NOI Online #2 提高组 游戏

没用二项式反演的菜比。

题目链接

Solution

非平局代表的树上祖先关系是比较好统计,(可以在处理一个点时,考虑用他去匹配他的子树中的东西)而平局的关系比较难统计。我们不妨求出至少 (k) 个祖先关系的方案数,接着用容斥原理得到恰好 (k) 个祖先关系的方案数。

求出至少 k 个祖先关系的方案数

状态设计

(f_{u, i}) 为以 (u) 为根的子树中,已经有了 (i) 对相互配对的祖先点的方案数。

状态转移

  • 这个状态是由 (u) 的每个儿子与 (u) 的影响共同作用的,不妨先把 (u) 的每个儿子都合并起来,最后考虑 (u) 对状态的影响。

  • 设之前已经合并过子树的数组是 (f_u),当前的儿子是 (v),我们只需合并 (f_u)(f_v) 即可,即 (f_{u, i + j} gets sum f_{u, i} imes f_{v, j}),可以理解为一个树上背包,体积是相互配对的祖先点对数,价值的方案数。

  • 考虑 (u) 的作用,让 (u) 和子树中的一个没有被匹配的点匹配,设 (u) 的颜色是 (C_u in {0, 1})。设 (u) 的子树中 (0 / 1) 颜色的点数量是 (Size_{u, 0/1})(!C_u) 表示和 (C_u) 的颜色相反。那么 (f_{u, i} gets f_{u, i - 1} imes (Size_{u, !C_u} - (i - 1)))(这一步很妙,无需记录别的信息就可以直接让 (u) 做到匹配,原因是通过已经匹配对数以及我们处理的树上信息可以计算出能匹配的点数)

答案

(g_i) 为整棵树中,至少有 (k) 对祖先关系(钦定了 (k) 对,剩下的不知道是不是祖先)的方案数,有 (g_i = C_m^i imes f_{1, i}),即先选出 (i) 对,然后剩下的自由匹配。

注意这里的至少和我们常接触的那个“至少”可能不太一样,具体看后面容斥的解释。

DP 的时间复杂度

第一个合并乍一看是 (O(n^3)) 的,但其实如果卡 (i) 的上限,发现 (i) 这一维不会超过 (Size_u)(以 (u) 为根的子树大小),那么一次合并的时间是 (Size_{之前子树的和} imes Size_v),相当于两块间两两的有序点对,那么两个点只会在 (LCA) 带来一次合并的贡献,所以是 (O(n^2)) 的。

用容斥原理得到恰好 k 个祖先

(Ans_i) 为整棵树中,恰好有 (k) 对祖先关系的方案数,即我们的答案。

发现这个 (g) 有点奇怪,就比如说我恰好有三组匹配 (Ans_3 = 1)(1、2、3) (三个祖先匹配关系的编号没有实际意义),但 (g_2) 可以有 (1、2) 被钦定,(3) 随机对成祖先、可以有 (1、3) 被钦定,(2) 随机对成祖先、可以有 (2、3) 被钦定,(1) 随机对成祖先三种。换句话说,(Ans_3)(g_2) 的贡献是 (C_3^2 imes Ans_3),总结一下:

可以把 (g_i) 的所有方案数划分成这样的类别:恰好有 (i, i + 1, i + 2, ...m),而对于一个恰好为 (j ge i) 组的一个方案,他对 (g_i) 的贡献有 (C_j^i),可以理解为从 (j) 个里面选出 (i) 个,作为 (g_i) 中实际选定的那 (i) 个,即:

(g_i = sum_{j=i}^{m} (m-i)! imes Ans_j)

由于 (C_i^i = 1),将如上式子进行变换,可以得到一个关于 (Ans_i) 的表达式:

(Ans_i = g_i - sum_{j=i+1}^m C_j^i imes Ans_j)

即我们只需要知道 (g_i)(m ge j > i)(Ans_j),就能算出 (Ans_i)

那么从大到小推一遍,(Ans) 就出来了。

总时间复杂度

(O(n^2))

Code

代码中实现没有再新建 (g)(Ans) 数组,而是直接在 (f_{1}) 数组上操作。

#include <iostream>
#include <cstdio>

using namespace std;

const int N = 5005, P = 998244353;

typedef long long LL;


// sz[i] 表示以 i 为根子树的大小,cnt[i][0 / 1] 与文中的 size[i][0 / 1] 意义相同
int n, m, f[N][N], sz[N], cnt[N][2], tmp[N];

int fact[N], infact[N];

char s[N];

int head[N], numE = 0;

struct E{
	int next, v;
} e[N << 1];

void inline add(int u, int v) {
	e[++numE] = (E) { head[u], v };
	head[u] = numE;
}

int inline power(int a, int b) {
	int res = 1;
	while (b) {
		if (b & 1) res = (LL)res * a % P;
		a = (LL)a * a % P;
		b >>= 1;
	}
	return res;
}

void dfs(int u, int fa) {
	f[u][0] = 1;
	for (int i = head[u]; i; i = e[i].next) {
		int v = e[i].v;
		if (v == fa) continue;
		dfs(v, u);
		for (int j = 0; j <= sz[u] + sz[v]; j++) tmp[j] = 0;
		for (int j = 0; j <= sz[u]; j++)
			for (int k = 0; k <= sz[v]; k++)
				tmp[j + k] = (tmp[j + k] + (LL)f[u][j] * f[v][k]) % P;
		for (int j = 0; j <= sz[u] + sz[v]; j++) f[u][j] = tmp[j];
		sz[u] += sz[v], cnt[u][0] += cnt[v][0], cnt[u][1] += cnt[v][1];
	}
	sz[u]++;
	if (s[u] == '0') cnt[u][0]++;
	else cnt[u][1]++;
	for (int i = sz[u]; i; i--)
		f[u][i] = (f[u][i] + (LL)f[u][i - 1] * (cnt[u][s[u] == '0' ? 1 : 0] - (i - 1))) % P;
}

int inline C(int a, int b) {
	return (LL)fact[a] * infact[b] % P * infact[a - b] % P;
}

int main() {
	scanf("%d%s", &n, s + 1);
	m = n / 2;
	fact[0] = infact[0] = 1;
	for (int i = 1; i <= m; i++) fact[i] = (LL)fact[i - 1] * i % P;
	infact[m] = power(fact[m], P - 2);
	for (int i = m - 1; i; i--) infact[i] = (LL)infact[i + 1] * (i + 1) % P;
	for (int i = 1, u, v; i < n; i++)
		scanf("%d%d", &u, &v), add(u, v), add(v, u);
	dfs(1, 0);
	for (int i = 0; i <= m; i++) f[1][i] = (LL)f[1][i] * fact[m - i] % P;  
	for (int i = m; ~i; i--) 
		for (int j = i + 1; j <= m; j++) 
			f[1][i] = ((f[1][i] - (LL)f[1][j] * C(j, i)) % P + P) % P;
	for (int i = 0; i <= m; i++) printf("%d
", f[1][i]);
	return 0;
}
原文地址:https://www.cnblogs.com/dmoransky/p/12788193.html