[WC2019] 数树

Solution

op=0

令两棵树的边集分别为 (S_1,S_2)(T=S_1and S_2),于是原问题等价于形成一个边集为 (T) 的新图。新图上有 (n-|T|) 个连通块,于是答案就是 (y^{n-|T|})

(op=0) 时,直接 (mathcal O(nlog n)) (map)/双指针找出 (|T|) 即可。

namespace sub0{
	pair<int,int> a[N],b[N];
	inline void main(){
		for(int i=1,u,v;i<n;++i){
			scanf("%d%d",&u,&v);if(u>v) swap(u,v);
			a[i]=make_pair(u,v);
		}
		for(int i=1,u,v;i<n;++i){
			scanf("%d%d",&u,&v);if(u>v) swap(u,v);
			b[i]=make_pair(u,v);
		}
		sort(a+1,a+n);sort(b+1,b+n);
		int ans=0;
		for(int i=1,j=1;i<n;++i){
			while(j<n&&b[j]<a[i]) ++j;
			if(b[j]==a[i]) ++ans;
		}
		printf("%d
",ksm(y,n-ans));
		exit(0);
	}
}

op=1

此时

[egin{aligned} ans&=sum_{S_2}y^{n-|S_1and S_2|}\ &=sum_{T}y^{n-|T|}sum_{S_2}[S_1and S_2=T] end{aligned} ]

后面这个形式不是我们所喜欢的,我们希望转化为形如 (Tsubseteq S_1and S_2) 的形式,这可以利用容斥原理来实现:

[f(S)=sum_{Tsubseteq S}sum_{Psubseteq T}(-1)^{|T|-|P|}f(P) ]

证明考虑计算一个集合 (P) 被计算的次数为 (sum_{i=0}^{|S|-|P|}dbinom{|S|-|P|}{i}(-1)^{i}=(1-1)^{|S|-|P|}=[|S|=|P|]),于是得证。

回到原式子:

[egin{aligned} ans&=sum_{S_2}y^{n-|S_1and S_2|}\ &=sum_{S_2}sum_{Tsubseteq(S_1and S_2)}sum_{Psubseteq T}(-1)^{|T|-|P|}y^{n-|P|}\ &=sum_{Tsubseteq S_1} g(T)sum_{Psubseteq T}(-1)^{|T|-|P|}y^{n-|P|}\ &=sum_{Tsubseteq S_1} g(T)y^{n-|T|}sum_{Psubseteq T}(-1)^{|T|-|P|}y^{|T|-|P|}\ end{aligned} ]

其中 (g(S)) 表示包含边集 (S) 的树的数量。

可以发现其实我们一直只关心 (|T|-|P|) ,因此可以化为:

[egin{aligned} ans&=sum_{Tsubseteq S_1} g(T)y^{n-|T|}sum_{k=0}^{|T|}inom{|T|}{K}(-y)^{k}\ &=sum_{Tsubseteq S_1} g(T)y^{n-|T|}(1-y)^{|T|}\ end{aligned} ]

考虑 (g(T)) 如何计算,设 (T) 中的边使 (n) 个点形成了 (k=n-|T|) 个连通块,第 (i) 个的大小为 (a_i)。于是这个问题等价于在 (k) 个点之间连边形成生成树,第 (i,j) 个点之间有 (a_i imes a_j) 条重边。利用 (matrix-tree) 定理可以得到,(g(T)=n^{k-2}prod_{i=1}^{k}a_i)

于是回到原式子,有

[egin{aligned} ans&=sum_{Tsubseteq S_1} n^{k-2}y^{k}(1-y)^{n-k}prod_{i=1}^{k}a_i \ &=dfrac{(1-y)^{n}}{n^2}sum_{Tin S_1}n^{k}y^{k}(1-y)^{-k}prod_{i=1}^{k}a_i\ &=dfrac{(1-y)^n}{n^2}sum_{Tin S_1}prod_{i=1}^{k}a_idfrac{ny}{(1-y)} end{aligned} ]

(w=dfrac{ny}{1-y}),考虑 (a_iw) 的组合意义,这等架于要求在每个连通块中选择一个点,每选择一个点造成 (w) 的贡献。对此考虑树形 (DP),设 (f_{i,0/1}) 表示仅考虑了 (i) 所在的子树,(i) 所在连通块是否选择了点时的答案,于是可以 (mathcal O(n)) 完成转移。

namespace sub1{
	vector<int> to[N];
	int W,f[N][2];
	inline void dfs(int u,int fa){
		f[u][0]=1;f[u][1]=W;
		for(int v:to[u]){
			if(v==fa) continue;
			dfs(v,u);
			int f0=f[u][0],f1=f[u][1];
			f[u][0]=(1ll*f0*f[v][0]+1ll*f0*f[v][1])%mod;
			f[u][1]=(1ll*f0*f[v][1]+1ll*f1*f[v][0]+1ll*f1*f[v][1])%mod;
		}
	}
	inline void main(){
		if(y==1){printf("%d
",ksm(n,n-2));exit(0);}
		for(int i=1,u,v;i<n;++i){
			scanf("%d%d",&u,&v);
			to[u].push_back(v);to[v].push_back(u);
		}	
		W=1ll*n*y%mod*ksm(dec(1,y),mod-2)%mod;
		dfs(1,0);int p=1ll*ksm(dec(1,y),n)*ksm(n,mod-3)%mod;
		printf("%d
",1ll*f[1][1]*p%mod);
		exit(0);
	} 
}

op=2

继续考虑 (op=1) 时的柿子,从枚举 (T)(S_1) 的子集,要求 (T)(S_2) 的子集,改为直接枚举 (T),要求 (T)(S_1,S_2) 的子集,于是有:

[egin{aligned} ans&=sum_{T} g(T)^2y^{n-|T|}(1-y)^{|T|}\ end{aligned} ]

进行与 (op=1) 一模一样的操作后有:

[ans=dfrac{(1-y)^n}{n^4}sum_{T}prod_{i=1}^{k}a_i^2dfrac{n^2y}{(1-y)} ]

此处 (T) 是枚举了所有边集,因此这就相当与考虑了所有划分连通块的方式,原问题等价于将 (n) 个点划分为若干个连通块,一个大小为 (x) 的连通块会产生 (w=x^2dfrac{n^2y}{1-y}) 的贡献,而连通块内部又有 (x^{x-2}) 种方法形成一棵生成树,于是单个连通块的贡献为 (x^xdfrac{n^2y}{1-y})

原问题相当于将 (n) 个点无序划分为若干个连通块,这与城市规划一题的形式一样,因此我们可以同样得到结论,原问题答案的 (EGF) 就是连通块贡献的 (EGF)(exp)

于是使用一遍多项式 (exp) 即可。

Code

#include<bits/stdc++.h>
using namespace std;
const int N=(1<<18)+20;
const int mod=998244353;
int n,y,op;
inline void inc(int &x,int y){x=(x+y>=mod)?x+y-mod:x+y;}
inline int dec(int x,int y){return (x-y<0)?x-y+mod:x-y;}
inline int ksm(int x,int y){
	int ret=1;
	for(;y;y>>=1,x=1ll*x*x%mod) if(y&1) ret=1ll*ret*x%mod;
	return ret;
}
namespace sub0{
	pair<int,int> a[N],b[N];
	inline void main(){
		for(int i=1,u,v;i<n;++i){
			scanf("%d%d",&u,&v);if(u>v) swap(u,v);
			a[i]=make_pair(u,v);
		}
		for(int i=1,u,v;i<n;++i){
			scanf("%d%d",&u,&v);if(u>v) swap(u,v);
			b[i]=make_pair(u,v);
		}
		sort(a+1,a+n);sort(b+1,b+n);
		int ans=0;
		for(int i=1,j=1;i<n;++i){
			while(j<n&&b[j]<a[i]) ++j;
			if(b[j]==a[i]) ++ans;
		}
		printf("%d
",ksm(y,n-ans));
		exit(0);
	}
}
namespace sub1{
	vector<int> to[N];
	int W,f[N][2];
	inline void dfs(int u,int fa){
		f[u][0]=1;f[u][1]=W;
		for(int v:to[u]){
			if(v==fa) continue;
			dfs(v,u);
			int f0=f[u][0],f1=f[u][1];
			f[u][0]=(1ll*f0*f[v][0]+1ll*f0*f[v][1])%mod;
			f[u][1]=(1ll*f0*f[v][1]+1ll*f1*f[v][0]+1ll*f1*f[v][1])%mod;
		}
	}
	inline void main(){
		if(y==1){printf("%d
",ksm(n,n-2));exit(0);}
		for(int i=1,u,v;i<n;++i){
			scanf("%d%d",&u,&v);
			to[u].push_back(v);to[v].push_back(u);
		}	
		W=1ll*n*y%mod*ksm(dec(1,y),mod-2)%mod;
		dfs(1,0);int p=1ll*ksm(dec(1,y),n)*ksm(n,mod-3)%mod;
		printf("%d
",1ll*f[1][1]*p%mod);
		exit(0);
	} 
}

namespace sub2{
	typedef vector<int> vec;
	typedef unsigned long long ull;
	int iv[N],tp,fac[N],jc[N];
	inline void init_inv(int n){
		if(!tp){iv[0]=iv[1]=fac[0]=fac[1]=jc[0]=jc[1]=1;tp=2;}
		for(;tp<=n;++tp){
			iv[tp]=1ll*(mod-mod/tp)*iv[mod%tp]%mod;
			fac[tp]=1ll*fac[tp-1]*tp%mod;
			jc[tp]=1ll*jc[tp-1]*iv[tp]%mod;
		}
	}		
	struct poly{
		vec v;
		inline poly(int w=0):v(1){v[0]=w;}
		inline poly(const vec&w):v(w){}
			
		inline int operator [](int x)const{return x>=v.size()?0:v[x];}
		inline int& operator [](int x){if(x>=v.size()) v.resize(x+1);return v[x];}
		inline int size(){return v.size();}
		inline void resize(int x){v.resize(x);} 
		
		inline poly slice(int len)const{
			if(len<=v.size()) return vec(v.begin(),v.begin()+len);
			vec ret(v);ret.resize(len);
			return ret;
		}
		inline poly operator *(const int &x)const{
			poly ret(v);
			for(int i=0;i<v.size();++i) ret[i]=1ll*ret[i]*x%mod; 
			return ret;
		}
	};
	
	
	int Wn[N<<1],lg[N],r[N],tot;
	inline void init_poly(int n){
		int p=1;while(p<=n)p<<=1;
		for(int i=2;i<=p;++i) lg[i]=lg[i>>1]+1;
		for(int i=1;i<p;i<<=1){
			int wn=ksm(3,(mod-1)/(i<<1));
			Wn[++tot]=1;
			for(int j=1;j<i;++j) ++tot,Wn[tot]=1ll*Wn[tot-1]*wn%mod;
		}
	}
	inline void init_pos(int lim){
		int len=lg[lim]-1;
		for(int i=0;i<lim;++i) r[i]=(r[i>>1]>>1)|((i&1)<<len);
	}
	
	ull fr[N];
	const ull Mod=998244353;
	inline void NTT(int *f,int lim,int tp){
		for(int i=0;i<lim;++i) fr[i]=f[r[i]];
		for(int mid=1;mid<lim;mid<<=1){
			for(int len=mid<<1,l=0;l+len-1<lim;l+=len){
				for(int k=l;k<l+mid;++k){
					ull w1=fr[k],w2=fr[k+mid]*Wn[mid+k-l]%Mod;
					fr[k]=w1+w2;fr[k+mid]=w1+Mod-w2; 
				}
			}
		}
		for(int i=0;i<lim;++i) f[i]=fr[i]%Mod;
		if(!tp){
			reverse(f+1,f+lim);
			int iv=ksm(lim,mod-2);
			for(int i=0;i<lim;++i) f[i]=1ll*f[i]*iv%mod;
		}
	}
	inline poly to_poly(int *a,int n){
		poly ret;
		ret.resize(n);
		memcpy(ret.v.data(),a,n<<2);
		return ret;
	}
	namespace Exp{
		const int logB=4;
		const int B=16;
		int f[N],ret[N],H[N];
		poly g[4][B];
		inline void exp(int lim,int l,int r,int dep){
			if(r-l<=128){
				for(int i=l;i<r;++i){
					ret[i]=(!i)?1:1ll*ret[i]*iv[i]%mod;
					for(int j=i+1;j<r;++j)	
						inc(ret[j],1ll*ret[i]*f[j-i]%mod);
				}
				return ;
			}
			int k=(r-l)/B;
			int len=1<<lim-logB+1;
			vector<unsigned long long> bl[B];
			for(int i=0;i<B;++i) bl[i].resize(k<<1);
			for(int i=0;i<B;++i){
				if(i>0){
					init_pos(len);
					for(int j=0;j<(k<<1);++j) H[j]=bl[i][j]%mod; 
					NTT(H,len,0);
					for(int j=0;j<k;++j)
						inc(ret[l+i*k+j],H[j+k]);
				}
				exp(lim-logB,l+i*k,l+(i+1)*k,dep+1);
				if(i<B-1){
					memcpy(H,ret+l+i*k,sizeof(int)*(k));
					memset(H+k,0,sizeof(int)*(k));
					init_pos(len);NTT(H,len,1);
					for(int j=i+1;j<B;++j)
						for(int t=0;t<(k<<1);++t) 
							bl[j][t]+=1ll*H[t]*g[dep][j-i-1][t]; 
				}
			}
		}
		
		inline poly getexp(poly F,int n){
			memcpy(f,F.v.data(),sizeof(int)*(n));
			int mx=lg[n]+1;init_inv(1<<mx);
			for(int i=0;i<n;++i) f[i]=1ll*f[i]*i%mod;
			memset(ret,0,sizeof(int)*(1<<mx));
			for(int lim=mx,dep=0;lim>=8;lim-=logB,dep++){
				int len=1<<(lim-logB+1);
				init_pos(len);
				for(int i=0;i<B-1;++i){
					g[dep][i].resize(len);
					memcpy(g[dep][i].v.data(),f+(len>>1)*i,sizeof(int)*(len));
					NTT(g[dep][i].v.data(),len,1);
				}
			}
			exp(mx,0,1<<mx,0);
			return to_poly(ret,n);
		}
	}
	inline void main(){
		if(y==1){printf("%d
",ksm(n,2*n-4));exit(0);}
		init_poly((n+1)<<1);init_inv(n);
		poly f;
		int p=1ll*n*n%mod*y%mod*ksm(dec(1,y),mod-2)%mod; 
		for(int i=0;i<=n;++i) f[i]=1ll*ksm(i,i)*p%mod*jc[i]%mod;
		f=Exp::getexp(f,n+1);
		printf("%d
",1ll*f[n]*fac[n]%mod*ksm(dec(1,y),n)%mod*ksm(n,mod-5)%mod);
	}	
}
int main(){
	scanf("%d%d%d",&n,&y,&op);
	if(!op) sub0::main();
	else if(op==1) sub1::main();
	else sub2::main();
	return 0;
}
原文地址:https://www.cnblogs.com/tqxboomzero/p/15235284.html