AC自动机学习笔记

AC自动机学习笔记

前言

由于太菜了, 不知道怎么解释自动机到底是个什么东西

所以大佬轻 D

正篇开始

本文中字符串下标从 1 开始

我们知道 KMP 是用来解决一个模式串在一个母串上匹配的问题的

那多个模式串在一个母串上匹配怎么做呢?

可以考虑 AC自动机

用这些模式串先构出一棵 Trie 树, 然后再在 Trie 树上跑类似于 KMP 的东西

我们可以考虑把一些标记打在 Trie 树上的点来满足题目的询问

好的, 那么我们现在有了一棵由模式串构成的 Trie 树

这里先要引入一个概念, fail 指针

何为 fail 指针

我们假设当前匹配到的是母串中的 (S[i - j]) , 也就是母串中 (i)(j) 的一个子串

然后跟这个串匹配的是模式串 (k_1)(C[1 - (j - i + 1)])

如果 (S[j + 1])(C[j - i + 2]) 相等, 我们就可以继续匹配

那如果不匹配呢?

如果我们找到一个其他的模式串 (k_2), 满足这个模式串的前缀和 (k_1) 的一个真后缀相等, 并且在所有的满足这个条件的 (k_i) 中, 这个相等的子串的长度最长

(k_2)(C[1 - p_1])(k_1)(C[(j - i + 2 - p_1) - (j - i + 1)]) 相等, 并且找不到一个 (k_3) 满足 (k_3) 中最大的匹配位置 (p_2 > p_1)

那么 (k_1)(C[p_1]) 在 Trie 树中的位置, 就是 (k_1)(C[j - i + 1]) 这个点的 fail 指向的点

那我们为什么要引入 fail 指针的一个概念呢?

我们知道, 假如说我们在这个点无法匹配时, 看在这个点的 fail , 能不能匹配

又由于没有比这个 fail 能够匹配更长最长真后缀的, 于是在你跳 fail 的时候不会漏掉可能合法的答案

如果你有一个字符串的前缀能够匹配这个字符, 那这个前缀对应的这个点, 一定可以通过不断跳当前点的 fail 得到

我们可以将跳 fail 的过程看作不断选择这个模式串的后缀, 砍掉前面最短的一部分, 满足其他的在这棵 Trie 树中存在

是不是有一点像 KMP 的不断跳 (nxt)

给段代码理解一下

void get_fail()
{
    for(int i = 0; i < 4; i++)
		if(t[0].ch[i]) t[t[0].ch[i]].fail = 0, q.push(t[0].ch[i]);
    //我们要先排除一个点的 fail 是他的父亲的情况, 由于你 fail 指的是一个最长真后缀, 你当前只有一个字符, 至少砍掉一个不就是没了
    while(!q.empty())
    {
		int u = q.front(); q.pop();
		for(int v, i = 0; i < 4; i++)
		{
	    	v = t[u].ch[i];
	    	if(v) t[v].fail = t[t[u].fail].ch[i], q.push(v); //存在 u 这个儿子 v , 那这个儿子 v 的 fail 指向的, 就是 u 的 fail 对应的这个点的对应的这个儿子, 可以看作在这两个字符串后面同时加了一个字符
	    	else t[u].ch[i] = t[t[u].fail].ch[i]; //不存在这个儿子 v , 就把 u 的这个儿子指向 u 的 fail 的这个儿子, 这一步并不会影响到 fail 的正确性, 相当于是我们匹配不到最长的就匹配次长的, 然后这个 u 的儿子指到 fail 的儿子相当于你先跳到 fail , 然后跳到 fail 的这个儿子, 事实上是两步, 我们把它并成一步
		}
    }
}

然后我们就把这个 AC自动机构出来了

那么接下来就直接拿母串在 AC自动机上匹配即可

那么如何维护题目所求的东西呢

我们知道到 u 点可以匹配, 那到 u 的 fail 也必然可以匹配

假如说我们要维护的是是否存在, 那么我们可以不断跳 fail 不断标记, 直到当前这个 fail 被标记过了就可以停止了

这样每个点只会被更新一次, 复杂度是对的

但如果我们要维护的是匹配次数呢

这样不断跳 fail , 必须要跳到底

如果是 (aaaaaaaaaaaaaaaaaaaaaaa) 这种串, 复杂度就不对了

那要怎么做呢

考虑每一次每个点的 fail , 只有一个, 假如说我们把每个点和他的 fail 看作树中的一对父子关系

那么每次跳 fail 可以看作在这棵 fail 树中, 从这个点不断跳他的父亲, 直到跳到根

那么对于一个点, 他被更新的次数是不是就是他子树中所有的点的总次数呢?

所以我们就可以, 在母串匹配的时候只在当前点打上标记, 然后把 fail 树建出来, 树形 DP 一下就行了

差不多就到这里吧, 还有什么东西之后再补充

贴个完整代码吧

Code1

这个是不停的跳 fail 的代码

洛谷AC自动机模板

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <queue>
const int N = 1e6 + 5; 
using namespace std;

int n, cnt;
char s[N];
struct node { int fail, cnt, ch[26]; } t[N]; 
queue<int> q; 

template < typename T >
inline T read()
{
    T x = 0, w = 1; char c = getchar();
    while(c < '0' || c > '9') { if(c == '-') w = -1; c = getchar(); }
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * w; 
}

void build()
{
    int len = strlen(s + 1);
    int u = 0;
    for(int tmp, i = 1; i <= len; i++)
    {
	tmp = s[i] - 'a';
	if(!t[u].ch[tmp]) t[u].ch[tmp] = ++cnt;
	u = t[u].ch[tmp]; 
    }
    t[u].cnt++; 
}

void get_fail()
{
    for(int i = 0; i < 26; i++)
	if(t[0].ch[i]) t[t[0].ch[i]].fail = 0, q.push(t[0].ch[i]); 
    while(!q.empty())
    {
	int u = q.front(); q.pop(); 
	for(int v, i = 0; i < 26; i++)
	{
	    v = t[u].ch[i];  
	    if(v)
	    {
		t[v].fail = t[t[u].fail].ch[i];
		q.push(v); 
	    }
	    else t[u].ch[i] = t[t[u].fail].ch[i]; 
	}
    }
}

int solve()
{
    int u = 0, len = strlen(s + 1), ans = 0;
    for(int v, tmp, i = 1; i <= len; i++)
    {
	v = u = t[u].ch[tmp = s[i] - 'a'];
	while(v && t[v].cnt != -1)
	{
	    ans += t[v].cnt;
	    t[v].cnt = -1;
	    v = t[v].fail; 
	}
    }
    return ans; 
}

int main()
{
    n = read <int> ();
    for(int i = 1; i <= n; i++)
    {
	scanf("%s", s + 1);
	build(); 
    }
    scanf("%s", s + 1);
    get_fail(); 
    printf("%d
", solve()); 
    return 0; 
}

Code2

这是构 fail 树的代码

洛谷AC自动机二次加强版

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <queue>
const int N = 2e6 + 5; 
using namespace std;

int n, cnt, cnte, sz[N], head[N], match[N];
char s[N];
struct node { int fail, ch[26]; } t[N]; 
struct edge { int to, nxt; } e[N]; 
queue<int> q; 

template < typename T >
inline T read()
{
    T x = 0, w = 1; char c = getchar();
    while(c < '0' || c > '9') { if(c == '-') w = -1; c = getchar(); }
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar(); 
    return x * w; 
}

inline void adde(int u, int v) { e[++cnte] = (edge) { v, head[u] }, head[u] = cnte; }

void get_fail()
{
    for(int i = 0; i < 26; i++)
	if(t[0].ch[i]) t[t[0].ch[i]].fail = 0, q.push(t[0].ch[i]);
    while(!q.empty())
    {
	int u = q.front(); q.pop();
	for(int v, i = 0; i < 26; i++)
	{
	    if(v = t[u].ch[i])
		t[v].fail = t[t[u].fail].ch[i], q.push(v);
	    else t[u].ch[i] = t[t[u].fail].ch[i]; 
	}
    }
}

void solve()
{
    int u = 0, len = strlen(s + 1);
    for(int tmp, i = 1; i <= len; i++)
    {
	u = t[u].ch[tmp = s[i] - 'a'];
	sz[u]++; 
    }
}

void dfs(int u)
{
    for(int v, i = head[u]; i; i = e[i].nxt)
	dfs(v = e[i].to), sz[u] += sz[v]; 
}

int main()
{
    n = read <int> ();
    for(int u = 0, len, tmp, i = 1; i <= n; i++, u = 0)
    {
	scanf("%s", s + 1), len = strlen(s + 1);
	for(int j = 1; j <= len; j++)
	{
	    if(!t[u].ch[tmp = s[j] - 'a']) t[u].ch[tmp] = ++cnt;
	    u = t[u].ch[tmp]; 
	}
	match[i] = u; 
    }
    scanf("%s", s + 1);
    get_fail(), solve();
    for(int i = 1; i <= cnt; i++)
	adde(t[i].fail, i);
    dfs(0); 
    for(int i = 1; i <= n; i++)
	printf("%d
", sz[match[i]]); 
    return 0; 
}
原文地址:https://www.cnblogs.com/ztlztl/p/12311023.html