AC自动机及其模板

模板

#include<queue>
#include<stdio.h>
#include<string.h>
using namespace std;

const int Max_Tot = 5e5 + 10;
const int Max_Len = 1e6 + 10;
const int Letter  = 26;

struct Aho{
    struct StateTable{
        int Next[Letter];
        int fail, cnt;
    }Node[Max_Tot];
    int Size;
    queue<int> que;

    inline void init(){
        while(!que.empty()) que.pop();
        memset(Node[0].Next, 0, sizeof(Node[0].Next));
        Node[0].fail = Node[0].cnt = 0;
        Size = 1;
    }

    inline void insert(char *s){
        int len = strlen(s);
        int now = 0;
        for(int i=0; i<len; i++){
            int idx = s[i] - 'a';
            if(!Node[now].Next[idx]){
                memset(Node[Size].Next, 0, sizeof(Node[Size].Next));
                Node[Size].fail = Node[Size].cnt = 0;
                Node[now].Next[idx] = Size++;
            }
            now = Node[now].Next[idx];
        }
        Node[now].cnt++;
    }

    inline void BuildFail(){
        Node[0].fail = -1;
        que.push(0);
        while(!que.empty()){
            int top = que.front();  que.pop();
            for(int i=0; i<Letter; i++){
                if(Node[top].Next[i]){
                    if(top == 0) Node[ Node[top].Next[i] ].fail = 0;
                    else{
                        int v = Node[top].fail;
                        while(v != -1){
                            if(Node[v].Next[i]){
                                Node[ Node[top].Next[i] ].fail = Node[v].Next[i];
                                break;
                            }v = Node[v].fail;
                        }if(v == -1) Node[ Node[top].Next[i] ].fail = 0;
                    }que.push(Node[top].Next[i]);
                }
            }
        }
    }

    inline void Get(int u, int &res){
        while(u){
            res += Node[u].cnt;
            Node[u].cnt = 0;
            u = Node[u].fail;
        }
    }

    int Match(char *s){
        int len = strlen(s);
        int res = 0, now = 0;
        for(int i=0; i<len; i++){
            int idx = s[i] - 'a';
            if(Node[now].Next[idx]) now = Node[now].Next[idx];
            else{
                int p = Node[now].fail;
                while(p!=-1 && Node[p].Next[idx]==0) p = Node[p].fail;
                if(p == -1) now = 0;
                else now = Node[p].Next[idx];
            }
            if(Node[now].cnt) Get(now, res);
        }
        return res;
    }
}ac;

char S[Max_Len];
int main(void)
{
//    ac.init();
//    ac.BuildFail();
//    ac.Match();
//    .....
    return 0;
}
View Code
#include<bits/stdc++.h>
using namespace std;

#define MAX_N 1000006  /// 主串长度
#define MAX_Tot 500005 /// 字典树上可能的最多的结点数 = Max串数 * Max串长

struct Aho{
    struct state{
        int next[26];
        int fail,cnt;
    }st[MAX_Tot]; /// 节点结构体
    int Size; /// 节点个数
    queue<int> que;/// BFS构建fail指针的队列

    void init(){
        while(que.size())que.pop();/// 清空队列
        for(int i=0;i<MAX_Tot;i++){/// 初始化节点,有时候 MLE 的时候,可以尝试将此初始化放到要操作的时候再来初始化
            memset(st[i].next,0,sizeof(st[i].next));
            st[i].fail=st[i].cnt=0;
        }
        Size=1;/// 本来就有一个空的根节点
    }

    void insert(char *S){/// 插入模式串
        int len=strlen(S);/// 复杂度为O(n),所以别写进for循环
        int now=0;/// 当前结点是哪一个,从0即根开始
        for(int i=0;i<len;i++){
            char c = S[i];
            if(!st[now].next[c-'a']) st[now].next[c-'a']=Size++;
            now=st[now].next[c-'a'];
        }
        st[now].cnt++;/// 给这个串末尾打上标记
    }

    void build(){/// 构建 fail 指针
        st[0].fail=-1;/// 根节点的 fail 指向自己
        que.push(0);/// 将根节点入队

        while(que.size()){
            int top = que.front(); que.pop();

            for(int i=0; i<26; i++){
                if(st[top].next[i]){/// 如果当前节点有 i 这个儿子
                    if(top == 0) st[st[top].next[i]].fail=0;/// 第二层节点 fail 应全指向根
                    else {
                        int v = st[top].fail;/// 走向 top 节点父亲的 fail 指针指向的地方,尝试找一个最长前缀
                        while(v != -1){/// 如果走到 -1 则说明回到根了
                            if(st[v].next[i]){/// 如果有一个最长前缀后面接着的也是 i 这个字符,则说明 top->next[i] 的 fail 指针可以指向这里
                                st[st[top].next[i]].fail = st[v].next[i];
                                break;/// break 保证找到的前缀是最长的
                            }
                            v = st[v].fail;/// 否则继续往父亲的父亲的 fail 跳,即后缀在变短( KMP 思想 )
                        } if(v == -1) st[st[top].next[i]].fail=0;/// 如果从头到尾都没找到,那么就只能指向根了
                    } que.push(st[top].next[i]);/// 将这个节点入队,为了下面建立 fail 节点做准备
                }
            }
        }
    }

    int get(int u){
        int res = 0;
        while(u){
            res = res + st[u].cnt;
            st[u].cnt = 0;
            u = st[u].fail;
        }
        return res;
    }

    int match(char *S){
        int len = strlen(S);/// 主串长度
        int res=0,now=0;/// 主串能够和多少个模式串匹配的结果、当前的节点是哪一个
        for(int i=0; i<len; i++){
            char c = S[i];
            if(st[now].next[c-'a']) now=st[now].next[c-'a'];/// 如果匹配了,则不用跳到 fail 处,直接往下一个字符匹配
            else {
                int p = st[now].fail;
                while( p!=-1 && st[p].next[c-'a']==0 ) p=st[p].fail;/// 跳到 fail 指针处去匹配 c-'a' ,直到跳到 -1 也就是没得跳的时候
                if(p == -1) now = 0;/// 如果全部都不匹配,只能回到根节点了
                else now = st[p].next[c-'a'];/// 否则当前节点就是到达了能够匹配的 fail 指针指向处
            }
            if(st[now].cnt)/// 如果当前节点是个字符串的结尾,这个时候就能统计其对于答案贡献了,答案的贡献应该是它自己 + 它所有 fail 指针指向的节点
                           /// 实际也就是它匹配了,那么它的 fail 指向的前缀以及 fail 的 fail 实际上也应该是匹配了,所以循环跳 fail 直到无法再跳为止
                res = res + get(now);
        }
        return res;
    }
}ac;

int T;
int N;
char S[MAX_N];
int main(){
    // ac.init();
    // ac.build();
    // ac.match();
    // ...
    return 0;
}
带注释
#include <stdio.h>  
#include <stdlib.h>  
#include <string.h>  
struct Node  
{  
    int cnt;//是否为该单词的最后一个结点   
    Node *fail;//失败指针   
    Node *next[26];//Trie中每个结点的各个节点   
}*queue[500005];//队列,方便用BFS构造失败指针   
char s[1000005];//主字符串   
char keyword[55];//需要查找的单词   
Node *root;//头结点   
void Init(Node *root)//每个结点的初始化   
{  
    root->cnt=0;  
    root->fail=NULL;  
    for(int i=0;i<26;i++)  
        root->next[i]=NULL;  
}  
void Build_trie(char *keyword)//构建Trie树   
{  
    Node *p,*q;  
    int i,v;  
    int len=strlen(keyword);  
    for(i=0,p=root;i<len;i++)  
    {  
        v=keyword[i]-'a';  
        if(p->next[v]==NULL)  
        {  
            q=(struct Node *)malloc(sizeof(Node));  
            Init(q);  
            p->next[v]=q;//结点链接   
        }  
        p=p->next[v];//指针移动到下一个结点   
    }  
    p->cnt++;//单词最后一个结点cnt++,代表一个单词   
}  
void Build_AC_automation(Node *root)  
{  
    int head=0,tail=0;//队列头、尾指针   
    queue[head++]=root;//先将root入队   
    while(head!=tail)  
    {  
        Node *p=NULL;  
        Node *temp=queue[tail++];//弹出队头结点   
        for(int i=0;i<26;i++)  
        {  
            if(temp->next[i]!=NULL)//找到实际存在的字符结点   
            { //temp->next[i] 为该结点,temp为其父结点   
                if(temp==root)//若是第一层中的字符结点,则把该结点的失败指针指向root   
                    temp->next[i]->fail=root;  
                else  
                {  
                    //依次回溯该节点的父节点的失败指针直到某节点的next[i]与该节点相同,  
                    //则把该节点的失败指针指向该next[i]节点;   
                    //若回溯到 root 都没有找到,则该节点的失败指针指向 root  
                    p=temp->fail;//将该结点的父结点的失败指针给p   
                    while(p!=NULL)  
                    {  
                        if(p->next[i]!=NULL)  
                        {  
                            temp->next[i]->fail=p->next[i];  
                            break;  
                        }  
                        p=p->fail;  
                    }  
                    //让该结点的失败指针也指向root   
                    if(p==NULL)  
                        temp->next[i]->fail=root;  
                }  
                queue[head++]=temp->next[i];//每处理一个结点,都让该结点的所有孩子依次入队   
            }  
        }  
    }  
}  
int query(Node *root)  
{ //i为主串指针,p为模式串指针   
    int i,v,count=0;  
    Node *p=root;  
    int len=strlen(s);  
    for(i=0;i<len;i++)  
    {  
        v=s[i]-'a';  
        //由失败指针回溯查找,判断s[i]是否存在于Trie树中   
        while(p->next[v]==NULL && p!=root)  
            p=p->fail;  
        p=p->next[v];//找到后p指针指向该结点   
        if(p==NULL)//若指针返回为空,则没有找到与之匹配的字符   
            p=root;  
        Node *temp=p;//匹配该结点后,沿其失败指针回溯,判断其它结点是否匹配   
        while(temp!=root)//匹配结束控制   
        {  
            if(temp->cnt>=0)//判断该结点是否被访问   
            {  
                count+=temp->cnt;//由于cnt初始化为 0,所以只有cnt>0时才统计了单词的个数   
                temp->cnt=-1;//标记已访问过   
            }  
            else//结点已访问,退出循环   
                break;  
            temp=temp->fail;//回溯 失败指针 继续寻找下一个满足条件的结点   
        }  
    }  
    return count;  
}  
int main()  
{  
    int T,n;  
    scanf("%d",&T);  
    while(T--)  
    {  
        root=(struct Node *)malloc(sizeof(Node));  
        Init(root);  
        scanf("%d",&n);  
        for(int i=0;i<n;i++)  
        {  
            scanf("
%s",keyword);  
            Build_trie(keyword);  
        }  
        Build_AC_automation(root);  
        scanf("
%s",s);  
        printf("%d
",query(root));  
    }  
    return 0;  
}  
指针版
const int max_node = 1e6 + 10;
const int max_len = 1e5 + 10;
const int Letter  = 26;

struct Aho{
    struct StateTable{
        int nxt[Letter];
        int fail, cnt;
        bool vis;
        void init(){
            memset(nxt, 0, sizeof(nxt));
            fail = 0;
            cnt = 0;
            vis = false;
        }
    }Node[max_node];

    int sz;
    queue<int> que;

    inline void init(){ while(!que.empty())que.pop(); Node[0].init(); sz = 1; }

    inline void insert(char *s, int len){
        int now = 0;
        for(int i=0; i<len; i++){
            int idx = s[i] - 'a';
            if(!Node[now].nxt[idx]){
                Node[sz].init();
                Node[now].nxt[idx] = sz++;
            }
            now = Node[now].nxt[idx];
        }
        Node[now].cnt++;
    }

    inline void build(){
        Node[0].fail = -1;
        que.push(0);
        while(!que.empty()){
            int top = que.front();  que.pop();
            for(int i=0; i<Letter; i++){
                if(Node[top].nxt[i]){
                    if(top == 0) Node[ Node[top].nxt[i] ].fail = 0;
                    else{
                        int v = Node[top].fail;
                        while(v != -1){
                            if(Node[v].nxt[i]){
                                Node[ Node[top].nxt[i] ].fail = Node[v].nxt[i];
                                break;
                            }v = Node[v].fail;
                        }if(v == -1) Node[ Node[top].nxt[i] ].fail = 0;
                    }que.push(Node[top].nxt[i]);
                }else Node[top].nxt[i] = top!=0?Node[ Node[top].fail ].nxt[i]:0;
            }
        }
    }

    int Match(char *s){
        int now = 0, res = 0;
        for(int i=0; s[i]!=''; i++){
            int idx = s[i] - 'a';
            now = Node[now].nxt[idx];
            int tmp = now;
            while(tmp != 0 && !Node[tmp].vis){
                res += Node[tmp].cnt;
                Node[tmp].vis = true;
                Node[tmp].cnt = 0;
                tmp = Node[tmp].fail;
            }
        }
        return res;
    }
}ac;
Trie 图

参考博客

http://blog.csdn.net/niushuai666/article/details/7002823
http://blog.csdn.net/silence401/article/details/52662605
http://blog.csdn.net/liu940204/article/details/51345954
http://blog.csdn.net/creatorx/article/details/71100840

相关题目

HDU 2222

题意 : 给出 n 个模式串再给出一个主串,问你有多少个模式串曾在这个主串上出现过

分析 : 模板题,注意每一次计数完成后要将 cnt 的值置为 0 以免重复计算

#include<bits/stdc++.h>
using namespace std;

#define MAX_N 1000006  /// 主串长度
#define MAX_Tot 500005 /// 字典树上可能的最多的结点数 = Max串数 * Max串长

struct Aho{
    struct state{
        int next[26];
        int fail,cnt;
    }st[MAX_Tot]; /// 节点结构体
    int Size; /// 节点个数
    queue<int> que;/// BFS构建fail指针的队列

    void init(){
        while(que.size())que.pop();/// 清空队列
        for(int i=0;i<MAX_Tot;i++){/// 初始化节点,有时候 MLE 的时候,可以尝试将此初始化放到要操作的时候再来初始化
            memset(st[i].next,0,sizeof(st[i].next));
            st[i].fail=st[i].cnt=0;
        }
        Size=1;/// 本来就有一个空的根节点
    }

    void insert(char *S){/// 插入模式串
        int len=strlen(S);/// 复杂度为O(n),所以别写进for循环
        int now=0;/// 当前结点是哪一个,从0即根开始
        for(int i=0;i<len;i++){
            char c = S[i];
            if(!st[now].next[c-'a']) st[now].next[c-'a']=Size++;
            now=st[now].next[c-'a'];
        }
        st[now].cnt++;/// 给这个串末尾打上标记
    }

    void build(){/// 构建 fail 指针
        st[0].fail=-1;/// 根节点的 fail 指向自己
        que.push(0);/// 将根节点入队

        while(que.size()){
            int top = que.front(); que.pop();

            for(int i=0; i<26; i++){
                if(st[top].next[i]){/// 如果当前节点有 i 这个儿子
                    if(top == 0) st[st[top].next[i]].fail=0;/// 第二层节点 fail 应全指向根
                    else {
                        int v = st[top].fail;/// 走向 top 节点父亲的 fail 指针指向的地方,尝试找一个最长前缀
                        while(v != -1){/// 如果走到 -1 则说明回到根了
                            if(st[v].next[i]){/// 如果有一个最长前缀后面接着的也是 i 这个字符,则说明 top->next[i] 的 fail 指针可以指向这里
                                st[st[top].next[i]].fail = st[v].next[i];
                                break;/// break 保证找到的前缀是最长的
                            }
                            v = st[v].fail;/// 否则继续往父亲的父亲的 fail 跳,即后缀在变短( KMP 思想 )
                        } if(v == -1) st[st[top].next[i]].fail=0;/// 如果从头到尾都没找到,那么就只能指向根了
                    } que.push(st[top].next[i]);/// 将这个节点入队,为了下面建立 fail 节点做准备
                }
            }
        }
    }

    int get(int u){
        int res = 0;
        while(u){
            res = res + st[u].cnt;
            st[u].cnt = 0;
            u = st[u].fail;
        }
        return res;
    }

    int match(char *S){
        int len = strlen(S);/// 主串长度
        int res=0,now=0;/// 主串能够和多少个模式串匹配的结果、当前的节点是哪一个
        for(int i=0; i<len; i++){
            char c = S[i];
            if(st[now].next[c-'a']) now=st[now].next[c-'a'];/// 如果匹配了,则不用跳到 fail 处,直接往下一个字符匹配
            else {
                int p = st[now].fail;
                while( p!=-1 && st[p].next[c-'a']==0 ) p=st[p].fail;/// 跳到 fail 指针处去匹配 c-'a' ,直到跳到 -1 也就是没得跳的时候
                if(p == -1) now = 0;/// 如果全部都不匹配,只能回到根节点了
                else now = st[p].next[c-'a'];/// 否则当前节点就是到达了能够匹配的 fail 指针指向处
            }
            if(st[now].cnt)/// 如果当前节点是个字符串的结尾,这个时候就能统计其对于答案贡献了,答案的贡献应该是它自己 + 它所有 fail 指针指向的节点
                           /// 实际也就是它匹配了,那么它的 fail 指向的前缀以及 fail 的 fail 实际上也应该是匹配了,所以循环跳 fail 直到无法再跳为止
                res = res + get(now);
        }
        return res;
    }
}aho;

int T;
int N;
char S[MAX_N];
int main(){
    scanf("%d",&T);
    while(T--){
        aho.init();
        scanf("%d",&N);
        for(int i=0;i<N;i++){
            scanf("%s",S);
            aho.insert(S);
        }
        aho.build();
        scanf("%s",S);
        printf("%d
",aho.match(S));
    }
    return 0;
}
View Code

HDU 2896

题意 : 中文就不赘述了……

分析 : 模板题,可见的ascii码范围的话,直接开到128即可

#include<queue>
#include<stdio.h>
#include<string.h>
using namespace std;

const int Max_Tot = 1e5 + 10;
const int Max_Len = 1e4 + 10;
const int Letter  = 128;

struct Aho{
    struct StateTable{
        int Next[Letter];
        int fail, id;
    }Node[Max_Tot];
    int Size;
    queue<int> que;

    inline void init(){
        while(!que.empty()) que.pop();
        memset(Node[0].Next, 0, sizeof(Node[0].Next));
        Node[0].fail = Node[0].id = 0;
        Size = 1;
    }

    inline void insert(char *s, const int id){
        int len = strlen(s);
        int now = 0;
        for(int i=0; i<len; i++){
            int idx = s[i];
            if(!Node[now].Next[idx]){
                memset(Node[Size].Next, 0, sizeof(Node[Size].Next));
                Node[Size].fail = Node[Size].id = 0;
                Node[now].Next[idx] = Size++;
            }
            now = Node[now].Next[idx];
        }
        Node[now].id = id;
    }

    inline void BuildFail(){
        Node[0].fail = -1;
        que.push(0);
        while(!que.empty()){
            int top = que.front();  que.pop();
            for(int i=0; i<Letter; i++){
                if(Node[top].Next[i]){
                    if(top == 0) Node[ Node[top].Next[i] ].fail = 0;
                    else{
                        int v = Node[top].fail;
                        while(v != -1){
                            if(Node[v].Next[i]){
                                Node[ Node[top].Next[i] ].fail = Node[v].Next[i];
                                break;
                            }v = Node[v].fail;
                        }if(v == -1) Node[ Node[top].Next[i] ].fail = 0;
                    }que.push(Node[top].Next[i]);
                }
            }
        }
    }

    inline void Get(int u, bool *used){
        while(u){
            if(!used[Node[u].id] && Node[u].id)
                used[Node[u].id] = true;
            u = Node[u].fail;
        }
    }

    bool Match(char *s, bool *used){
        int now = 0;
        bool ok = false;
        for(int i=0; s[i]; i++){
            int idx = s[i];
            if(Node[now].Next[idx]) now = Node[now].Next[idx];
            else{
                int p = Node[now].fail;
                while(p!=-1 && Node[p].Next[idx]==0){
                    p = Node[p].fail;
                }
                if(p == -1) now = 0;
                else now = Node[p].Next[idx];
            }
            if(Node[now].id) { Get(now, used); ok = true; }
        }
        if(ok) return true;
        return false;
    }
}ac;

char S[Max_Len];
bool used[501];
int main(void)
{
    int n, m;
    memset(used, false, sizeof(used));
    while(~scanf("%d", &n)){
        ac.init();
        for(int i=1; i<=n; i++){
            scanf("%s", S);
            ac.insert(S, i);
        }
        ac.BuildFail();
        int ans = 0;
        scanf("%d", &m);
        for(int i=1; i<=m; i++){
            scanf("%s", S);
            if(ac.Match(S, used)){
                printf("web %d:", i);
                for(int j=1; j<=n; j++){
                    if(used[j]){
                        printf(" %d", j);
                        used[j] = false;
                    }
                }puts("");
                ans++;
            }
        }
        printf("total: %d
", ans);
    }
    return 0;
}
View Code

HDU 3065

题意 : 中文就不赘述了......

分析 : 还是模板题

#include<queue>
#include<stdio.h>
#include<string.h>
using namespace std;

const int Max_Tot = 5e5 + 10;
const int Max_Len = 2e6 + 10;
const int Letter  = 26;

struct Aho{
    struct StateTable{
        int Next[Letter];
        int fail, id;
    }Node[Max_Tot];
    int Size;
    queue<int> que;

    inline void init(){
        while(!que.empty()) que.pop();
        memset(Node[0].Next, 0, sizeof(Node[0].Next));
        Node[0].fail = Node[0].id = 0;
        Size = 1;
    }

    inline void insert(char *s, int id){
        int len = strlen(s);
        int now = 0;
        for(int i=0; i<len; i++){
            int idx = s[i] - 'A';
            if(!Node[now].Next[idx]){
                memset(Node[Size].Next, 0, sizeof(Node[Size].Next));
                Node[Size].fail = Node[Size].id = 0;
                Node[now].Next[idx] = Size++;
            }
            now = Node[now].Next[idx];
        }
        Node[now].id = id;
    }

    inline void BuildFail(){
        Node[0].fail = -1;
        que.push(0);
        while(!que.empty()){
            int top = que.front();  que.pop();
            for(int i=0; i<Letter; i++){
                if(Node[top].Next[i]){
                    if(top == 0) Node[ Node[top].Next[i] ].fail = 0;
                    else{
                        int v = Node[top].fail;
                        while(v != -1){
                            if(Node[v].Next[i]){
                                Node[ Node[top].Next[i] ].fail = Node[v].Next[i];
                                break;
                            }v = Node[v].fail;
                        }if(v == -1) Node[ Node[top].Next[i] ].fail = 0;
                    }que.push(Node[top].Next[i]);
                }
            }
        }
    }

    inline void Get(int u, int *arr){
        while(u){
            if(Node[u].id) arr[Node[u].id]++;
            u = Node[u].fail;
        }
    }

    inline void Match(char *s, int *arr){
        int now = 0;
        for(int i=0; s[i]; i++){
            if(!(s[i] >= 'A' && s[i] <= 'Z')){ now = 0; continue; }
            int idx = s[i] - 'A';
            if(Node[now].Next[idx]) now = Node[now].Next[idx];
            else{
                int p = Node[now].fail;
                while(p!=-1 && Node[p].Next[idx]==0) p = Node[p].fail;
                if(p == -1) now = 0;
                else now = Node[p].Next[idx];
            }
            if(Node[now].id) Get(now, arr);
        }
    }
}ac;

char S[Max_Len];
int arr[1001];
char str[1001][51];
int main(void)
{
    memset(arr, 0, sizeof(arr));
    int n;
    while(~scanf("%d", &n)){

        ac.init();
        for(int i=1; i<=n; i++){
            scanf("%s", str[i]);
            ac.insert(str[i], i);
        }
        ac.BuildFail();
        scanf("%s", S);
        ac.Match(S, arr);
        for(int i=1; i<=n; i++){
            if(arr[i]){
                printf("%s: %d
", str[i], arr[i]);
                arr[i] = 0;
            }
        }
    }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/qwertiLH/p/7617742.html