AC自动机

AC 自动机定义:多模匹配算法。相比于用KMP算法高效完成的单模匹配算法,多模大概就是有一个S(主串)和复数个T(模板串),对每一个模板串都进行一次kmp算是暴力的做法,而ac自动机就是用来高效解决这类问题的算法;

AC自动机的构造:Trie树作为搜索数据结构+Fail指针作为当前字符失配时跳转到的具有最长公共前后缀的字符的位置

例如:cat, cash, app, apple, aply, ok;

首先建立一个Trie树,用来保存所有的单词;

其中Next [ i ] [ j ]  代表编号为 i 的节点的第 j 个儿子(是儿子!是儿子!是儿子!)是编号为 k 的节点。

其中编号 i 为节点的编号,即 i 的最大个数即为节点的个数——相同的字母编号不一定相同

 编号 k 为 节点本身字母的代表符号 ,如本题仅有小写字母,故编号k最大为26个——相同的字母编号相同

 这样能节约空间,不用对每个节点都开一个26的数组

对于AC自动机,大体的思路是在树上KMP,假设我们匹配到一个失配的位置,则匹配到下一个拥有相同前缀的地方继续匹配,和KMP一样是通过传送门一样的方式提高效率的。

其中,不难得出结论,与root(根节点)相连的节点若失配,则直接指向根节点(没有前缀了);

对于其他节点,在求child的失配指针时,首先我们要找到其father的失配指针所指向的节点,假如是T节点的话,我们就要看T的child中有没有和child节点字母相同的节点,

如果有,这个节点就是child的失配指针,如果没有,则需要继续找T的失配指针,直到指向root为止;

下面举个栗子

应该很明显为什么要找父节点的失配指针了,因为可以保证前缀的最大化(虽然除了父节点也没其他的能找了),

然后父节点的失配指针再没有找到的话,就相当于本来的最大相同后缀为abcd削减到abc一样继续搜索,直到发现并没有任何一个符合;

 下面贴上模板

HDU2222为例的模板

#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<queue>
using namespace std;
const int maxn = 5000010;
struct Trie
{
    int Next[maxn][26];//Next[i][j]代表编号为i的节点的第j个儿子为(k + 'a')
    int fail[maxn];
    int end[maxn];

    int root, L;
    //
    int newnode()//新建点
    {
        for (int i = 0; i < 26; i++)
            Next[L][i] = -1;
        end[L++] = 0;
        return L - 1;
    }
    void init()
    {
        L = 0;
        root = newnode();
        //root = 0,同时Next[0][x]整体赋值为-1
    }
    //
    /*字典树的构建过程是这样的,当要插入许多单词的时候,我们要从前往后遍历整个字符串
    当我们发现当前要插入的字符其节点再先前已经建成,我们直接去考虑下一个字符即可
    当我们发现当前要插入的字符没有再其前一个字符所形成的树下没有自己的节点
    我们就要创建一个新节点来表示这个字符,接下往下遍历其他的字符。然后重复上述操作。*/
    void insert(char buf[])//建立字典树
    {
        int len = strlen(buf);
        int Now = root;//当前所在的root位置
        for (int i = 0; i < len; i++)
        {
            if (Next[Now][buf[i] - 'a'] == -1)//没有自己的节点
                Next[Now][buf[i] - 'a'] = newnode();
            Now = Next[Now][buf[i] - 'a'];//将当前节点后移
        }
        end[Now]++;
    }

    void build()//构建失配函数
    {
        queue<int>Q;
        fail[root] = root;//相当于fail[0] = 0;
     //初始化队列
        for (int i = 0; i < 26; i++)
            if (Next[root][i] == -1)
                Next[root][i] = root;
            else
            {
                fail[Next[root][i]] = root;
                Q.push(Next[root][i]);
            }
     //按bfs顺序计算失配函数
        while (!Q.empty())
        {
            int now = Q.front();
            Q.pop();
            for (int i = 0; i < 26; i++)
                if (Next[now][i] == -1)
                    Next[now][i] = Next[fail[now]][i];
                else
                {
                    fail[Next[now][i]] = Next[fail[now]][i];
                    Q.push(Next[now][i]);
                }
        }
    }
    int query(char buf[])
    {
        int len = strlen(buf);
        int now = root;
        int res = 0;
        for (int i = 0; i < len; i++)
        {
            now = Next[now][buf[i] - 'a'];
            int temp = now;
            while (temp != root)
            {
                res += end[temp];
                end[temp] = 0;
                temp = fail[temp];
            }
        }
        return res;
    }
    /*void debug()
    {
      for (int i = 0; i < L; i++)
      {
        printf("id = %3d,fail = %3d,end = %3d,chi = [", i, fail[i], end[i]);
        for (int j = 0; j < 26; j++)
        printf("%2d", Next[i][j]);
        printf("] ");
      }
    }*/
};


char buf[maxn];
Trie ac;
int main()
{
    int T;
    int n;
    scanf("%d", &T);
    while (T--)
    {
        scanf("%d", &n);
        ac.init();
        for (int i = 0; i < n; i++)
        {
            scanf("%s", buf);
            ac.insert(buf);
        }
        ac.build();
        scanf("%s", buf);
        printf("%d ", ac.query(buf));
    }
}
//这里是优化过的代码,优化的话主要的方向是减少失配的时候跳的次数(因为反复沿着失配边走比较复杂),所以会把不存在的边补上(有点像强联通的思路,即能够到达的直接连上)
 

 感谢大佬的讲解:https://www.cnblogs.com/TheRoadToTheGold/p/6290732.html //字典树部分讲解

接下来放一个用后缀链接优化过的真*模板以及各功能块的使用

#include<iostream>
#include<queue>
#include<cstring>
#include<cstdio>
using namespace std;
const int maxn = 1000050;
const int size_all = 26;
struct Trie
{
    int Next[maxn][size_all];
    int end[maxn];//标记该节点是否为尾节点
    int last[maxn];//last后缀链接优化
    int fail[maxn];//fail指针
    int cnt[maxn];//计数
    int ans = 0;

    int root;
    void init()
    {
        root = 1;
        memset(Next[0], 0, sizeof(Next[0]));
        end[0] = last[0] = fail[0] = 0;
    }
    void insert(char *s)
    {
        int n = strlen(s);
        int u = 0;
        for (int i = 0; i < n; i++)
        {
            int id = s[i] - 'a';//依据题目情况定
            if (Next[u][id] == 0)
            {
                Next[u][id] = root;
                memset(Next[root], 0, sizeof(Next[root]));
                end[root++] = 0;
            }
            u = Next[u][id];
        }
        end[u] = 1;//标记为终止节点代表有以这个为结尾的单词
        cnt[u] = 0;
    }
    void print(int i)//不要在意名字,具体就是用来解决问题的一环
    {
        //此处是用来统计出现的次数(允许出现重叠)
        if (end[i])
        {
            cnt[i]++;
            print(last[i]);
        }
    }
    void find(char *s)//开始匹配主串
    {
        int n = strlen(s);
        int j = 0;
        for (int i = 0; i < n; i++)
        {
            int id = s[i] - 'a';
            //while(j && !next[j][id] ) j = fail[j];
            j = Next[j][id];
            if (end[j]) print(j);
            else if (last[j]) print(last[j]);
        }
    }
    void build()
    {
        queue<int>q;
        for (int i = 0; i < size_all; i++)
        {
            int u = Next[0][i];
            if (u)
            {
                last[u] = fail[u] = 0;
                q.push(u);
            }

        }
        while (!q.empty())
        {
            int r = q.front();
            q.pop();
            for (int i = 0; i < size_all; i++)
            {
                int u = Next[r][i];
                //if(!u) continue;
                if (!u) {
                    Next[r][i] = Next[fail[r]][i];
                    continue;
                }
                q.push(u);
                int v = fail[r];
                while (v && Next[v][i] == 0) v = fail[v];
                fail[u] = Next[v][i];
                last[u] = (end[fail[u]]) ? fail[u] : last[fail[u]];
            }
        }
    }
};
Trie ac;
int cas;
int n;
char buf[maxn];
int main()
{
    scanf("%d", &cas);
    while (cas--)
    {
        scanf("%d", &n);
        ac.init();
        for (int i = 0; i < n; i++)
        {
            scanf("%s", buf);
            ac.insert(buf);
        }
        ac.build();
        scanf("%s", buf);
        ac.find(buf);
        //solve()
    }
}

各种题目的应用以及AC代码:

HDU2222 略

HDU2896 病毒侵袭 ---- 模板题

#include<cstdio>
#include<iostream>
#include<cstring>
#include<queue>
#include<map>
#include<set>
#include<string>

using namespace std;
const int size = 128;
const int maxn = 100005;
struct Trie
{
    int Next[maxn][size];
    int end[maxn];
    int last[maxn];
    int fail[maxn];
    int root;
    int _[maxn];
    set<int> ans;
    void init()
    {
        root = 1;
        memset(Next[0],0,sizeof(Next[0]));
        end[0] = last[0] = fail[0] = 0;
    }
    void insert(char *s, int pos)
    {
        int n =strlen(s);
        int u = 0;
        for(int i = 0;i < n;i++)
        {
            int id = s[i];
            if(Next[u][id] == 0)
            {
                Next[u][id] = root;
                memset(Next[root], 0, sizeof(Next[root]));
                end[root++] = 0;
            }
            u = Next[u][id];
        }
        end[u] = 1;
        _[u] = pos;
    }
    void print(int i)
    {
        if(end[i])
        {
            ans.insert(_[i]);
            print(last[i]);
        }
    }
    void find(char *s)
    {
        int n = strlen(s);

        int j = 0;
        for(int i = 0;i < n;i++)
        {
            int id = s[i];
            //while( j && !Next[j][id]) j = fail[j];
            j = Next[j][id];
            if(end[j]) print(j);
            else if(last[j]) print(last[j]);
        }
    }
    void build()
    {
        queue<int>q;
        for(int i = 0;i < size;i++)
        {
            int u = Next[0][i];
            if(u)
            {
                last[u] = fail[u] = 0;
                q.push(u);
            }
        }
        while(!q.empty())
        {
            int r = q.front();
            q.pop();
            for(int i = 0;i < size;i++)
            {
                int u = Next[r][i];
                if(!u)
                {
                    Next[r][i] = Next[fail[r]][i];
                    continue;
                }
                q.push(u);
                int v = fail[r];
                while(v && Next[v][i] == 0) v = fail[v];
                fail[u] = Next[v][i];
                last[u] = (end[fail[u]]) ? fail[u] : last[fail[u]];
            }
        }

    }


};
Trie ac;

char str[maxn];
int N,M;
int main()
{
    scanf("%d", &N);
    ac.init();

    for(int i = 0;i < N;i++)
    {
        scanf("%s", str);
        ac.insert(str, i);
    }
    ac.build();
    scanf("%d",&M);
    int sum = 0;
    for(int i = 0;i < M;i++)
    {
        scanf("%s",str);
        ac.find(str);
        if(ac.ans.size())
        {
            printf("web %d:", i + 1);
            for(set<int>::iterator it = ac.ans.begin(); it != ac.ans.end();it++)
            {
                printf(" %d", *it + 1);
            }
            printf("
");
            ac.ans.clear();
            sum++;
        }

    }
    printf("total: %d
", sum);
}
View Code

HDU 3065 病毒侵袭持续中 模板题

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<queue>
#include<cstring>
using namespace std;
const int maxn = 600005;
const int size = 128;
char arr[1005][55];
struct Trie{
    int Next[maxn][size];
    int end[maxn];
    int last[maxn];
    int fail[maxn];
    int cnt[maxn];
    int lis[maxn];
    int root;
    void init()
    {
        root = 1;
        memset(Next[0], 0, sizeof(Next[0]));
        end[0] = last[0] = fail[0] = 0;
    }

    void insert(char *s, int pos)
    {
        int n = strlen(s);
        int u = 0;
        for(int i = 0; i < n;i++)
        {
            int id = s[i];
            if(Next[u][id] == 0)
            {
                Next[u][id] = root;
                memset(Next[root], 0,sizeof(Next[root]));
                end[root++] = 0;
            }
            u = Next[u][id];
        }

        end[u] = 1;
        cnt[u] = 0;    
        lis[pos] = u;
    }
    void print(int i)
    {
        if(end[i])
        {
            cnt[i]++;
            print(last[i]);
        }
    }
    void find(char *s)
    {
        int n = strlen(s);
        int j = 0;
        for(int i = 0; i < n;i++)
        {
            int id = s[i];
            //while(j && !Next[j][id]) j = fail[j];
            j = Next[j][id];
            if(end[j]) print(j);
            else if(last[j]) print(last[j]);
        }
    }
    void build()
    {
        queue<int>q;
        for(int i = 0; i < size; i++)
        {
            int u = Next[0][i];
            if(u)
            {
                last[u] = fail[u] = 0;
                q.push(u);

            }
        }
        while(!q.empty())
        {
            int r = q.front();
            q.pop();
            for(int i = 0;i < size;i++)
            {
                int u = Next[r][i];
                //if(!u) continue;
                if(!u)
                {
                    Next[r][i] = Next[fail[r]][i];
                    continue;
                }
                q.push(u);
                int v = fail[r];
                while(v && Next[v][i] == 0)
                    v = fail[v];
                fail[u] = Next[v][i];
                last[u] = (end[fail[u]]) ? fail[u] : last[fail[u]];

            }

        }

    }

};
Trie ac;
char mai[2000005];
int main()
{

    int n;
    while(scanf("%d", &n)!= EOF)
    {
        ac.init();
        for(int i = 0;i < n;i++)
        {
            scanf("%s", arr[i]);
            ac.insert(arr[i], i);
        }
        scanf("%s", mai);
        ac.build();
        ac.find(mai);
        for(int i = 0;i < n;i++)
        {
            int pos = ac.lis[i];
            if(ac.cnt[pos])
            {
                printf("%s: %d
", arr[i], ac.cnt[pos]);
            }
        }
    }    
}
View Code

Zoj 3228 search the string 模板题

#include<iostream>
#include<queue>
#include<cstring>
#include<cstdio>
using namespace std;
const int maxn = 600005;
const int sigma_size = 26;
struct Trie
{
    int Next[maxn][sigma_size];
    int end[maxn];
    int time[maxn], len[maxn];
    int last[maxn], fail[maxn];
    int cnt[maxn][2];
    int root;
    void init()
    {
        root = 1;
        memset(Next[0], 0, sizeof(Next[0]));
        end[0] = last[0] = fail[0] = 0;
    }

    void insert(char *s)
    {
        int n = strlen(s), u = 0;
        for(int i = 0;i < n;i++)
        {
            int id = s[i] - 'a';
            if(Next[u][id] == 0)
            {
                Next[u][id] = root;
                memset(Next[root], 0, sizeof(Next[root]));
                end[root++] = 0;
            }
            u = Next[u][id];
        }
        end[u] = 1;
        time[u] = 0;
        len[u] = n;
        cnt[u][0] = cnt[u][1] = 0;
    }
    void print(int i, int pos)
    {
        if(end[i])
        {
            cnt[i][0]++;
            if(time[i] + len[i] <= pos)
            {
                time[i] = pos;
                cnt[i][1]++;
            }
            print(last[i], pos);
        }
    }
    void find(char *s)
    {
        int n = strlen(s), j = 0;
        for(int i = 0;i < n;i++)
        {
            int id = s[i] - 'a';
            while(j && !Next[j][id]) j = fail[j];
            j = Next[j][id];
            if(end[j]) print(j, i + 1);
            else if(last[j]) print(last[j], i + 1);
        }
    }
    int find_T(char *s, int type)
    {
        int n = strlen(s), u = 0;
        for(int i = 0;i < n;i++)
        {
            int id = s[i] - 'a';
            u = Next[u][id];
        }
        return cnt[u][type];
    }
    void build()
    {
        queue<int>q;
        for(int i = 0;i < sigma_size;i++)
        {
            int u = Next[0][i];
            if(u)
            {
                last[u] = fail[u] = 0;
                q.push(u);
            }
        }
        while(!q.empty())
        {
            int r = q.front();
            q.pop();
            for(int i = 0;i < sigma_size; i++)
            {
                int u = Next[r][i];
                if(!u) continue;
                q.push(u);
                int v = fail[r];
                while(v && Next[v][i] == 0) v = fail[v];
                fail[u] = Next[v][i];
                last[u] = (end[fail[u]]) ? fail[u] :last[fail[u]];
            }
        }
    }
};
Trie ac;
const int maxstring = 100000 + 1000;
char str[maxstring];
char t[maxstring][7];
int type[maxstring];
int main()
{
    int kase = 0;
    while(~scanf("%s",str))
    {
        ac.init();
        int n;
        scanf("%d",&n);
        for(int i = 0;i < n;i++)
        {
            scanf("%d%s",&type[i], t[i]);
            ac.insert(t[i]);
        }
        ac.build();
        ac.find(str);
        printf("Case %d
", ++kase);
        for(int i = 0;i < n;i++)
        {
            printf("%d
", ac.find_T(t[i], type[i]));
        }
        printf("
");
    }
}
View Code

 //    推荐题目POJ2778  POJ1625(这个比较恶心)  

原文地址:https://www.cnblogs.com/TheStuckedCat/p/9361109.html