CF438E The Child and Binary Tree(生成函数,NTT)

题目链接:洛谷 CF原网

题目大意:有 $n$ 个互不相同的正整数 $c_i$。问对于每一个 $1le ile m$,有多少个不同形态(考虑结构和点权)的二叉树满足每个点权都在 $c$ 中出现过,且点权和为 $i$。答案对 $998244353$ 取模。

$1le n,mle 10^5$。


首先考虑DP,$f_i$ 表示点权和为 $i$ 的树数。

那么枚举根节点的点权和两棵子树的点权和 $f_k=sumlimits^n_{i=1}c_isumlimits^{k-c_i}_{j=0}f_jf_{k-c_i-j}$。

初始状态 $f_0=1$。因为空树也能作为子树。

这样的复杂度是 $O(nm^2)$,不能过。

考虑 $c$ 的生成函数 $C(x)=sum x^{c_i}$ 和 $f$ 的生成函数 $F(x)=sum f_ix^i$。(你问我怎么想到的?我也不知道啊)

那么容易发现原来的式子就是几个函数的卷积。

$F=C imes F imes F+1$(注意 $f_0=1$)

$C imes F^2-F+1=0$

$F=dfrac{1pmsqrt{1-4C}}{2C}$

接下来看看上面该取正还是负。

取正时 $limlimits_{x ightarrow 0}F(x)=+infty$,不收敛,舍去。

取负时 $limlimits_{x ightarrow 0}F(x)=1$,符合题意。

那么 $F=dfrac{1-sqrt{1-4C}}{2C}=dfrac{2}{1+sqrt{1-4C}}$。

直接套模板即可。时间复杂度 $O(mlog m)$。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=333333,mod=998244353;
#define FOR(i,a,b) for(int i=(a);i<=(b);i++)
#define ROF(i,a,b) for(int i=(a);i>=(b);i--)
#define MEM(x,v) memset(x,v,sizeof(x))
inline int read(){
    char ch=getchar();int x=0,f=0;
    while(ch<'0' || ch>'9') f|=ch=='-',ch=getchar();
    while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();
    return f?-x:x;
}
int n,m,c[maxn],lim,l,rev[maxn],invtmp[maxn],Binv[maxn],sqrtmp[maxn],Csqrt[maxn],Cinv[maxn];
inline void init(int upr){
    for(lim=1,l=0;lim<upr;lim<<=1,l++);
    FOR(i,0,lim-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
}
inline int add(int a,int b){return a+b<mod?a+b:a+b-mod;}
inline int sub(int a,int b){return a<b?a-b+mod:a-b;}
inline int qpow(int a,int b){
    int ans=1;
    for(;b;b>>=1,a=1ll*a*a%mod) if(b&1) ans=1ll*ans*a%mod;
    return ans;
}
void NTT(int *A,int tp){
    FOR(i,0,lim-1) if(i<rev[i]) swap(A[i],A[rev[i]]);
    for(int i=1;i<lim;i<<=1)
        for(int j=0,r=i<<1,Wn=qpow(3,mod-1+tp*(mod-1)/r);j<lim;j+=r)
            for(int k=0,w=1;k<i;k++,w=1ll*w*Wn%mod){
                int x=A[j+k],y=1ll*A[i+j+k]*w%mod;
                A[j+k]=add(x,y);A[i+j+k]=sub(x,y);
            }
    if(tp==-1) for(int i=0,linv=qpow(lim,mod-2);i<lim;i++) A[i]=1ll*A[i]*linv%mod;
}
void poly_inv(int *A,int *B,int deg){
    if(deg==1) return void(B[0]=qpow(A[0],mod-2));
    poly_inv(A,B,(deg+1)>>1);
    init(deg<<1);
    FOR(i,0,deg-1) invtmp[i]=A[i];
    FOR(i,deg,lim-1) invtmp[i]=0;
    NTT(invtmp,1);NTT(B,1);
    FOR(i,0,lim-1) B[i]=1ll*sub(2,1ll*invtmp[i]*B[i]%mod)*B[i]%mod;
    NTT(B,-1);
    FOR(i,deg,lim-1) B[i]=0;
}
void poly_sqrt(int *A,int *B,int deg){
    if(deg==1) return void(B[0]=1);
    poly_sqrt(A,B,(deg+1)>>1);
    init(deg<<1);
    FOR(i,0,lim-1) Binv[i]=0;
    poly_inv(B,Binv,deg);
    init(deg<<1);
    FOR(i,0,deg-1) sqrtmp[i]=A[i];
    FOR(i,deg,lim-1) Binv[i]=sqrtmp[i]=0;
    NTT(sqrtmp,1);NTT(Binv,1);
    FOR(i,0,lim-1) sqrtmp[i]=1ll*sqrtmp[i]*Binv[i]%mod;
    NTT(sqrtmp,-1);
    FOR(i,0,deg-1) B[i]=499122177ll*add(B[i],sqrtmp[i])%mod;
    FOR(i,deg,lim-1) B[i]=0;
}
int main(){
    n=read();m=read();
    FOR(i,1,n){
        int x=read();
        if(x<=m) c[x]=1;
    }
    FOR(i,1,m) c[i]=(mod-4ll*c[i]%mod)%mod;
    c[0]=1;
    poly_sqrt(c,Csqrt,m+1);
    Csqrt[0]=add(Csqrt[0],1);
    poly_inv(Csqrt,Cinv,m+1);
    FOR(i,1,m) printf("%d
",add(Cinv[i],Cinv[i]));
}
View Code
原文地址:https://www.cnblogs.com/1000Suns/p/10424492.html