【NOIonline2020】T3游戏

树形dp+容斥

(dp[x][i])表示以(x)为根的子树中选了(i)对产生了非平局的匹配,的方案数。

那么这显然可以树形背包做一下,再统计一下(x)与其子树内的点匹配即可。

然后,我们记(f(i)=dp[1][i]cdot(frac{n}{2}-i)!)。那么(f(i))就表示至少(i)局非平局的方案数。(先选出(i)个非平局,剩下(frac{n}{2}-i)局随便排)。

考虑容斥,记(g(i))恰好(i)局非平局的方案数。首先有(g(frac{n}{2})=f(frac{n}{2}))。这很容易产生一个错误的想法,(g(i)=f(i)-sum_{j=i+1}^{frac{n}{2}}g(j))

我们来看一下。假设(frac{n}{2}=3),那么(dp[1][1])会囊括以下情况:

001 010 100

其中,第(i)位为1表示这个位置与其匹配的为非平局。

那么(f(1))呢?发现:

001可以演变成001 011 101 111这4种

010可以演变成010 011 110 111这4种

100可以演变成100 101 110 111这4种

不难看出,011 110 101(即(g(2)))都被算了2次((C_{2}^{1})),而111(即(g(3)))被算了3((C_{3}^{1}))次,所以这就不能直接减了。

那么,我们得出一个新的式子:

[g(i)=f(i)-sum_{j=i+1}^{frac{n}{2}}C_{j}^{i}cdot g(j) ]

然后,此题就解决了。

#include<bits/stdc++.h>
#define MAXN 100010
using namespace std;
const int MOD=998244353;
int tot,n,A[MAXN],head[MAXN],dp[5010][5010],sizeA[MAXN],sizeB[MAXN],tmp[5010],frac[5010],infr[5010],f[5010],ans[5010];
char S[MAXN];
struct node{
	int ed,last;
}G[MAXN<<2];
int Quick_Pow(int a,int p){
	int res=1;
	while(p){
		if(p&1)res=1LL*res*a%MOD;
		a=1LL*a*a%MOD;
		p>>=1;
	}
	return res;
}
void Add(int st,int ed){
	tot++;
	G[tot]=node{ed,head[st]};
	head[st]=tot;
}
void DFS(int x,int fa){
	dp[x][0]=1;
	for(int i=head[x];i;i=G[i].last){
		int t=G[i].ed;
		if(t==fa)continue;
		DFS(t,x);
		for(int j=0;j<=min(sizeA[t],sizeB[t])+min(sizeA[x],sizeB[x]);j++)tmp[j]=0;
		for(int j=0;j<=min(sizeA[t],sizeB[t]);j++)for(int k=0;k<=min(sizeA[x],sizeB[x]);k++)tmp[k+j]=(tmp[k+j]+1LL*dp[t][j]*dp[x][k]%MOD)%MOD;
		for(int j=0;j<=min(sizeA[t],sizeB[t])+min(sizeA[x],sizeB[x]);j++)dp[x][j]=tmp[j];
		sizeA[x]+=sizeA[t],sizeB[x]+=sizeB[t];
	}
	if(S[x]=='0'){
		sizeA[x]++;
		for(int j=min(sizeA[x],sizeB[x]);j>=1;j--)dp[x][j]=(dp[x][j]+1LL*dp[x][j-1]*(sizeB[x]-j+1)%MOD)%MOD;
	}
	else {
		sizeB[x]++;
		for(int j=min(sizeA[x],sizeB[x]);j>=1;j--)dp[x][j]=(dp[x][j]+1LL*dp[x][j-1]*(sizeA[x]-j+1)%MOD)%MOD;
	}
}
int C(int n,int m){
	if(n<m)return 0;
	return 1LL*frac[n]*infr[m]%MOD*infr[n-m]%MOD;
}
int main(){
	frac[0]=1,infr[0]=1;
	for(int i=1;i<=5000;i++)frac[i]=1LL*i*frac[i-1]%MOD,infr[i]=Quick_Pow(frac[i],MOD-2);
	scanf("%d",&n);
	scanf("%s",S+1);
	for(int i=1;i<=n-1;i++){
		int x,y;
		scanf("%d %d",&x,&y);
		Add(x,y);
		Add(y,x);
	}
	DFS(1,0);
	for(int i=n/2;i>=0;i--)f[i]=1LL*dp[1][i]*frac[n/2-i]%MOD;
	for(int i=n/2;i>=0;i--){
		ans[i]=f[i];
		for(int j=i+1;j<=n/2;j++)ans[i]=(ans[i]-1LL*C(j,i)*ans[j]%MOD+MOD)%MOD;
	}
	for(int i=0;i<=n/2;i++)printf("%d
",ans[i]);
	return 0;
}
原文地址:https://www.cnblogs.com/SillyTieT/p/12778957.html