HDU 2222 Keywords Search(AC自动机入门)

  题意:给出若干个单词和一段文本,问有多少个单词出现在其中。如果两个单词是相同的,得算两个单词的贡献。

  分析:直接就是AC自动机的模板了。

  具体见代码:

  1 #include <stdio.h>
  2 #include <algorithm>
  3 #include <string.h>
  4 #include <queue>
  5 using namespace std;
  6 const int MAX_N = 1000000 + 50;
  7 const int MAX_Tot = 500000 + 50;
  8 
  9 struct Aho
 10 {
 11     struct state
 12     {
 13         int nxt[26];
 14         int fail,cnt;
 15     }stateTable[MAX_Tot];
 16 
 17     int size;
 18 
 19     queue<int> que;
 20 
 21     void init()
 22     {
 23         while(que.size()) que.pop();
 24         for(int i=0;i<MAX_Tot;i++)
 25         {
 26             memset(stateTable[i].nxt,0,sizeof(stateTable[i].nxt));
 27             stateTable[i].fail = stateTable[i].cnt = 0;
 28         }
 29         size = 1;
 30     }
 31 
 32     void insert(char *s)
 33     {
 34         int n = strlen(s);
 35         int now = 0;
 36         for(int i=0;i<n;i++)
 37         {
 38             char c = s[i];
 39             if(!stateTable[now].nxt[c-'a'])
 40                 stateTable[now].nxt[c-'a'] = size++;
 41             now = stateTable[now].nxt[c-'a'];
 42         }
 43         stateTable[now].cnt++;
 44     }
 45 
 46     void build()
 47     {
 48         stateTable[0].fail = -1;
 49         que.push(0);
 50 
 51         while(que.size())
 52         {
 53             int u = que.front();que.pop();
 54             for(int i=0;i<26;i++)
 55             {
 56                 if(stateTable[u].nxt[i])
 57                 {
 58                     if(u == 0) stateTable[stateTable[u].nxt[i]].fail = 0;
 59                     else
 60                     {
 61                         int v = stateTable[u].fail;
 62                         while(v != -1)
 63                         {
 64                             if(stateTable[v].nxt[i])
 65                             {
 66                                 stateTable[stateTable[u].nxt[i]].fail = stateTable[v].nxt[i];
 67                                 break;
 68                             }
 69                             v = stateTable[v].fail;
 70                         }
 71                         if(v == -1) stateTable[stateTable[u].nxt[i]].fail = 0;
 72                     }
 73                     que.push(stateTable[u].nxt[i]);
 74                 }
 75             }
 76         }
 77     }
 78 
 79     int Get(int u)
 80     {
 81         int res = 0;
 82         while(u)
 83         {
 84             res += stateTable[u].cnt;
 85             stateTable[u].cnt = 0; //这里可做拓展
 86             /*
 87                 这里清空的原因如下:
 88                 例如单词是he,需要匹配的文本是hehe
 89                 那么,我们在匹配完第一个he以后,会再次匹配he,
 90                 如果不清空,那么he这个单词就又被匹配了一遍。
 91                 10
 92                 1
 93                 abcab
 94                 abcabcab
 95                 对上面这个数据,如果不注释掉上面的语句,答案是1.
 96                 否则,答案是2.
 97                 10
 98                 2
 99                 abcab
100                 abcabcab
101                 abcabcabcab
102             */
103             u = stateTable[u].fail;
104         }
105         return res;
106     }
107 
108     int match(char *s)
109     {
110         int n = strlen(s);
111         int res = 0, now = 0;
112         for(int i=0;i<n;i++)
113         {
114             char c = s[i];
115             if(stateTable[now].nxt[c-'a']) now = stateTable[now].nxt[c-'a'];
116             else
117             {
118                 int p = stateTable[now].fail;
119                 while(p != -1 && stateTable[p].nxt[c-'a'] == 0) p = stateTable[p].fail;
120                 if(p == -1) now = 0;
121                 else now = stateTable[p].nxt[c-'a'];
122             }
123             if(stateTable[now].cnt) res += Get(now);
124         }
125         return res;
126     }
127 }aho;
128 
129 int T,n;
130 char s[MAX_N];
131 
132 int main()
133 {
134     int T;scanf("%d",&T);
135     while(T--)
136     {
137         aho.init();
138         scanf("%d",&n);
139         for(int i=1;i<=n;i++)
140         {
141             scanf("%s",s);
142             aho.insert(s);
143         }
144         aho.build();
145         scanf("%s",s);
146         printf("%d
",aho.match(s));
147     }
148 }

  顺便注意上面注释中的拓展点。

————————————————————————————————————————————————

  发现上面的代码有问题(虽然能AC),正确代码如下:

  1 #include <stdio.h>
  2 #include <algorithm>
  3 #include <string.h>
  4 #include <queue>
  5 using namespace std;
  6 const int MAX_N = 1000000 + 50;
  7 const int MAX_Tot = 500000 + 50;
  8 
  9 struct Aho
 10 {
 11     struct state
 12     {
 13         int nxt[26];
 14         int fail,cnt;
 15     }stateTable[MAX_Tot];
 16 
 17     int size;
 18 
 19     queue<int> que;
 20 
 21     void init()
 22     {
 23         while(que.size()) que.pop();
 24         for(int i=0;i<MAX_Tot;i++)
 25         {
 26             memset(stateTable[i].nxt,0,sizeof(stateTable[i].nxt));
 27             stateTable[i].fail = stateTable[i].cnt = 0;
 28         }
 29         size = 1;
 30     }
 31 
 32     void insert(char *s)
 33     {
 34         int n = strlen(s);
 35         int now = 0;
 36         for(int i=0;i<n;i++)
 37         {
 38             char c = s[i];
 39             if(!stateTable[now].nxt[c-'a'])
 40                 stateTable[now].nxt[c-'a'] = size++;
 41             now = stateTable[now].nxt[c-'a'];
 42         }
 43         stateTable[now].cnt++;
 44     }
 45 
 46     void build()
 47     {
 48         stateTable[0].fail = -1;
 49         que.push(0);
 50 
 51         while(que.size())
 52         {
 53             int u = que.front();que.pop();
 54             for(int i=0;i<26;i++)
 55             {
 56                 if(stateTable[u].nxt[i])
 57                 {
 58                     if(u == 0) stateTable[stateTable[u].nxt[i]].fail = 0;
 59                     else
 60                     {
 61                         int v = stateTable[u].fail;
 62                         while(v != -1)
 63                         {
 64                             if(stateTable[v].nxt[i])
 65                             {
 66                                 stateTable[stateTable[u].nxt[i]].fail = stateTable[v].nxt[i];
 67                                 break;
 68                             }
 69                             v = stateTable[v].fail;
 70                         }
 71                         if(v == -1) stateTable[stateTable[u].nxt[i]].fail = 0;
 72                     }
 73                     que.push(stateTable[u].nxt[i]);
 74                 }
 75             }
 76         }
 77     }
 78 
 79     int Get(int u)
 80     {
 81         int res = 0;
 82         while(u)
 83         {
 84             if(stateTable[u].cnt == -1) break;
 85             res += stateTable[u].cnt;
 86             stateTable[u].cnt = -1;
 87             u = stateTable[u].fail;
 88         }
 89         return res;
 90     }
 91 
 92     int match(char *s)
 93     {
 94         int n = strlen(s);
 95         int res = 0, now = 0;
 96         for(int i=0;i<n;i++)
 97         {
 98             char c = s[i];
 99             if(stateTable[now].nxt[c-'a']) now = stateTable[now].nxt[c-'a'];
100             else
101             {
102                 int p = stateTable[now].fail;
103                 while(p != -1 && stateTable[p].nxt[c-'a'] == 0) p = stateTable[p].fail;
104                 if(p == -1) now = 0;
105                 else now = stateTable[p].nxt[c-'a'];
106             }
107             //if(stateTable[now].cnt)
108             res += Get(now);
109         }
110         return res;
111     }
112 }aho;
113 
114 int T,n;
115 char s[MAX_N];
116 
117 int main()
118 {
119     int T;scanf("%d",&T);
120     while(T--)
121     {
122         aho.init();
123         scanf("%d",&n);
124         for(int i=1;i<=n;i++)
125         {
126             scanf("%s",s);
127             aho.insert(s);
128         }
129         aho.build();
130         scanf("%s",s);
131         printf("%d
",aho.match(s));
132     }
133 }
View Code
原文地址:https://www.cnblogs.com/zzyDS/p/5975428.html