「PKUWC2018」猎人杀

题目

好题好题

这个分母一直在变,看上去完全不知道怎么去算

考虑容斥一波,设(g_i)表示至少有(i)死在(1)之后的概率,那么答案就是(sum_{i=0}^n(-1)^ig_i)

考虑一下(g_i)怎么算,发现又不会算了

考虑一个问题,就是现在有两个人(i,j),其中(i)(j)先死的概率是多少

这个概率是(frac{w_i}{w_i+w_j}),瞎猜一下大概是这个样子的

我们设(p_i)(i)(j)先死的概率,设(S=sum_{i=1}^nw_i),那么就有

[p_i=frac{w_i}{S}+frac{S-w_i-w_j}{S}p_i ]

这样解一下方程就能得到(p_i=frac{w_i}{w_i+w_j})

我们考虑刚才那个(g_i)还是非常不好求的样子,我们先考虑对于一个集合(T),这个集合(T)里的人都在(1)之后死的概率

显然我们可以把(T)集合里的人合并成一个人,那么容斥之后的答案就是

[sum_{Tsubset S}(-1)^{|T|}frac{w_1}{w_1+sum T} ]

我们注意到(sum_{i=1}^nw_i leq 10^5),这启示我们把每一个分母出现的次数都算出来

考虑到前面还有一个容斥系数,我们可以把每个人都写成一个生成函数(1-x^{w_i})

分治(ntt)求一下(prod_{i=1}^n1-x^{w_i})就能算每一个分母出现的次数了

代码

#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
#define pb push_back
#define re register
#define LL long long
#define max(a, b) ((a) > (b) ? (a) : (b))
#define min(a, b) ((a) < (b) ? (a) : (b))
const int mod = 998244353;
const int maxn = 262144 + 5;
const int G[2] = { 3, (mod + 1) / 3 };
inline int read() {
    char c = getchar();
    int x = 0;
    while (c < '0' || c > '9') c = getchar();
    while (c >= '0' && c <= '9') x = (x << 3) + (x << 1) + c - 48, c = getchar();
    return x;
}
std::vector<int> q[maxn * 3];
int n, len, ra[maxn], la[maxn], inv[maxn], a[maxn], rev[maxn], pre[maxn];
int __og[25][2];
inline int ksm(int a, int b) {
    int S = 1;
    for (; b; b >>= 1, a = 1ll * a * a % mod)
        if (b & 1)
            S = 1ll * S * a % mod;
    return S;
}
inline void NTT(int *f, int o) {
    for (re int i = 0; i < len; i++)
        if (i < rev[i])
            std::swap(f[i], f[rev[i]]);
    for (re int w = 0, i = 2; i <= len; i <<= 1, w++) {
        int ln = i >> 1;
        int og1;
        if (!__og[w][o])
            og1 = __og[w][o] = ksm(G[o], (mod - 1) / i);
        else
            og1 = __og[w][o];
        for (re int t, og = 1, l = 0; l < len; l += i, og = 1)
            for (re int x = l; x < l + ln; ++x) {
                t = 1ll * f[x + ln] * og % mod, og = 1ll * og * og1 % mod;
                f[x + ln] = (f[x] - t + mod) % mod, f[x] = (f[x] + t) % mod;
            }
    }
    if (!o)
        return;
    int Inv = ksm(len, mod - 2);
    for (re int i = 0; i < len; i++) f[i] = 1ll * f[i] * Inv % mod;
}
void cdq(int l, int r, int t) {
    if (l == r) {
        q[t].pb(1);
        for (re int i = 1; i < a[l]; i++) q[t].pb(0);
        q[t].pb(mod - 1);
        return;
    }
    int mid = l + r >> 1;
    cdq(l, mid, t << 1), cdq(mid + 1, r, t << 1 | 1);
    len = 1;
    while (len <= pre[r] - pre[l - 1]) len <<= 1;
    for (re int i = 0; i < len; i++) rev[i] = rev[i >> 1] >> 1 | ((i & 1) ? len >> 1 : 0);
    for (re int i = 0; i < q[t << 1].size(); i++) la[i] = q[t << 1][i];
    for (re int i = q[t << 1].size(); i < len; i++) la[i] = 0;
    for (re int i = 0; i < q[t << 1 | 1].size(); i++) ra[i] = q[t << 1 | 1][i];
    for (re int i = q[t << 1 | 1].size(); i < len; i++) ra[i] = 0;
    NTT(la, 0), NTT(ra, 0);
    for (re int i = 0; i < len; i++) la[i] = 1ll * la[i] * ra[i] % mod;
    NTT(la, 1);
    for (re int i = 0; i <= pre[r] - pre[l - 1]; i++) q[t].pb(la[i]);
}
inline int calc(int x) { return 1ll * a[0] * inv[a[0] + x] % mod; }
int main() {
    n = read();
    inv[1] = 1;
    for (re int i = 0; i < n; i++) a[i] = read();
    pre[0] = a[0];
    for (re int i = 1; i < n; i++) pre[i] = pre[i - 1] + a[i];
    cdq(1, n - 1, 1);
    for (re int i = 2; i <= pre[n - 1]; i++) inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
    int ans = 0;
    for (re int i = 0; i < q[1].size(); i++) ans = (ans + 1ll * q[1][i] * calc(i) % mod) % mod;
    printf("%d
", ans);
    return 0;
}
原文地址:https://www.cnblogs.com/asuldb/p/11055120.html