【SHTSC2013】超级跳马

题目链接

题目大意

(n*m)的网格上,一只马在点((1,1)),点((i,j))可以跳到((i-1,j+k))((i,j+k))((i+1,j+k)),其中(k)是一个奇数,求跳到((n,m))的方案数。

解析

设:
(f_{i,j})表示跳到((j,i))的方案数(为了方便我换了一下(i,j)的顺序,相当于按列为阶段转移)
(a_{i,j}=sum_{k=0}f_{i,j-2k})
(b_{i,j}=sum_{k=0}f_{i,j-2k-1})

得到三者之间关系是:

(f_{i,j}=a_{i-1,j}+a_{i-1,j-1}+a_{i-1,j+1})
(a_{i,j}=b_{i-1,j}+f_{i,j})
(b_{i,j}=a_{i-1,j})

仅根据这三条式子,就能够用(O(nm))的时间复杂度求出(f_{i,j})了。

(mleq 10^9),还需优化。

原来是三个状态的转移,现在我们变一下式子:
(a_{i,j}=a_{i-2,j}+a_{i-1,j}+a_{i-1,j-1}+a_{i-1,j+1})

现在只剩(a)一个状态的转移了,最后求(f_{i,j}=a_{i,j}-a_{i-2,j})即可。

由于(m)很大,考虑使用矩阵乘法。

转移矩阵的构造方法很巧妙,我们把(a_{i-1,1 sim n})还有(a_{i-2,1 sim n})放在初始矩阵的第一行,其它位置全部填(0)

例如(n=3)时,初始矩阵为:
(egin{Bmatrix} a_{i-1,1} & a_{i-1,2} & a_{i-1,3} & a_{i-2,1} & a_{i-2,2} & a_{i-2,3} \ 0 & 0 & 0 & 0 & 0 & 0 \ 0 & 0 & 0 & 0 & 0 & 0 \ 0 & 0 & 0 & 0 & 0 & 0 \ 0 & 0 & 0 & 0 & 0 & 0 \ 0 & 0 & 0 & 0 & 0 & 0 end{Bmatrix} quad)

转移矩阵为:
(egin{Bmatrix} 1 & 1 & 0 & 1 & 0 & 0 \ 1 & 1 & 1 & 0 & 1 & 0 \ 0 & 1 & 1 & 0 & 0 & 1 \ 1 & 0 & 0 & 0 & 0 & 0 \ 0 & 1 & 0 & 0 & 0 & 0 \ 0 & 0 & 1 & 0 & 0 & 0 end{Bmatrix} quad)

这样时间复杂度降为(O(n^3logm)),问题解决了。

Code

#include <cstdio>
#include <cstring>

const int N = 57, P = 30011;
int max(int a, int b) { return a > b ? a : b; }
int min(int a, int b) { return a < b ? a : b; }

int n, m;

struct matrix
{
	int num[N * 2][N * 2];
	matrix operator*(matrix a)
	{
		matrix c; memset(c.num, 0, sizeof(c.num));
		for (int i = 0; i < 2 * n; i++)
			for (int j = 0; j < 2 * n; j++)
				for (int k = 0; k < 2 * n; k++)
					c.num[i][j] = (c.num[i][j] + num[i][k] * a.num[k][j] % P) % P;
		return c;
	}
} bas, mov, ret;

int getit(int m, int n)
{
	if (m <= 0) return 0;
	memset(bas.num, 0, sizeof(bas.num));
	memset(mov.num, 0, sizeof(mov.num));
	memset(ret.num, 0, sizeof(ret.num));
	for (int j = 0; j < n; j++) for (int i = max(j - 1, 0); i <= min(j + 1, n - 1); i++) mov.num[i][j] = 1;
	for (int i = n; i < 2 * n; i++) mov.num[i][i - n] = 1;
	for (int j = n; j < 2 * n; j++) mov.num[j - n][j] = 1;
	bas.num[0][0] = 1;
	for (int i = 0; i < 2 * n; i++) ret.num[i][i] = 1;
	m--;
	while (m)
	{
		if (m & 1) ret = ret * mov;
		mov = mov * mov, m >>= 1;
	}
	bas = bas * ret;
	return bas.num[0][n - 1];
}

int main()
{
	scanf("%d%d", &n, &m);
	printf("%d
", (getit(m, n) - getit(m - 2, n) + P) % P);
	return 0;
}
原文地址:https://www.cnblogs.com/zjlcnblogs/p/11121154.html