矩阵乘法与邻接矩阵

问题

设现在有一个邻接矩阵 (A) ,那么 (A^p) 表示什么。

相信大家都搜了一下,发现 (A[i][j]) 表示从 (i)(j) 经过了 (p) 步的总方案数,但是原理却不一定明白,所以这篇文章主要想从另一个思路证明上面的结论。

证明

我们先抛开矩阵乘法,考虑单纯用 (DP) 这道题目怎样做。

我们可以设 (DP[i][j][p]) 表示从 (i) 点出发,经过了 (p) 步到达 (j) 点的方案数。

初始化因为 (p=0) 的初始化没有什么意义,所以我们直接看 (p=1) 的初始化。

(p=1) 的时候的初始化其实就是这张图的邻接矩阵,因为邻接矩阵就是原图,原图中间的两个点就是通过一条边相连。

我们考虑转移

[DP[i][j][p]=sum_{k=1}^n DP[k][j][p-1]*DP[i][k][1] ]

其实就是我们考虑枚举全部的点在 (p-1) 的方案数,然后乘 ,

因为第三维只是和 (p-1) 有关,所以第三维可以省掉,然后和朴素的矩阵幂比较,我们发现他们长的一模一样,其实都是在求邻接矩阵的 (p) 次方,所以我们就可以使用矩阵快速幂来直接求最终的矩阵了,上面的问题也得到了证明。

典型例题

P3758 [TJOI2017]可乐

这道题目这篇博客已经讲得很明白了,不想再抄一遍了。emmm

#include<bits/stdc++.h>
#define ll long long
struct Mat{
	int size;
	ll **M=NULL;
	inline ll Start()
	{
		if (M!=NULL) return M[1][1];
		else return LLONG_MIN;
	}
	inline void Clear(int sz)
	{
		if (M==NULL) {New(sz);return ;}
		for (int i=0;i<=sz;i++)
		for (int j=0;j<=sz;j++)
		M[i][j]=0;
		return ;
	}
	inline void New(int sz)
	{
		if (M!=NULL)
		{
			printf("
RE
");
			printf("This matrix has been used!
");
			return ;
		}
		size=sz;
		M=new ll*[sz+10];
		for (int i=0;i<sz+10;i++)
		M[i]=new ll[sz+10];
		Clear(sz);
		return ;
	}
	inline void Build(int sz)
	{
		size=sz;
		if (M==NULL) New(sz);
		for (int i=1;i<=sz;i++) M[i][i]=1;
		return ;
	}
	inline void Init(ll now[],int sz)
	{
		if (M==NULL) New(sz);
		int num=0;
		for (int i=1;i<=sz;i++)
		for (int j=1;j<=sz;j++)
		M[i][j]=now[++num];
		return ;
	}
	inline void Out()
	{
		if (M!=NULL)
		for (int i=1;i<=size;i++)
		{
			for (int j=1;j<=size;j++)
			printf("%lld ",M[i][j]);
			printf("
");
		}
		return ;
	}
	inline void Delete()
	{
		for (int i=0;i<size+10;i++)
		delete []M[i];
		delete []M;
		M=NULL;
		size=0;
		return ;
	}
};
ll mod;
inline Mat operator * (Mat a,Mat b)
{
	Mat c;
	c.Clear(a.size);
	for (int i=1;i<=c.size;i++)
	for (int j=1;j<=c.size;j++)
	for (int k=1;k<=c.size;k++)
	c.M[i][j]=(c.M[i][j]+a.M[i][k]*b.M[k][j]%mod)%mod;
	return c;
}
inline Mat Mat_qpow(Mat a,ll p)
{
	Mat ans,base;
	base.New(a.size);
	ans.Build(a.size);
	for (base=a;p;p>>=1,base=base*base)
	if (p&1) ans=ans*base;
	return ans;
}
int main()
{
	mod=2017;
	int n,m,k;
	scanf("%d%d",&n,&m);
	Mat now,ans;
	now.New(n+1);ans.New(n+1);
	for (int i=1;i<=m;i++)
	{
		int s,e;
		scanf("%d%d",&s,&e);
		now.M[s][e]=now.M[e][s]=1;
	}
	scanf("%d",&k);
	for (int i=1;i<=n;i++)
	{
		now.M[i][n+1]=1;
		now.M[i][i]=1;
	}
	now.M[n+1][n+1]=1;
	ans=Mat_qpow(now,k);
	ll Ans=0;
	for (int i=1;i<=n+1;i++)
	Ans=(Ans+ans.M[1][i])%mod;
	printf("%lld
",Ans);
	ans.Delete();now.Delete();
	return 0;
}

上面的证明有一些问题,以后再补锅

原文地址:https://www.cnblogs.com/last-diary/p/11732999.html