HDU 2243 考研路茫茫――单词情结 ——(AC自动机+矩阵快速幂)

  和前几天做的AC自动机类似。

  思路简单但是代码200余行。。

  假设solve_sub(i)表示长度为i的不含危险单词的总数。

  最终答案为用总数(26^1+26^2+...+26^n)减去(solve_sub(1)+solve_sub(2)+...+solve_sub(n))。前者构造f[i]=f[i-1]*26+26然后矩阵快速幂即可(当然也可以分治的方法)。后者即构造出dp矩阵p,然后计算(p^1+p^2+...+p^n),对其分治即可。

  代码如下:

  1 #include <stdio.h>
  2 #include <algorithm>
  3 #include <string.h>
  4 #include <vector>
  5 #include <queue>
  6 #include <iostream>
  7 using namespace std;
  8 const int MAX_Tot = 30 + 5;
  9 const int mod = 100000;
 10 typedef unsigned long long ull;
 11 
 12 int m,n;
 13 
 14 struct matrix
 15 {
 16     ull e[MAX_Tot][MAX_Tot];
 17     int n,m;
 18     matrix() {}
 19     matrix(int _n,int _m): n(_n),m(_m) {memset(e,0,sizeof(e));}
 20     matrix operator * (const matrix &temp)const
 21     {
 22         matrix ret = matrix(n,temp.m);
 23         for(int i=1;i<=ret.n;i++)
 24         {
 25             for(int j=1;j<=ret.m;j++)
 26             {
 27                 for(int k=1;k<=m;k++)
 28                 {
 29                     ret.e[i][j] += e[i][k]*temp.e[k][j];
 30                 }
 31             }
 32         }
 33         return ret;
 34     }
 35     matrix operator + (const matrix &temp)const
 36     {
 37         matrix ret = matrix(n,m);
 38         for(int i=1;i<=n;i++)
 39         {
 40             for(int j=1;j<=m;j++)
 41             {
 42                 ret.e[i][j] += e[i][j]+temp.e[i][j];
 43             }
 44         }
 45         return ret;
 46     }
 47     void getE()
 48     {
 49         for(int i=1;i<=n;i++)
 50         {
 51             for(int j=1;j<=m;j++)
 52             {
 53                 e[i][j] = i==j?1:0;
 54             }
 55         }
 56     }
 57 };
 58 
 59 matrix qpow(matrix temp,int x)
 60 {
 61     int sz = temp.n;
 62     matrix base = matrix(sz,sz);
 63     base.getE();
 64     while(x)
 65     {
 66         if(x & 1) base = base * temp;
 67         x >>= 1;
 68         temp = temp * temp;
 69     }
 70     return base;
 71 }
 72 
 73 matrix solve(matrix a, int k)
 74 {
 75     if(k == 1) return a;
 76     int n = a.n;
 77     matrix temp = matrix(n,n);
 78     temp.getE();
 79     if(k & 1)
 80     {
 81         matrix ex = qpow(a,k);
 82         k--;
 83         temp = temp + qpow(a,k/2);
 84         return temp * solve(a,k/2) + ex;
 85     }
 86     else
 87     {
 88         temp = temp + qpow(a,k/2);
 89         return temp * solve(a,k/2);
 90     }
 91 }
 92 
 93 struct Aho
 94 {
 95     struct state
 96     {
 97         int nxt[26];
 98         int fail,cnt;
 99     }stateTable[MAX_Tot];
100 
101     int size;
102 
103     queue<int> que;
104 
105     void init()
106     {
107         while(que.size()) que.pop();
108         for(int i=0;i<MAX_Tot;i++)
109         {
110             memset(stateTable[i].nxt,0,sizeof(stateTable[i].nxt));
111             stateTable[i].fail = stateTable[i].cnt = 0;
112         }
113         size = 1;
114     }
115 
116     void insert(char *s)
117     {
118         int n = strlen(s);
119         int now = 0;
120         for(int i=0;i<n;i++)
121         {
122             char c = s[i];
123             if(!stateTable[now].nxt[c-'a'])
124                 stateTable[now].nxt[c-'a'] = size++;
125             now = stateTable[now].nxt[c-'a'];
126         }
127         stateTable[now].cnt = 1;
128     }
129 
130     void build()
131     {
132         stateTable[0].fail = -1;
133         que.push(0);
134 
135         while(que.size())
136         {
137             int u = que.front();que.pop();
138             for(int i=0;i<26;i++)
139             {
140                 if(stateTable[u].nxt[i])
141                 {
142                     if(u == 0) stateTable[stateTable[u].nxt[i]].fail = 0;
143                     else
144                     {
145                         int v = stateTable[u].fail;
146                         while(v != -1)
147                         {
148                             if(stateTable[v].nxt[i])
149                             {
150                                 stateTable[stateTable[u].nxt[i]].fail = stateTable[v].nxt[i];
151                                 // 在匹配fail指针的时候顺便更新cnt
152                                 if(stateTable[stateTable[stateTable[u].nxt[i]].fail].cnt == 1)
153                                     stateTable[stateTable[u].nxt[i]].cnt = 1;
154                                 break;
155                             }
156                             v = stateTable[v].fail;
157                         }
158                         if(v == -1) stateTable[stateTable[u].nxt[i]].fail = 0;
159                     }
160                     que.push(stateTable[u].nxt[i]);
161                 }
162                 /*****建立自动机nxt指针*****/
163                 else
164                 {
165                     if(u == 0) stateTable[u].nxt[i] = 0;
166                     else
167                     {
168                         int p = stateTable[u].fail;
169                         while(p != -1 && stateTable[p].nxt[i] == 0) p = stateTable[p].fail;
170                         if(p == -1) stateTable[u].nxt[i] = 0;
171                         else stateTable[u].nxt[i] = stateTable[p].nxt[i];
172                     }
173                 }
174                 /*****建立自动机nxt指针*****/
175             }
176         }
177     }
178 
179     matrix build_matrix()
180     {
181         matrix ans = matrix(size,size);
182         for(int i=0;i<size;i++)
183         {
184             for(int j=0;j<26;j++)
185             {
186                 if(!stateTable[i].cnt && !stateTable[stateTable[i].nxt[j]].cnt)
187                     ans.e[i+1][stateTable[i].nxt[j]+1]++;
188             }
189         }
190         return ans;
191     }
192 }aho;
193 
194 void print(matrix p)
195 {
196     int n = p.n;
197     int m = p.m;
198     for(int i=1;i<=n;i++)
199     {
200         for(int j=1;j<=m;j++)
201         {
202             printf("%d ",p.e[i][j]);
203         }
204         puts("");
205     }
206 }
207 
208 int main()
209 {
210     while(scanf("%d%d",&m,&n) == 2)
211     {
212         aho.init();
213         char s[15];
214         for(int i=1;i<=m;i++)
215         {
216             scanf("%s",s);
217             aho.insert(s);
218         }
219         aho.build();
220         matrix p = aho.build_matrix();
221         p = solve(p,n);
222         ull temp = 0;
223         for(int i=1;i<=aho.size;i++) temp += p.e[1][i];
224         matrix t = matrix(1,2);
225         t.e[1][2] = 1;
226         matrix A = matrix(2,2);
227         A.e[1][1] = A.e[2][1] = 26; A.e[2][2] = 1;
228         t = t * qpow(A,n);
229         ull ans = t.e[1][1] - temp;
230         printf("%llu
",ans);
231     }
232     return 0;
233 }

  最后觉得,,我之前矩阵模板里的print()真好用啊233= =。

原文地址:https://www.cnblogs.com/zzyDS/p/6512469.html