BZOJ3583 杰杰的女性朋友 矩阵

原文链接https://www.cnblogs.com/zhouzhendong/p/BZOJ3583.html

题目传送门 - BZOJ3583

题意

  有一个 $n$ 个点构成的有向图。

  对于每一个点 $i$ ,给定两组参数,每组参数分别有 $k$ 个值。这两组参数分别记做: $in[i][1cdots k],out[i][1cdots k]$ 。

  从点 $i$ 连到点 $j$ 的边数定义为 $sum_{t=1}^k in[i][t] imes out[i][t]$ 。

  $m$ 组询问,每次询问从 点 $x$ 走到点 $y$ ,经过不超过 $d$ 条边的方案总数。

  $nleq 1000,mleq 50,kleq 20,dleq 2^{31}-1$

题解

  首先我们选取 $k$ 个中介点,很容易得到一个由原图的 $n$ 个点转移到这 $k$ 个点的方案数的转移矩阵 $mathbf O$;类似的,可以得一个由 $k$ 个中介点转移到原图的 $n$ 个点的方案数的转移矩阵 $mathbf I$ 。

  假设我们要求的是恰好经过 $d$ 条路径的方案数,那么,显然,我们只需要求出 $(mathbf{OI})^d$ 的第 $i$ 行第 $j$ 列的值即可。

  但是我们发现这个矩阵是 $1000 imes 1000$ 的,复杂度显然不行。我们发现 $k$ 非常小,而且矩阵 $mathbf {IO}$ 的长宽都是 $k$ 。

  由于矩阵乘法具有结合律,所以我们可以把原式写成:

  $mathbf{O} (mathbf{IO})^d mathbf{I}$ 这样时间复杂度就对了。

  但是原题要求的是不超过 $d$ 步的。

  考虑新增一个点,这个点只能走到自己,将询问中的终点连向它即可。

代码

#pragma GCC optimize("O2")
#include <bits/stdc++.h>
using namespace std;
const int N=1005,K=25,mod=1e9+7;
struct Mat{
	int r,c;
	vector <vector <int> > v;
	Mat(){}
	Mat(int _r,int _c,int x){
		r=_r,c=_c;
		vector <int> vec;
		vec.clear();
		for (int i=0;i<=c;i++)
			vec.push_back(0);
		v.clear();
		for (int i=0;i<=r;i++)
			v.push_back(vec);
		if (r==c)
			for (int i=0;i<=r;i++)
				v[i][i]=x;
	}
	void Print(){
		for (int i=0;i<=r;i++,puts(""))
			for (int j=0;j<=c;j++)
				printf("%3d ",v[i][j]);
		puts("");
	}
};
Mat operator * (Mat A,Mat B){
	Mat C(A.r,B.c,0);
	if (A.c!=B.r)
		return C;
	for (int i=0;i<=A.r;i++)
		for (int j=0;j<=B.c;j++)
			for (int k=0;k<=A.c;k++)
				C.v[i][j]=(1LL*A.v[i][k]*B.v[k][j]+C.v[i][j])%mod;
	return C;
}
Mat Pow(Mat x,int y){
	Mat ans(x.r,x.c,1);
	for (;y;y>>=1,x=x*x)
		if (y&1)
			ans=ans*x;
	return ans;
}
int read(){
	int x=0;
	char ch=getchar();
	while (!isdigit(ch))
		ch=getchar();
	while (isdigit(ch))
		x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
	return x;
}
int n,m,k;
Mat I,O,M,res;
int main(){
	n=read(),k=read();
	O=Mat(n,k,0);
	I=Mat(k,n,0);
	for (int i=1;i<=n;i++){
		for (int j=1;j<=k;j++)
			O.v[i][j]=read();
		for (int j=1;j<=k;j++)
			I.v[j][i]=read();
	}
	I.v[0][0]=O.v[0][0]=1;
	m=read();
	while (m--){
		int x=read(),y=read(),d=read();
		O.v[y][0]=1;
		res=O*Pow(I*O,d);
		int ans=0;
		for (int i=0;i<=k;i++)
			ans=(1LL*res.v[x][i]*I.v[i][0]+ans)%mod;
		printf("%d
",ans);
		O.v[y][0]=0;
	}
	return 0;
}

  

原文地址:https://www.cnblogs.com/zhouzhendong/p/BZOJ3583.html