loj#2983. 「WC2019」数树

题目描述


题解

至少相比一年以前想到了拆y^i,只不过没想到提y^n出来而已(确信)

op=0

块=点-边,hash

op=1

假设一棵红树的块数为j,则贡献为y^j*方案数

方案数直接用prufer算(n^{m-2}prod a_i)会算重,会连上蓝树的边

套路:恰好=-1后的至少

问题是直接把(y-1+1)展开会发现顺序反了,有兴趣可以自己尝试

对于一个树S,设|S|为边数,则贡献为(y^{n-|S|})

把y^n提出来,设z=1/y,变成每选一条边就乘上(z-1),此时把(z-1+1)^j二项式展开后就是对的了

具体来说就是当i<j时,算z^i的多选边的方案会把j算C(j,i)次

于是可以设f[i][j]表示当前到子树i,块大小为j的系数和,O(n^2)

瓶颈在于求(prod a_i),等价于每个连通块里选恰好一个点

所以设(f[i][0/1])表示当前连通块内是否有选点,随便转移即可O(n)

op=2

先提y^n,op=1的展开仍然可以用

枚举树S,贡献为(sum_{S}(z-1)^{|S|}*(方案)^2)

边不好搞,设块m=n-|S|,则原式为(sum_S (z-1)^{n-m}*(n^{m-2}prod a_i)^2)

直接枚举块大小ai,那么原式变为(sum (z-1)^{n-m}*n^{2m-4}*prod a_i^2*prod a_i^{a_i-2}*C(n,...)/块数!),后面的是块内的方案,除以块数!把块变为无序

提掉n和常数,可以设f[i]表示n个点放了i个的答案

(f[i]=sum f[i-j]*(z-1)^{-1}*n^2*j^j*inom{i-1}{j-1}),即枚举最后一个点所在块

这个用分治ntt可以做到log^2,也可以换一种做法:

设OGF(F_i=(z-1)^{-1}*n^2*i^i)(G(x))(F(x))的EGF,则([x^n]e^{G(x)})就是答案

因为(e^{G(x)}=sum frac{G^i(x)}{i!}),即枚举段数再除阶乘变为无序

常数略大

注意EGF和exp是两种不同的东西,并且EGF中的i!是形式,最后一定要乘回去变为OGF

code

exp调试方法:

内部的数组可以用static,不要用namespace

https://www.cnblogs.com/gmh77/p/13166794.html#%E8%B0%83%E8%AF%95%E6%96%B9%E6%B3%95%E6%B3%A8%E6%84%8F%E4%BA%8B%E9%A1%B9

#include <bits/stdc++.h>
#define fo(a,b,c) for (a=b; a<=c; a++)
#define fd(a,b,c) for (a=b; a>=c; a--)
#define min(a,b) (a<b?a:b)
#define max(a,b) (a>b?a:b)
#define mod 998244353
#define Mod 998244351
#define G 3
#define ll long long
//#define file
using namespace std;

struct graph{
	int a[200001][2],ls[100001],len;
	void New(int x,int y) {++len;a[len][0]=y;a[len][1]=ls[x];ls[x]=len;}
} gr;
ll dp[100001][2],jc[262144],Jc[262144],w[262144],y,z,ans;
int n,op,i,j,k,l,Len,N,len;

ll qpower(ll a,int b) {ll ans=1; while (b) {if (b&1) ans=ans*a%mod;a=a*a%mod;b>>=1;} return ans;}
void swap(int &x,int &y) {int z=x;x=y;y=z;}

void work0()
{
	static map<pair<int,int>,bool> hs;
	static map<pair<int,int>,bool> :: iterator I;
	
	fo(i,1,n-1)
	{
		scanf("%d%d",&j,&k);
		if (j>k) swap(j,k);
		hs[pair<int,int>(j,k)]=1;
	}
	fo(i,1,n-1)
	{
		scanf("%d%d",&j,&k);
		if (j>k) swap(j,k);
		I=hs.find(pair<int,int>(j,k));
		if (I!=hs.end()) ++ans;
	}
	ans=qpower(y,n-ans);
}

void dfs(int Fa,int t)
{
	ll x,y;
	int i;
	
	dp[t][0]=1;
	for (i=gr.ls[t]; i; i=gr.a[i][1])
	if (gr.a[i][0]!=Fa)
	{
		dfs(t,gr.a[i][0]);
		x=(dp[t][0]*dp[gr.a[i][0]][1]%mod*n+dp[t][0]*dp[gr.a[i][0]][0]%mod*z)%mod;
		y=(dp[t][1]*dp[gr.a[i][0]][1]%mod*n+(dp[t][1]*dp[gr.a[i][0]][0]+dp[t][0]*dp[gr.a[i][0]][1])%mod*z)%mod;
		dp[t][0]=x,dp[t][1]=y;
	}
	dp[t][1]=(dp[t][1]+dp[t][0])%mod;
}
void work1()
{
	fo(i,1,n-1) scanf("%d%d",&j,&k),gr.New(j,k),gr.New(k,j);
	dfs(0,1);
	ans=dp[1][1]*qpower(n,Mod)%mod*qpower(y,n)%mod;
}

void init()
{
	jc[0]=jc[1]=Jc[0]=Jc[1]=w[1]=1;
	fo(i,2,262143) w[i]=mod-w[mod%i]*(mod/i)%mod,jc[i]=jc[i-1]*i%mod,Jc[i]=Jc[i-1]*w[i]%mod;
}
void dft(ll *a,int tp,int N,int len)
{
	static ll A[262144];
	int i,j,k,l,S=N,s1=2,s2=1;
	ll u,v,w,W;
	
	fo(i,0,N-1)
	{
		j=i,k=0;
		fo(l,1,len)
		k=k*2+(j&1),j>>=1;
		A[i]=a[k];
	}
	memcpy(a,A,N*8);
	
	fo(i,1,len)
	{
		W=(tp==1)?qpower(G,(mod-1)/s1):qpower(G,(mod-1)-(mod-1)/s1);
		S>>=1;
		fo(j,0,S-1)
		{
			w=1;
			fo(k,0,s2-1)
			{
				u=a[j*s1+k],v=a[j*s1+k+s2]*w;
				a[j*s1+k]=(u+v)%mod;
				a[j*s1+k+s2]=(u-v)%mod;
				w=w*W%mod;
			}
		}
		s1<<=1,s2<<=1;
	}
}
void mul(ll *a,ll *b,ll *c,int N,int len)
{
	static ll A[262144],B[262144];
	int i,N2=qpower(N,Mod);
	
	memset(A,0,N*8),memset(B,0,N*8);
	fo(i,0,N-1) A[i]=a[i],B[i]=b[i];
	dft(A,1,N,len),dft(B,1,N,len);
	fo(i,0,N-1) A[i]=A[i]*B[i]%mod;
	dft(A,-1,N,len);
	fo(i,0,N-1) c[i]=A[i]*N2%mod;
}
void ny(ll *a,ll *b,int N,int len)
{
	static ll A[262144],c[262144];
	int i;
	memset(b,0,N*8);
	if (N==1) {b[0]=qpower(a[0],Mod);return;}
	ny(a,b,N/2,len-1);
	
	memset(c,0,N*8*2);
	mul(b,b,c,N,len);
	memset(A,0,N*8*2),memcpy(A,a,N*8);
	mul(c,A,c,N*2,len+1);
	fo(i,0,N-1) b[i]=(2*b[i]-c[i])%mod;
}
void dao(ll *a,ll *b,int N,int len)
{
	int i;
	fo(i,0,N-2) b[i]=a[i+1]*(i+1)%mod;b[N-1]=0;
}
void ji(ll *a,ll *b,int N,int len)
{
	int i;
	fd(i,N-1,1) b[i]=a[i-1]*w[i]%mod;b[0]=0;
}
void Ln(ll *a,ll *b,int N,int len)
{
	static ll A[262144],B[262144];
	int i;
	
	memset(A,0,N*8*2),memset(B,0,N*8*2);
	dao(a,A,N,len),ny(a,B,N,len);
	mul(A,B,b,N*2,len+1);
	ji(b,b,N,len);
}
void Exp(ll *a,ll *b,int N,int len)
{
	static ll A[262144];
	int i;
	memset(b,0,N*8*2);
	if (N==1) {b[0]=1;return;}
	Exp(a,b,N/2,len-1);
	
	memset(A,0,N*8*2);
	Ln(b,A,N,len);
	fo(i,N,N+N-1) A[i]=0;
	fo(i,0,N-1) A[i]=(-A[i]+a[i])%mod;++A[0];
	mul(A,b,b,N*2,len+1);
}
void work2()
{
	static ll f[262144],F[262144];
	
	fo(i,1,n)
	f[i]=qpower(z,Mod)*n%mod*n%mod*qpower(i,i)%mod*Jc[i]%mod;len=ceil(log2(n+1)),N=qpower(2,len);
	Exp(f,F,N,len);
	ans=(F[n]*jc[n]%mod)*qpower(z,n)%mod*qpower(qpower(n,Mod),4)%mod*qpower(y,n)%mod; //EGF->OGF
}

int main()
{
	#ifdef file
	freopen("loj2983.in","r",stdin);
	#endif
	
	init();
	scanf("%d%lld%d",&n,&y,&op),z=qpower(y,Mod)-1;
	if (y==1)
	{
		switch (op)
		{
			case 0:{printf("%lld
",1);break;}
			case 1:{printf("%lld
",qpower(n,n-2));break;}
			case 2:{printf("%lld
",qpower(n,2*(n-2)));break;}
		}
		return 0;
	}
	
	switch (op)
	{
		case 0:{work0();break;}
		case 1:{work1();break;}
		case 2:{work2();break;}
	}
	printf("%lld
",(ans+mod)%mod);
	
	fclose(stdin);
	fclose(stdout);
	return 0;
}
原文地址:https://www.cnblogs.com/gmh77/p/13336920.html