首先,题目中的过程可以看作:每次选择任意一个燃料仓,给它装填 (1) 单位的燃料,如果此时恰好 “填满” 了它,就给答案 (+1)。
考虑 (n) 号燃料仓填满的概率,因为所有燃料仓是等价的,由期望线性性,答案就是这个概率乘 (n)。
填满 (n) 号燃料仓前,我们必定给它装填了 (1) 单位。考虑这之前的状态:前 (n - 1) 个燃料仓中至少有一个装填了少于 (b) 单位,第 (n) 个燃料仓恰好装填了 (a - 1) 单位。所以说,(n) 号仓被填满概率就是:
可以把 ([min{x_1, x_2, cdots, x_{n - 1}} < b]) 转化成 (1 - [x_1, x_2, cdots, x_{n - 1} ge b])。考虑写成 EGF 的形式:
我们要求的即是 (sum_{i ge 0} hat{f}_i cdot i!)。
考虑换元,令 (u = e^{frac{1}{n}x}, v = (frac{x}{n})),那么有:
先假设 (hat{f}(x) = sum f_{p, q} u^p v^q),我们对每一项分别考虑:
这一项对答案的贡献是:
所以,我们已经可以快速求 (f_{p, q} u^p v^q) 对答案的贡献了,现在考虑如何求 (f_{p, q})。
令 (sum_{i = 0}^{b - 1} frac{v^i}{i!} = P),(P) 是关于 (v) 的 (b - 1) 次多项式。二项式定理展开,我们有:
我们只要算出 (P^1, P^2, cdots, P^{n - 1}) 的各项系数即可。
暴力算是 (O(n^2b^2)) 的,可以用 FFT 优化到 (O(n^2b log nb)),下面讲一个 (O(n^2b)) 的方法。
发现 (P' = P - frac{v^{b - 1}}{(b - 1)!}),考虑微分方程:
其中 (Q) 是 (P^{k - 1}) 乘上一个单项式。我们按照 (k) 从小到大的顺序递推,假设我们已经求出了当前的 (Q)。设 (P^k = sum_{i = 0}^{m} p_iv^i, Q = sum_{i = 0}^{m} q_iv^i),那么:
我们还有 (p_0 = 1),所以就可以直接递推了。总共的时间复杂度为 (O(n^2b))。
#include <bits/stdc++.h>
#define rep(i, a, b) for (int i = (a); i <= int(b); i++)
#define per(i, a, b) for (int i = (a); i >= int(b); i--)
using namespace std;
const int maxn = 250, maxm = maxn * maxn, mod = 998244353;
int n, a, b, fact[maxm + 3], finv[maxm + 3], inv[maxm + 3], m, p[maxm + 3], q[maxm + 3];
int qpow(int a, int b) {
int c = 1;
for (; b; b >>= 1, a = 1ll * a * a % mod) if (b & 1) c = 1ll * a * c % mod;
return c;
}
void prework(int n) {
fact[0] = 1;
rep(i, 1, n) fact[i] = 1ll * fact[i - 1] * i % mod;
finv[n] = qpow(fact[n], mod - 2);
per(i, n, 1) finv[i - 1] = 1ll * finv[i] * i % mod;
rep(i, 1, n) inv[i] = 1ll * fact[i - 1] * finv[i] % mod;
}
int C(int n, int m) {
return 1ll * fact[n] * finv[m] % mod * finv[n - m] % mod;
}
int main() {
scanf("%d %d %d", &n, &a, &b);
prework(max(a, n * b));
p[0] = 1;
int res = 0;
rep(k, 1, n - 1) {
rep(i, 0, m) q[i + b - 1] = 1ll * p[i] * finv[b - 1] % mod;
m += b - 1;
rep(i, 0, m - 1) p[i + 1] = 1ll * inv[i + 1] * k % mod * (p[i] - q[i] + mod) % mod;
int num = qpow((1 - 1ll * (n - 1 - k) * inv[n] % mod + mod) % mod, mod - 2);
int cur = 1ll * (k & 1 ? 1 : mod - 1) * C(n - 1, k) % mod * qpow(1ll * inv[n] * num % mod, a - 1) % mod * fact[a - 1] % mod;
rep(i, 0, m) {
cur = 1ll * cur * (i == 0 ? 1 : inv[n]) % mod * (i == 0 ? 1 : i + a - 1) % mod * num % mod;
res = (res + 1ll * p[i] * cur) % mod;
}
}
res = 1ll * res * finv[a - 1] % mod;
printf("%d
", res);
return 0;
}