AC自动机代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>

using namespace std;

const int maxn=5e5+10;

struct node
{
    int en;
    int vis[26];
    int fail;
    node(){fail=0;en=0;for(int i=0;i<26;i++)vis[i]=0;}
}trie[maxn];

class ac
{
    int cnt;
public:
    ac():cnt(0){}
    void init(){cnt=0;clear();}
    void clear()
    {
        for(int i=0;i<26;i++)trie[cnt].vis[i]=0;
        trie[cnt].en=trie[cnt].fail=0;
    }
    void build(char str[])
    {
        int len=strlen(str),ch,now=0;
        for(int i=0;i<len;i++)
        {
            ch=str[i]-'a';
            if(!trie[now].vis[ch])
            {
                trie[now].vis[ch]=++cnt;
                clear();
            }
            now=trie[now].vis[ch];
        }
        trie[now].en++;
    }
    void get_fail()
    {
        queue<int> q;
        int u;
        for(int i=0;i<26;i++)
        {
            if(trie[0].vis[i])
            {
                trie[trie[0].vis[i]].fail=0;
                q.push(trie[0].vis[i]);
            }
        }
        while(!q.empty())
        {
            u=q.front();
            q.pop();
            for(int i=0;i<26;i++)
            {
                if(trie[u].vis[i])
                {
                    trie[trie[u].vis[i]].fail=trie[trie[u].fail].vis[i];
                    q.push(trie[u].vis[i]);
                }
                else
                {
                    trie[u].vis[i]=trie[trie[u].fail].vis[i];
                }
            }
        }
    }
    int solve(char str[])
    {
        int len=strlen(str),now=0,ch,ans=0,tmp;
        for(int i=0;i<len;i++)
        {
            ch=str[i]-'a';
            now=trie[now].vis[ch];
            tmp=now;
            while(tmp&&trie[tmp].en!=-1)
            {
                ans+=trie[tmp].en;
                trie[tmp].en=-1;
                tmp=trie[tmp].fail;
            }
        }
        return ans;
    }
};

char str[2*maxn];
ac acm;

int main()
{
    int n,t;
    scanf("%d",&t);
    while(t--)
    {
        acm.init();
        scanf("%d",&n);
        for(int i=0;i<n;i++)
        {
            scanf("%s",str);
            acm.build(str);
        }
        acm.get_fail();
        scanf("%s",str);
        printf("%d
",acm.solve(str));
    }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/wyhbadly/p/11276897.html