[WC2019]数树

Problem

问题 0:已知两棵 (n) 个节点的树的形态(两棵树的节点标号均为 (1)(n)),其中第一棵树是红树,第二棵树是蓝树。要给予每个节点一个([1,y])中的整数,使得对于任意两个节点 (p,q),如果存在一条路径 ((a_1=p,a_2,cdots,a_m=q)) 同时属于这两棵树,则 (p,q) 必须被给予相同的数。求给予数的方案数。

问题 1:已知蓝树,对于红树的所有 (n^{n-2}) 种选择方案,求问题 0 的答案之和。

问题 2:对于蓝树的所有 (n^{n-2}) 种选择方案,求问题 1 的答案之和。

Algorithm

矩阵树定理容斥DP生成函数计数

Solution

问题 0 :有零点难度

如果 ((i,j)) 边在红树和蓝树都出现,那么 (i,j) 两个结点必须钦定相同的数字,缩到一块。最后就是 (y^{n-|E_1 cap E_2|})

问题 1 :有一点难度

首先对于蓝树,令其边集为(E)。考察某一种红树,其边集为(E'),则 (E_0=E cap E') 形成了许多连通块,每个连通块的选择方案均独立,贡献为 (y^{n-|E_0|})

反向考虑有多少种红树,使得 (E_0) 相等。这里的图 (G=<V,E_0>) 形成了 (m=n-|E_0|) 个连通块,块内的边已经连好了,而块与块之间的边未定。假设这些连通块大小为 (a_1,a_2,cdots)(i)(j) 两个连通块之间有 (a_ia_j) 条边。利用矩阵树定理求解问题。最后我们需要求出下面矩阵的行列式

[left| egin{matrix} a_1(n-a_1)&-a_1a_2&-a_1a_3&cdots&-a_1a_{m-1}\ -a_2a_1&a_2(n-a_2)&-a_2a_3&cdots&-a_2a_{m-1}\ -a_3a_1&-a_3a_2&a_3(n-a_3)&cdots&-a_3a_{m-1}\ vdots&vdots&vdots&ddots&vdots\ -a_{m-1}a_1&-a_{m-1}a_2&-a_{m-1}a_3&cdots&a_{m-1}(n-a_{m-1})\ end{matrix} ight|]

将第 2 行至第 m-1 行元素加到第一行上

[left| egin{matrix} a_1a_n&a_2a_n&a_3a_n&cdots&a_{m-1}a_n\ -a_2a_1&a_2(n-a_2)&-a_2a_3&cdots&-a_2a_{m-1}\ -a_3a_1&-a_3a_2&a_3(n-a_3)&cdots&-a_3a_{m-1}\ vdots&vdots&vdots&ddots&vdots\ -a_{m-1}a_1&-a_{m-1}a_2&-a_{m-1}a_3&cdots&a_{m-1}(n-a_{m-1})\ end{matrix} ight|]

第 2 行至第 m-1 行加上第 1 行乘上(dfrac{a_i}{a_n})

[left| egin{matrix} a_1a_n&a_2a_n&a_3a_n&cdots&a_{m-1}a_n\ 0&a_2n&0&cdots&0\ 0&0&a_3n&cdots&0\ vdots&vdots&vdots&ddots&vdots\ 0&0&0&cdots&a_{m-1}n\ end{matrix} ight|]

最终得到

[Ans=n^{m-2}prod_{1leq ileq m}a_i ]

可是得到的生成树中新连接的某些边又可能在 (E-E_0) 中。考虑容斥。利用容斥的定义得到

[Ans(E_0)=sum_{E_0subseteq E_1subseteq E}(-1)^{|E_1|-|E_0|}T(E_1) ]

其中 (T(E_1)) 表示图 (G=<V,E_1>) 中对连通块间生成树的计数。

(E_0) 又能产生贡献 (y^{n-|E_0|})。所以最后需要求解

[egin{aligned} Ans&=sum_{E_0subseteq E}y^{n-|E_0|}sum_{E_0subseteq E_1subseteq E}(-1)^{|E_1|-|E_0|}T(E_1)\ &=sum_{E_1subseteq E}(-1)^{|E_1|}y^{n-|E_1|}T(E_1)sum_{E_0subseteq E_1}(-1)^{|E_0|}y^{|E_1|-|E_0|}\ &=sum_{E_1subseteq E}(-1)^{|E_1|}y^{n-|E_1|}T(E_1)(y-1)^{|E_1|}\ &=sum_{E_1subseteq E}y^{n-|E_1|}(1-y)^{|E_1|}T(E_1)\ &=sum_{m=1}^ny^m(1-y)^{n-m}f_m end{aligned} ]

目标就是求解 (f_m),即蓝树上选m个连通块下,红树的方案数

考虑组合意义:相当于选取 (m) 个连通块,并在每个连通块内选择一个代表点。由于每个连通块内有 (size) 个代表点,故贡献刚刚好为乘积和。于是我们可以DP了。

如果对于每个结点记录划分了多少连通块,复杂度只能做到 (mathcal O(n^2))(树上背包)。但发现发现式子写成这样

[(1-y)^nsum_{m=1}^nleft(frac{y}{1-y} ight)^mf_m ]

这样每次增加一个连通块时只需把答案乘上 (frac{y}{1-y} imes n) 就行了!(乘上 (n) 由于前面推导的 (n^{m-2}prod a_i)

(dp_{u,0/1}) 表示到 (u) 结点时,有/没有选择代表点,令增加连通块的时刻定义与设置代表点等价,转移如下

[egin{cases} dp_{u,0}=dp'_{u,0} imes dp_{v,0}+dp'_{u,0} imes dp_{v,1}\ dp_{u,1}=dp'_{u,0} imes dp_{v,1}+dp'_{u,1} imes dp_{v,0}+dp'_{u,1} imes dp_{v,1}\ end{cases} ]

初始化 (dp_{u,0}=1,dp_{u,1}=frac{ny}{1-y})

所以答案为

[Ans=frac{(1-y)^n}{n^2} imes dp_{root,1} ]

问题 2 :有两点难度

还是看这个式子

[(1-y)^nsum_{m=1}^nleft(frac{y}{1-y} ight)^mf_m ]

这里 (f_m) 应该重新定义为:两棵树形态任意,边的交形成的连通块大小为m,连通块大小的乘积和。这里采用生成函数计数。

构造生成函数

[F(x)=sum_{i>0}frac{i^2 imes i^{i-2}x^i}{i!} ]

([x^i]F(x)) 表示选取一个大小为 (i) 的连通块时,内部任意连边,并且附带两次连通块大小乘积的贡献。两次的原因在后面;

[G(x)=sum_{i>0}left(frac{y}{1-y} ight)^ifrac{n^{i-2} imes n^{i-2}F^i(x)}{i!} ]

表示枚举任意数量的连通块产生的贡献数。

这里可以解释贡献两次的原因了:第一次是对这些连通块形成蓝树的方案数连乘因子的贡献;第二次是形成红树的方案数连乘因子的贡献,两者独立。

所以答案即为

[Ans=[x^n]G(x) ]

然后化简 (G(x))

[G(x)=frac1{n^4}sum_{i>0}frac{left(dfrac{n^2y}{1-y}f(x) ight)^i}{i!}=frac1{n^4}e^{F'(x)} ]

其中

[F'(x)=frac{n^2y}{1-y}f(x)=sum_{i>0}frac{n^2y}{1-y}frac{i^ix^i}{i!} ]

多项式 exp 即可。

复杂度(mathcal O(nlog n))

最后注意(y=1)的情况。

#include <bits/stdc++.h>
const int N = 100005, P = 998244353;
using std::vector;
typedef vector<int> Poly;
typedef long long LL;
int inc(int a, int b) { return (a += b) >= P ? a-P : a; }
int pow(int a, int b) {
	int t = 1;
	for (; b; b >>= 1, a = 1LL*a*a%P)
		if (b & 1) t = 1LL*t*a%P;
	return t;
}
int W[N*4], inv[N*4], fac[N*4], ifac[N*4];
void prework(int n) {
	for (int i = 1; i < n; i <<= 1)
		for (int j = W[i] = 1, Wn = pow(3, (P-1)/i>>1); j < i; j++)
			W[i+j] = 1LL*W[i+j-1]*Wn%P;
	inv[1] = fac[0] = ifac[0] = 1;
	for (int i = 2; i <= n; i++) inv[i] = 1LL*(P-P/i)*inv[P%i]%P;
	for (int i = 1; i <= n; i++) fac[i] = 1LL*fac[i-1]*i%P, ifac[i] = 1LL*ifac[i-1]*inv[i]%P;
}
void fft(Poly &a, int n, int opt) {
	a.resize(n);
	static int rev[N*4];
	for (int i = 1; i < n; i++)
		if ((rev[i] = rev[i>>1]>>1|(i&1?n>>1:0)) > i) std::swap(a[i], a[rev[i]]);
	for (int q = 1; q < n; q <<= 1)
		for (int p = 0; p < n; p += q<<1)
			for (int i = 0, t; i < q; i++)
				t = 1LL*a[p+q+i]*W[q+i]%P, a[p+q+i] = inc(a[p+i], P-t), a[p+i] = inc(a[p+i], t);
	if (opt) return;
	for (int i = 0, inv = pow(n, P-2); i < n; i++) a[i] = 1LL*a[i]*inv%P;
	std::reverse(a.begin()+1, a.end());
}
Poly poly_inv(Poly A) {
	Poly B(1, pow(A[0], P-2)), C(2);
	for (int L = 1; L < A.size(); L <<= 1) {
		for (int i = 0; i < L*2; i++) C[i] = i < A.size() ? A[i] : 0;
		fft(B, L*4, 1), fft(C, L*4, 1);
		for (int i = 0; i < L*4; i++) B[i] = (2*B[i]-1LL*B[i]*B[i]%P*C[i]%P+P)%P;
		fft(B, L*4, 0); B.resize(L*2);
	}
	return B.resize(A.size()), B;
}
void fix(Poly &A) { int t = A.size(); while (t > 1 && !A[t-1]) t--; A.resize(t); }
int getsz(int n) { int x = 1; while (x < n) x <<= 1; return x; }
Poly operator + (Poly A, Poly B) {
	A.resize(std::max(A.size(), B.size()));
	for (int i = 0; i < B.size(); i++) A[i] = inc(A[i], B[i]);
	return fix(A), A;
}
Poly operator - (Poly A, Poly B) {
	A.resize(std::max(A.size(), B.size()));
	for (int i = 0; i < B.size(); i++) A[i] = inc(A[i], P-B[i]);
	return fix(A), A;
}
Poly operator * (Poly A, Poly B) {
	int L = getsz(A.size()+B.size()-1);
	fft(A, L, 1), fft(B, L, 1);
	for (int i = 0; i < L; i++) A[i] = 1LL*A[i]*B[i]%P;
	return fft(A, L, 0), fix(A), A;
}
Poly poly_deri(Poly A) {
	for (int i = 0; i < A.size()-1; i++) A[i] = 1LL*(i+1)*A[i+1]%P;
	return A.resize(A.size()-1), A;
}
Poly poly_int(Poly A) {
	for (int i = A.size()-1; i; i--) A[i] = 1LL*inv[i]*A[i-1]%P;
	return A[0] = 0, A;
}
Poly poly_ln(Poly A) {
	Poly B = poly_deri(A) * poly_inv(A);
	return B.resize(A.size()), poly_int(B);
}
Poly poly_exp(Poly A) {
	Poly B(1, 1), C;
	for (int L = 1; L < A.size(); L <<= 1)
		B.resize(L*2), C = A + Poly(1, 1) - poly_ln(B), C.resize(L*2), B = B*C;
	return B.resize(A.size()), B;
}
int n, y, op, f[N][2], k;
vector<int> e[N];
void dp(int u, int fa) {
	f[u][0] = 1, f[u][1] = k;
	for (int i = 0, v; i < e[u].size(); i++)
		if ((v = e[u][i]) != fa) {
			dp(v, u);
			int f0 = f[u][0], f1 = f[u][1];
			f[u][0] = ((LL)f0*f[v][0] + (LL)f0*f[v][1]) % P;
			f[u][1] = ((LL)f0*f[v][1] + (LL)f1*f[v][0] + (LL)f1*f[v][1]) % P;
		}
}
int main() {
	scanf("%d%d%d", &n, &y, &op);
	if (y == 1) { printf("%d", op == 0 ? 1 : (op == 1 ? pow(n, n-2) : 1LL*pow(n, n-2)*pow(n, n-2)%P)); return 0; }
	if (op == 0) {
		for (int i = 1; i < 2*n-1; i++) {
			int u, v; scanf("%d%d", &u, &v);
			e[u].push_back(v), e[v].push_back(u);
		}
		int ans = 0;
		for (int i = 1; i <= n; i++) {
			std::sort(e[i].begin(), e[i].end()); ans += e[i].size();
			ans -= std::unique(e[i].begin(), e[i].end()) - e[i].begin();
		} ans /= 2;
		printf("%d", pow(y, n-ans));
	} else {
		k = 1LL*n*y%P*pow(P+1-y, P-2)%P;
		if (op == 1) {
			for (int i = 1; i < n; i++) {
				int u, v; scanf("%d%d", &u, &v);
				e[u].push_back(v), e[v].push_back(u);
			}
			dp(1, 0);
			printf("%d", 1LL*f[1][1]*pow(P+1-y, n)%P*pow(1LL*n*n%P, P-2)%P);
		} else {
			prework((n+1)*2);
			Poly F(n+1);
			for (int i = 1; i <= n; i++)
				F[i] = 1LL*n*k%P*pow(i, i)%P*ifac[i]%P;
			F = poly_exp(F);
			printf("%d", 1LL*F[n]*fac[n]%P*pow(P+1-y, n)%P*pow(pow(n, 4), P-2)%P);
		}
	}
	return 0;
}
原文地址:https://www.cnblogs.com/ac-evil/p/14175048.html