[BZOJ5306] [HAOI2018]染色

题目描述

为了报答小 C 的苹果, 小 G 打算送给热爱美术的小 C 一块画布, 这块画布可以抽象为一个长度为 N 的序列, 每个位置都可以被染成 M 种颜色中的某一种.

然而小 C 只关心序列的 N 个位置中出现次数恰好为 S 的颜色种数, 如果恰好出现了 S 次的颜色有 K 种, 则小C会产生 W_k 的愉悦度.

小 C 希望知道对于所有可能的染色方案, 他能获得的愉悦度的和对 1004535809 取模的结果是多少.

输入格式

第一行三个整数 N, M, S.

接下来一行 M + 1 个整数, 第 i 个数表示 W_{i-1}.

输出格式

输出一行一个整数表示答案.

样例输入

8 8 3
3999 8477 9694 8454 3308 8961 3018 2255 4910

样例输出

524070430

Solution

(f(i))表示恰好(i)种颜色有(s)个的方案数,(g(i))表示至少的方案数。

那么可以得到:

[g(i)=sum_{j=i}^Minom{M}{j}f(j)=inom{m}{i}(m-i)^{n-si}frac{(si)!}{(s!)^i}inom{n}{si} ]

后面的也比较好理解,第一项枚举放哪些颜色,第二项剩下的颜色随便选,后面两项枚举顺序和硬点的颜色放的位置。

注意这里(M=min(m,lfloor n/s floor))

那么对这玩意广义容斥一波可以得到:

[egin{align} f(i)&=sum_{j=i}^{M}(-1)^{j-i}inom{j}{i}g(j) \&=sum_{j=i}^{M}(-1)^{j-i}inom{j}{i}inom{m}{j}(m-j)^{n-sj}frac{(sj)!}{(s!)^j}inom{n}{sj}\ &=sum_{j=i}^{M}(-1)^{j-i}inom{j}{i}inom{m}{j}(m-j)^{n-sj}frac{n!}{(s!)^j(n-sj)!}\ end{align} ]

(f)化简一下:

[egin{align}f(i)&=sum_{j=i}^{M}(-1)^{j-i}inom{j}{i}inom{m}{j}(m-j)^{n-sj}frac{(sj)!}{(s!)^j}inom{n}{sj}\ &=frac{1}{i!} imessum_{j=i}^{M}(-1)^{j-i}frac{1}{(j-i)!} imes frac{m!}{(m-j)!}(m-j)^{n-sj}frac{n!}{(s!)^j(n-sj)!}\end{align} ]

那么注意到这是一个卷积的形式,设:

[s(i)=(-1)^ifrac{1}{i!},t(i)=frac{m!}{(m-i)!}(m-i)^{n-si}frac{n!}{(s!)^j(n-si)!}\ ]

那么:

[f(i)=frac{1}{i!}sum_{j=i}^{M}s(j-i)t(j) ]

这玩意可以(reverse) (t)之后和(s)卷起来,然后再(reverse)一次,就可以得到后面的卷积了。

具体的可以自己照着上面的步骤推一下。

然后算答案就直接:

[ans=sum_{i=0}^{M}w_if(i) ]

我才没有因为一开始的g(i)写错了调了一晚上呢

#include<bits/stdc++.h>
using namespace std;
 
void read(int &x) {
    x=0;int f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
    for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
 
void print(int x) {
    if(x<0) putchar('-'),x=-x;
    if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('
');}

#define lf double
#define ll long long 

const int maxn = 6e5+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int mod = 1004535809;

int n,m,S,w[maxn],s[maxn],t[maxn],fac[maxn*100],ifac[maxn*100],r[maxn],N,bit,pos[maxn];

int qpow(int a,int x) {
	int res=1;
	for(;x;x>>=1,a=1ll*a*a%mod) if(x&1) res=1ll*res*a%mod;
	return res;
}

void ntt(int *a,int f) {
	for(int i=0;i<N;i++) if(pos[i]>i) swap(a[i],a[pos[i]]);
	for(int i=1;i<N;i<<=1) {
		int wn=qpow(f==1?3:qpow(3,mod-2),(mod-1)/(i<<1));
		for(int j=0,ww=1;j<N;j+=(i<<1),ww=1)
			for(int k=0;k<i;k++,ww=1ll*ww*wn%mod) {
				int x=a[j+k]%mod,y=1ll*a[i+j+k]*ww%mod;
				a[j+k]=(x+y)%mod,a[i+j+k]=(x-y)%mod;
			}
	}
	if(f==-1) {
		int inv=qpow(N,mod-2);
		for(int i=0;i<N;i++) a[i]=1ll*a[i]*inv%mod;
	}
}

int main() {
	read(n),read(m),read(S);int M=min(m,n/S);
	int L=max(n,m);
	for(int i=0;i<=m;i++) read(w[i]);
	fac[0]=ifac[0]=1;
	for(int i=1;i<=L;i++) fac[i]=1ll*fac[i-1]*i%mod;
	ifac[L]=qpow(fac[L],mod-2);
	for(int i=L-1;i;i--) ifac[i]=1ll*ifac[i+1]*(i+1)%mod;
	for(int i=0,p=1;i<=M;i++,p=-p) s[i]=(p*ifac[i]+mod)%mod;
	for(int i=0;i<=M;i++) t[i]=1ll*fac[m]*ifac[m-i]%mod*qpow(m-i,n-S*i)%mod*fac[n]%mod*qpow(ifac[S],i)%mod*ifac[n-S*i]%mod;
	
	reverse(t,t+M+1);
    N=1;while(N<=M<<1) N<<=1,bit++;
	for(int i=0;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1));
	ntt(s,1),ntt(t,1);
	for(int i=0;i<N;i++) r[i]=1ll*s[i]*t[i]%mod;
	ntt(r,-1);
	reverse(r,r+M+1);
	int ans=0;
	for(int i=0;i<=M;i++) ans=(ans+1ll*r[i]*ifac[i]%mod*w[i]%mod)%mod;
	write((ans+mod)%mod);
	return 0;
}
原文地址:https://www.cnblogs.com/hbyer/p/10550248.html