[cf710F]String Set Queries

考虑对于$S$内的每一个串,建立一个ac自动机

类似于线段树的结构去合并,即当有两个ac自动机所含串数量相同,就将这两个ac自动机暴力合并

删除可以再建一组ac自动机,减去即可

由于每一个串至多参与$log n$次合并,且最终ac自动机个数也为$log n$,因此总复杂度即为$o(nlog n)$

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 #define N 300005
 4 queue<int>q;
 5 queue<pair<int,int> >qq;
 6 int V,n,p,sz[2][21],r[2][21],nex[N],tot[N],sum[N],ch[N][26];
 7 char s[N];
 8 void build(int r){
 9     nex[r]=r;
10     for(int i=0;i<26;i++)
11         if (ch[r][i]){
12             q.push(ch[r][i]);
13             nex[ch[r][i]]=r;
14         }
15     while (!q.empty()){
16         int k=q.front();
17         q.pop();
18         sum[k]=tot[k]+sum[nex[k]];
19         for(int i=0;i<26;i++)
20             if (ch[k][i]){
21                 int u=nex[k];
22                 while ((u!=r)&&(!ch[u][i]))u=nex[u];
23                 if (ch[u][i])u=ch[u][i];
24                 nex[ch[k][i]]=u;
25                 q.push(ch[k][i]);
26             }
27     }
28 }
29 void add_str(int r){
30     int k=r;
31     for(int i=0;s[i];i++)k=ch[k][s[i]-'a']=++V;
32     tot[k]=1;
33     build(r);
34 }
35 int query_str(int r){
36     int k=r,ans=0;
37     for(int i=0;s[i];i++){
38         while ((k!=r)&&(!ch[k][s[i]-'a']))k=nex[k];
39         if (ch[k][s[i]-'a'])k=ch[k][s[i]-'a'];
40         ans+=sum[k];
41     }
42     return ans;
43 }
44 void merge(int r1,int r2){
45     qq.push(make_pair(r1,r2));
46     while (!qq.empty()){
47         int x=qq.front().first,y=qq.front().second;
48         qq.pop();
49         tot[x]+=tot[y];
50         for(int i=0;i<26;i++)
51             if (ch[y][i]){
52                 if (!ch[x][i])ch[x][i]=ch[y][i];
53                 else qq.push(make_pair(ch[x][i],ch[y][i]));
54             }
55     }
56     build(r1);
57 }
58 void add(int p){
59     int x=sz[p][0];
60     sz[p][++x]=1;
61     add_str(r[p][x]=++V);
62     while ((x>1)&&(sz[p][x]==sz[p][x-1])){
63         sz[p][x-1]+=sz[p][x];
64         merge(r[p][x-1],r[p][x]);
65         x--;
66     }
67     sz[p][0]=x;
68 }
69 int query(int p){
70     int ans=0;
71     for(int i=1;i<=sz[p][0];i++)ans+=query_str(r[p][i]);
72     return ans;
73 }
74 int main(){
75     scanf("%d",&n);
76     for(int i=1;i<=n;i++){
77         scanf("%d%s",&p,s);
78         if (p==1)add(0);
79         if (p==2)add(1);
80         if (p==3){
81             printf("%d
",query(0)-query(1));
82             fflush(stdout);
83         }
84     }
85 }
View Code
原文地址:https://www.cnblogs.com/PYWBKTDA/p/14142708.html