[BZOJ3625] [Codeforces Round #250]小朋友和二叉树

Description

我们的小朋友很喜欢计算机科学,而且尤其喜欢二叉树。
考虑一个含有n个互异正整数的序列c[1],c[2],...,c[n]。如果一棵带点权的有根二叉树满足其所有顶点的权值都在集合{c[1],c[2],...,c[n]}中,我们的小朋友就会将其称作神犇的。并且他认为,一棵带点权的树的权值,是其所有顶点权值的总和。
给出一个整数m,你能对于任意的s(1<=s<=m)计算出权值为s的神犇二叉树的个数吗?请参照样例以更好的理解什么样的两棵二叉树会被视为不同的。
我们只需要知道答案关于998244353(7172^23+1,一个质数)取模后的值。

Input

第一行有2个整数 n,m(1<=n<=10^5; 1<=m<=10^5)。
第二行有n个用空格隔开的互异的整数 c[1],c[2],...,c[n](1<=c[i]<=10^5)。

Output

输出m行,每行有一个整数。第i行应当含有权值恰为i的神犇二叉树的总数。请输出答案关于998244353(=7172^23+1,一个质数)取模后的结果。

Sample Input

2 3
1 2

Sample Output

1
3
9

Solution

话说多项式的代码是真的难写...

思路其实比较简单,设(f(n))表示花(n)的代价得到的二叉树的个数,(g(n))表示有没有代价为(n)的点,只能为(0,1)

那么很简单得到(dp)方程:

[f(n)=sum_{i=1}^ng(i)sum_{j=1}^{n-i}f(j)f(n-i-j) ]

这里是枚举根节点填什么,然后两边分别怎么填。

其中(f)的边界为(f(0)=1)

利用直觉生成函数法可得,把(f,g)写成生成函数的形式可得:

[F(x)=sum_{n=0}^{infty} f(n)x^n,G(x)=sum_{n=0}^{infty} g(n)x^n ]

然后把(dp)方程写成卷积形式:

[F(x)=F^2(x)G(x)+1 ]

(+1)是表示第(0)项。

解得:

[F(x)=frac{2}{1pmsqrt{1-4G(x)}} ]

注意到([0]G(x)=0),取(-)分母为(0),所以取正号:

[F(x)=frac{2}{1+sqrt{1-4G(x)}} ]

然后就是多项式求逆和开根的板子了。

注意求逆和开根不要用同样的数组,不然各种冲突...我调了好久...

#include<bits/stdc++.h>
using namespace std;
 
void read(int &x) {
    x=0;int f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
    for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
 
void print(int x) {
    if(x<0) putchar('-'),x=-x;
    if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('
');}

#define lf double
#define ll long long 

const int maxn = 2e5+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int mod = 998244353;
const int inv2 = 499122177;

int f[maxn],g[maxn],n,m,mxn,bit,N,w[maxn],rw[maxn],s[maxn],t[maxn],pos[maxn];

int qpow(int a,int x) {
	int res=1;
	for(;x;x>>=1,a=1ll*a*a%mod) if(x&1) res=1ll*res*a%mod;
	return res;
}

void prepare() {
	w[0]=1;w[1]=qpow(3,(mod-1)/mxn);
	for(int i=2;i<=mxn;i++) w[i]=1ll*w[i-1]*w[1]%mod;
	rw[0]=1,rw[1]=qpow(qpow(3,mod-2),(mod-1)/mxn);
	for(int i=2;i<=mxn;i++) rw[i]=1ll*rw[i-1]*rw[1]%mod;
}

void ntt(int *r,int op) {
	for(int i=1;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]);
	for(int i=1,d=mxn>>1;i<N;i<<=1,d>>=1) 
		for(int j=0;j<N;j+=i<<1)
			for(int k=0;k<i;k++) {
				int x=r[j+k],y=1ll*r[i+j+k]*(op==1?w:rw)[k*d]%mod;
				r[j+k]=(x+y)%mod,r[i+j+k]=(x-y+mod)%mod;
				}
	if(op==-1) {
		int inv=qpow(N,mod-2);
		for(int i=0;i<N;i++) r[i]=1ll*r[i]*inv%mod;
	}
}

int tmp1[maxn],tmp2[maxn],tmp3[maxn];

void get_pos(int len) {
	for(bit=0,N=1;N<len;N<<=1,bit++);
	for(int i=1;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1));
}

void poly_inv(int *r,int *b,int len) {
	if(len==1) return b[0]=qpow(r[0],mod-2),void();
	poly_inv(r,b,len>>1);
	for(int i=0;i<len;i++) tmp1[i]=b[i],tmp2[i]=r[i];
	get_pos(len<<1);
	ntt(tmp1,1),ntt(tmp2,1);
	for(int i=0;i<N;i++) b[i]=((2ll*tmp1[i]%mod-1ll*tmp2[i]*tmp1[i]%mod*tmp1[i]%mod)%mod+mod)%mod;
	ntt(b,-1);
	for(int i=len;i<N;i++) b[i]=0;
	for(int i=0;i<len<<1;i++) tmp1[i]=tmp2[i]=0;
}

void poly_sqrt(int *r,int *b,int len) {
	if(len==1) return b[0]=r[0],void();
	poly_sqrt(r,b,len>>1);
	poly_inv(b,tmp3,len);
	get_pos(len<<1);
	for(int i=0;i<len;i++) tmp2[i]=r[i];
	ntt(tmp2,1),ntt(tmp3,1);
	for(int i=0;i<N;i++) tmp3[i]=1ll*tmp3[i]*tmp2[i]%mod;
	ntt(tmp3,-1);
	for(int i=0;i<len;i++) b[i]=1ll*inv2*(b[i]+tmp3[i])%mod;
	for(int i=0;i<len<<1;i++) tmp3[i]=tmp2[i]=0;
}

int main() {
	read(n),read(m);
	for(int i=1,x;i<=n;i++) read(x),x<=m?g[x]=1:0;
	for(mxn=1;mxn<=m<<1;mxn<<=1);
	prepare();
	for(int i=1;i<=m;i++) g[i]=(mod-4*g[i])%mod;
	g[0]=(g[0]+1)%mod;
	poly_sqrt(g,s,mxn>>1);s[0]=(s[0]+1)%mod;
	poly_inv(s,t,mxn>>1);
	for(int i=1;i<=m;i++) write((2ll*t[i]%mod+mod)%mod);
	return 0;
}
原文地址:https://www.cnblogs.com/hbyer/p/10563556.html