LOJ2541 PKUWC2018 猎人杀 期望、容斥、生成函数、分治FFT

传送门


首先,每一次有一个猎人死亡之后(sum w)会变化,计算起来很麻烦,所以考虑在某一个猎人死亡之后给其打上标记,仍然计算他的(w),只是如果打中了一个打上了标记的人就重新选择。这样对应于每一个人的概率仍然是一样的,而(sum w)在计算的过程中不会变。

因为要求最后死的概率,似乎不是很好求,考虑容斥。枚举一个集合(S),我们强制集合(S)中的猎人在(1)号猎人死亡之后死亡。设集合(S)中所有猎人的(w)之和为(A),所有猎人的(w)之和为(sum),那么集合(S)能够产生的贡献为((-1) ^ {|S|} imes frac{w_1}{sum} imes sumlimits_{i=0} ^ {infty} (1 - frac{A + w_1}{sum})^i)

注意到后面是一个无穷递减等比数列,那么(sumlimits_{i=0} ^ {infty} (1 - frac{A + w_1}{sum})^i = frac{1}{1 - (1 - frac{A + w_1}{sum})} = frac{sum}{A + w_1}),那么原式等于((-1)^{|S|} imes frac{w_1}{A + w_1})

那么我们只需要计算每一个集合的(A)就可以了。

注意到对于(A)的计算,实质是一个(01)背包。但是直接(DP)肯定复杂度爆炸,考虑生成函数求解

(i)个猎人的生成函数为(-x^{w_i} + 1)(-x^{w_i})表示选择第(i)个猎人,但是集合的贡献乘上(-1)(+1)表示不选择第(i)个猎人。然后分治(FFT)求解,我们就可以得到对于所有的(A)(frac{w_1}{A + w_1})前面的系数了。

总的复杂度为(O(n log^2n))

#include<bits/stdc++.h>
#define ll long long
#define mid ((l + r) >> 1)
//This code is written by Itst
using namespace std;

inline int read(){
    int a = 0;
    char c = getchar();
    bool f = 0;
    while(!isdigit(c)){
        if(c == '-')
            f = 1;
        c = getchar();
    }
    while(isdigit(c)){
        a = (a << 3) + (a << 1) + (c ^ '0');
        c = getchar();
    }
    return f ? -a : a;
}

const int MOD = 998244353 , G = 3 , INV = 332748118 , MAXN = 2e5 + 10;
int val[MAXN] , dir[MAXN] , N , need , inv_need;
vector < int > v[MAXN];

inline int poww(ll a , int b){
    int times = 1;
    while(b){
        if(b & 1)
            times = times * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return times;
}

inline void NTT(vector < int > &arr , int type){
    while(arr.size() < need)
        arr.push_back(0);
    for(int i = 1 ; i < need ; ++i)
        if(i < dir[i])
            swap(arr[i] , arr[dir[i]]);
    for(int i = 1 ; i < need ; i <<= 1){
        int wn = poww(type == 1 ? G : INV , (MOD - 1) / (i << 1));
        for(int j = 0 ; j < need ; j += i << 1){
            ll w = 1;
            for(int k = 0 ; k < i ; ++k , w = w * wn % MOD){
                int x = arr[j + k] , y = arr[i + j + k] * w % MOD;
                arr[j + k] = x + y >= MOD ? x + y - MOD : x + y;
                arr[i + j + k] = x - y < 0 ? x - y + MOD : x - y;
            }
        }
    }
}

inline void solve(int l , int r){
    need = 1;
    while(need <= v[l].size() + v[r].size())
        need <<= 1;
    inv_need = poww(need , MOD - 2);
    for(int i = 1 ; i < need ; ++i)
        dir[i] = (dir[i >> 1] >> 1) | (i & 1 ? need >> 1 : 0);
    NTT(v[l] , 1);
    NTT(v[r] , 1);
    for(int i = 0 ; i < need ; ++i)
        v[l][i] = 1ll * v[l][i] * v[r][i] % MOD;
    NTT(v[l] , -1);
    for(int i = 0 ; i < need ; ++i)
        v[l][i] = 1ll * v[l][i] * inv_need % MOD;
    while(v[l][v[l].size() - 1] == 0)
        v[l].erase(--v[l].end());
}

int main(){
#ifndef ONLINE_JUDGE
    freopen("in" , "r" , stdin);
    //freopen("out" , "w" , stdout);
#endif
    N = read();
    if(N == 1){
        puts("1");
        return 0;
    }
    for(int i = 1 ; i <= N ; ++i){
        val[i] = read();
        if(i != 1){
            v[i].push_back(1);
            while(v[i].size() < val[i])
                v[i].push_back(0);
            v[i].push_back(MOD - 1);
        }
    }
    int ans = 0;
    for(int i = 1 ; i < N ; i <<= 1)
        for(int j = 2 ; j + i <= N ; j += i << 1){
            solve(j , j + i);
            vector < int >().swap(v[j + i]);
        }
    for(int i = 0 ; i < v[2].size() ; ++i)
        ans = (ans + 1ll * poww(i + val[1] , MOD - 2) * v[2][i]) % MOD;
    cout << 1ll * ans * val[1] % MOD;
    return 0;
}
原文地址:https://www.cnblogs.com/Itst/p/10274435.html