BJWC2011 禁忌

题目链接


题解

多模式匹配首先建 AC 自动机,看到 (len le 10^9) 想到矩阵乘法优化。

朴素 DP

关于分割的最大值,可以贪心,只要走到一个能匹配串的点立刻返回根继续匹配就行,一定能保证最优。

以最后的结果枚举算期望显然是 ( ext{alphaset} ^ {len}) 的,显然不可取。由于期望线性,不妨算贡献。

(f[i][j]) 为长度为 (i) 的字符串,走到 AC 自动机上对应节点为 (j) 的概率。

转移就是在 AC 自动机上枚举字符集,如果转移到的点能匹配( Fail 链上有禁忌串),则返回根,设这个转移点 ((u, v)),即 (f[i][v] = sum f[i - 1][u] imes frac{1}{ ext{alphaset}})

(ans = sum_{i = 0}^{len - 1} f[i][u] imes frac{1}{ ext{alphaset}}) 满足有边 ((u, rt)) 的。

复杂度 (O(75 ext{len}))

矩阵优化

AC 自动机上的节点数最多为 (75) 个,且其实贡献是一个相加形式,矩阵优化应该是可行的。

考虑边递推每一层的同时维护 (ans),即矩阵多加一列, 即构造矩阵 ([F_i, ans] imes A = [F_{i + 1}, ans])

  • 对于一条边 ((u, v)) 贡献:(A[u][j] Leftarrow + frac{1}{ ext{alphaset}})

  • 特别地若这条边 ((u, rt)),对答案有贡献:(A[u][idx + 1] Leftarrow + frac{1}{ ext{alphaset}})

  • 注意 (ans) 本身要传递至下一层:(A[idx + 1][idx + 1] Leftarrow +1)

时间复杂度 (O(75 ^ 3log_2{ ext{len}}))

注意此题卡精度,所有地方包括 (frac{1}{ ext{alphaset}}) 都要开 long double

#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;

typedef long double LD;

const int L = 80, S = 20;

int n, len, m;
int tr[L][26], fail[L], q[L], idx;
bool e[L];
char s[S];

struct Mat{
	LD w[L][L];
	int n, m;
	Mat operator * (const Mat &b) const {
		Mat c; c.n = n, c.m = b.m;
		for (int i = 0; i <= n; i++) {
			for (int j = 0; j <= c.m; j++) {
				c.w[i][j] = 0;
				for (int k = 0; k <= m; k++)
					c.w[i][j] += w[i][k] * b.w[k][j];
			}
		}
		return c;
	}

	void print() {
		puts("Matrix !");
		for (int i = 0; i <= n; i++) {
			for (int j = 0; j <= m; j++) printf("%.2Lf ", w[i][j]);
			puts("");
		}
	}
} res, A;

void insert() {
	int p = 0;
	for (int i = 1; s[i]; i++) {
		int ch = s[i] - 'a';
		if (!tr[p][ch]) tr[p][ch] = ++idx;
		p = tr[p][ch];
	}
	e[p] = true;
} 

void build() {
	int hh = 0, tt = -1;
	for (int i = 0; i < m; i++)
		if (tr[0][i]) q[++tt] = tr[0][i];
	while (hh <= tt) {
		int u = q[hh++];
		for (int i = 0; i < m; i++) {
			int v = tr[u][i];
			if (v) {
				fail[v] = tr[fail[u]][i];
				if (e[fail[v]]) e[v] = true;
				q[++tt] = v;
			} else tr[u][i] = tr[fail[u]][i];
		}
	}
}

int main() {
	scanf("%d%d%d", &n, &len, &m);
	for (int i = 1; i <= n; i++) {
		scanf("%s", s + 1);
		insert();
	}
	build();
	res.n = 0, res.m = A.n = A.m = idx + 1;
	res.w[0][0] = 1; A.w[idx + 1][idx + 1] = 1;
	for (int u = 0; u <= idx; u++) {
		for (int i = 0; i < m; i++) {
			int v = tr[u][i];
			if (e[v]) {
				A.w[u][idx + 1] += (LD)1 / m;
				v = 0;
			}
			A.w[u][v] += (LD)1 / m;
		}
	}
	while (len) {
		if (len & 1) res = res * A;
		A = A * A;
		len >>= 1;
	}
	printf("%Lf
", res.w[0][idx + 1]);
}
原文地址:https://www.cnblogs.com/dmoransky/p/12444387.html