AC 自动机学习笔记

AC 自动机(Aho-Corasick Automaton),也是一种字符串匹配算法。主要用于解决多个模式串匹配主串的问题,它的本质是用 Trie + KMP 算法。

原理

与 KMP 算法类似,主要步骤:

  1. 将所有模式串构建成一棵 Trie 树;

  2. 在 Trie 上构造所有节点的前缀指针;

  3. 利用前缀指针对主串进行匹配。

模式串与主串的匹配

与 KMP 算法完全相同。

匹配步骤:

  1. 如果当前字符匹配(ch[j][S[i + 1] - 'a'] != 0),则继续匹配下一个字符(i++, j = ch[j][S[j + 1] - 'a']);

  2. 如果当前字符失配(ch[j][S[i + 1] - 'a'] == 0),则重新对齐(j = nxt[j])直到匹配或者找到了 0。

构造节点的前缀指针

构建方式与 KMP 算法类似。

在 KMP 中,如果在模式串的一个位置失配,则需要回到模式串的前面一个位置继续匹配。从位置 $i$ 处失配后回到 $j$ 位置 ,记作 $fail(i) = j$ 。

考虑 $fail(i) = j$ 的条件:串的前 j个字符组成的前缀,是前 个字符组成前缀的后缀。理论依据是,这样可以保证每一时刻已匹配的字符尽量多,避免遗漏。

现在将问题转化为,在一棵 Trie 上,求一个节点 $j$,使得从根到 $j$ 的路径组成的串是从根到 $i$ 的路径组成串的后缀

如图(图片来自 Menci 的《AC 自动机学习笔记》):

img

设 $i$ 父节点为 $i'$, 的入边上的字母为 $c$。

一个显然的结论是,如果 $fail(i')$ 有字母 $c$ 的出边,则该出边指向的点即为 $fail(i)$。例如,上图中 $fail(7) = 1, fail(8) = 2$。

如果 $fail(i')$ 没有字母 $c$ 的出边,则沿着失配函数继续向上找,找到 $fail(fail(i'))$ …… 直到找到根为止,如果找不到一个符合条件的节点,则 $fail(i)$ 为根。 例如,上图中 $fail(3) = 0$。

BFS 优化

if (!ch[u][i]) ch[u][i] = ch[nxt[u]][i];

如果当前节点 u 不存在 i 的转移边,则创建对应儿子,并让它的指向更短前缀,即指向该节点前缀指针。

原来判断节点的转移边 i 是否存在,存在就直接赋值,否则缩短前缀继续判断直到匹配(while (v > 0 && !ch[v][i]) v = nxt[v])。根据 BFS 遍历特性,浅层的节点已经构建好了前缀指针,如果让节点不存在的转移边 i 直接指向的更短前缀。这样就无需判断,直接赋值即可。(如果节点前缀指针也没有 i 的转移边怎么办。其实前缀指针的 i 的转移边已经在浅层遍历时指向更短前缀,达到了上面 while 语句的效果。)

通过 BFS 构建步骤:

  1. 初始化所有与根节点相连的转移边(ch[0][c] = 1);

  2. 由浅到深遍历每个节点,每到达一个节点遍历它的所有转移边 i

  3. 如果节点 u 存在 i 的转移边(ch[u][i] != 0),则该节点转移边进队,且它的儿子 ch[u][i] 前缀指针指向该节点前缀指针(nxt[ch[u][i]] = ch[nxt[i]][i]);

  4. 如果节点 u 不存在 i 的转移边(ch[u][i] == 0),则让节点不存在的转移边 i 指向该节点前缀指针(ch[u][i] = ch[nxt[u]][i])。

模板

Luogu 3808

给定 $n$ 个模式串 $s_i$ 和一个文本串 $t$,求有多少个不同的模式串在文本串里出现过。
两个模式串不同当且仅当他们编号不同。

const int MAXN = 1000005;

int book[MAXN];
int ch[MAXN][30];
int nxt[MAXN], tot;
int que[MAXN], l, r;
char p[MAXN], s[MAXN];

void init() {
	tot = 0, l = 1, r = 1;
	memset(ch, 0, sizeof(ch));
	memset(nxt, 0, sizeof(nxt));
	memset(book, 0, sizeof(book));
	memset(que, 0, sizeof(que));
}

void insert(char *s) {
	int u = 0;
	int len = strlen(s);
	for (int i = 0; i < len; i++) {
		int c = s[i] - 'a';
		if (!ch[u][c]) ch[u][c] = ++tot;
		u = ch[u][c];
	}
	book[u]++;
}

void build(){
	for (int i = 0; i < 26; i++){
		if (ch[0][i]) {
			nxt[ch[0][i]] = 0;
			que[r++] = ch[0][i];
		}
	}

	while (l < r) {
		int u = que[l++];
		for (int i = 0; i < 26; i++) {
			if(!ch[u][i]) ch[u][i] = ch[nxt[u]][i];
			else {
				que[r++] = ch[u][i];
				nxt[ch[u][i]] = ch[nxt[u]][i];
			}
		}
	}
}

int query(char *s) {
	int res = 0;
	int len = strlen(s), u = 0;
	for(int i = 0; i < len; i++){
		u = ch[u][s[i] - 'a'];
		for(int k = u; k && ~book[k]; k = nxt[k]) {
			res += book[k];
			book[k] = -1;
		}
	}
	return res;
}

query 函数中:

for(int k = u; k && ~book[k]; k = nxt[k]) {
	res += book[k];
	book[k] = -1;
}

k != 0book[k] 未被标记(book[k] != -1)时执行。

原理是按照 Trie 的方式去匹配。当匹配到一个模式串时,累加它被标记的次数,并缩短后缀,继续判断是否在匹配串集中,直到根节点。每次访问直接把 book[k]累加到 res 即可,为了避免重复访问,访问过后标记为 -1

参考资料

原文地址:https://www.cnblogs.com/lcfsih/p/14391347.html