hdoj2896(AC自动机简单题)

题目链接:https://vjudge.net/problem/HDU-2896

题意:给出n个模式串(没有相同的模式串),模式串总长<=1e5。然后给出m个文本,文本总长<=1e7,求每个文本串中出现的模式串(最多3种)。


思路:

  板子题。因为要输出文本串中出现的模式串编号,所以需要记录字典树中以每个结点为末字符的模式串编号(key数组)。trie数组的第二位开到了130(所有ascii码都有),因此输入字符串用到了scanf("%[^ ]",s),但它不吃换行符,要用getchar吃掉换行符。

  又因为每输入一个文本,查询时会破坏key数组,因此创建AC自动机时用key1来记录,每次用key1来更新key2,用key2去查询。

AC code:

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;

const int maxn=1e5+5;
const int maxs=1e4+5;
int n,m,cnt,ans,trie[maxn][130],key1[maxn],key2[maxn],fail[maxn];
int res[5],cnt1;
char s[maxs];

void build(char *s,int k){
    int len=strlen(s),u=0;
    for(int i=0;i<len;++i){
        int t=s[i];
        if(!trie[u][t]){
            ++cnt;
            memset(trie[cnt],0,sizeof(trie[cnt]));
            key1[cnt]=0,fail[0]=0;
            trie[u][t]=cnt;
        }
        u=trie[u][t];
    }
    key1[u]=k;
}

void get_fail(){
    queue<int> que;
    for(int i=0;i<130;++i){
        if(trie[0][i]){
            fail[trie[0][i]]=0;
            que.push(trie[0][i]);
        }
    }
    while(!que.empty()){
        int u=que.front();que.pop();
        for(int i=0;i<130;++i){
            if(trie[u][i]){
                fail[trie[u][i]]=trie[fail[u]][i];
                que.push(trie[u][i]);
            }
            else{
                trie[u][i]=trie[fail[u]][i];
            }
        }
    }
}

void query(char *s){
    int len=strlen(s),u=0;
    for(int i=0;i<len;++i){
        int t=s[i];
        u=trie[u][t];
        for(int j=u;j&&key2[j]!=-1;j=fail[j]){
            if(key2[j]) res[++cnt1]=key2[j];
            key2[j]=-1;
        }
    }
}

int main(){
    scanf("%d",&n);
    getchar();
    cnt=0,ans=0;
    memset(trie[0],0,sizeof(trie[0]));
    key1[0]=0;
    for(int i=1;i<=n;++i){
        scanf("%[^
]",s);
        getchar();
        build(s,i);
    }
    fail[0]=0;
    get_fail();
    scanf("%d",&m);
    getchar();
    for(int i=1;i<=m;++i){
        for(int j=0;j<=cnt;++j)
            key2[j]=key1[j];
        scanf("%[^
]",s);
        getchar();
        cnt1=0;
        query(s);
        if(cnt1){
            ++ans;
            sort(res+1,res+cnt1+1);
            printf("web %d:",i);
            for(int j=1;j<=cnt1;++j)
                printf(" %d",res[j]);
            printf("
");
        }
    }
    printf("total: %d
",ans);
    return 0;
}
原文地址:https://www.cnblogs.com/FrankChen831X/p/12491355.html