洛谷P2336 喵星球上的点名

解:SAM + 线段树合并 + DFS序。

姓和名之间插入特殊字符,转化为下题:

给定串集合S,T,问S中每个串包含了T中的几个串?T中每个串被多少个S中的串包含?

解:对S建广义SAM,并用线段树合并维护每个节点的endpos集合中出现了多少个不同的串(即该节点代表的子串被S中多少个串包含)。

T中每个串在S的sam上跑,如果没能跑完就被包含0次。否则答案就是到达的节点上的串数。第二问解决。

标记T中每个串最后到达的节点。S中每个串跑S的sam会得到若干个点。统计这些点到根路径的并集上的标记个数即可。

将这些点按DFS序排序并去重,累加每个点到根路径上的标记数,再减去排序后相邻两点的lca到根路径上的标记数,得到的就是这些路径并集上的标记总数。第一问解决。

  1 #include <bits/stdc++.h>
  2 
  // Capacity limits: N sizes every per-state array, M sizes the segment-tree node pool.
  const int N = 400010, M = 10000010;

  // One directed edge of the suffix-link (parent) tree, kept in a linked adjacency list.
  struct Edge {
      int nex;  // index of the next edge leaving the same node; 0 ends the list
      int v;    // child node this edge points to
  } edge[N];
  int tp;  // number of edges currently stored in edge[]

  std::map<int, int> tr[N];  // tr[s][c]: state reached from state s on symbol c
  int fail[N];               // suffix link of each state (parent in the link tree)
  int len[N];                // length of the longest substring of each state
  int tot = 1;               // number of states in use; state 1 is the root
  int e[N];                  // head edge index of each node's adjacency list
  int n, m;                  // the two counts read first by main()
  int siz[N];                // after DFS_2: sum of ed[] along the path root -> node
  int ed[N];                 // number of query strings that end exactly at this state
  int stk[N], top;           // scratch list of states collected in main()
  int num2;                  // length of the Euler tour written by DFS_1
  int pos2[N];               // first Euler-tour position of each node
  int ST[N << 1][22];        // sparse table over the Euler tour (shallowest node per window)
  int pw[N << 1];            // pw[i] = floor(log2(i)), filled by prework()
  int d[N];                  // depth of each node in the link tree
  int ls[M], rs[M];          // left / right child of each segment-tree node
  int num;                   // number of segment-tree nodes allocated
  int sum[M];                // number of distinct ids stored below each segment-tree node
  int rt[N];                 // segment-tree root attached to each state
  std::vector<int> str[N];   // str[i]: symbols of the i-th stored string, parts joined by -1
 14 
 15 inline void add(int x, int y) {
 16     tp++;
 17     edge[tp].v = y;
 18     edge[tp].nex = e[x];
 19     e[x] = tp;
 20     return;
 21 }
 22 
 23 inline bool cmp(const int &a, const int &b) {
 24     return pos2[a] < pos2[b];
 25 }
 26 
 27 void insert(int p, int l, int r, int &o) {
 28     if(!o) o = ++num;
 29     if(l == r) {
 30         sum[o] = 1;
 31         return;
 32     }
 33     int mid = (l + r) >> 1;
 34     if(p <= mid) insert(p, l, mid, ls[o]);
 35     else insert(p, mid + 1, r, rs[o]);
 36     sum[o] = sum[ls[o]] + sum[rs[o]];
 37     return;
 38 }
 39 
 40 int merge(int x, int y) {
 41     if(!x || !y) return x | y;
 42     int o = ++num;
 43     ls[o] = merge(ls[x], ls[y]);
 44     rs[o] = merge(rs[x], rs[y]);
 45     if(!ls[o] && !rs[o]) sum[o] = 1;
 46     else sum[o] = sum[ls[o]] + sum[rs[o]];
 47     return o;
 48 }
 49 
 50 int ask(int L, int R, int l, int r, int o) {
 51     if(!o) return 0;
 52     if(L <= l && r <= R) return sum[o];
 53     int mid = (l + r) >> 1, ans = 0;
 54     if(L <= mid) ans += ask(L, R, l, mid, ls[o]);
 55     if(mid < R) ans += ask(L, R, mid + 1, r, rs[o]);
 56     return ans;
 57 }
 58 
 59 void DFS_1(int x) {
 60     pos2[x] = ++num2;
 61     ST[num2][0] = x;
 62     for(int i = e[x]; i; i = edge[i].nex) {
 63         int y = edge[i].v;
 64         d[y] = d[x] + 1;
 65         DFS_1(y);
 66         ST[++num2][0] = x;
 67         rt[x] = merge(rt[x], rt[y]);
 68     }
 69     return;
 70 }
 71 
 72 inline void prework() {
 73     for(int i = 2; i <= num2; i++) pw[i] = pw[i >> 1] + 1;
 74     for(int j = 1; j <= pw[num2]; j++) {
 75         for(int i = 1; i + (j << 1) - 1 <= num2; i++) {
 76             if(d[ST[i][j - 1]] < d[ST[i + (1 << (j - 1))][j - 1]])
 77                 ST[i][j] = ST[i][j - 1];
 78             else
 79                 ST[i][j] = ST[i + (1 << (j - 1))][j - 1];
 80         }
 81     }
 82     return;
 83 }
 84 
 85 inline int lca(int x, int y) {
 86     x = pos2[x];
 87     y = pos2[y];
 88     if(x > y) std::swap(x, y);
 89     int t = pw[y - x + 1];
 90     if(d[ST[x][t]] < d[ST[y - (1 << t) + 1][t]])
 91         return ST[x][t];
 92     else
 93         return ST[y - (1 << t) + 1][t];
 94 }
 95 
 96 void DFS_2(int x) {
 97     siz[x] += ed[x];
 98     for(int i = e[x]; i; i = edge[i].nex) {
 99         int y = edge[i].v;
100         siz[y] = siz[x];
101         DFS_2(y);
102     }
103     return;
104 }
105 
106 inline int split(int p, int f) {
107     int Q = tr[p][f], nQ = ++tot;
108     len[nQ] = len[p] + 1;
109     fail[nQ] = fail[Q];
110     fail[Q] = nQ;
111     //memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
112     tr[nQ] = tr[Q];
113     while(tr[p][f] == Q) {
114         tr[p][f] = nQ;
115         p = fail[p];
116     }
117     return nQ;
118 }
119 
120 inline int insert(int p, int f, int id) {
121     int np;
122     if(tr[p].count(f)) {
123         int Q = tr[p][f];
124         if(len[Q] == len[p] + 1) {
125             np = Q;
126         }
127         else {
128             np = split(p, f);
129         }
130         insert(id, 1, n, rt[np]);
131         return np;
132     }
133     np = ++tot;
134     len[np] = len[p] + 1;
135     while(p && !tr[p].count(f)) {
136         tr[p][f] = np;
137         p = fail[p];
138     }
139     if(!p) {
140         fail[np] = 1;
141     }
142     else {
143         int Q = tr[p][f];
144         if(len[Q] == len[p] + 1) {
145             fail[np] = Q;
146         }
147         else {
148             fail[np] = split(p, f);
149         }
150     }
151     insert(id, 1, n, rt[np]);
152     return np;
153 }
154 
155 void out(int l, int r, int o) {
156     if(!o) return;
157     if(l == r) {
158         printf("%d ", r);
159         return;
160     }
161     int mid = (l + r) >> 1;
162     out(l, mid, ls[o]);
163     out(mid + 1, r, rs[o]);
164     return;
165 }
166 
167 inline void clear() {
168     for(int i = 1; i <= tot; i++) {
169         e[i] = len[i] = fail[i] = rt[i] = 0;
170         tr[i].clear();
171     }
172     for(int i = 1; i <= num; i++) {
173         ls[i] = rs[i] = sum[i] = 0;
174     }
175     tp = num = 0;
176     tot = 1;
177     return;
178 }
179 
180 int main() {
181     scanf("%d%d", &n, &m);
182     for(int i = 1; i <= n; i++) {
183         int k, x, p = 1;
184         scanf("%d", &k);
185         for(int j = 1; j <= k; j++) {
186             scanf("%d", &x);
187             str[i].push_back(x);
188             p = insert(p, x, i);
189         }
190         str[i].push_back(-1);
191         p = insert(p, -1, i);
192         scanf("%d", &k);
193         for(int j = 1; j <= k; j++) {
194             scanf("%d", &x);
195             str[i].push_back(x);
196             p = insert(p, x, i);
197         }
198     }
199     /// build
200     for(int i = 2; i <= tot; i++) {
201         //printf("add %d %d 
", fail[i], i);
202         add(fail[i], i);
203     }
204     DFS_1(1);
205 
206     for(int i = 1; i <= m; i++) {
207         int k, x, p = 1, fd = 0, ans = 0;
208         scanf("%d", &k);
209         for(int j = 1; j <= k; j++) {
210             scanf("%d", &x);
211             if(!tr[p].count(x)) fd = 1;
212             else p = tr[p][x];
213         }
214         if(!fd) {
215             ans = sum[rt[p]];
216             ed[p]++;
217         }
218         printf("%d
", ans);
219     }
220 
221     DFS_2(1);
222     prework();
223 
224     for(int i = 1; i <= n; i++) {
225         int p = 1; top = 0;
226         for(int j = 0; j < (int)str[i].size(); j++) {
227             int x = str[i][j];
228             p = tr[p][x];
229             stk[++top] = p;
230         }
231         std::sort(stk + 1, stk + top + 1,cmp);
232         top = std::unique(stk + 1, stk + top + 1) - stk - 1;
233         int ans = 0;
234         for(int j = 1; j <= top; j++) {
235             ans += siz[stk[j]];
236             if(j < top) ans -= siz[lca(stk[j], stk[j + 1])];
237         }
238         printf("%d ", ans);
239     }
240 
241     return 0;
242 }
AC代码

AC自动机解法

原文地址:https://www.cnblogs.com/huyufeifei/p/10496912.html