@loj


@description@

在一个 s 个点的图中,存在 s - n 条边,使图中形成了 n 个连通块,第 i 个连通块中有 (a_i) 个点。

现在我们需要再连接 n - 1 条边,使该图变成一棵树。对一种连边方案,设原图中第 i 个连通块连出了 (d_i) 条边,那么这棵树 T 的价值为:

[val(T) = (prod_{i=1}^{n}d_{i}^{m})(sum_{i=1}^{n}d_{i}^{m}) ]

你的任务是求出所有可能的生成树的价值之和,对 998244353 取模。

原题戳我

@solution@

@正文@

注意到 (d_i) 为度数,那么考虑 prufer 序列,直接写出答案表达式:

[ans = sum_{(sum_{i=1}^{n}b_i)=n-2}(frac{(n-2)!}{prod_{i=1}^{n}b_i!}) imes(prod_{i=1}^{n}a_{i}^{b_i + 1}) imes(prod_{i=1}^{n}(b_{i} + 1)^{m}) imes(sum_{i=1}^{n}(b_{i} + 1)^{m}) ]

其中 (b_i + 1 = d_i)

作一些简单的变形:

[ans = (n-2)! imes(prod_{i=1}^{n}a_i) imessum_{i=1}^{n}sum_{(sum_{j=1}^{n}b_j)=n-2}(frac{(b_{i} + 1)^{2m} imes a_{i}^{b_{i}}}{b_{i}!}) imes(prod_{j=1,j ot =i}^{n}frac{(b_{j} + 1)^{2m} imes a_{j}^{b_{j}}}{b_{j}!}) ]

引入生成函数。如果记 (P(x) = sum_{i=0}frac{(i + 1)^{2m} imes x^i}{i!})(Q(x) = sum_{i=0}frac{(i + 1)^{m} imes x^i}{i!}),则:

[ans = (n-2)! imes(prod_{i=1}^{n}a_i) imes([x^{n-2}]sum_{i=1}^{n}P(a_i x) imes(prod_{j=1,j ot =i}^{n}Q(a_j x)))\ ans = (n-2)! imes(prod_{i=1}^{n}a_i) imes([x^{n-2}]prod_{i=1}^{n}Q(a_i x) imessum_{i=1}^{n}frac{P(a_i x)}{Q(a_i x)})]

注意到 (frac{P(a_i x)}{Q(a_i x)}) 其实就是 (frac{P(x)}{Q(x)}) 的第 k 项乘上 (a_i^{k})

也就是说 (sum_{i=1}^{n}frac{P(a_i x)}{Q(a_i x)}) 就是 (frac{P(x)}{Q(x)}) 的第 k 项乘上 (sum_{i=1}^{n}a_i^{k}),而 (sum_{i=1}^{n}a_i^{k}) 是可以快速求出的(在补充部分介绍)。

尝试把 (prod_{i=1}^{n}Q(a_i x)) 也化成加法形式:利用对数,可以得到 (prod_{i=1}^{n}Q(a_i x) = exp(sum_{i=1}^{n}ln(Q(a_i x))))

之后就没有了。只要求出了 (sum_{i=1}^{n}a_i^{k}),剩下的都是模板。

@补充@

关于如何求 (sum_{i=1}^{n}a_i^{k}),其实方法比较多,这里介绍一种:

注意到 (ln(1 - x) = -sum_{i=1}frac{x^i}{i}),那么只要求出 (sum_{i=1}^{n}ln(1 - a_ix)),也就求出了 (sum_{i=1}^{n}a_i^{k})

利用对数的性质,有 (sum_{i=1}^{n}ln(1 - a_ix) = ln(prod_{i=1}^{n}(1 - a_ix)))

然后里面那个式子分治 fft 可以 O(nlog^2n) 搞定,这样一来总时间复杂度其实就是 O(nlog^2n)。

@accepted code@

#include <cstdio>
#include <algorithm>
using namespace std;

const int MAXN = 4*30000;
const int MOD = 998244353;

struct mint{
	int x;
	mint(int _x=0) : x(_x) {}
	friend mint operator + (mint a, mint b) {
		return a.x + b.x >= MOD ? a.x + b.x - MOD : a.x + b.x;
	}
	friend mint operator - (mint a, mint b) {
		return a.x - b.x < 0 ? a.x - b.x + MOD : a.x - b.x;
	}
	friend mint operator * (mint a, mint b) {
		return (int)(1LL * a.x * b.x % MOD);
	}
	friend mint pow_mod(mint b, int p) {
		mint ret = 1;
		while( p ) {
			if( p & 1 ) ret *= b;
			b *= b;
			p >>= 1;
		}
		return ret;
	}
	friend mint operator / (mint a, mint b) {
		return a * pow_mod(b, MOD - 2);
	}
	friend void operator += (mint &a, mint b) {a = a + b;}
	friend void operator -= (mint &a, mint b) {a = a - b;}
	friend void operator *= (mint &a, mint b) {a = a * b;}
	friend void operator /= (mint &a, mint b) {a = a / b;}
};

namespace poly{
	const mint G = 3;	
	mint w[20], iw[20], inv[MAXN + 5];
	void init() {
		for(int i=0;i<20;i++) {
			w[i] = pow_mod(G, (MOD-1)/(1<<i));
			iw[i] = pow_mod(w[i], MOD-2);
		}
		inv[1] = 1;
		for(int i=2;i<=MAXN;i++)
			inv[i] = 0 - (MOD/i)*inv[MOD%i];
	}
	void debug(mint *A, int n) {
		for(int i=0;i<n;i++)
			printf("%d ", A[i].x);
		puts("");
	}
	void ntt(mint *A, int n, int type) {
		for(int i=0,j=0;i<n;i++) {
			if( i < j ) swap(A[i], A[j]);
			for(int k=(n>>1);(j^=k)<k;k>>=1);
		}
		for(int i=1;(1<<i)<=n;i++) {
			int s = (1 << i), t = (s >> 1);
			mint u = (type == 1 ? w[i] : iw[i]);
			for(int j=0;j<n;j+=s) {
				mint p = 1;
				for(int k=0;k<t;k++,p*=u) {
					mint x = A[j+k], y = A[j+k+t];
					A[j+k] = x + y*p, A[j+k+t] = x - y*p;
				}
			}
		}
		if( type == -1 ) {
			mint iv = inv[n];
			for(int i=0;i<n;i++)
				A[i] *= iv;
		}
	}
	int length(int n) {
		int len; for(len = 1; len < n; len <<= 1);
		return len;
	}
	void pcopy(mint *A, mint *B, int n, int l) {
		for(int i=0;i<n;i++) A[i] = B[i];
		for(int i=n;i<l;i++) A[i] = 0;
	}
	mint t1[MAXN + 5], t2[MAXN + 5];
	void pmul(mint *A, int nA, mint *B, int nB, mint *C) {
		int len = length(nA + nB - 1);
		pcopy(t1, A, nA, len), ntt(t1, len, 1);
		pcopy(t2, B, nB, len), ntt(t2, len, 1);
		for(int i=0;i<len;i++) C[i] = t1[i] * t2[i];
		ntt(C, len, -1);
	}
	mint t3[MAXN + 5], t4[MAXN + 5];
	void pinv(mint *A, mint *B, int n) {
		if( n == 1 ) {
			B[0] = 1 / A[0];
			return ;
		}
		int m = (n + 1) >> 1; pinv(A, B, m);
		int len = length(n << 1);
		pcopy(t3, A, n, len), ntt(t3, len, 1);
		pcopy(t4, B, m, len), ntt(t4, len, 1);
		for(int i=0;i<len;i++) B[i] = t4[i]*(2 - t3[i]*t4[i]);
		ntt(B, len, -1);
	}
	void pdif(mint *A, mint *B, int n) {
		for(int i=1;i<n;i++)
			B[i-1] = A[i] * i;
	}
	void pint(mint *A, mint *B, int n) {
		for(int i=n-1;i>=0;i--)
			B[i+1] = A[i] / (i + 1);
		B[0] = 0;
	}
	mint t5[MAXN + 5], t6[MAXN + 5];
	void pln(mint *A, mint *B, int n) {
		pinv(A, t5, n), pdif(A, t6, n);
		pmul(t5, n, t6, n, B);
		pint(B, B, n);
	}
	mint t7[MAXN + 5], t8[MAXN + 5];
	void pexp(mint *A, mint *B, int n) {
		if( n == 1 ) {
			B[0] = 1;
			return ;
		}
		int m = (n + 1) >> 1; pexp(A, B, m);
		int len = length(n << 1);
		pcopy(t7, B, m, len), pln(t7, t8, n), pcopy(t7, t8, n, len);
		pcopy(t8, B, m, len);
		for(int i=0;i<n;i++) t7[i] = A[i] - t7[i];
		t7[0] = t7[0] + 1;
		ntt(t7, len, 1), ntt(t8, len, 1);
		for(int i=0;i<len;i++) B[i] = t7[i] * t8[i];
		ntt(B, len, -1);
	}
}

int n, m, k;

mint A[MAXN + 5], B[MAXN + 5];
void init() {
	mint t = 1;
	for(int i=0;i<n;i++,t*=i) {
		mint a = 1 / t, b = pow_mod(mint(i + 1), m);
		A[i] = a * b * b, B[i] = a * b;
	}
	poly::init();
}

mint a[MAXN + 5], f[MAXN + 5], s[MAXN + 5];
int solve(int l, int r) {
	if( l == r ) {
		f[l<<1] = 1, f[l<<1|1] = 0 - a[l];
		return 2;
	}
	int mid = (l + r) >> 1;
	int ls = solve(l, mid), rs = solve(mid + 1, r);
	poly::pmul(f + (l<<1), ls, f + ((mid + 1) << 1), rs, f + (l << 1));
	return ls + rs - 1;
}
void get_pow_sum() {
	solve(0, n - 1), poly::pln(f, s, n + 1);
	s[0] = n;
	for(int i=1;i<=n;i++)
		s[i] = 0 - s[i]*i;
}

mint t1[MAXN + 5], t2[MAXN + 5];
int main() {
	scanf("%d%d", &n, &m), k = n - 2, init();
	for(int i=0;i<n;i++) scanf("%d", &a[i].x);
	
	get_pow_sum();
	poly::pln(B, t1, n);
	for(int i=0;i<n;i++)
		t1[i] *= s[i];
	poly::pexp(t1, t2, n);
	poly::pinv(B, t1, n);
	poly::pmul(A, n, t1, n, t1);
	for(int i=0;i<n;i++)
		t1[i] *= s[i];
	poly::pmul(t1, n, t2, n, t1);
	mint ans = t1[n - 2];
	for(int i=0;i<n;i++) ans *= a[i];
	for(int i=1;i<=n-2;i++) ans *= i;
	printf("%d
", ans.x);
}

@details@

顺带一提,这道题还有依赖于斯特林数的 O(nmlogn) 的做法(但是我看不懂 QaQ)。

原文地址:https://www.cnblogs.com/Tiw-Air-OAO/p/12119237.html