洛谷 P5326 [ZJOI2019]开关

洛谷 P5326 [ZJOI2019]开关

https://www.luogu.com.cn/problem/P5326

Snipaste_2020-06-30_18-18-35.png

Snipaste_2020-06-30_18-18-29.png

Snipaste_2020-06-30_18-18-44.png

Tutorial

https://www.luogu.com.cn/blog/xht37/solution-p5326

https://www.cnblogs.com/PinkRabbit/p/ZJOI2019D2T1.html

(p_i=dfrac {p_i}{sum p_i})

(f(x))表示在第(k)步到达合法状态的概率的生成函数,因为只关心第一次到达合法状态的情况,所以设(g(x))表示走(k)步后回到原来的状态的概率,(h(x))表示第(k)步第一次走到合法状态的概率,则有(f(x)=g(x)h(x) o h(x)=dfrac{f(x)}{g(x)}) .设(h(x)=sum a_k x^k),则我们要求就是

[sum ka_k=h'(1)=dfrac{f'(1)g(1)-f(1)g'(1)}{g^2(1)} ]

考虑如何求(f(x)).到达合法状态的条件为选择开关(i)的次数与(s_i)相等.则有

[F_i(x)=dfrac{e^{p_ix}+(-1)^{s_i}e^{-p_ix}}2 ]

发现(f(x))是OGF,(F_i(x))为EGF,为了相互转化,将(prod F_i(x))表示为(sum c_k(e^x)^k)的形式,其中(c_k)可以用背包在(O(nsum p))的时间求得,最后得到

[egin{align} f(x)&=sum_k ([x^k]k!sum_i c_i(e^x)^i)x^k \ &=sum_k(k!sum_i c_i [x^k](e^x)^i)x^k \ &=sum_k(k!sum_ic_idfrac{i^k}{k!})x^k \ &=sum_k (sum_i c_ii^k)x^k \ &=sum_ic_isum_{k}i^kx^k \ &=sum_idfrac{c_i}{1-ix} end{align} ]

(g(x))的处理类似,最后得到

[g(x)=sum_idfrac{d_i}{1-ix} ]

但是发现当(i=1)时会有(1-x)这一项,所以不能直接将(x=1)带入,考虑分子分母同乘((1-x)),得到新的(f(x),g(x))

[f(x)=sum_idfrac{c_i(1-x)}{1-ix}=c_1+sum_{i ot=1}dfrac{c_i(1-x)}{1-ix} ]

所以此时(f(1)=c_1)

[f'(x)=sum_{i ot=1}dfrac{c_i(ix-1)+ic_i(1-x)}{(1-ix)^2} \ f'(1)=sum_{i ot=1}dfrac{c_i(i-1)}{(1-i)^2}=sum_{i ot=1}dfrac{c_i}{i-1} ]

(g(1),g'(1))也类似计算,即可得到答案.

Code

#include <cstdio>
#include <cstring>
#include <iostream>
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define inver(a) power(a,mod-2)
using namespace std;
inline char gc() {
//	return getchar();
	static char buf[100000],*l=buf,*r=buf;
	return l==r&&(r=(l=buf)+fread(buf,1,100000,stdin),l==r)?EOF:*l++;
}
template<class T> void rd(T &x) {
	x=0; int f=1,ch=gc();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=gc();}
	while(ch>='0'&&ch<='9'){x=x*10-'0'+ch;ch=gc();}
	x*=f;
}
typedef long long ll;
const int mod=998244353,r2=(mod+1)>>1;
const int maxn=100+5,maxP=1e5+50;
int n,P,s[maxn],p[maxn];
int c[2][maxP],d[2][maxP];
inline int sub(int x) {return x<0?x+mod:x;}
ll power(ll x,ll y) {
	ll re=1;
	while(y) {
		if(y&1) re=re*x%mod;
		x=x*x%mod;
		y>>=1;
	}
	return re;
}
inline int sqr(int x) {return (ll)x*x%mod;}
inline void upd(int *a,int *b,int v,int w) {
	for(int i=0;i<=(P<<1);++i) if(b[i]) {
		a[i+w]=(a[i+w]+(ll)v*b[i])%mod;
	}
}
int main() {
	rd(n);
	for(int i=1;i<=n;++i) rd(s[i]);
	for(int i=1;i<=n;++i) rd(p[i]),P+=p[i];
	int cur=0;
	c[cur][P]=d[cur][P]=1;
	for(int i=1;i<=n;++i) {
		cur^=1;
		memset(c[cur],0,sizeof(c[cur])),memset(d[cur],0,sizeof(d[cur]));
		upd(c[cur],c[cur^1],r2,p[i]),upd(c[cur],c[cur^1],(ll)r2*(s[i]==1?mod-1:1)%mod,-p[i]);
		upd(d[cur],d[cur^1],r2,p[i]),upd(d[cur],d[cur^1],r2,-p[i]);
	}
	int an=0,c1=c[cur][P<<1],d1=d[cur][P<<1],t=inver(P);
	for(int i=-P;i<P;++i) {
		an=(an+inver(sub((ll)i*t%mod-1))*sub((ll)c[cur][i+P]*d1%mod-(ll)c1*d[cur][i+P]%mod))%mod;
	}
	an=(ll)an*sqr(inver(d1))%mod;
	printf("%d
",an);
	return 0;
}
原文地址:https://www.cnblogs.com/ljzalc1022/p/13215463.html