AC自动机处理多串匹配——cf1202E

si+sj中间有一个切割点,我们在t上枚举这个切割点i,即以t[i]作为最后一个字符时求有多少si可以匹配,以t[i+1]作为第一个字符时有多少sj可以匹配

那么对s串正着建一个ac自动机,反着建一个自动机,然后t正反各匹配一次,用sum[]数组记录t[i]作为最后一个字符可以匹配的串数量

注意:求sum数组时,暴力跳fail显然会t,考虑到跳fail是为了统计匹配串的后缀,那么我们在build时,就可以在处理fail指针时就可以把那个fail的end加到now的end上去,这样就避免了暴力跳fail

#include<bits/stdc++.h>
using namespace std;
#define N 200005

struct Trie{
    int nxt[N][26],fail[N],end[N];
    int root,L;
    int newnode(){
        memset(nxt[L],-1,sizeof nxt[L]);
        end[L]=0;
        return L++;
    }
    void init(){
        L++;
        root=newnode();
    }
    void insert(char buf[]){
        int len=strlen(buf+1);
        int now=root;
        for(int i=1;i<=len;i++){
            if(nxt[now][buf[i]-'a']==-1)
                nxt[now][buf[i]-'a']=newnode();
            now=nxt[now][buf[i]-'a'];
        }
        end[now]++;
    }
    void build(){
        queue<int>q;
        fail[root]=root;
        for(int i=0;i<26;i++)
            if(nxt[root][i]==-1)
                nxt[root][i]=root;
            else {
                fail[nxt[root][i]]=root;
                q.push(nxt[root][i]);
            }
        while(q.size()){
            int now=q.front();
            q.pop(); 
            for(int i=0;i<26;i++)
                if(nxt[now][i]==-1)
                    nxt[now][i]=nxt[fail[now]][i];
                else {
                    fail[nxt[now][i]]=nxt[fail[now]][i];
                    end[nxt[now][i]]+=end[nxt[fail[now]][i]];
                    q.push(nxt[now][i]);
                }
        }
    }
    int sum[N];
    int query(char buf[]){
        int len=strlen(buf+1);
        int now=root;
        for(int i=1;i<=len;i++){
            now=nxt[now][buf[i]-'a'];
            sum[i]+=end[now];
        }
    }
}; 

char buf[N],t[N];
Trie t1,t2;
int n;
void reserve(char s[]){
    int i=1,j=strlen(s+1);
    while(i<j){
        swap(s[i],s[j]);
        ++i,--j;
    }
}

int main(){
    t1.init();
    t2.init();    
    scanf("%s%d",t+1,&n);
    for(int i=1;i<=n;i++){
        scanf("%s",buf+1);
        t1.insert(buf);
        reserve(buf);
        t2.insert(buf);
    }
    t1.build();
    t2.build();
    t1.query(t);
    reserve(t);
    t2.query(t);
    
    int len=strlen(t+1);
    long long ans=0;
    for(int i=0;i<len;i++)
        ans+=(long long)t1.sum[i]*t2.sum[len-i];
    cout<<ans<<'
';
}
原文地址:https://www.cnblogs.com/zsben991126/p/11518657.html