AT5202 [AGC038E] Gachapon(min-max)

AT5202 [AGC038E] Gachapon(min-max)

题目大意

有一个随机数生成器,生成 ([0,n-1]) 之间的整数,其中生成 (i) 的概率为 (frac{A_i}{S}),其中,(S=sum A_i)

这个随机数生成器不断生成随机数,当 (forall iin[0,n-1])(i) 至少出现了 (B_i) 次时,停止生成,否则继续生成。

求期望生成随机数的次数,输出答案对 (998244353) 取模的结果。

数据范围

(A_i,B_igeq 1)(sum A_i,sum B_i,nleq 400)

解题思路

显然是一个 min-max 反演

[Ans = sum_{T subseteq S}(-1)^{|T|+1}frac {S}{sum_{iin T}A_i}f(T) ]

其中,(f(T)) 表示 T 集合中第一个至少出现了 (B_i) 次的期望次数。

考虑暴力求 T 集合的答案

[f(T) = sum_{i=1}P(x=i) imes i=sum_{i=0}^{sumB}P(x > i) ]

如何求出 (P(x>i)) 呢?考虑用方案数除以总方案数,方案数就是一个背包问题,用生成函数表示是

[f(T)=sum_{i=0}left[frac {x^i}{i!} ight]left(prod_{jin T}sum_{t=0}^{B_j-1}A_j^tfrac {x^t}{t!} ight) ]

容易发现我们只用一维即可,时间复杂度是 (Theta(n^2))

观察发现只要选中的生成函数不会变,而且前面的 ((-1)^{|T|+1}) 可以乘进去,又发现 (S) 很小,我们用另一维状态去压缩它即可,时间复杂度 (Theta(n^3)),最后统计答案即可。


/*
      />  フ
      |  _  _|
      /`ミ _x 彡
      /      |
     /   ヽ   ?
  / ̄|   | | |
  | ( ̄ヽ__ヽ_)_)
  \二つ
  */

#include <queue>
#include <vector>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define MP make_pair
#define ll long long
#define fi first
#define se second
using namespace std;

template <typename T>
void read(T &x) {
    x = 0; bool f = 0;
    char c = getchar();
    for (;!isdigit(c);c=getchar()) if (c=='-') f=1;
    for (;isdigit(c);c=getchar()) x=x*10+(c^48);
    if (f) x=-x;
}

template<typename F>
inline void write(F x, char ed = '
') {
    static short st[30];short tp=0;
    if(x<0) putchar('-'),x=-x;
    do st[++tp]=x%10,x/=10; while(x);
    while(tp) putchar('0'|st[tp--]);
    putchar(ed);
}

template <typename T>
inline void Mx(T &x, T y) { x < y && (x = y); }

template <typename T>
inline void Mn(T &x, T y) { x > y && (x = y); }

const int P = 998244353;
const int N = 405;
ll inv[N], fac[N], a[N], b[N], A, B, ans, n;
ll fpw(ll x, ll mi) {
    ll res = 1;
    for (; mi; mi >>= 1, x = x * x % P)
        if (mi & 1) res = res * x % P;
    return res;
}
ll g[N][N], f[N][N];
int main() {
    read(n);
    inv[0] = fac[0] = inv[1] = fac[1] = 1;
    for (int i = 2;i <= 400; i++) inv[i] = (P - P / i) * inv[P % i] % P;
    for (int i = 2;i <= 400; i++)
        inv[i] = inv[i-1] * inv[i] % P,
        fac[i] = fac[i-1] * i % P;

    for (int i = 1;i <= n; i++) {
        read(a[i]), read(b[i]);
        A += a[i], B += b[i];
    }
    f[0][0] = -1;
    /* for (int i = 1;i <= 50; i++) write(inv[i], ' '), write(fac[i]); */
    for (int i = 1;i <= n; i++) {
        memcpy(g, f, sizeof(g));
        for (int s = A;s >= 0; s--) {
            for (int j = B;j >= 0; j--) {
                if (s < a[i]) { f[s][j] = 0; continue; }
                ll t = a[i];
                f[s][j] = f[s-a[i]][j];
                for (int k = 1;k < b[i]; k++, t = t * a[i] % P) 
                    f[s][j] = (f[s][j] + t * inv[k] % P * f[s-a[i]][j-k]) % P;
            }
        }
        for (int j = 0;j <= A; j++)
            for (int k = 0;k <= B; k++)
                f[j][k] = (g[j][k] - f[j][k] + P) % P;
    }
    for (int s = 1;s <= A; s++) {
        ll tt = fpw(s, P - 2), t = A * tt % P;
        ll res = 0;
        for (int i = 0;i <= B; i++, t = t * tt % P) 
            res = (res + f[s][i] * fac[i] % P * t) % P;
        res %= P, ans = (ans + res) % P;
    }
    write(ans);
    return 0;
}

*/
原文地址:https://www.cnblogs.com/Hs-black/p/13687231.html