题解 LOJ#2320. 「清华集训 2017」生成树计数

题目链接

直接考虑这棵树的 Prufer 序列,要求的就是:

[sum_{v_1+v_2+cdots=n-2}(prod_i(v_i+1)^msum_i(v_i+1)^mprod_ia_i^{v_i+1}) ]

然后首先先把一个 (prod a_i) 提到前面去,要算的就是

[(prod a_i)sum_{v_1+v_2+cdots=n-2}(prod_ia_i^{v_i}(v_i+1)^msum_i(v_i+1)^m) ]

后面又有求和号又有连乘号不好写生成函数,于是我们先拆一下,就是:

[(prod a_i)sum_jsum_{v_1+v_2+cdots=n-2}(prod_ia_i^{v_i}(v_i+1)^m)(v_j+1)^m ]

也就是说我们要同时写两个生成函数:

[A(x)=sum_{ngeq 1}(n+1)^mx^n\ B(x)=sum_{ngeq 1}(n+1)^{2m}x^n ]

然后要求的就可以写成:

[(prod a_i)sum_j[x^{n-2}]F_j(x) ]

其中

[F_j(x)=B(a_jx)prod_{i eq j}A(a_ix)\ =frac{B(a_jx)}{A(a_jx)}prod_{i}A(a_ix) ]

然后这个东西直接算不好算,先考虑左半部分,我们令

[G(x)=frac{B(x)}{A(x)} ]

于是我们相当于要对于每个 (i) 求出 (G(a_ix)),这样显然没法做。但是注意到我们求的是所有 (i) 的和,于是这其实是一个等幂和!

对于右半部分,我们可以做 (ln) 之后全部加起来在 (exp) 回去,这个也就是一个等幂和。

复杂度 (mathcal O(nlog^2 n))

#include"poly.h"

// 由于多项式板子太长直接省略。详情可见:
// https://www.cnblogs.com/whx1003/p/14152053.html

#define fi first
#define se second
#define mp(x, y) std::make_pair(x, y)

namespace IdemSum {
	typedef std::pair<poly, poly> ppp;
	inline ppp IdemSum(const vec &A) {
		if(A.size() == 1) return std::make_pair(1, poly({ 1, mod - A[0] }));
		
		int mid = A.size() >> 1;
		ppp L = IdemSum(vec(A.begin(), A.begin() + mid));
		ppp R = IdemSum(vec(A.begin() + mid, A.end()));
		return mp(L.fi.mul(R.se) + R.fi.mul(L.se), L.se.mul(R.se));
	}
}

int n, m; poly A, B, F, G, H;
ll Fac[maxn], Inv[maxn];
ll S; vec a;

inline ll sgn(ll x) { return x & 1 ? mod - 1 : 1; }
int main() {
	n = getll(), m = getll(), S = 1;
	for(int i = 0; i < n; ++i)
		a.push_back(getll()), S = S * a[i] % mod;
	
	Fac[0] = 1;
	for(int i = 1; i <= n; ++i)
		Fac[i] = Fac[i - 1] * i % mod;
	Inv[n] = fsp(Fac[n], mod - 2);
	for(int i = n; i >= 1; --i)
		Inv[i - 1] = Inv[i] * i % mod;
	
	for(int i = 0; i <= n; ++i) A[i] = fsp(i + 1, m) * Inv[i] % mod;
	for(int i = 0; i <= n; ++i) B[i] = fsp(i + 1, 2 * m) * Inv[i] % mod;
	
	F = B * A.inv(), G = A.ln();
	std::pair<poly, poly> con = IdemSum::IdemSum(a);
	H = con.fi.slice(n) * con.se.inv();
	
	for(int i = 0; i <= n; ++i) F[i] = F[i] * H[i] % mod;
	for(int i = 0; i <= n; ++i) G[i] = G[i] * H[i] % mod;
	G = G.exp();
	
	putll((F * G)[n - 2] * S % mod * Fac[n - 2] % mod), IObuf::flush();
}
原文地址:https://www.cnblogs.com/whx1003/p/14152816.html