P4457[BJOI2018]治疗之雨【期望dp,高斯消元】

正题

题目链接:https://www.luogu.com.cn/problem/P4457


题目大意

开始一个人最大生命值为\(n\),剩余\(hp\)点生命,然后每个时刻如果生命值没有满那么有\(\frac{1}{m+1}\)的概率回复一点生命,然后敌人攻击\(k\)次,每次有\(\frac{1}{m+1}\)概率造成一点伤害。

求期望多少次后生命值降到\(0\)或以下。

\(1\leq T\leq 100,1\leq n\leq 1500,1\leq m,k\leq 10^9\)


解题思路

\(dp\)方程还是很好推的,设\(p_i\)表示在敌人攻击时受到\(i\)点伤害的概率,那么就是

\[p_i=(\frac{1}{m+1})^i(\frac{m}{m+1})^{k-i}\binom{k}{i} \]

的概率,这个\(i\)只需要计算到\(n\)就好了。

然后设\(f_i\)表示剩余\(i\)点生命时期望还需要打多久
然后枚举一个\(j\)表示本回合受到的伤害,分成回复了生命或者没有回复生命两种情况,方程就是

\[f_i=\frac{1}{m+1}(\sum_{j=0}^{i}p_jf_{i-j+1})+\frac{m}{m+1}(\sum_{j=0}^{i-1}p_jf_{i-j}+1) \]

当然\(f_n\)需要特殊处理

\[f_n=\sum_{i=0}^np_if_{n-i}+1 \]

发现这个方程是有前有后的环状转移,但是暴力高斯消元\(O(n^3)\)的时间复杂度接受不了。

不难发现的是我们现在的方程矩阵其实就是一个下三角矩阵再往右扩一列。我们可以先\(O(n^2)\)把这个下三角消成对角线然后第\(i\)列就只有\(i\)\(i+1\)两个系数了,反过来再消一次就好了。

时间复杂度\(O(Tn^2)\)


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const ll N=1600,P=1e9+7;
ll T,n,hp,m,k,a[N][N],b[N];
ll inv[N],p[N];
ll power(ll x,ll b){
	ll ans=1;
	while(b){
		if(b&1)ans=ans*x%P;
		x=x*x%P;b>>=1;
	}
	return ans;
}
ll C(ll n,ll m){
	ll ans=1;
	for(ll i=n-m+1;i<=n;i++)ans=ans*i%P;
	return ans*inv[m]%P;
}
signed main()
{
	inv[1]=1;
	for(ll i=2;i<N;i++)
		inv[i]=P-(P/i)*inv[P%i]%P;
	inv[0]=1;
	for(ll i=1;i<N;i++)
		inv[i]=inv[i-1]*inv[i]%P;
	scanf("%lld",&T);
	while(T--){
		scanf("%lld%lld%lld%lld",&n,&hp,&m,&k);
		ll invm=power(m+1,P-2);
		if(!k||k==1&&!m){puts("-1");continue;}
		else if(!m){
			ll ans=0;
			while(hp>0){if(hp<n)hp++;hp-=k;ans++;}
			printf("%lld\n",ans);continue;
		}
		ll tmp=power(invm,k);
		for(ll i=0;i<=min(k,n);i++)
			p[i]=tmp*power(m,k-i)%P*C(k,i)%P;
		memset(a,0,sizeof(a));
		memset(b,0,sizeof(b));
		for(ll i=0;i<=min(k,n-1);i++)
			a[n][n-i]=P-p[i];
		a[n][n]++;b[n]++;
		for(ll i=1;i<n;i++){
			a[i][i]=1;b[i]=1;
			for(ll j=0;j<=min(k,i-1);j++)
				(a[i][i-j]+=P-invm*m%P*p[j]%P)%=P;
			for(ll j=0;j<=min(k,i);j++)
				(a[i][i-j+1]+=P-invm*p[j]%P)%=P;
		}
		for(ll i=1;i<=n;i++){
			ll inv=power(a[i][i],P-2);
			a[i][i]=1;b[i]=b[i]*inv%P;
			a[i][i+1]=a[i][i+1]*inv%P;
			for(ll j=i+1;j<=n;j++){
				ll rate=P-a[j][i];a[j][i]=0;
				(a[j][i+1]+=a[i][i+1]*rate)%=P;
				(b[j]+=b[i]*rate)%=P;
			}
		}
		for(ll i=n-1;i>=1;i--){
			ll rate=P-a[i][i+1];
			b[i]=(b[i]+rate*b[i+1])%P;
		}
		printf("%lld\n",(b[hp]+P)%P);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/QuantAsk/p/14457018.html