【loj2325】「清华集训 2017」小Y和恐怖的奴隶主 概率dp+倍增+矩阵乘法

题目描述

你有一个m点生命值的奴隶主,奴隶主受伤未死且当前随从数目不超过k则再召唤一个m点生命值的奴隶主。

T次询问,每次询问如果如果对面下出一个n点攻击力的克苏恩,你的英雄期望会受到到多少伤害。

输入

输入第一行包含三个正整数 T,m,k ,T 表示询问组数,m,k 的含义见题目描述。

接下来 T 行,每行包含一个正整数 n ,表示询问进行 n 次攻击后扣减Boss的生命值点数的期望。

输出

输出共 T 行,对于每个询问输出一行一个非负整数,表示该询问的答案对 998244353 取模的结果。

样例输入

3 2 6
1
2
3

样例输出

499122177
415935148
471393168


题解

概率dp+倍增+矩阵乘法

首先需要知道本题弱化版 【bzoj4832】[Lydsy2017年4月月赛]抵制克苏恩 的概率dp写法:不维护期望,只维护概率,统计概率对答案的贡献。设 $p[i][j][k][l]$ 表示 $i$ 回合奴隶主1、2、3血剩余情况为 $j$ 、$k$ 、$l$ 的概率,那么对答案的贡献就是 $frac {p[i][j][k][l]}{j+k+l+1}$ 。

本题的 $n$ 较大,考虑矩阵乘法。先预处理出状态及转移。然后相当于一个行向量乘以n个方阵,使用快速幂。

但是经过计算可知状态数为 $sumlimits_{i=0}^kC_{i+m-1}^{m-1}$ ,加上计数器总和最大为166,每次都快速幂复杂度为 $O(T·166^3log n)$ ,会TLE。

考虑到一个行向量乘以一个方阵的时间时 $O(n^2)$ 的,因此可以倍增预处理出方阵的 $2^i$ 次幂,然后把每个矩阵依次乘到行向量上即可。

时间复杂度 $O(166^3log n+T·166^2log n)$ 

有点卡常。。。

#include <cstdio>
#include <cstring>
#define mod 998244353
typedef long long ll;
int tot = 1;
ll inv(ll x)
{
	ll ans = 1 , y = mod - 2;
	while(y)
	{
		if(y & 1) ans = ans * x % mod;
		x = x * x % mod , y >>= 1;
	}
	return ans;
}
struct data
{
	ll v[170][170];
	data() {memset(v , 0 , sizeof(v));}
	ll *operator[](int a) {return v[a];}
	data operator*(data &a)
	{
		data ans;
		int i , j , k;
		for(i = 1 ; i <= tot ; i ++ )
			for(j = 1 ; j <= tot ; j ++ )
				for(k = 1 ; k <= tot ; k ++ )
					ans[i][j] = (ans[i][j] + v[i][k] * a[k][j]) % mod;
		return ans;
	}
}A[60];
int f[9][9] , g[9][9][9];
ll ans[170] , tmp[170];
void mul(ll *A , data B)
{
	int i , j;
	memset(tmp , 0 , sizeof(tmp));
	for(i = 1 ; i <= tot ; i ++ )
		for(j = 1 ; j <= tot ; j ++ )
			tmp[j] = (tmp[j] + A[i] * B[i][j]) % mod;
	for(i = 1 ; i <= tot ; i ++ ) A[i] = tmp[i];
}
int main()
{
	int T , m , p , i , j , k;
	ll n , e;
	scanf("%d%d%d" , &T , &m , &p);
	A[0][1][1] = A[0][2][1] = A[0][2][2] = 1; 
	if(m == 1) tot = 3 , A[0][tot][1] = A[0][tot][2] = A[0][tot][3] = inv(2);
	else if(m == 2)
	{
		for(i = 0 ; i <= p ; i ++ )
			for(j = 0 ; j <= p ; j ++ )
				if(i + j <= p)
					f[i][j] = ++tot;
		for(i = 0 ; i <= p ; i ++ )
		{
			for(j = 0 ; j <= p ; j ++ )
			{
				if(i + j <= p)
				{
					e = inv(i + j + 1);
					A[0][f[i][j]][1] = A[0][f[i][j]][f[i][j]] = e;
					if(i) A[0][f[i][j]][f[i - 1][j]] = i * e % mod;
					if(j)
					{
						if(i + j < p) A[0][f[i][j]][f[i + 1][j]] = j * e % mod;
						else A[0][f[i][j]][f[i + 1][j - 1]] = j * e % mod;
					}
				}
			}
		}
	}
	else
	{
		for(i = 0 ; i <= p ; i ++ )
			for(j = 0 ; j <= p ; j ++ )
				for(k = 0 ; k <= p ; k ++ )
					if(i + j + k <= p)
						g[i][j][k] = ++tot;
		for(i = 0 ; i <= p ; i ++ )
		{
			for(j = 0 ; j <= p ; j ++ )
			{
				for(k = 0 ; k <= p ; k ++ )
				{
					if(i + j + k <= p)
					{
						e = inv(i + j + k + 1);
						A[0][g[i][j][k]][1] = A[0][g[i][j][k]][g[i][j][k]] = e;
						if(i) A[0][g[i][j][k]][g[i - 1][j][k]] = i * e % mod;
						if(j)
						{
							if(i + j + k < p) A[0][g[i][j][k]][g[i + 1][j - 1][k + 1]] = j * e % mod;
							else A[0][g[i][j][k]][g[i + 1][j - 1][k]] = j * e % mod;
						}
						if(k)
						{
							if(i + j + k < p) A[0][g[i][j][k]][g[i][j + 1][k]] = k * e % mod;
							else A[0][g[i][j][k]][g[i][j + 1][k - 1]] = k * e % mod;
						}
					}
				}
			}
		}
	}
	for(i = 1 ; i < 60 ; i ++ ) A[i] = A[i - 1] * A[i - 1];
	while(T -- )
	{
		scanf("%lld" , &n);
		memset(ans , 0 , sizeof(ans));
		ans[3] = 1;
		for(i = 0 ; i < 60 ; i ++ )
			if(n & (1ll << i))
				mul(ans , A[i]);
		printf("%lld
" , ans[1]);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/GXZlegend/p/8124621.html