[HNOI2008]GT考试

嘟嘟嘟


这道题刚开始我连dp方程都没设出来,现在想一想还是我对dp的理解不够深。


(dp[i][j])表示长串匹配到第(i)位,短串匹配到第(j)位时的方案数。因为题中说不让匹配成功,所以答案是(dp[n][m - 1])
但转移不好写,因为这个状态不够具体。应该在加一个条件:长串(s)[(1)~(i)]的后缀和短串的前缀最长的公共部分为(j)。这样转移就好办了。


如果还想不出来,可以想(dp[i][j])能转移到什么状态:
1.匹配成功:(dp[i][j]) -> (dp[i + 1][j + 1])
2.匹配不成功:这个时候(dp[i][j]) -> (dp[i + 1][k])。这个(k)(i + 1)这个位置填什么字符有关。
也就是说:

[dp[i][j] = sum _ {k = 0} ^ {m - 1} dp[i - 1][k] * f[k][j] ]

这个(f[k][j])表示短串的第(k)个位置有多少种方案能转移到(j)。由此可见,这个数组跟长串无关。
所以可以先预处理这个数组:用kmp即可。


然后我们就有了一个(O(nm ^ 2))的算法,交上去能得40分。


优化:
看上面的那个转移方程

[dp[i][j] = sum _ {k = 0} ^ {m - 1} dp[i - 1][k] * f[k][j] ]

发现就是一个普通的矩阵乘法。
然后我们矩阵快速幂一下就可以啦。

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define rg register
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 1e6 + 5;
const int maxm = 25;
inline ll read()
{
  ll ans = 0;
  char ch = getchar(), last = ' ';
  while(!isdigit(ch)) last = ch, ch = getchar();
  while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
inline void write(ll x)
{
  if(x < 0) x = -x, putchar('-');
  if(x >= 10) write(x / 10);
  putchar(x % 10 + '0');
}

int n, m, mod;
char s[maxm];

struct Mat
{
	int a[maxm][maxm];
	Mat operator * (const Mat& oth)const
	{
		Mat ret; Mem(ret.a, 0);
		for(int i = 0; i < m; ++i)
			for(int j = 0; j < m; ++j)
				for(int k = 0; k < m; ++k)
					ret.a[i][j] += a[i][k] * oth.a[k][j], ret.a[i][j] %= mod;
		return ret;
	}
}F;

Mat quickpow(Mat A, int b)
{
	Mat ret; Mem(ret.a, 0);
	for(int i = 0; i < m; ++i) ret.a[i][i] = 1;
	for(; b; b >>= 1, A = A * A)
		if(b & 1) ret = ret * A;
	return ret;
}

int nxt[maxm];
void kmp()
{
	for(int i = 2, j = 0; i <= m; ++i)
    {
    	while(j && s[j + 1] != s[i]) j = nxt[j];
    	if(s[j + 1] == s[i]) j++; nxt[i] = j;
    }
	for(int i = 0; i < m; ++i) 
		for(int j = 0; j <= 9; ++j)
      	{
			int k = i;
			while(k && s[k + 1] != j + '0') k = nxt[k];
			if(s[k + 1] == j + '0') k++;
			if(k < m) F.a[i][k]++;
      	}
}

int dp[maxn][maxm];
int main()
{
	n = read(); m = read(); mod = read();
	scanf("%s", s + 1);
	Mem(F.a, 0); kmp();
	Mat A = quickpow(F, n);
	int ans = 0;
	for(int i = 0; i < m; ++i) ans = (ans + A.a[0][i]) % mod;
	write(ans), enter;
	return 0;
}
原文地址:https://www.cnblogs.com/mrclr/p/10119118.html