CF446DDZY Loves Games【高斯消元,矩阵乘法】

正题

题目链接:https://www.luogu.com.cn/problem/CF446D


题目大意

给出\(n\)个点\(m\)条边的一张无向图,一些点有陷阱,走到时会损失一条生命,总共有\(k\)条生命,求从\(1\)出发随机游走到\(n\)没有死亡且到终点时仅剩一条命的概率。

\(1\leq n\leq 500,1\leq m\leq 10^5,2\leq k\leq 10^9\)

陷阱点个数不超过\(100\)


解题思路

这个\(k\)很大,这个陷阱点个数又很少,我们可以考虑矩阵乘法,预处理\(a_{i,j}\)表示陷阱点\(i\)走到陷阱点\(j\)且中间没有走陷阱点的概率,然后矩阵乘法转移即可。

但是现在的问题是我们如何快速预处理出\(a_{i,j}\),可以考虑枚举终点\(x\)那么有\(f_x=1\),然后其他的陷阱点处\(f_x=0\),一般的点处\(f_x=\frac{1}{deg_x}\sum_{x\rightarrow y}f_{y}\),这样我们就能对于每个起点预处理出\(g_{i,j}\)表示从无陷阱的节点\(i\)走到陷阱点\(j\)且中间没有其他陷阱点的概率。

之后我们枚举起点陷阱点的出边就可以预处理出\(a\)了,因为上面的过程要用到高斯消元,所以这样的复杂度是\(O(n^4)\)的,无法通过本题。

不难注意到上面的消元中,我们只有陷阱点处的常数(且陷阱点处仅有常数)发生了变化,所以我们可以直接高斯消出每个非陷阱点和所有陷阱点的关系式,然后直接带入常数即可。

记陷阱点个数为\(r\),时间复杂度\(O((n+r)^3+r^3\log k)\)


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
using namespace std;
const int N=610,S=110;
const double eps=1e-8;
struct matrix{
	double a[S][S];
}ans,m,c;
matrix operator*(const matrix &a,const matrix &b){
	memset(c.a,0,sizeof(c.a));
	for(int i=0;i<S;i++)
		for(int j=0;j<S;j++)
			for(int k=0;k<S;k++)
				c.a[i][j]+=a.a[i][k]*b.a[k][j];
	return c;
}
int n,h,k,deg[N],a[N][N];
double f[N][N];bool v[N];
vector<int> q;
int main()
{
	scanf("%d%d%d",&n,&h,&k);
	for(int i=1;i<=n;i++){
		scanf("%d",&v[i]);
		if(v[i]){q.push_back(i);f[i][n+q.size()]=f[i][i]=1;}
	}
	int r=n+q.size();
	for(int i=1,x,y;i<=h;i++){
		scanf("%d%d",&x,&y);
		a[x][y]++;a[y][x]++;
		deg[x]++;deg[y]++;
	}
	for(int i=1;i<=n;i++){
		if(v[i])continue;
		for(int j=1;j<=n;j++)
			f[i][j]=-1.0*a[i][j]/(double)deg[i];
		f[i][i]=1;
	}
	for(int i=1;i<=n;i++){
		for(int j=i;j<=n;j++)
			if(fabs(f[i][j])>eps){
				swap(f[i],f[j]);
				break;
			}
		double d=f[i][i];
		for(int j=i;j<=r;j++)
			f[i][j]=f[i][j]/d;
		for(int j=1;j<=n;j++){
			if(i==j)continue;
			double rate=-f[j][i]/f[i][i];
			for(int k=i;k<=r;k++)
				f[j][k]+=rate*f[i][k];
		}
	}
	for(int i=0;i<q.size();i++)
		ans.a[0][i]=f[1][n+i+1];
	for(int i=0;i<q.size();i++){
		for(int j=0;j<q.size();j++){
			for(int k=1;k<=n;k++)
				m.a[i][j]+=a[q[i]][k]*f[k][n+j+1];
			m.a[i][j]/=(double)deg[q[i]];
		}
	}
	k-=2;
	while(k){
		if(k&1)ans=ans*m;
		m=m*m;k>>=1;
	}
	printf("%.10lf\n",ans.a[0][q.size()-1]);
	return 0;
}
原文地址:https://www.cnblogs.com/QuantAsk/p/15683917.html