LOJ#2720 你的名字

题意:给定母串s和若干个询问。每个询问是一个串t和两个数l,r,表示求t中有多少个本质不同的子串没有在s[l,r]中出现过。

解:我写的并不是正解......是个毒瘤做法。只在loj上面卡时过了就写loj的题号好了...

首先有个68分部分分是l = 1, r = |s|,这个怎么做呢?

回忆起之前写的广义SAM的套路,我们建出广义SAM之后把s的所有子串标记。

然后对于每个t跑一遍SAM,跳fail的时候如果该节点被标记了就停止。这样走到的节点所代表的子串总数就是该串的答案。

68分还是比较友善的...

然后顺着这个思路思考正解。

多了个母串范围限制,又听说这个题是线段树合并,这就很自然了。

对于每个节点用值域线段树维护right集合,然后对于每一个询问,查看该节点是否有个right存在于[l,r]之间。存在就GG了,同时也不能往上传递。否则可以加上这个节点的贡献。有一种居中的情况就是一个节点所表示的串中,有些可选,有些不行。这种情况下不用向上传递(然而我当时没想到,传了...无伤大雅)

细节上就是找到那一段暧昧区域的最右边一个right,用线段树上找第k个实现。

那么我们要一个一个询问的处理吗?我很天真的以为线段树合并之后下面的线段树就不存在了....于是只能把询问一次性处理。于是我又SB的对询问开了个线段树,继续线段树合并......

具体来说,对于每个询问串,都在对应节点加上该串的询问。然后DFSfail树,在每一个节点遍历询问线段树,处理询问,然后向上合并。如果不会有贡献就删去这个节点,减少复杂度。

还有个小问题,right线段树合并的时候sum是两棵树的sum和,但是询问的那棵线段树合并要去重,所以不能把sum累加...然后发现询问线段树不需要查sum......这样就可以了。

因为加了很多常数优化所以代码不是很能看......复杂度分析也不会...反正估计也是不对的。uoj洛谷都T了。bzoj空间限制512M根本开不下。只有loj过了(loj牛逼!!!)

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <algorithm>
  4 #include <string>
  5 #include <queue>
  6 
  7 template <class T> inline void read(T &x) {
  8     x = 0;
  9     char c = getchar();
 10     while(c < '0' || c > '9') {
 11         c = getchar();
 12     }
 13     while(c >= '0' && c <= '9') {
 14         x = (x << 3) + (x << 1) + c - 48;
 15         c = getchar();
 16     }
 17     return;
 18 }
 19 
 20 typedef long long LL;
 21 const int N = 1000010, M = 4000010;
 22 using std::string;
 23 
 24 struct SGT { // 两个线段树合并
 25     int tot, sum[M * 6], ls[M * 6], rs[M * 6], rt[M * 6], p, k;
 26     std::queue<int> Q;
 27     inline int np() {
 28         if(Q.empty()) {
 29             return ++tot;
 30         }
 31         int t = Q.front();
 32         Q.pop();
 33         sum[t] = ls[t] = rs[t] = 0;
 34         return t;
 35     }
 36     void add(int l, int r, int &o) {
 37         if(!o) {
 38             o = np();
 39         }
 40         if(l == r) {
 41             sum[o]++;
 42             return;
 43         }
 44         int mid = (l + r) >> 1;
 45         if(p <= mid) {
 46             add(l, mid, ls[o]);
 47         }
 48         else {
 49             add(mid + 1, r, rs[o]);
 50         }
 51         sum[o] = sum[ls[o]] + sum[rs[o]];
 52         return;
 53     }
 54     int merge(int x, int y) {
 55         if(!x || !y) {
 56             return x | y;
 57         }
 58         int z = np();
 59         sum[z] = sum[x] + sum[y];
 60         ls[z] = merge(ls[x], ls[y]);
 61         rs[z] = merge(rs[x], rs[y]);
 62         Q.push(x);
 63         Q.push(y);
 64         return z;
 65     }
 66     int ask(int L, int R, int l, int r, int o) {
 67         if(!o) {
 68             return 0;
 69         }
 70         if(L <= l && r <= R) {
 71             return sum[o];
 72         }
 73         int mid = (l + r) >> 1, ans = 0;
 74         if(L <= mid) {
 75             ans += ask(L, R, l, mid, ls[o]);
 76         }
 77         if(mid < R) {
 78             ans += ask(L, R, mid + 1, r, rs[o]);
 79         }
 80         return ans;
 81     }
 82     inline void exmerge(int x, int y) {
 83         rt[x] = merge(rt[x], rt[y]);
 84         return;
 85     }
 86     int getK(int l, int r, int o) {
 87         if(l == r) {
 88             return r;
 89         }
 90         int mid = (l + r) >> 1;
 91         if(k <= sum[ls[o]]) {
 92             return getK(l, mid, ls[o]);
 93         }
 94         else {
 95             k -= sum[ls[o]];
 96             return getK(mid + 1, r, rs[o]);
 97         }
 98     }
 99 }rt, st;
100 
101 string str[N];
102 char ss[N], s[N];
103 int tot = 1, fail[M], len[M], e[M], tr[M][26], top, n, m, nodel[N], noder[N], edgenex[M], edgev[M];
104 LL nodea[N];
105 
106 inline void add(int x, int y) {
107     top++;
108     edgev[top] = y;
109     edgenex[top] = e[x];
110     e[x] = top;
111     return;
112 }
113 
114 inline int split(int p, int f) {
115     int Q = tr[p][f];
116     int nQ = ++tot;
117     len[nQ] = len[p] + 1;
118     fail[nQ] = fail[Q];
119     fail[Q] = nQ;
120     memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
121     while(tr[p][f] == Q) {
122         tr[p][f] = nQ;
123         p = fail[p];
124     }
125     return nQ;
126 }
127 
128 inline int insert(int p, char c) {
129     int f = c - 'a';
130     if(tr[p][f]) {
131         int Q = tr[p][f];
132         if(len[Q] == len[p] + 1) {
133             return Q;
134         }
135         return split(p, f);
136     }
137     int np = ++tot;
138     len[np] = len[p] + 1;
139     while(p && !tr[p][f]) {
140         tr[p][f] = np;
141         p = fail[p];
142     }
143     if(!p) {
144         fail[np] = 1;
145     }
146     else {
147         int Q = tr[p][f];
148         if(len[Q] == len[p] + 1) {
149             fail[np] = Q;
150         }
151         else {
152             fail[np] = split(p, f);
153         }
154     }
155     return np;
156 }
157 
158 void work(int l, int r, int &o, int x) {
159     if(!o || !st.sum[o]) {
160         if(o) {
161             st.Q.push(o);
162         }
163         o = 0;
164         return;
165     }
166     if(l == r) {
167         // node[r]
168         int sum = 0;
169         if(nodel[r] + len[fail[x]] <= noder[r]) {
170             sum = rt.ask(nodel[r] + len[fail[x]], noder[r], 1, n, rt.rt[x]);
171         }
172         if(!sum) {
173             nodea[r] += len[x] - len[fail[x]];
174         }
175         else {
176             int temp = 0;
177             if(nodel[r] + len[x] - 1 <= noder[r]) {
178                 temp = rt.ask(nodel[r] + len[x] - 1, noder[r], 1, n, rt.rt[x]);
179             }
180             if(temp) {
181                 //GG
182                 st.sum[o] = 0;
183                 st.Q.push(o);
184                 o = 0;
185                 return;
186             }
187             else {
188                 // add some...
189                 // find ->| (the right pos)
190                 rt.k = rt.ask(1, nodel[r] + len[x] - 2, 1, n, rt.rt[x]);
191                 int ed = rt.getK(1, n, rt.rt[x]);
192                 nodea[r] += len[x] - (ed - nodel[r] + 1);
193             }
194         }
195         return;
196     }
197     int mid = (l + r) >> 1;
198     if(st.ls[o]) {
199         work(l, mid, st.ls[o], x);
200     }
201     if(st.rs[o]) {
202         work(mid + 1, r, st.rs[o], x);
203     }
204     st.sum[o] = st.sum[st.ls[o]] + st.sum[st.rs[o]];
205     if(!st.sum[o]) {
206         st.Q.push(o);
207         o = 0;
208     }
209     return;
210 }
211 
212 void solve(int x) {
213     for(int i = e[x]; i; i = edgenex[i]) {
214         int y = edgev[i];
215         solve(y);
216         if(x > 1) {
217             rt.exmerge(x, y);
218         }
219     }
220     if(x > 1) {
221         work(1, m, st.rt[x], x);
222         if(fail[x] > 1) {
223             st.exmerge(fail[x], x);
224         }
225     }
226     return;
227 }
228 
229 int main() {
230 
231     //freopen("name.in", "r", stdin);
232     //freopen("name.out", "w", stdout);
233 
234     scanf("%s", ss);
235     n = strlen(ss);
236     int last = 1;
237     for(int i = 0; i < n; i++) {
238         last = insert(last, ss[i]);
239     }
240     read(m);
241     for(int i = 1; i <= m; i++) {
242         scanf("%s", s);
243         str[i] = (string)(s);
244         int t = strlen(s);
245         last = 1;
246         for(int j = 0; j < t; j++) {
247             last = insert(last, s[j]);
248         }
249         read(nodel[i]);
250         read(noder[i]);
251     }
252     //
253     int p = 1;
254     for(int i = 0; i < n; i++) {
255         p = tr[p][ss[i] - 'a'];
256         rt.p = i + 1;
257         rt.add(1, n, rt.rt[p]);
258     }
259     for(int i = 2; i <= tot; i++) {
260         add(fail[i], i);
261     }
262 
263     for(int i = 1; i <= m; i++) {
264         int t = str[i].size(), p = 1;
265         for(int j = 0; j < t; j++) {
266             // str[i][j]
267             p = tr[p][str[i][j] - 'a'];
268             st.p = i;
269             st.add(1, m, st.rt[p]);
270         }
271     }
272 
273     solve(1);
274 
275     for(int i = 1; i <= m; i++) {
276         printf("%lld
", nodea[i]);
277     }
278     return 0;
279 }
AC代码

正解不是广义SAM,是普通SAM,还不用离线......还是我太菜了>_<

时限4s,你们感受一下...尤其是跟下面那个AC代码的对比......

正解:

68分:对S和T分别建sam然后同时跑T。跑到一个位置的时候会有一个匹配长度lenth。这时给T的节点打上长度为lenth的标记。

最后拓扑序跑一遍T的sam,统计答案。总不同子串数 - 匹配子串数。

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <algorithm>
  4 
  5 typedef long long LL;
  6 const int N = 500010, M = 10000010;
  7 
  8 struct Edge {
  9     int nex, v;
 10 }edge[N << 1]; int tp;
 11 
 12 int tr[N << 1][26], fail[N << 1], len[N << 1], last, tot;
 13 int e[N << 1], n, vis[N << 1], Time, f[N], vis2[N << 1], use[N << 1], vis3[N << 1];
 14 int rt[N << 1], ls[M], rs[M], cnt;
 15 char str[N], ss[N];
 16 
 17 inline void init() {
 18     tot = last = 1;
 19     return;
 20 }
 21 
 22 inline void add(int x, int y) {
 23     tp++;
 24     edge[tp].v = y;
 25     edge[tp].nex = e[x];
 26     e[x] = tp;
 27     return;
 28 }
 29 
 30 void insert(int p, int l, int r, int &o) {
 31     if(!o) o = ++cnt;
 32     use[o] = 1;
 33     if(l == r) {
 34         return;
 35     }
 36     int mid = (l + r) >> 1;
 37     if(p <= mid) insert(p, l, mid, ls[o]);
 38     else insert(p, mid + 1, r, rs[o]);
 39     return;
 40 }
 41 
 42 int merge(int x, int y) {
 43     if(!x || !y) return x | y;
 44     int o = ++cnt;
 45     use[o] = use[x] | use[y];
 46     ls[o] = merge(ls[x], ls[y]);
 47     rs[o] = merge(rs[x], rs[y]);
 48     return o;
 49 }
 50 
 51 inline void insert(char c, int id) {
 52     int f = c - 'a', p = last, np = ++tot;
 53     last = np;
 54     len[np] = len[p] + 1;
 55     ///insert(id, 1, n, rt[np]);
 56     while(p && !tr[p][f]) {
 57         tr[p][f] = np;
 58         p = fail[p];
 59     }
 60     if(!p) {
 61         fail[np] = 1;
 62     }
 63     else {
 64         int Q = tr[p][f];
 65         if(len[Q] == len[p] + 1) {
 66             fail[np] = Q;
 67         }
 68         else {
 69             int nQ = ++tot;
 70             len[nQ] = len[p] + 1;
 71             fail[nQ] = fail[Q];
 72             fail[Q] = fail[np] = nQ;
 73             memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
 74             while(tr[p][f] == Q) {
 75                 tr[p][f] = nQ;
 76                 p = fail[p];
 77             }
 78         }
 79     }
 80     return;
 81 }
 82 
 83 void DFS_1(int x) {
 84     for(int i = e[x]; i; i = edge[i].nex) {
 85         int y = edge[i].v;
 86         DFS_1(y);
 87         rt[x] = merge(rt[x], rt[y]);
 88     }
 89     return;
 90 }
 91 
 92 namespace sam {
 93     int tot, len[N << 1], fail[N << 1], tr[N << 1][26], last, large[N << 1];
 94     int bin[N << 1], topo[N << 1];
 95     inline void clear() {
 96         for(int i = 1; i <= tot; i++) {
 97             memset(tr[i], 0, sizeof(tr[i]));
 98             large[i] = bin[i] = fail[i] = len[i] = 0;
 99         }
100         tot = last = 1;
101         return;
102     }
103     inline void insert(char c) {
104         //printf("insert : "); putchar(c); printf("
");
105         int f = c - 'a', p = last, np = ++tot;
106         last = np;
107         len[np] = len[p] + 1;
108         while(p && !tr[p][f]) {
109             tr[p][f] = np;
110             p = fail[p];
111         }
112         if(!p) {
113             fail[np] = 1;
114         }
115         else {
116             int Q = tr[p][f];
117             if(len[Q] == len[p] + 1) {
118                 fail[np] = Q;
119             }
120             else {
121                 int nQ = ++tot;
122                 len[nQ] = len[p] + 1;
123                 fail[nQ] = fail[Q];
124                 fail[Q] = fail[np] = nQ;
125                 memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
126                 while(tr[p][f] == Q) {
127                     tr[p][f] = nQ;
128                     p = fail[p];
129                 }
130             }
131         }
132         return;
133     }
134     inline LL sort() {
135         for(int i = 1; i <= tot; i++) {
136             bin[len[i]]++;
137         }
138         for(int i = 1; i <= tot; i++) {
139             bin[i] += bin[i - 1];
140         }
141         for(int i = 1; i <= tot; i++) {
142             topo[bin[len[i]]--] = i;
143         }
144         LL ans = 0;
145         for(int i = tot; i >= 2; i--) {
146             int x = topo[i];
147             if(large[x] > len[fail[x]]) {
148                 ans -= std::min(len[x], large[x]) - len[fail[x]];
149             }
150             ans += len[x] - len[fail[x]];
151             large[fail[x]] = std::max(large[fail[x]], large[x]);
152         }
153         return ans;
154     }
155 }
156 
157 int main() {
158 
159     freopen("in.in", "r", stdin);
160     freopen("my.out", "w", stdout);
161 
162     init();
163     scanf("%s", str);
164     n = strlen(str);
165     for(int i = 0; i < n; i++) {
166         insert(str[i], i + 1);
167     }
168     for(int i = 2; i <= tot; i++) add(fail[i], i);
169     //DFS_1(1);
170     /// build over
171     int q, x, y;
172     scanf("%d", &q);
173     for(Time = 1; Time <= q; Time++) {
174         //printf("i = %d 
", Time);
175         scanf("%s%d%d", ss, &x, &y);
176         int m = strlen(ss);
177         sam::clear();
178         for(int i = 0; i < m; i++) {
179             sam::insert(ss[i]);
180         }
181         /// match
182         int p1 = 1, p2 = 1, lenth = 0;
183         for(int i = 0; i < m; i++) {
184             int ff = ss[i] - 'a';
185             p2 = sam::tr[p2][ff];
186             while(p1 && !tr[p1][ff]) {
187                 p1 = fail[p1];
188                 lenth = len[p1];
189             }
190             if(!p1) {
191                 p1 = 1;
192             }
193             if(tr[p1][ff]) {
194                 p1 = tr[p1][ff];
195                 lenth++;
196             }
197             sam::large[p2] = std::max(sam::large[p2], lenth);
198         }
199         LL ans = sam::sort();
200         printf("%lld
", ans);
201     }
202     return 0;
203 }
68分代码

100分:写个匹配函数来判断能不能匹配。失配的话先不跳fail,而是lenth--。

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <algorithm>
  4 
  5 typedef long long LL;
  6 const int N = 500010, M = 30000010;
  7 
  8 struct Edge {
  9     int nex, v;
 10 }edge[N << 1]; int tp;
 11 
 12 int tr[N << 1][26], fail[N << 1], len[N << 1], last, tot;
 13 int e[N << 1], n, Time, use[M];
 14 int rt[N << 1], ls[M], rs[M], cnt, X, Y;
 15 char str[N], ss[N];
 16 
 17 inline void init() {
 18     tot = last = 1;
 19     return;
 20 }
 21 
 22 inline void add(int x, int y) {
 23     tp++;
 24     edge[tp].v = y;
 25     edge[tp].nex = e[x];
 26     e[x] = tp;
 27     return;
 28 }
 29 
 30 void insert(int p, int l, int r, int &o) {
 31     if(!o) o = ++cnt;
 32     use[o] = 1;
 33     if(l == r) {
 34         return;
 35     }
 36     int mid = (l + r) >> 1;
 37     if(p <= mid) insert(p, l, mid, ls[o]);
 38     else insert(p, mid + 1, r, rs[o]);
 39     return;
 40 }
 41 
 42 int merge(int x, int y) {
 43     if(!x || !y) return x | y;
 44     int o = ++cnt;
 45     use[o] = use[x] | use[y];
 46     ls[o] = merge(ls[x], ls[y]);
 47     rs[o] = merge(rs[x], rs[y]);
 48     return o;
 49 }
 50 
 51 inline void insert(char c, int id) {
 52     int f = c - 'a', p = last, np = ++tot;
 53     last = np;
 54     len[np] = len[p] + 1;
 55     insert(id, 1, n, rt[np]);
 56     while(p && !tr[p][f]) {
 57         tr[p][f] = np;
 58         p = fail[p];
 59     }
 60     if(!p) {
 61         fail[np] = 1;
 62     }
 63     else {
 64         int Q = tr[p][f];
 65         if(len[Q] == len[p] + 1) {
 66             fail[np] = Q;
 67         }
 68         else {
 69             int nQ = ++tot;
 70             len[nQ] = len[p] + 1;
 71             fail[nQ] = fail[Q];
 72             fail[Q] = fail[np] = nQ;
 73             memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
 74             while(tr[p][f] == Q) {
 75                 tr[p][f] = nQ;
 76                 p = fail[p];
 77             }
 78         }
 79     }
 80     return;
 81 }
 82 
 83 void DFS_1(int x) {
 84     for(int i = e[x]; i; i = edge[i].nex) {
 85         int y = edge[i].v;
 86         DFS_1(y);
 87         rt[x] = merge(rt[x], rt[y]);
 88     }
 89     return;
 90 }
 91 
 92 inline bool ask(int L, int R, int l, int r, int o) {
 93     //printf("ask [%d %d] [%d %d] o = %d  sum = %d 
", L, R, l, r, o, use[o]);
 94     if(!o) return 0;
 95     if(L <= l && r <= R) return use[o];
 96     int mid = (l + r) >> 1; bool ans = 0;
 97     if(L <= mid) ans |= ask(L, R, l, mid, ls[o]);
 98     if(mid < R) ans |= ask(L, R, mid + 1, r, rs[o]);
 99     return ans;
100 }
101 
102 inline bool match(int p, int lenth, int f) {
103     if(!tr[p][f]) return 0;
104     return ask(X + lenth, Y, 1, n, rt[tr[p][f]]);
105 }
106 
107 namespace sam {
108     int tot, len[N << 1], fail[N << 1], tr[N << 1][26], last, large[N << 1];
109     int bin[N << 1], topo[N << 1];
110     inline void clear() {
111         for(int i = 1; i <= tot; i++) {
112             memset(tr[i], 0, sizeof(tr[i]));
113             large[i] = bin[i] = fail[i] = len[i] = 0;
114         }
115         tot = last = 1;
116         return;
117     }
118     inline void insert(char c) {
119         //printf("insert : "); putchar(c); printf("
");
120         int f = c - 'a', p = last, np = ++tot;
121         last = np;
122         len[np] = len[p] + 1;
123         while(p && !tr[p][f]) {
124             tr[p][f] = np;
125             p = fail[p];
126         }
127         if(!p) {
128             fail[np] = 1;
129         }
130         else {
131             int Q = tr[p][f];
132             if(len[Q] == len[p] + 1) {
133                 fail[np] = Q;
134             }
135             else {
136                 int nQ = ++tot;
137                 len[nQ] = len[p] + 1;
138                 fail[nQ] = fail[Q];
139                 fail[Q] = fail[np] = nQ;
140                 memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
141                 while(tr[p][f] == Q) {
142                     tr[p][f] = nQ;
143                     p = fail[p];
144                 }
145             }
146         }
147         return;
148     }
149     inline LL sort() {
150         for(int i = 1; i <= tot; i++) {
151             bin[len[i]]++;
152         }
153         for(int i = 1; i <= tot; i++) {
154             bin[i] += bin[i - 1];
155         }
156         for(int i = 1; i <= tot; i++) {
157             topo[bin[len[i]]--] = i;
158         }
159         LL ans = 0;
160         for(int i = tot; i >= 2; i--) {
161             int x = topo[i];
162             if(large[x] > len[fail[x]]) {
163                 ans -= std::min(len[x], large[x]) - len[fail[x]];
164             }
165             ans += len[x] - len[fail[x]];
166             large[fail[x]] = std::max(large[fail[x]], large[x]);
167         }
168         return ans;
169     }
170 }
171 
172 int main() {
173 
174     //printf("%d 
", (sizeof(ls) * 3) / 1048576);
175 
176     freopen("name.in", "r", stdin);
177     freopen("name.out", "w", stdout);
178 
179     init();
180     scanf("%s", str);
181     n = strlen(str);
182     for(int i = 0; i < n; i++) {
183         insert(str[i], i + 1);
184     }
185     for(int i = 2; i <= tot; i++) add(fail[i], i);
186     DFS_1(1);
187     /// build over
188     int q;
189     scanf("%d", &q);
190     for(Time = 1; Time <= q; Time++) {
191         //printf("i = %d 
", Time);
192         scanf("%s%d%d", ss, &X, &Y);
193         int m = strlen(ss);
194         sam::clear();
195         for(int i = 0; i < m; i++) {
196             sam::insert(ss[i]);
197         }
198         /// match
199         int p1 = 1, p2 = 1, lenth = 0;
200         for(int i = 0; i < m; i++) {
201             int ff = ss[i] - 'a';
202             p2 = sam::tr[p2][ff];
203             while(p1 && !match(p1, lenth, ff)) {
204                 if(lenth) lenth--;
205                 if(lenth == len[fail[p1]]) p1 = fail[p1];
206             }
207             if(!p1) {
208                 p1 = 1;
209             }
210             else {
211                 p1 = tr[p1][ff];
212                 lenth++;
213             }
214             sam::large[p2] = std::max(sam::large[p2], lenth);
215             //printf("lneth = %d 
", lenth);
216         }
217         LL ans = sam::sort();
218         printf("%lld
", ans);
219     }
220     return 0;
221 }
AC代码
原文地址:https://www.cnblogs.com/huyufeifei/p/10335861.html