[HAOI2018]染色(NTT)

前置芝士

可重集排列
NTT

前置定义

$$
\begin{aligned}
f_i&=C_m^i\cdot \frac{n!}{(S!)^i\,(n-iS)!}\cdot (m-i)^{n-iS}\\
ans_i&=\sum\limits_{j=i}^{lim}(-1)^{j-i}\,C_j^i\, f_j
\end{aligned}
$$

其中 $lim=\min(m,\lfloor n/S\rfloor)$。

理解:从 $m$ 种颜色中选出 $i$ 种,令它们各恰好出现 $S$ 次,用可重集排列确定这 $iS$ 个位置,剩余的 $n-iS$ 个位置用其余 $m-i$ 种颜色任意染色。不过这样剩余块中也可能有颜色恰好出现 $S$ 次,所以需要容斥一下。

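$f_i$ 并不是"恰好 $i$ 种颜色出现 $S$ 次"的方案数:一个恰好有 $j$ 种颜色出现 $S$ 次的方案会在 $f_i$ 中被计 $C_j^i$ 次,即 $f_i=\sum_{j=i}^{lim}C_j^i\,ans_j$。直观地看,$C_j^i f_j$ 里一定包含着 $f_i$ 中重复计入的部分,要减掉;而减的时候又可能多减掉原本属于 $j+1,\dots$ 的方案,于是符号正负交替。严格地说,对上式做二项式反演即得 $ans_i$ 的表达式。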

推式

$$
\begin{aligned}
ans_i&=\sum\limits_{j=i}^{lim}(-1)^{j-i}\frac{j!}{i!\,(j-i)!}f_j\\
ans_i\cdot i!&=\sum\limits_{j=i}^{lim}\frac{(-1)^{j-i}}{(j-i)!}\cdot (f_j\cdot j!)
\end{aligned}
$$

设生成函数 $G,F$ 分别为 $G_k=\frac{(-1)^k}{k!}$、$F_j=f_j\cdot j!$,再把 $F$ 翻转一下(下文中的 $F$ 指翻转后的序列,即 $F_{lim-j}=f_j\cdot j!$):

$$
\begin{aligned}
ans_i\cdot i!&=\sum\limits_{j=i}^{lim}G_{j-i}\cdot F_{lim-j}\\
H&=G*F\\
ans_i\cdot i!&=H_{lim-i}
\end{aligned}
$$

Code

套 NTT 模板即可(模数 $1004535809$,原根取 $3$)。

#include<bits/stdc++.h>
typedef long long LL;
// NTT-friendly prime modulus: 1004535809 = 479 * 2^21 + 1; its primitive root is 3.
const LL mod = 1004535809;
const LL gg = 3;
// Capacity of every global array in this file (fac, fav, r, W, f, g, h, ans).
const LL maxn = 10000000 + 9;
// Reads one decimal integer from stdin.
// Any non-digit characters before the first digit are skipped; a '-' seen
// among them makes the result negative.
// Returns 0 if end of input is reached before any digit is seen.
inline LL Read(){
    LL x(0),f(1);
    int c=getchar();   // int, not char, so EOF is distinguishable from a real byte
    while(c<'0' || c>'9'){
        if(c==EOF) return 0;   // otherwise this loop would spin forever at EOF
        if(c=='-') f=-1;
        c=getchar();
    }
    while(c>='0' && c<='9'){
        x=x*10+(c-'0');
        c=getchar();
    }
    return x*f;
}
// Binary exponentiation: returns base^b modulo `mod`. b must be non-negative.
// base is not reduced first, so a negative base yields a possibly negative
// result (e.g. Pow(-1, odd) == -1, because C++ % keeps the sign); callers
// must normalise such results themselves.
inline LL Pow(LL base,LL b){
    LL acc = 1;
    for (; b != 0; b >>= 1) {
        if (b & 1)
            acc = acc * base % mod;
        base = base * base % mod;
    }
    return acc;
}
LL fac[maxn];  // fac[i] = i! mod `mod` (filled in main)
LL fav[maxn];  // fav[i] = (i!)^{-1} mod `mod` (filled in main)
LL r[maxn];    // bit-reversal permutation: written by Fir, read by NTT
// Binomial coefficient C(n, m) mod `mod`, read from the fac/fav tables.
// Requires 0 <= m <= n and both tables to be filled up to index n.
inline LL Get_c(int n,int m){
    LL value = fac[n] * fav[m] % mod;
    value = value * fav[n - m] % mod;
    return value;
}
// Chooses the transform length for a product of two length-n sequences:
// the smallest power of two that is >= 2n. Fills the bit-reversal table r
// for that length and returns it. Requires n >= 1 (otherwise the shift
// below would be by -1).
inline LL Fir(LL n){
    LL size = 1, bits = 0;
    for (; size < (n << 1); size <<= 1)
        ++bits;
    for (int i = 0; i < size; ++i)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bits - 1));
    return size;
}
// In-place iterative number-theoretic transform of a[0..n-1] modulo `mod`.
// n must be the length returned by the most recent Fir call, because the
// bit-reversal table r is taken from there.
// type == 1: forward transform; type == -1: inverse transform, including the
// final multiplication by n^{-1}. Entries of a are assumed to lie in [0, mod).
inline void NTT(LL *a,int n,int type){
    // Reorder into bit-reversed index order.
    for (int i = 0; i < n; ++i)
        if (i < r[i])
            std::swap(a[i], a[r[i]]);

    // Butterflies: merge blocks of size `half` into blocks of size `span`.
    for (LL half = 1; half < n; half <<= 1) {
        const LL span = half << 1;
        LL root = Pow(gg, (mod - 1) / span);     // primitive span-th root of unity
        if (type == -1)
            root = Pow(root, mod - 2);           // inverse root for the inverse transform
        for (LL start = 0; start < n; start += span) {
            LL w = 1;
            for (LL k = 0; k < half; ++k) {
                const LL lo = a[start + k];
                const LL hi = a[start + half + k] * w % mod;
                a[start + k] = (lo + hi) % mod;
                a[start + half + k] = (lo - hi + mod) % mod;
                w = w * root % mod;
            }
        }
    }

    if (type == -1) {
        const LL inv_len = Pow(n, mod - 2);
        for (int i = 0; i < n; ++i)
            a[i] = a[i] * inv_len % mod;
    }
}
// Input sizes n, m, S; lim = min(m, n/S) (set in main); ret accumulates the answer.
LL n, m, S, lim, ret;
LL W[maxn];    // input weights W[0..m]
LL f[maxn];    // F_i = f_i * i!, reversed before the NTT
LL g[maxn];    // G_i = (-1)^i / i!
LL h[maxn];    // pointwise product, then the convolution G*F after the inverse NTT
LL ans[maxn];  // ans[i] for i = 0..lim
int main(){
    n=Read(); m=Read(); S=Read();
    for(int i=0;i<=m;++i) W[i]=Read();
    lim=std::min(m,n/S);
    fac[0]=fac[1]=1;
    int up(std::max(n,m));
    for(int i=2;i<=up;++i) 
        fac[i]=fac[i-1]*i%mod;
    fav[up]=Pow(fac[up],mod-2);
    for(int i=up;i>=1;--i) 
        fav[i-1]=fav[i]*i%mod;
    for(int i=0;i<=lim;++i)
        f[i]=Get_c(m,i)*fac[n]%mod* Pow(Pow(fac[S],i),mod-2)%mod *fav[n-i*S]%mod *Pow(m-i,n-i*S)%mod *fac[i]%mod;
    for(int i=0;i<=(lim>>1);++i) 
        std::swap(f[i],f[lim-i]);
    for(int i=0;i<=lim;++i)
        g[i]=(Pow(-1,i)*fav[i]+mod)%mod;
    LL limit(Fir(lim+1));
    NTT(f,limit,1); NTT(g,limit,1);
    for(int i=0;i<limit;++i) h[i]=g[i]*f[i]%mod;
    NTT(h,limit,-1);
    
    for(int i=0;i<=lim;++i) ans[i]=h[lim-i]*fav[i]%mod;
    for(int i=0;i<=lim;++i) ret=(ret+ans[i]*W[i]%mod)%mod;	
    printf("%lld
",ret);
    return 0;
}
原文地址:https://www.cnblogs.com/y2823774827y/p/10699933.html