Codeforces 1349D Slime and Biscuits

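设 \(E_i\) 为所有饼干第一次都收到 \(i\) 手上、即游戏在 \(i\) 结束这种情况对期望时间的贡献(时间乘以概率),\(P_i\) 为这种情况的概率;再令 \(E'_i\) 为只有当所有饼干都收到 \(i\) 手上时游戏才结束的期望时间,\(C\) 为将所有饼干从一个人手上转移到另一个不同的人手上的期望时间。记 \(S\) 为饼干总数,\(A_i\) 为 \(i\) 初始的饼干数,\(F_x\) 为目标者手上的饼干数从 \(x\) 变为 \(x+1\) 的期望时间,那么有: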

\[
\begin{aligned}
Ans &= \sum_{i=1}^n E_i \\
E_i &= E'_i - \sum_{j=1}^n [i \neq j]\,(E_j + P_j C) \\
Ans = \sum_{i=1}^n E_i &= \sum_{i=1}^n E'_i - (n-1)\sum_{i=1}^n E_i - (n-1)C \\
n \cdot Ans &= \sum_{i=1}^n E'_i - (n-1)C \\
p_1 &= \frac{S-x}{S} \times \frac{n-2}{n-1} \\
p_2 &= \frac{S-x}{S} \times \frac{1}{n-1} \\
p_3 &= \frac{x}{S} \\
F_0 &= \frac{1}{p_2} \\
F_x &= \left(\frac{p_2+p_3}{p_2}-1\right)\left(\frac{1}{1-p_1}+F_{x-1}\right)+\frac{1}{1-p_1} \\
E'_i &= \sum_{k=A_i}^{S-1} F_k, \quad C = \sum_{k=0}^{S-1} F_k
\end{aligned}
\]

/*program by mangoyang*/
#pragma GCC optimize("Ofast", "inline")
#include<bits/stdc++.h>
#define inf ((ll) 3e18)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
/*
 * Reads one decimal integer from stdin into x.
 *
 * Every non-digit character before the number is skipped; if any skipped
 * character is '-', the parsed value is negated. Digits are consumed until the
 * first non-digit (which is consumed as well) or end of input.
 *
 * If end of input is reached before any digit is seen, x is left at 0 and the
 * function returns instead of polling getchar() forever.
 */
template <class T>
inline void read(T &x){
    x = 0;
    bool negative = false;
    int ch = getchar();
    // Skip the prefix; the EOF test is what guarantees termination on exhausted input.
    while(ch != EOF && !isdigit(ch)){
        if(ch == '-') negative = true;
        ch = getchar();
    }
    // isdigit(EOF) is false, so this loop also stops at end of input.
    for(; isdigit(ch); ch = getchar()) x = x * 10 + (ch - '0');
    if(negative) x = -x;
}
// N: capacity of the global arrays below; mod: the modulus used by up(), Pow() and main
// (main obtains modular inverses as Pow(x, mod - 2)).
const int N = 500005, mod = 998244353;
// f[k]: value F_k of the recurrence evaluated in main; sum[k]: suffix sum f[k] + ... + f[S-1];
// a[i]: i-th input value (1-based); n: number of input values; S: total of all a[i].
int f[N], sum[N], a[N], n, S;
// Adds y to x in place and subtracts mod once if the sum reached mod.
// With both operands in [0, mod] the result stays in [0, mod).
inline void up(int &x, int y){
	x += y;
	if(x >= mod) x -= mod;
}
// Binary exponentiation: returns a^b reduced modulo mod (1 when b == 0).
inline int Pow(int a, int b){
	long long base = a;
	long long result = 1;
	while(b){
		// Multiply in the current power of the base when the low bit of b is set.
		if(b & 1) result = result * base % mod;
		base = base * base % mod;
		b >>= 1;
	}
	return (int) result;
}
/*
 * Reads n and a[1..n] from stdin, fills f[x] with the recurrence
 *   F_0 = n - 1,
 *   F_x = ((p2 + p3) / p2 - 1) * (1 / (1 - p1) + F_{x-1}) + 1 / (1 - p1),
 * builds suffix sums of f, and prints
 *   ( sum_i sum[a[i]] - (n - 1) * sum[0] ) / n
 * with all arithmetic done modulo mod.
 */
int main(){
	read(n);
	for(int i = 1; i <= n; i++) read(a[i]), S += a[i];

	// Inverses that do not depend on the loop variable are computed exactly once.
	const int invS = Pow(S, mod - 2);                         // 1 / S
	const int invOthers = Pow(n - 1, mod - 2);                // 1 / (n - 1)
	const int otherShare = 1ll * (n - 2) * invOthers % mod;   // (n - 2) / (n - 1)

	f[0] = n - 1;                                             // F_0 = 1 / p2 at x = 0
	for(int x = 1; x < S; x++){
		const int outside = 1ll * (S - x) * invS % mod;       // (S - x) / S
		const int p1 = 1ll * outside * otherShare % mod;
		const int p2 = 1ll * outside * invOthers % mod;
		const int p3 = 1ll * x * invS % mod;

		// 1 / (1 - p1) appears twice in the recurrence; evaluate it once and reuse it.
		const int invMove = Pow(mod + 1 - p1, mod - 2);

		int ratio = 1ll * (p2 + p3) % mod * Pow(p2, mod - 2) % mod;
		up(ratio, mod - 1);                                   // (p2 + p3) / p2 - 1

		int inner = invMove;
		up(inner, f[x-1]);                                    // 1 / (1 - p1) + F_{x-1}

		f[x] = 1ll * ratio * inner % mod;
		up(f[x], invMove);
	}

	// sum[k] = f[k] + f[k+1] + ... + f[S-1]; sum[S] keeps its zero initial value.
	for(int k = S - 1; k >= 0; k--)
		sum[k] = (sum[k+1] + f[k]) % mod;

	int res = 0;
	for(int i = 1; i <= n; i++) up(res, sum[a[i]]);
	up(res, mod - 1ll * (n - 1) * sum[0] % mod);              // subtract (n - 1) * sum[0]
	res = 1ll * res * Pow(n, mod - 2) % mod;                  // divide by n
	cout << res << '\n';
	return 0;
}


原文地址:https://www.cnblogs.com/mangoyang/p/12885331.html