【BZOJ3625】【codeforces438E】小朋友和二叉树 生成函数+多项式求逆+多项式开根

首先,我们构造一个函数$G(x)$,若存在$k∈C$,则$[x^k]G(x)=1$。

不妨设$F(x)$为最终答案的生成函数,则$[x^n]F(x)$即为权值为$n$的神犇二叉树个数。

不难推导出,$[x^n]F(x)=sum_{i=0}^{n}[x^i]G(x)sum_{j=0}^{n-i}[x^j]F(j) imes [x^{n-j-i}]F(n-j-i)$。

(这个式子的意思就是说,不妨设当前根节点的权值为i,然后枚举左右两个子树的权值)

这个式子显然可以通过动规的方式去推,从而得出答案,优化后的时间复杂度是$O(n^2)$的,显然不行。

我们对式子进行化简,考虑到$[x^0]F(x)=1$,那么$F(x)=G(x) imes F^2(x)+1$。

通过移项,得到$G imes F^2-F+1=0$,是一个关于$F$的一元二次方程。

由于多项式$G(x)$是已知的,那么我们就可以通过求根公式解出$F(x)$。

套入求根公式,得到$F(x)=frac{1±sqrt{1-4G}}{2G}$。

考虑到$F(0)=1$,$G(0)=0$,那么$F(x)=frac{1-sqrt{1-4G}}{2G}$

分子分母同时乘上$1+sqrt{1-4G}$,化简得到$F(x)=frac{2}{1+sqrt{1-4G}}$。

然后就是多项式开根+多项式求逆了。

#include<bits/stdc++.h>
#define M (1<<18)
#define L long long
#define MOD 998244353
#define inv2 499122177
#define G 3
using namespace std;

L pow_mod(L x,L k){
    L ans=1;
    while(k){
        if(k&1) ans=ans*x%MOD;
        x=x*x%MOD; k>>=1;
    }
    return ans;
}

void change(L a[],int n){
    for(int i=0,j=0;i<n-1;i++){
        if(i<j) swap(a[i],a[j]);
        int k=n>>1;
        while(j>=k) j-=k,k>>=1;
        j+=k;
    }
}
void NTT(L a[],int n,int on){
    change(a,n);
    for(int h=2;h<=n;h<<=1){
        L wn=pow_mod(G,(MOD-1)/h);
        for(int j=0;j<n;j+=h){
            L w=1;
            for(int k=j;k<j+(h>>1);k++){
                L u=a[k],t=w*a[k+(h>>1)]%MOD;
                a[k]=(u+t)%MOD;
                a[k+(h>>1)]=(u-t+MOD)%MOD;
                w=w*wn%MOD;
            }
        }
    }
    if(on==-1){
        L inv=pow_mod(n,MOD-2);
        for(int i=0;i<n;i++) a[i]=a[i]*inv%MOD;
        reverse(a+1,a+n);
    }
}

void getinv(L a[],L b[],int n){
    if(n==1){b[0]=pow_mod(a[0],MOD-2); return;}
    static L c[M],d[M];
    memset(c,0,n<<4); memset(d,0,n<<4);
    getinv(a,c,n>>1);
    for(int i=0;i<n;i++) d[i]=a[i];
    NTT(d,n<<1,1); NTT(c,n<<1,1);
    for(int i=0;i<(n<<1);i++) b[i]=(2*c[i]-d[i]*c[i]%MOD*c[i]%MOD+MOD)%MOD;
    NTT(b,n<<1,-1);
    for(int i=0;i<n;i++) b[n+i]=0;
}

void sqrt(L a[],L b[],int n){
    if(n==1) return void(b[0]=1);
    sqrt(a,b,n>>1);
    static L invb[M],d[M]; 
    memset(invb,0,M<<3); memset(d,0,M<<3);
    getinv(b,invb,n);
    for(int i=0;i<n;i++) d[i]=a[i];
    NTT(b,n<<1,1); NTT(d,n<<1,1); NTT(invb,n<<1,1);
    for(int i=0;i<(n<<1);i++) b[i]=inv2*(b[i]+d[i]*invb[i]%MOD)%MOD;
    NTT(b,n<<1,-1); 
    for(int i=0;i<n;i++) b[i+n]=0;
}
L a[M]={0},b[M]={0};
int main(){
    int n,m; scanf("%d%d",&n,&m);
    int nn=1; while(nn<=m) nn<<=1;
    a[0]=1;
    for(int i=1;i<=n;i++){
        int x; scanf("%d",&x);
        if(x<=m) a[x]=(a[x]-4+MOD)%MOD;
    }
    sqrt(a,b,nn); b[0]=(b[0]+1)%MOD;
    memset(a,0,nn<<3);
    getinv(b,a,nn);
    for(int i=1;i<=m;i++) printf("%lld
",a[i]*2%MOD);
}

 

原文地址:https://www.cnblogs.com/xiefengze1/p/9085371.html