[luogu4705] 玩游戏

题目链接

洛谷:https://www.luogu.org/problemnew/show/P4705

Solution

精神污染....这玩意比数树还难写...就是窝太菜了代码过于冗长然后调试还写了两三K

思路据说比较套路??反正窝是不会


我们可以很容易的把答案写出来:

[ans_k=frac{1}{nm}sum_{i=1}^{n}sum_{j=1}^m (a_i+b_j)^k ]

我们忽略掉那个(nm)的系数,最后乘回来就好了,然后化简下:

[egin{align} ans_k=&sum_{i=1}^{n}sum_{j=1}^{m}sum_{x=0}^kinom{k}{x}a_i^xb_i^{k-x}\ =&sum_{x=0}^kinom{k}{x}sum_{i=1}^{n}a_i^xsum_{j=1}^{m}b_i^{k-x} end{align} ]

显然这是个卷积形式,然后算法瓶颈就在如何求出(f(x)=sum_{i=1}^{n}a_i^x)

我们写出这个玩意的生成函数:

[egin{align} F(x)=&sum_{i=0}^{infty} f(i)x^i\ =&sum_{i=0}^{infty}sum_{j=1}^{n}a_j^ix^i\ =&sum_{j=1}^{n}sum_{i=0}^{infty} a_j^ix^i\ =&sum_{j=1}^{n}frac{1}{1-a_jx} end{align} ]

注意到:

[(ln (1-ax))'=frac{-a}{1-ax} ]

即:

[x(ln(1-ax))'=1-frac{1}{1-ax} ]

也就是说:

[F(x)=sum_{i=1}^{n}1-x(ln(1-a_ix))'=n-xsum_{i=1}^{n}(ln(1-a_ix))' ]

注意到导数满足加法律,后面的在化一下就是:

[egin{align} F(x)=&n-xleft(sum_{i=1}^{n}ln (1-a_ix) ight)'\ =&n-xleft(ln prod_{i=1}^{n}(1-a_ix) ight)' end{align} ]

注意到里面的连乘形式可以分治( m FFT)(O(nlog ^2 n))解决,然后再照着式子算一下就好了,需要写个多项式求(ln),注意要把前面忽略的东西弄回去。

总复杂度(O(nlog^2 n))

代码大概还能凑合着看吧...

#include<bits/stdc++.h>
using namespace std;

void read(int &x) {
    x=0;int f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
    for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}

void print(int x) {
    if(x<0) putchar('-'),x=-x;
    if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('
');}

#define lf double
#define ll long long 

#define pii pair<int,int >
#define vec vector<int >

#define pb push_back
#define mp make_pair
#define fr first
#define sc second

#define FOR(i,l,r) for(register int i=l,r_##i=r;i<=r_##i;++i) 

const int maxn = 1e6+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int mod = 998244353;

int a[maxn],b[maxn],fac[maxn],ifac[maxn],inv[maxn];
int w[maxn],n,m,T,mxn,bit,N,tmp[15][maxn],pos[maxn];

int add(int x,int y) {return x+y>mod?x+y-mod:x+y;}
int del(int x,int y) {return x-y<0?x-y+mod:x-y;}
int mul(int x,int y) {return 1ll*x*y-1ll*x*y/mod*mod;}

int qpow(int aa,int x) {
    int res=1;
    for(;x;x>>=1,aa=mul(aa,aa)) if(x&1) res=mul(res,aa);
    return res;
}

void clear(int *l,int *r) {
    if(l>=r) return ;
    while(l!=r) *l++=0;*l=0;
}

void ntt_init(int len) {
    for(mxn=1;mxn<=len;mxn<<=1);
    w[0]=1;w[1]=qpow(3,(mod-1)/mxn);
    for(int i=2;i<=mxn;i++) w[i]=mul(w[i-1],w[1]);

    inv[0]=inv[1]=fac[0]=ifac[0]=1;
    for(int i=2;i<=mxn;i++) inv[i]=mul(mod-mod/i,inv[mod%i]);
    for(int i=1;i<=mxn;i++) fac[i]=mul(fac[i-1],i);
    for(int i=1;i<=mxn;i++) ifac[i]=mul(ifac[i-1],inv[i]);
}

void get(int len) {for(bit=0,N=1;N<=len;N<<=1,bit++);}

void get_pos() {for(int i=1;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1));}

void ntt(int *r,int op) {
    for(int i=1;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]);
    for(int i=1,d=mxn>>1;i<N;i<<=1,d>>=1)
        for(int j=0;j<N;j+=i<<1)
            for(int k=0;k<i;k++) {
                int x=r[j+k],y=mul(r[i+j+k],w[k*d]);
                r[j+k]=add(x,y),r[i+j+k]=del(x,y);
            }
    if(op==-1) {
        reverse(r+1,r+N);int d=qpow(N,mod-2);
        for(int i=0;i<N;i++) r[i]=mul(r[i],d);
    }
}

void poly_inv(int *r,int *t,int len) {
    if(len==1) return t[0]=qpow(r[0],mod-2),void();
    poly_inv(r,t,len>>1);get(len);get_pos();
    for(int i=0;i<len;i++) tmp[0][i]=r[i],tmp[1][i]=t[i];
    clear(tmp[0]+len,tmp[0]+N);
    clear(tmp[1]+len,tmp[1]+N);
    ntt(tmp[0],1),ntt(tmp[1],1);
    for(int i=0;i<N;i++) t[i]=del(mul(2,tmp[1][i]),mul(mul(tmp[0][i],tmp[1][i]),tmp[1][i]));
    ntt(t,-1),clear(t+len,t+N);
}

void poly_der(int *r,int *t,int len) {
    get(len);
    for(int i=1;i<len;i++) t[i-1]=mul(r[i],i);
    clear(t+len-1,t+N);
}

void poly_ln(int *r,int *t,int len) {
    poly_inv(r,tmp[3],len);
    poly_der(r,tmp[4],len);
    get(len),get_pos();
    clear(tmp[3]+len,tmp[3]+N),clear(tmp[4]+len,tmp[4]+N);
    ntt(tmp[3],1),ntt(tmp[4],1);
    for(int i=0;i<N;i++) tmp[5][i]=mul(tmp[3][i],tmp[4][i]);
    ntt(tmp[5],-1);
    for(int i=0;i<len;i++) t[i+1]=mul(inv[i+1],tmp[5][i]);t[0]=0;
    clear(t+len+1,t+N);
}

vector<int > p[maxn];

void dc_fft(int *s,int x,int l,int r) {  // divide and conquar of NTT
    if(l==r) return p[x].resize(2),p[x][0]=1,p[x][1]=del(0,s[l]),void();
    int mid=(l+r)>>1,ls=x<<1,rs=x<<1|1;
    dc_fft(s,ls,l,mid),dc_fft(s,rs,mid+1,r);
    int r1=mid-l+1,r2=r-mid;
    for(int i=0;i<=r1;i++) tmp[6][i]=p[ls][i];
    for(int i=0;i<=r2;i++) tmp[7][i]=p[rs][i];
    get(r1+r2),get_pos(); 
    clear(tmp[6]+r1+1,tmp[6]+N),clear(tmp[7]+r2+1,tmp[7]+N);
    ntt(tmp[6],1),ntt(tmp[7],1);
    FOR(i,0,N-1) tmp[8][i]=mul(tmp[6][i],tmp[7][i]);
    ntt(tmp[8],-1);
    p[x].resize(r1+r2+1);
    for(int i=0;i<=r1+r2;i++) p[x][i]=tmp[8][i];
    p[ls].clear(),p[rs].clear();
}

int A[maxn],B[maxn];

void solve(int *s,int *t,int len) {
    dc_fft(s,1,1,len);get(len);int nn=N;
    FOR(i,0,p[1].size()-1) tmp[9][i]=p[1][i];get(T);nn=N;
    clear(tmp[9]+p[1].size(),tmp[9]+N);
    p[1].clear();
    poly_ln(tmp[9],tmp[10],nn);
    poly_der(tmp[10],tmp[11],nn);
    for(int i=1;i<nn;i++) t[i]=del(0,tmp[11][i-1]);t[0]=len;
    for(int i=0;i<nn;i++) t[i]=mul(t[i],ifac[i]);
}

int main() {
    read(n),read(m);FOR(i,1,n) read(a[i]);FOR(i,1,m) read(b[i]);read(T);
    ntt_init((max(n+m,T))<<1);
    solve(a,A,n);
    solve(b,B,m);
    get(T<<1),get_pos();
    clear(A+T+1,A+N),clear(B+T+1,B+N);ntt(A,1),ntt(B,1);
    for(int i=0;i<N;i++) A[i]=mul(A[i],B[i]);
    ntt(A,-1);
    for(int i=1;i<=T;i++) write(mul(mul(fac[i],A[i]),mul(inv[n],inv[m])));
    return 0;
}
原文地址:https://www.cnblogs.com/hbyer/p/10750971.html