@loj


@description@

本题包含三个问题:

问题 0:已知两棵 n 个结点的树的形态(两棵树的结点标号均为 1~n),其中第一棵树是红树,第二棵树是蓝树。要给予每个结点一个 [1, y] 中的整数,使得对于任意两个节点 p, q,如果存在一条路径 (a1 = p, a2, ..., am = q) 同时属于这两棵树,则 p, q 必须被给予相同的数。求给予数的方案数。

问题 1:已知蓝树,对于红树的所有 (n^{n-2}) 种选择方案,求问题 0 的答案之和。

问题 2:对于蓝树的所有 (n^{n-2}) 种选择方案,求问题 1 的答案之和。

原题请戳我查看qwq

@solution@

说点人话,若两棵树边集的交集为 S,则答案等于 (y^{n - |S|})

前排提醒:下面可能会出现类似 (1 - y) 作分母的情况,当 y = 1 时没有意义。所以需要优先特判掉。
注意 y = 1 时 |S| 并不会影响,所以只取决于有多少种可能的情况。

@问题 0@

相信大家都会做。

@问题 1@

不难想到一个指数级的思路:枚举交集 S,记 f(S) 表示满足要求的树的个数。

交集恰好为 S 显然不好做,而且看起来很好容斥。我们枚举 T,计算交集包含 T 的情况,记为 g(T)。
稍微思考一下得到容斥式子 (f(S) = sum_{Ssubseteq T}(-1)^{|T|-|S|}g(T))

则最终答案有如下式子:

[ans = sum_{S}f(S) imes y^{n - |S|}\ = sum_{S}sum_{Ssubseteq T}(-1)^{|T| - |S|}g(T) imes y^{n - |S|}]

尝试消去 S:

[ans = y^nsum_{T}g(T)sum_{Ssubseteq T}(-1)^{|T|-|S|}y^{-|S|} \ = y^nsum_{T}g(T)sum_{i=0}^{|T|}C_{|T|}^{i}(-1)^{|T|-i}y^{-i}]

用一个二项式定理就可以得到 (ans = y^nsum_{T}g(T)(y^{-1} - 1)^{|T|})
不妨先记 (u = (y^{-1} - 1)),则 (ans = y^nsum_{T}g(T)u^{|T|})

尽管如此还是一个指数级算法。考虑 g(T) 应该怎么求,然后优化成多项式算法。
如果给定边集 T,只要另一棵树中包含 T 中这些边即可。因此相当于先用 T 中的边将 1~n 的点连成 k 个大小为 a1, a2, ..., ak 的连通块,然后再连成一棵树的方案数。
用 matrix-tree / prufer 可以证明这个方案数为 (g(T) = n^{k-2} imesprod_{i=1}^{k}a_i)(证明详见下面的补充部分)。

由于 T 中的边连成的连通块个数 (k = n-|T|),所以将原式进一步改写为:

[ans = y^nsum_{T}(n^{k-2} imesprod_{i=1}^{k}a_i imes u^{n-k}) \ = frac{y^n imes u^n}{n^2} imessum_{T}(prod_{i=1}^{k}(a_i imes n imes u^{-1}))]

可以作 O(n^2) 的树形 dp:记 dp[i][j] 表示以 i 为根的子树被分成了若干连通块,其中 i 所在的连通块大小为 j,其他连通块的总贡献为 dp[i][j]。

当然可以更简单:考虑 (a_i imes n imes u^{-1}) 的组合意义。即大小为 (a_i) 的连通块中选择一个,贡献 (n imes u^{-1})
然后记 dp[0/1][i] 表示 i 所在的连通块是否有点贡献了 (n imes u^{-1}),这样子就是 O(n) 的树形 dp 了。

@问题 2@

如果你像我一开始一样,从上面的 dp[0/1][i] 入手,最后就会陷入两个生成函数互相卷积的怪圈中,只能分治 fft O(nlog^2n) 求解。。。

考虑依然是容斥,其它过程都与上面一样,只是 g(T) 的计算式子变为 (g(T) = (n^{k-2} imesprod_{i=1}^{k}a_i)^2)(因为要枚举两棵树嘛)。

那么最终答案 (h[n] = frac{y^n imes u^n}{n^4} imessum_{T}(prod_{i=1}^{k}(a_i^2 imes n^2 imes u^{-1})))

现在枚举 T 反而不好办了。我们考虑直接枚举序列 a,算出有多少边集 T。不妨令点 1 所在的连通块大小为 a1,枚举与点 1 在同一连通块的点得到 h 的转移:

[h[n] = sum_{i=0}^{n-1}C_{n-1}^{i} imes (i+1)^2 imes n^2 imes u^{-1} imes (i+1)^{i-1} imes h[n-i-1] ]

上面那个可以直接 O(n^2) 做了。不过还可以进一步优化:
(p[i] = (i+1)^2 imes n^2 imes u^{-1} imes (i+1)^{i-1}),则上面的卷积又可以写作 (h[n+1] = sum_{i=0}^{n}C_{n}^{i} imes p[i] imes h[n-i])

这是一个经典的卷积式子,可以写成指数型生成函数然后求多项式 exp(具体可见下面的补充部分)。
时间复杂度 O(nlogn)。

@补充部分@

对上面所提到的两个问题的细节补充。

(1)1~n 的点连成 k 个大小为 a1, a2, ..., ak 的连通块,然后再连成一棵树的方案数为 (n^{k-2} imesprod_{i=1}^{k}a_i)
证明我选择的是 prufer 序列(懒得写matrix-tree的矩阵证法,网上应该找得到)

由于一个数在 prufer 序列中的出现次数为它的度数减一,又因为从某个大小为 ai 的连通块连出去一条边有 ai 种选择,所以有:

[ans = sum_{sum_{i=1}^{k}d_i = 2k-2}frac{(k-2)!}{prod_{i=1}^{k}(d_i-1)!}prod_{i=1}^{k}a_i^{d_i}\ =prod_{i=1}^{k}a_i imes sum_{sum_{i=1}^{k}(d_i-1) = k-2}frac{(k-2)!}{prod_{i=1}^{k}(d_i-1)!}prod_{i=1}^{k}a_i^{d_i-1}\ =prod_{i=1}^{k}a_i imes (sum_{i=1}^{k}a_i)^{k-2} = prod_{i=1}^{k}a_i imes n^{k-2}]

关于后面那个怎么来的,其实是逆用多项式的展开:

[(x_1 + x_2 + dots + x_n)^k = sum_{sum_{i=1}^{n}a_i = k}(frac{k!}{prod_{i=1}^{n}a_i!}prod_{i=1}^{n}x_i^{a_i}) ]

(2)关于指数型生成函数的 exp 对应的卷积意义。
首先要认识到,对于指数型生成函数而言,积分相等于右移,求导相当于左移。
假如令 (f(x) = sum_{i=0}frac{a_{i}}{i!}x^i),则 (f'(x) = sum_{i=0}frac{a_{i+1}}{i!}x^i)(int f(x) = sum_{i=1}frac{a_{i-1}}{i!}x^i)

根据求导法则,有 (ln(f(x))' = frac{f'(x)}{f(x)}),即 (ln(f(x))' imes f(x) = f'(x))

如果记 (g(x) = ln(f(x))' = sum_{i=0}frac{b_{i}}{i!}x^i),比较第 n 项等式两边的系数,可以得到:

[sum_{i=0}^{n}frac{a_i}{i!} imesfrac{b_{n-i}}{(n-i)!} = frac{a_{n+1}}{n!} ]

然后可以推出 (a_{n+1} = sum_{i=0}^{n}C_n^i imes a_i imes b_{n-i}),就是我们题目中的卷积式子。

@accepted code@

#include <set>
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;

const int MOD = 998244353;
const int MAXN = 400000;

struct mint{
	int x;
	mint(int _x = 0) : x(_x) {}
	friend mint operator + (mint a, const mint &b) {return (a.x + b.x) % MOD;}
	friend mint operator - (mint a, const mint &b) {return (a.x + MOD - b.x) % MOD;}
	friend mint operator * (mint a, const mint &b) {return 1LL * a.x * b.x % MOD;}
	friend void operator += (mint &a, const mint &b) {a = a + b;}
	friend void operator -= (mint &a, const mint &b) {a = a - b;}
	friend void operator *= (mint &a, const mint &b) {a = a * b;}
	friend mint mpow(mint b, int p) {
		if( b.x == 1 ) return 1;
		mint ret = 1;
		while( p ) {
			if( p & 1 ) ret = ret * b;
			b = b * b;
			p >>= 1;
		}
		return ret;
	}
	friend mint operator / (mint a, const mint &b) {return a * mpow(b, MOD - 2);}
	friend void operator /= (mint &a, const mint &b) {a = a / b;}
};

int n, y, op;

void solve0() {
	if( op == 0 ) printf("%d
", 1);
	else if( op == 1 ) printf("%d
", mpow((mint)n, n - 2).x);
	else printf("%d
", mpow((mint)n, 2*(n - 2)).x);
}

set<pair<int, int> >e;
void solve1() {
	int ans = 0;
	for(int i=1;i<n;i++) {
		int u, v; scanf("%d%d", &u, &v);
		if( u > v ) swap(u, v);
		e.insert(make_pair(u, v));
	}
	for(int i=1;i<n;i++) {
		int u, v; scanf("%d%d", &u, &v);
		if( u > v ) swap(u, v);
		if( e.count(make_pair(u, v)) ) ans++;
	}
	printf("%d
", mpow((mint)y, n - ans).x);
}

struct edge{
	edge *nxt; int to;
}edges[2*MAXN + 5], *adj[MAXN + 5], *ecnt = edges;
void addedge(int u, int v) {
	edge *p = (++ecnt);
	p->to = v, p->nxt = adj[u], adj[u] = p;
	p = (++ecnt);
	p->to = u, p->nxt = adj[v], adj[v] = p;
}
mint dp[2][MAXN + 5], del;
void dfs(int x, int f) {
	dp[0][x] = 1, dp[1][x] = del;
	for(edge *p=adj[x];p;p=p->nxt) {
		if( p->to == f ) continue;
		dfs(p->to, x);
		dp[1][x] = dp[1][x] * dp[1][p->to] + dp[1][x] * dp[0][p->to] + dp[0][x] * dp[1][p->to];
		dp[0][x] = dp[0][x] * dp[1][p->to] + dp[0][x] * dp[0][p->to];
	}
}
void solve2() {
	for(int i=1;i<n;i++) {
		int u, v; scanf("%d%d", &u, &v);
		addedge(u, v);
	}
	mint u = 1; u = (u - y) / y;
	mint p = mpow(y * u, n) / n / n;
	del = n / u, dfs(1, 0);
	printf("%d
", (dp[1][1] * p).x);
}

namespace poly{
	const mint G = 3;
	mint w[20], iw[20], inv[MAXN + 5];
	void init() {
		inv[1] = 1;
		for(int i=2;i<=MAXN;i++)
			inv[i] = MOD - inv[MOD%i]*(MOD/i);
		for(int i=0;i<20;i++)
			w[i] = mpow(G, (MOD-1)/(1<<i)), iw[i] = 1 / w[i];
	}
	void ntt(mint *A, int n, int type) {
		for(int i=0,j=0;i<n;i++) {
			if( i < j ) swap(A[i], A[j]);
			for(int k=(n>>1);(j^=k)<k;k>>=1);
		}
		for(int i=1;(1<<i)<=n;i++) {
			int s = (1 << i), t = (s >> 1);
			mint u = (type == 1 ? w[i] : iw[i]);
			for(int j=0;j<n;j+=s) {
				mint p = 1;
				for(int k=0;k<t;k++,p*=u) {
					mint x = A[j + k], y = A[j + k + t];
					A[j + k] = x + y*p, A[j + k + t] = x - y*p;
				}
			}
		}
		if( type == -1 ) {
			for(int i=0;i<n;i++)
				A[i] *= inv[n];
		}
	}
	mint t1[MAXN + 5], t2[MAXN + 5];
	int length(int n) {
		int l; for(l = 1; l < n; l <<= 1);
		return l;
	}
	void mul(mint *A, int nA, mint *B, int nB, mint *C) {
		int nC = (nA + nB - 1), len = length(nC);
		for(int i=0;i<nA;i++) t1[i] = A[i];
		for(int i=nA;i<len;i++) t1[i] = 0;
		for(int i=0;i<nB;i++) t2[i] = B[i];
		for(int i=nB;i<len;i++) t2[i] = 0;
		ntt(t1, len, 1), ntt(t2, len, 1);
		for(int i=0;i<len;i++) C[i] = t1[i] * t2[i];
		ntt(C, len, -1);
	}
	mint t3[MAXN + 5], t4[MAXN + 5];
	void pinv(mint *A, mint *B, int n) {
		if( n == 1 ) {
			B[0] = 1 / A[0];
			return ;
		}
		int m = (n + 1) >> 1;
		pinv(A, B, m);
		int len = length(n << 1);
		for(int i=0;i<m;i++) t3[i] = B[i];
		for(int i=m;i<len;i++) t3[i] = 0;
		for(int i=0;i<n;i++) t4[i] = A[i];
		for(int i=n;i<len;i++) t4[i] = 0;
		ntt(t3, len, 1), ntt(t4, len, 1);
		for(int i=0;i<len;i++)
			B[i] = t3[i] * (2 - t3[i] * t4[i]);
		ntt(B, len, -1);
	}
	void pdif(mint *A, mint *B, int n) {
		for(int i=1;i<n;i++)
			B[i-1] = A[i] * i;
	}
	void pint(mint *A, mint *B, int n) {
		for(int i=n-1;i>=0;i--)
			B[i+1] = A[i] * inv[i + 1];
		B[0] = 0;
	}
	mint t5[MAXN + 5], t6[MAXN + 5];
	void ln(mint *A, mint *B, int n) {
		pdif(A, t5, n), pinv(A, t6, n);
		mul(t5, n - 1, t6, n, t5);
		pint(t5, B, n);
	}
	mint t7[MAXN + 5], t8[MAXN + 5];
	void exp(mint *A, mint *B, int n) {
		if( n == 1 ) {
			B[0] = 1;
			return ;
		}
		int m = (n + 1) >> 1;
		exp(A, B, m);
		for(int i=0;i<m;i++) t7[i] = B[i];
		for(int i=m;i<n;i++) t7[i] = 0;
		ln(t7, t8, n);
		for(int i=0;i<n;i++) t7[i] = A[i] - t8[i];
		t7[0].x += 1;
		for(int i=0;i<m;i++) t8[i] = B[i];
		mul(t7, n, t8, m, B);
	}
}

mint fct[MAXN + 5], ifct[MAXN + 5];
void init() {
	poly::init(); fct[0] = 1;
	for(int i=1;i<=MAXN;i++) fct[i] = fct[i-1] * i;
	ifct[MAXN] = 1 / fct[MAXN];
	for(int i=MAXN-1;i>=0;i--) ifct[i] = ifct[i+1] * (i+1);
}
/*
mint comb(int n, int m) {
	return fct[n] * ifct[m] * ifct[n-m];
}
*/
mint f[MAXN + 5], g[MAXN + 5];
void solve3() {
	init();
	mint u = 1; u = (u - y) / y;
	mint p = mpow(y * u, n) / (mint(n) * n * n * n);
	del = n / u * n;
/*
	for(int i=0;i<n;i++)
		g[i] = mpow(mint(i+1), i-1) * del * mint(i+1) * mint(i+1);
	f[0] = 1;
	for(int i=0;i<n;i++)
		for(int j=0;j<=i;j++)
			f[i+1] += comb(i, j)*g[i-j]*f[j];
	printf("%d
", (f[n] * p).x);
*/
	for(int i=0;i<n;i++)
		g[i] = mpow(mint(i+1), i-1) * del * mint(i+1) * mint(i+1), g[i] *= ifct[i];
	poly::pint(g, g, n);
	poly::exp(g, f, n + 1);
	printf("%d
", (f[n] * p * fct[n]).x);
}

int main() {
	scanf("%d%d%d", &n, &y, &op);
	if( y == 1 ) solve0();
	else if( op == 0 ) solve1();
	else if( op == 1 ) solve2();
	else if( op == 2 ) solve3();
}

@details@

讲道理,这道题并不算太难分析。

不过可以学到很多分析组合计数的知识与技巧。

原文地址:https://www.cnblogs.com/Tiw-Air-OAO/p/12092766.html