AC自动机是著名的多模匹配算法之一。常见的例子就是给出n个单词,再给你包含m个字符的文章,问你有多少个单词在文章中出现过。
其实AC自动机是以字典树和KMP的基础上实现的。
首先要构造一个Tire,然后再在上面构造失配然后再匹配。
失配(fail)指针:使当前字符失配时跳转到具有最长公共前后缀的字符继续匹配。
但是我现在还是不怎么懂失配是怎么弄的,先留个坑,后面来填。
(8.14补)今天听了实验室的人讲的AC自动机,感觉理解也深了。
其实就是从最长的那个串按着失配指针一只往前跳,我们知道失配指针是指向当前字符串的最长后缀的,那么我们不会漏下任何一个。
所以跟着失配指针我们可以找到所有能匹配的个数(因为建树会在节点标记他是否是字符串结尾)。
模板:
/* gyt Live up to every day */ #include<cstdio> #include<cmath> #include<iostream> #include<algorithm> #include<vector> #include<stack> #include<cstring>` #include<queue> #include<set> #include<string> #include<map> #include <time.h> #define PI acos(-1) using namespace std; typedef long long ll; typedef double db; const int maxn = 10000+5; const ll maxm = 1e7; const ll mod = 1e9 + 7; const int INF = 0x3f3f3f; const ll inf = 1e15 + 5; const db eps = 1e-9; const int kind=26; struct node{ node *fail; node *next[kind]; int coun; void nodee() { fail=NULL; coun=0; for (int i=0; i<kind; i++) next[i]=NULL; } }*root; char str[1000000+100]; void updata() { node *p=root; int len=strlen(str); for (int i=0; i<len; i++) { int pos=str[i]-'a'; if (p->next[pos]==NULL) { p->next[pos]=new node; p->next[pos]->nodee(); p=p->next[pos]; } else p=p->next[pos]; } p->coun++; // cout<<p->coun<<endl; } void getfail() { node *p=root, *son, *tmp; queue<struct node*>que; que.push(p); while(!que.empty()) { tmp=que.front(); que.pop(); for (int i=0; i<26; i++) { son=tmp->next[i]; if (son!=NULL) { if (tmp==root) { son->fail=root; } else { p=tmp->fail; while(p) { if (p->next[i]) { son->fail=p->next[i]; break; } p=p->fail; } if (!p) son->fail=root; } que.push(son); } } } } void query() { int len=strlen(str); node *p, *tmp; p=root; int cnt=0; for (int i=0; i<len; i++) { int pos=str[i]-'a'; while(!p->next[pos]&& p!=root) p=p->fail; p=p->next[pos]; if (!p) p=root; tmp=p; while(tmp!=root) { if (tmp->coun>=0) { cnt+=tmp->coun; tmp->coun=-1; } else break; tmp=tmp->fail; } //cout<<cnt<<endl; } printf("%d ", cnt); } void solve() { root=new node; root->nodee(); root->fail=NULL; int n; scanf("%d", &n); getchar(); for (int i=0; i<n; i++) { gets(str); updata(); } getfail(); gets(str); query(); } int main() { int t = 1; // freopen("in.txt", "r", stdin); scanf("%d", &t); while(t--) solve(); return 0; }
如果要求在字符串里出现的次数,我们再在结构体里定义一个id,表示编号,query里面在变变
void query(char *str) { int len=strlen(str); node *p, *tmp; p=root; int cntt=0; for (int i=0; i<len; i++) { int pos=str[i]-'A'; while(!p->next[pos]&& p!=root) p=p->fail; p=p->next[pos]; if (!p) p=root; tmp=p; while(tmp!=root && tmp->id>0) { //cout<<1<<endl; int aa=tmp->id; // cout<<aa<<endl; cnt[aa]++; tmp=tmp->fail; } } }
#include<cstdio> #include<cmath> #include<iostream> #include<algorithm> #include<vector> #include<stack> #include<cstring>` #include<queue> #include<set> #include<string> #include<map> #include <time.h> #define PI acos(-1) #include<cstdio> #include<cstring> using namespace std; const int kind = 26; //有多少种字符 const int N = 1005; const int M = 2000005; struct node { node *next[kind]; node *fail; int id;//病毒编号 node() { for(int i = 0; i < kind; i++) next[i] = NULL; fail = NULL; id = 0; } }*q[51*N]; node *root; int head, tail; char source[M], s[M]; char vir[N][55]; int cnt[N]; void updata(char *str, int id) { node *p=root; int len=strlen(str); for (int i=0; i<len; i++) { int pos=str[i]-'A'; if (p->next[pos]==NULL) { p->next[pos]=new node(); //p->next[pos]->nodee(); p=p->next[pos]; } else p=p->next[pos]; } // p->coun++; p->id = id; // cout<<p->coun<<endl; } void getfail() { node *p=root, *son, *tmp; queue<struct node*>que; que.push(p); while(!que.empty()) { tmp=que.front(); que.pop(); for (int i=0; i<26; i++) { son=tmp->next[i]; if (son!=NULL) { if (tmp==root) { son->fail=root; } else { p=tmp->fail; while(p) { if (p->next[i]) { son->fail=p->next[i]; break; } p=p->fail; } if (!p) son->fail=root; } que.push(son); } } } } void query(char *str) { int len=strlen(str); node *p, *tmp; p=root; int cntt=0; for (int i=0; i<len; i++) { int pos=str[i]-'A'; while(!p->next[pos]&& p!=root) p=p->fail; p=p->next[pos]; if (!p) p=root; tmp=p; while(tmp!=root && tmp->id>0) { //cout<<1<<endl; int aa=tmp->id; // cout<<aa<<endl; cnt[aa]++; tmp=tmp->fail; } } } int main() { int n; // freopen("in.txt", "r", stdin); while(~scanf("%d",&n)) { memset(cnt, 0, sizeof(cnt)); head = tail = 0; root = new node(); for(int i = 1; i <= n; i++) { scanf("%s", vir[i]); updata(vir[i], i); } getfail(); scanf("%s",source); int len = strlen(source); int l = 0; for(int i = 0; i <= len; i++) { if(source[i] >= 'A' && source[i] <= 'Z') s[l++] = source[i]; else { s[l] = '