「TJOI / HEOI2016」求和 的一个优秀线性做法

我们把(S(i, j)j!)看成是把(i)个球每次选择一些球(不能为空)扔掉,选(j)次后把所有球都扔掉的情况数(顺序有关)。因此(S(i, j)j! = i![x^i](e^x - 1)^j)

为了求出答案,我们需要研究如下的生成函数的性质。

(P(x) = sum_{i = 0}^{n}(2e^x - 2)^i = sum_{i = 0}^{n} 2^i sum_{j = 0}^{i} (-1)^{i - j}e^{jx} {i choose j} = sum_{j = 0}^{n} e^{jx}sum_{i = j}^{n} 2^i(-1)^{i - j} {i choose j})

(a_j = sum_{i = j}^{n} (-2)^i {i choose j})。在线性时间内计算(a_j)是个经典的问题。

(a_0)是很容易计算的。

(j ge 1)时:

(a_j)

(= sum_{i = j}^{n} (-2)^i ({i - 1 choose j} + {i - 1 choose j - 1}))

(= -2sum_{i = j}^{n - 1} (-2)^i{i choose j} -2sum_{i = j - 1}^{n - 1} (-2)^i{i choose j - 1})

(= -2a_j + 2(-2)^{n} {n choose j} - 2a_{j - 1} + 2(-2)^{n} {n choose j - 1})

转换为递推式(a_j = frac{1}{3} (2(-2)^n {n choose j} + 2(-2)^n{n choose j - 1} - 2a_{j - 1}))

欲求的答案就是(sum_{j = 0}^{n} (-1)^ja_j sum_{i = 0}^{n} i![x^i]e^{jx})

我们发现答案就是(sum_{i = 0}^{n} i![x^i]e^{jx} = sum_{i = 0}^{n} j^i),可以使用等比数列求和公式计算。

我们需要计算(j^{n + 1}),这可以先计算出(j)为素数处的取值,然后再用线性筛算出(1 leq j leq n)时的取值。复杂度变成了(O(frac{n}{ln n} cdot log_2{n}) = O(n))

于是,我们在(O(n))的时间内做出了本题。顺便获得目前的rk1.

代码如下:

#include <bits/stdc++.h>
#define debug(x) cerr << #x << " " << (x) << endl
using namespace std;

const int N = 100005;
const long long mod = 998244353ll;

int n, pri[N], cnt = 0;
bool is_pri[N];
long long pw1[N], pw2[N], inv[N], binom[N], a[N], ans = 0ll;

long long qpow (long long a, long long b) {
	long long res = 1ll;
	for (; b; b >>= 1, a = a * a % mod) {
		if (b & 1) res = res * a % mod;
	}
	return res;
}

void init () {
	pw1[1] = pw2[0] = inv[1] = 1ll;
	for (int i = 1; i <= max(n, 3); i++) is_pri[i] = (i != 1), pw2[i] = 2ll * (mod - pw2[i - 1]) % mod;
	for (int i = 2; i <= max(n, 3); i++) {
		inv[i] = (mod / i) * (mod - inv[mod % i]) % mod;
		if (is_pri[i]) pw1[i] = qpow(i, n + 1), pri[cnt++] = i;
		for (int j = 0; j < cnt && i * pri[j] <= n; j++) {
			is_pri[i * pri[j]] = false;
			pw1[i * pri[j]] = pw1[i] * pw1[pri[j]] % mod;
			if (i % pri[j] == 0) break;
		}
	}
	binom[0] = 1ll;
	for (int i = 1; i <= n; i++) binom[i] = binom[i - 1] * (n - i + 1) % mod * inv[i] % mod;
}

int main () {
	scanf("%d", &n), init();

	a[0] = 0ll;
	for (int i = 0; i <= n; i++) a[0] = (a[0] + pw2[i]) % mod;
	for (int i = 1; i <= n; i++) {
		a[i] = pw2[n] * (binom[i] + binom[i - 1]) % mod;
		a[i] = (a[i] - a[i - 1] + mod) % mod;
		a[i] = 2ll * a[i] % mod * inv[3] % mod;
	}

	for (int i = 0; i <= n; i++) {
		if (!i) ans = (ans + a[i]) % mod;
		else if (i == 1) ans = (ans + mod * mod - a[i] * (n + 1)) % mod;
		else if (i & 1) ans = (ans + mod * mod - a[i] * (pw1[i] + mod - 1) % mod * inv[i - 1]) % mod;
		else ans = (ans + a[i] * (pw1[i] + mod - 1) % mod * inv[i - 1]) % mod;
	}
	printf("%lld
", ans);
	return 0;
}
原文地址:https://www.cnblogs.com/mathematician/p/12693500.html