【SDOI2017】硬币游戏(概率dp,高斯消元)

很神奇的一道题目。

首先先举一个例子,等会结合着讲:只有两个人猜,猜的串分别是 A = TTH A= exttt{TTH} A=TTH B = HTT B= exttt{HTT} B=HTT

设所有人猜的序列为 s 1 , s 2 , ⋯   , s n s_1,s_2,cdots,s_n s1,s2,,sn

首先对于这种可能存在无限情况的题目,我们要学会归类:

把所有可能的硬币序列(可能有无限种)分成两类:未终止状态和已终止状态。

其中未终止状态表示:当前的硬币序列(设为 T T T)还没有胜者。即不存在任意一个 s i s_i si,使得 s i s_i si T T T 中出现过。例如在上述例子中, TT exttt{TT} TT HHT exttt{HHT} HHT HHH... ⏟ 无限个 underbrace{ exttt{HHH...}}_{ ext{无限个}} 无限个 HHH... 都是未终止状态。

终止状态表示:当进行到当前的硬币序列(设为 T T T)时,决出了胜者。即存在一个 i i i 使得 s i s_i si T T T 的后缀,且不存在任意的 j ≠ i j eq i j=i,使得 s j s_j sj T T T 中出现过。例如在上述例子中, TTH exttt{TTH} TTH HH...H ⏟ 若干个 TT underbrace{ exttt{HH...H}}_{ ext{若干个}} exttt{TT} 若干个 HH...HTT 都是终止状态。

可以发现未终止状态和终止状态并不能包含所有由 H exttt{H} H T exttt{T} T 构造出来的字符串。例如在上述例子中, TTHT exttt{TTHT} TTHT THTHTHTTH exttt{THTHTHTTH} THTHTHTTH 都是不可达状态。发现不可达状态其实就是在某个终止状态后加上任意非空字符串构造出来的。

然后我们要找出这些情况之间的关系,然后用 dp 或解方程来求解。

如何建立未终止状态和已终止状态之间的关系?

N N N 是任意一个未终止状态。

不妨 N N N 后面加上某一个 s i s_i si,就能达到一个终止状态了。

但是你发现,在一些特殊情况下,这个 N + s i N+s_i N+si 可能是一个不可达状态。

什么意思呢?那我们刚刚那个样例来分析:

对于 N + TTH N+ exttt{TTH} N+TTH,它可能是两种状态:

  1. 它是一个以 A = TTH A= exttt{TTH} A=TTH 为终止的终止状态,即 N + A N+A N+A

  2. 它是一个以 B = HTT B= exttt{HTT} B=HTT 为终止的终止状态后面再接上一个字符串达到的不可达状态,即 N ′ + B + H N'+B+ exttt{H} N+B+H,其中可知 N ′ N' N 同样是一个未终止状态,且满足 N ′ + H = N N'+ exttt{H}=N N+H=N

也就是说, 1 8 N = P ( A ) + 1 2 P ( B ) dfrac{1}{8}N=P(A)+dfrac{1}{2}P(B) 81N=P(A)+21P(B)。(其中 P ( A ) P(A) P(A) A A A 获胜的概率, P ( B ) P(B) P(B) 同理)

按照类似的思路,我们可以两两枚举每个串,比较他们的前缀和后缀,如果相同就加上一个系数。

形式化地,设 p r e i , j pre_{i,j} prei,j 表示 s i s_i si 的长度为 j j j 的前缀 s u f i , j suf_{i,j} sufi,j 表示 s i s_i si 的长度为 j j j 的后缀, x i x_i xi 表示以 s i s_i si 结尾的终止状态出现的概率(即 i i i 获胜的概率)。

对于每一个 i i i,我们可以列出如下方程:

1 2 m N = ∑ j = 1 n x j ∑ k = 1 m [ p r e i , k = s u f j , k ] 1 2 m − k dfrac{1}{2^m}N=sum_{j=1}^nx_jsum_{k=1}^m[pre_{i,k}=suf_{j,k}]dfrac{1}{2^{m-k}} 2m1N=j=1nxjk=1m[prei,k=sufj,k]2mk1

(对于中间那个判断我们可以用哈希处理)

然后还有一条方程: ∑ i = 1 n x i = 1 sumlimits_{i=1}^nx_i=1 i=1nxi=1

所以一共有 n + 1 n+1 n+1 个未知数 n + 1 n+1 n+1 个方程,就能用高斯消元解出每一个 x i x_i xi 了。

代码如下:

#include<bits/stdc++.h>

#define N 310

using namespace std;

const unsigned int base=19260817;

int n,m;
char s[N][N];
unsigned int poww[N],sum[N][N];
double div2[N],x[N],a[N][N];

unsigned int get(int i,int l,int r)
{
	if(l>r) return 0;
	return sum[i][r]-sum[i][l-1]*poww[r-l+1];
}

void Gauss()
{
	for(int i=1;i<=n+1;i++)
	{
        for(int j=i;j<=n+1;j++)
		{
            if(!a[j][i]) continue;
            for(int k=i;k<=n+2;k++) 
				swap(a[i][k],a[j][k]);
            break;
        }
        for(int j=i+1;j<=n+2;j++)
		{
            if(!a[j][i]) continue;
            double tmp=a[j][i]/a[i][i];
            for(int k=i;k<=n+2;k++)
				a[j][k]-=a[i][k]*tmp;
        }
    }
    x[n+1]=a[n+1][n+2]/a[n+1][n+1];
    for(int i=n;i>=1;i--){
        for(int j=i+1;j<=n+1;j++)
			x[i]-=a[i][j]*x[j];
        x[i]/=a[i][i];
    }
}

int main()
{
	scanf("%d%d",&n,&m);
	poww[0]=div2[0]=1;
	for(int i=1;i<=m;i++)
		poww[i]=poww[i-1]*base,div2[i]=div2[i-1]/2.0;
	for(int i=1;i<=n;i++)
	{
		scanf("%s",s[i]+1);
		for(int j=1;j<=m;j++)
			sum[i][j]=sum[i][j-1]*base+s[i][j];
	}
	for(int i=1;i<=n;i++)
	{
		for(int j=1;j<=n;j++)
		{
			double ans=0;
			for(int k=1;k<=m;k++)
				if(get(i,1,k)==get(j,m-k+1,m))
					ans+=div2[m-k];
			a[i][j]=ans;
		}
		a[i][n+1]=-div2[m];
	}
	for(int i=1;i<=n;i++)
		a[n+1][i]=1;
	a[n+1][n+2]=1;
	Gauss();
	for(int i=1;i<=n;i++)
		printf("%.10lf
",x[i]);
	return 0;
}
/*
2 2
TTH
HTT
*/
原文地址:https://www.cnblogs.com/ez-lcw/p/14448637.html