Luogu6478 游戏

Description

link

链接里面的题面清晰易懂 (.jpg)

Solution

恰好 (=) 至少( (or) 至多)(+)二项式反演(或者叫容斥)

那么这个题就转成了求至少 (x) 对祖先方案的数量

(f_{i,j}) 为以 (i) 为根的子树里面有 (j) 对祖先关系的方案数

树上背包:

(1.)(i) 对关系的和有 (j) 对关系的合并

(2.) 顶点和下面的点进行匹配

这两种转移并不难写,统计子树大小,子树黑白点个数即可

(dfs) 就可以完成这个过程

最后记得考虑在根的时候要乘阶乘

因为找祖孙关系是有顺序的

接着上二项式反演把真的值算出来就好的了

关于二项式反演:

[f_x=sum^{n}_{i=x}inom d x g_ iLeftrightarrow g_x=sum_{d=x}^{n}(-1)^{d-x}inom d x f_d ]

Code

#include<bits/stdc++.h>
using namespace std;
#define int long long
namespace yspm{
	inline int read()
	{
		int res=0,f=1; char k;
		while(!isdigit(k=getchar())) if(k=='-') f=-1;
		while(isdigit(k)) res=res*10+k-'0',k=getchar();		
		return res*f;
	}
	const int N=5010,mod=998244353;
	int sz1[N],sz0[N],sz[N],fac[N],inv[N],tmp[N],a[N],b[N],f[N][N],n;
	inline int C(int n,int m){return fac[n]*inv[m]%mod*inv[n-m]%mod;}
	vector<int> g[N];
	char s[N];
	inline void dfs(int x,int fa)
	{
		sz[x]=1; f[x][0]=1; int siz=g[x].size();
		for(int i=0;i<siz;++i) 
		{
			int t=g[x][i]; if(t==fa) continue; dfs(t,x);
			for(int j=0;j<=sz[x]+sz[t]+1;++j) tmp[j]=0;
			for(int j=0;j<=sz[x];++j) 
			{
				for(int k=0;k<=sz[t];++k) tmp[j+k]=(tmp[j+k]+f[x][j]*f[t][k])%mod;
			} sz[x]+=sz[t]; sz0[x]+=sz0[t]; sz1[x]+=sz1[t];
			for(int j=0;j<=sz[x];++j) f[x][j]=tmp[j];
		}
		if(s[x]=='0')
		{
			sz0[x]++;
			for(int i=sz1[x]-1;i>=0;--i) f[x][i+1]=(f[x][i+1]+f[x][i]*(sz1[x]-i)%mod)%mod;
		}
		else 
		{
			sz1[x]++;
			for(int i=sz0[x]-1;i>=0;--i) f[x][i+1]=(f[x][i+1]+f[x][i]*(sz0[x]-i)%mod)%mod;
		}return ;
	}
	signed main()
	{
		fac[0]=1; for(int i=1;i<N;++i) fac[i]=fac[i-1]*i%mod;
		inv[0]=inv[1]=1; for(int i=2;i<N;++i) inv[i]=mod-mod/i*inv[mod%i]%mod;
		for(int i=1;i<N;++i) inv[i]=inv[i-1]*inv[i]%mod;
		n=read(); scanf("%s",s+1);
		for(int i=1;i<n;++i)
		{
			int x=read(),y=read();
			g[x].push_back(y);
			g[y].push_back(x);
		} dfs(1,0);
		for(int i=0;i<=n/2;++i) a[i]=fac[n/2-i]*f[1][i]%mod;
		for(int i=0;i<=n/2;++i)
		{
			for(int d=i;d<=n/2;++d) 
			{
				if((d-i)&1) b[i]=(b[i]-C(d,i)*a[d]%mod+mod)%mod;
				else b[i]=(b[i]+C(d,i)*a[d]%mod)%mod;
			}
		}
		for(int i=0;i<=n/2;++i) printf("%lld
",b[i]); 
		return 0;
	}
}
signed main(){return yspm::main();}
原文地址:https://www.cnblogs.com/yspm/p/12945710.html