玩游戏

Description

对于 $1\leq n,m,k\leq 10^5$,给定序列 $\{a_n\}$ 和 $\{b_m\}$,对所有 $t\in[1,k]$ 求

$$\frac{1}{nm}\sum_{i=1}^n \sum_{j=1}^m (a_i+b_j)^t$$

答案对 $998244353$ 取模。

Solution

记上式中的二重和为 $c_t$,用二项式定理展开 $(a_i+b_j)^t$,有

$$
\begin{aligned}
c_t&=\sum_{i=1}^n \sum_{j=1}^m \sum_{r=0}^t \binom{t}{r} a_i^r b_j^{t-r}\\
&=\sum_{r=0}^t \binom{t}{r} \Big( \sum_{i=1}^n a_i^r\Big) \Big( \sum_{j=1}^m b_j^{t-r}\Big)
\end{aligned}
$$

令 $A(x)=\sum_{w\geq 0} \Big(\sum_{i=1}^n a_i^w\Big) x^w$,有

$$
\begin{aligned}
A(x)&=\sum_{i=1}^n \sum_{w\geq 0} a_i^w x^w \\
&=\sum_{i=1}^n \frac{1}{1-a_ix} \\
&=n-\sum_{i=1}^n \frac{-a_ix}{1-a_ix} \\
&=n-x\sum_{i=1}^n \frac{-a_i}{1-a_ix} \\
&=n-x\sum_{i=1}^n \big(\ln (1-a_ix)\big)' \\
&=n-x\Big(\ln\prod_{i=1}^n (1-a_ix)\Big)'
\end{aligned}
$$

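于是只需用分治乘法(分治 NTT)求出 $\prod_{i=1}^n (1-a_ix)$,再对其求 $\ln$,即可得到 $A(x)$;然后将第 $w$ 项系数除以 $w!$,得到其 EGF 形式。$B$ 与 $A$ 同理。最后 $c_t$ 就是 $A$ 与 $B$ 的二项卷积:把两个 EGF 直接相乘,再将第 $t$ 项乘以 $t!$ 即可。不要忘了最后除以 $nm$。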

#include<stdio.h>
#include<vector>
// Shorthand for loop indices used throughout the file. The `register`
// storage class was deprecated in C++11 and removed in C++17 (clang rejects
// it outright), so the macro expands to a plain `int`; optimizing compilers
// do register allocation on their own anyway.
#define rint int

typedef long long ll;

// Fast reader: returns the next decimal integer from stdin.
// Non-digit characters before the number are skipped; if a '-' appears among
// them, the result is negated. Returns 0 if EOF is reached before any digit.
// NOTE: no overflow check -- assumes every value fits in an int.
inline int read(){
    int x=0,flag=1;
    // Keep getchar()'s result in an int so EOF stays distinguishable from
    // every valid character, even where plain char is unsigned.
    int c=getchar();
    // Stop at EOF too; otherwise truncated input makes this loop spin forever.
    while(c!=EOF&&(c<'0'||c>'9')){if(c=='-')flag=0;c=getchar();}
    while(c>='0'&&c<='9'){x=x*10+(c-'0');c=getchar();}
    return flag? x:-x;
}

// N   : capacity of every global/static polynomial buffer (2^21 plus slack).
// Mod : NTT-friendly prime, 998244353 = 119 * 2^23 + 1.
// G   : primitive root of Mod, used as the NTT root of unity generator.
const int N=(1<<21)+7, Mod=998244353, G=3;

// Square-and-multiply: returns x^y modulo Mod. With the default exponent
// Mod-2 this is the modular inverse of x (Fermat; Mod is prime).
// Assumes 0 <= x < Mod so the products fit in long long -- TODO confirm for
// new callers.
ll qpow(ll x,int y=Mod-2){
    ll result=1;
    for(ll base=x; y!=0; y>>=1){
        if(y&1) result=result*base%Mod;
        base=base*base%Mod;
    }
    return result;
}

const int Gi=qpow(G);

// Bit-reversal permutation for the current NTT length; rebuilt by Rk(n).
int rk[N]={0};
// Exchanges two long long values through a temporary. Unlike an XOR swap,
// this stays correct when x and y refer to the same object.
inline void swap(ll &x,ll &y){ll tmp=x;x=y;y=tmp;}
void NTT(bool op,int n,ll *F){
    for(rint i=0;i<n;i++)
        if(i<rk[i]) swap(F[i],F[rk[i]]);
    for(rint p=2;p<=n;p<<=1){
        rint len=p>>1;
        ll w=qpow(op? G:Gi,(Mod-1)/p);
        for(rint k=0;k<n;k+=p){
            ll now=1;
            for(rint l=k;l<k+len;l++){
                ll t=F[l+len]*now%Mod;
                F[l+len]=(F[l]-t+Mod)%Mod;
                F[l]=(F[l]+t)%Mod;
                now=now*w%Mod;
            }
        }
    }
}

// Copies b[0..n-1] into a[0..n-1] (destination first); no-op for n <= 0.
inline void Cop(int n,ll *a,ll *b){while(n-->0)*a++=*b++;}
// Sets F[0..n-1] to zero; no-op for n <= 0.
inline void Clear(int n,ll *F){while(n-->0)*F++=0;}
// Fills rk[0..n-1] with the bit-reversal permutation for an NTT of length n
// (n expected to be a power of two): rk[i] is i with its low log2(n) bits
// reversed, built from rk[i/2] plus i's lowest bit placed at the top.
inline void Rk(int n){
    int half=n>>1;
    for(int i=0;i<n;i++){
        int top=(i&1)?half:0;
        rk[i]=(rk[i>>1]>>1)|top;
    }
}

// Polynomial product via NTT: X[0..n+m] = a[0..n] * b[0..m] (mod Mod),
// where n and m are the degrees of a and b. X may alias a or b, because both
// inputs are copied into private static buffers before X is written. Those
// buffers are zeroed again before returning, so every call starts clean.
// Clobbers rk[].
void Mul(ll *X,int n,int m,ll *a,ll *b){
    static ll x[N],y[N];
    Cop(n+1,x,a);
    Cop(m+1,y,b);
    int deg=n+m;
    int len=1;
    while(len<=deg) len<<=1;
    Rk(len);
    NTT(1,len,x);
    NTT(1,len,y);
    for(int i=0;i<len;i++) x[i]=x[i]*y[i]%Mod;
    NTT(0,len,x);
    // The inverse NTT leaves a factor of len; divide it out while copying.
    ll scale=qpow(len);
    for(int i=0;i<=deg;i++) X[i]=x[i]*scale%Mod;
    Clear(len,x);
    Clear(len,y);
}

// Power-series inverse by Newton iteration. On return, b[0..n-1] holds
// a^{-1} mod x^n, and b[n..L-1] is zeroed, where L is this level's NTT length
// (the smallest power of two >= 2n).
// Preconditions (not checked here): a[0] must be nonzero mod Mod, and b must
// be all-zero over the range the NTTs touch. The forward transform of b reads
// past the ceil(n/2) coefficients that the recursive call produces.
// Clobbers rk[] (via Rk) and a private static scratch buffer.
void Inv(int n,ll *a,ll *b){
    static ll x[N];
    // Base case: the constant term is the modular inverse of a[0].
    if(n==1){b[0]=qpow(a[0]);return;}
    // First get b = a^{-1} mod x^{ceil(n/2)}, then double the precision below.
    Inv((n+1)>>1,a,b); int m=n;
    // NTT length: smallest power of two >= 2m, so that b*b*a
    // (degree <= 2m-2) does not wrap around.
    for(n=1;n<(m<<1);n<<=1); Rk(n);
    // x = a mod x^m, zero-padded to the NTT length.
    Clear(n,x),Cop(m,x,a);
    NTT(1,n,x),NTT(1,n,b);
    // Newton step, evaluated pointwise: b <- b * (2 - a*b).
    for(rint i=0;i<n;i++)
        b[i]=b[i]*(2-x[i]*b[i]%Mod+Mod)%Mod;
    NTT(0,n,b); ll inv=qpow(n);
    // Undo the inverse NTT's factor of n, and keep only the first m terms.
    for(rint i=0;i<m;i++) b[i]=b[i]*inv%Mod;
    for(rint i=m;i<n;i++) b[i]=0;
}

// Power-series logarithm: b[0..n-1] = integral(a' / a) mod x^n, with the
// constant term b[0] set to 0. This equals ln(a) when a[0] == 1.
// Clobbers rk[] and a private static scratch buffer.
void ln(int n,ll *a,ll *b){
    static ll x[N];
    // x = 1/a mod x^n. Inv needs a zeroed output buffer.
    Clear(n<<1,x);
    Inv(n,a,x);
    // b = a' mod x^n; the top coefficient has no source term.
    for(int i=1;i<n;i++){
        b[i-1]=a[i]*i%Mod;
    }
    b[n-1]=0;
    // x = a' * (1/a); Mul tolerates x being both an input and the output.
    Mul(x,n-1,n-1,b,x);
    // b = term-by-term integral of x, truncated to n coefficients.
    b[0]=0;
    for(int i=1;i<n;i++){
        b[i]=x[i-1]*qpow(i)%Mod;
    }
}

void Solve(int l,int r,ll *a,ll *b){
    int len=r-l+1;
    if(l==r){b[0]=1,b[1]=(Mod-a[l])%Mod;return;}
    int mid=(l+r)>>1,n=1; for(;n<=len;n<<=1);
    ll Lf[N],Rf[N];
    Clear(n,Lf),Clear(n,Rf);
    Solve(l,mid,a,Lf),Solve(mid+1,r,a,Rf); Rk(n);
    NTT(1,n,Lf),NTT(1,n,Rf);
    for(rint i=0;i<n;i++) b[i]=Lf[i]*Rf[i]%Mod;
    NTT(0,n,b); ll inv=qpow(n);
    for(rint i=0;i<=len;i++) b[i]=b[i]*inv%Mod;
    for(rint i=len+1;i<n;i++) b[i]=0;
}

// Lengths of the two input sequences.
int n,m;
// a, b : input values (1-indexed); later overwritten with ln series and
//        their derivatives.
// A, B : prod(1 - a_i x) and prod(1 - b_j x); later the power-sum EGFs.
// fac, inv : factorials and inverse factorials modulo Mod.
ll a[N],b[N];
ll A[N],B[N];
ll fac[N],inv[N];

// Larger of two ints; returns y when they are equal.
inline int max(int x,int y){
    if(y<x) return x;
    return y;
}

int main(){
    n=read(),m=read();
    for(rint i=1;i<=n;i++) a[i]=read();
    for(rint i=1;i<=m;i++) b[i]=read();
    int t=read(); int k=max(max(n,m),t); 
    Solve(1,n,a,A),Solve(1,m,b,B);
    ln(k+2,A,a),ln(k+2,B,b); 
    for(rint i=0;i<=k;i++)
        a[i]=a[i+1]*(i+1)%Mod,b[i]=b[i+1]*(i+1)%Mod;
    a[k+1]=b[k+1]=0;
    A[0]=n; B[0]=m;
    for(rint i=1;i<=k;i++)
        A[i]=(Mod-a[i-1])%Mod,B[i]=(Mod-b[i-1])%Mod;
    fac[0]=1; int rg=k<<1;
    for(int i=1;i<=rg;i++) fac[i]=fac[i-1]*i%Mod;
    inv[rg]=qpow(fac[rg]);
    for(rint i=rg-1;~i;i--) inv[i]=inv[i+1]*(i+1)%Mod;
    for(rint i=0;i<=k;i++)
        A[i]=A[i]*inv[i]%Mod,B[i]=B[i]*inv[i]%Mod;
    ll ret=1ll*n*m%Mod;
    for(m=rg,n=1;n<=m;n<<=1); Rk(n);
    NTT(1,n,A),NTT(1,n,B);
    for(rint i=0;i<n;i++) A[i]=A[i]*B[i]%Mod;
    NTT(0,n,A); ll inv_=qpow(n)*qpow(ret)%Mod;
    for(rint i=1;i<=t;i++) printf("%lld
",A[i]*inv_%Mod*fac[i]%Mod);
}
原文地址:https://www.cnblogs.com/wwlwQWQ/p/15026512.html