P3808 【模板】AC自动机(简单版)

题目背景

这是一道简单的AC自动机模板题。

用于检测正确性以及算法常数。

为了防止卡OJ,在保证正确的基础上只有两组数据,请不要恶意提交。

管理员提示:本题数据内有重复的单词,且重复单词应该计算多次,请各位注意

题目描述

给定n个模式串和1个文本串,求有多少个模式串在文本串里出现过。

输入输出格式

输入格式:

第一行一个n,表示模式串个数;

下面n行每行一个模式串;

下面一行一个文本串。

输出格式:

一个数表示答案

输入输出样例

输入样例#1: 
2
a
aa
aa
输出样例#1: 
2

说明

subtask1[50pts]:∑length(模式串)<=10^6,length(文本串)<=10^6,n=1;

subtask2[50pts]:∑length(模式串)<=10^6,length(文本串)<=10^6;

Solution:

  AC自动机板子。

  简单讲一下,对于多模式串的匹配,如果都用kmp解决则主串要扫多次,复杂度直接变为$O(n^2)$。

  而AC自动机可以只扫一次主串进行多模式串匹配,先将所有模式串构建出一棵trie树,然后类比kmp我们在trie树上构建失配边,每次失配直接跳向失配边所指指针继续匹配,这样就能将匹配的时间复杂度变为$O(sum{|P|}+|T|)$,至于求失配边我们可以类比kmp的next数组求法,将递归改为在trie树上bfs递推,递推很好理解直接见代码。

  解释一下fail数组:根节点到$fail[i]$表示的字符串,是根节点到$i$表示的字符串的最长后缀。

  这样就能保证到了$i$节点时,根节点到$fail[i]$所表示的字符串一定出现过,失配时就能接上去继续匹配。

  小技巧:

  1、在AC自动机的构建时,往往会有一个类似路径压缩的优化:当不存在$trie[p][i]$节点时,直接$trie[p][i]=trie[fail[p]][i]$,即将该空节点指向其失配边所指节点的该字符节点。

  2、匹配时往往会多次走失配边,影响效率,于是另起一个$last[i]$表示$i$节点所指向的失配边下一个出现的字符串结尾(专业术语叫后缀链接——suffix link),实现时直接在bfs求解fail数组时递推就好了。

  对于本题,我们直接建好AC自动机,然后愉快的匹配,每次都累加一下失配边所指的完整字符串个数就好了,注意一下每个模式串重复出现只能算一次,所以累加完后要清除end标记。然后的话本题数据比较水,last数组优化并不明显,但是在加强版的模板题中优化效果还是可以的。(话说AC自动机并不难,以前咋不会呢?用下心体会,就好了!)

代码:

#include<bits/stdc++.h>
#define il inline
#define ll long long
#define For(i,a,b) for(int (i)=(a);(i)<=(b);(i)++)
#define Bor(i,a,b) for(int (i)=(b);(i)>=(a);(i)--)
using namespace std;
const int N=500005;
int n,trie[N][26],tot,end[N],fail[N],last[N];
char s[N<<1];

il void insert(char *s){
    int len=strlen(s)-1,p=0,x;
    For(i,0,len){
        x=s[i]-'a';
        if(!trie[p][x])trie[p][x]=++tot;
        p=trie[p][x];
    }
    end[p]++;
}

il void bfs(){
    queue<int>q;
    For(i,0,25) if(trie[0][i]) fail[trie[0][i]]=0,q.push(trie[0][i]);
    while(!q.empty()){
        int u=q.front();q.pop();
        For(i,0,25){
            int v=trie[u][i];
            if(v) fail[v]=trie[fail[u]][i],last[v]=end[fail[v]]?fail[v]:last[fail[v]],q.push(v);
            else trie[u][i]=trie[fail[u]][i];
        }
    }
}

il int find(char *s){
    int ans=0,len=strlen(s)-1,p=0;
    For(i,0,len){
        p=trie[p][s[i]-'a'];
        for(int j=p;j&&end[j]!=-1;j=last[j]) ans+=end[j],end[j]=-1;
    }
    return ans;
}

int main(){
    scanf("%d",&n);
    For(i,1,n) scanf("%s",s),insert(s);
    bfs();
    scanf("%s",s);
    cout<<find(s);
    return 0;
}
原文地址:https://www.cnblogs.com/five20/p/9443656.html