洛谷P4384 制胡窜

这题TM是计数神题......SAM就是个板子,别脑残写错就完事了。有个技巧是快速定位子串,倍增即可。

考虑反着来,就是两个断点切割所有串,求方案数。

大概分类讨论一下......先特判掉一些情况。然后考虑最左右的两个串是否相交。

相交的情况比较友善,先特殊统计有断点在交集中。之后枚举第一个断点切割了1 ~ i个串。第二个断点就切割i+1 ~ m个串。

然后写出一个∑的式子,拆开之后发现维护right集合平方和和right集合相邻元素的乘积即可。

不相交的麻烦点(极其麻烦...)考虑枚举第一个断点切割了1~i个串,这里的i的限制是L ~ R,可以先求出来。

然后继续化简一个∑式子,最后发现要维护的东西差不多。于是这题做完了。

感想:对拍真好用.jpg

  1 #include <bits/stdc++.h>
  2 
  3 typedef long long LL;
  4 const int N = 200010, M = 10000010;
  5 
  6 struct Edge {
  7     int nex, v;
  8 }edge[N]; int tp;
  9 
 10 int tr[N][10], fail[N], len[N], rt[N], last = 1, tot = 1;
 11 int ls[M], rs[M], cnt, e[N], ed[N], fa[N][20], pw[N];
 12 LL sum0[M], sum2[M], sumd[M], large[M], small[M];
 13 int n, q;
 14 char str[N];
 15 
 16 inline void add(int x, int y) {
 17     tp++;
 18     edge[tp].v = y;
 19     edge[tp].nex = e[x];
 20     e[x] = tp;
 21     return;
 22 }
 23 
 24 inline void pushup(int o) {
 25     if(!ls[o] && !rs[o]) return;
 26     sum0[o] = sum0[ls[o]] + sum0[rs[o]];
 27     sum2[o] = sum2[ls[o]] + sum2[rs[o]];
 28     sumd[o] = sumd[ls[o]] + sumd[rs[o]];
 29     if(ls[o] && rs[o]) sumd[o] += small[rs[o]] * large[ls[o]];
 30     large[o] = std::max(large[ls[o]], large[rs[o]]);
 31     small[o] = std::min(small[ls[o]], small[rs[o]]);
 32     return;
 33 }
 34 
 35 int merge(int x, int y) {
 36     if(!x || !y) return x | y;
 37     int o = ++cnt;
 38     large[o] = large[x];
 39     small[o] = small[x];
 40     sum0[o] = sum0[x];
 41     sum2[o] = sum2[x];
 42     sumd[o] = sumd[x];
 43     ls[o] = merge(ls[x], ls[y]);
 44     rs[o] = merge(rs[x] ,rs[y]);
 45     pushup(o);
 46     return o;
 47 }
 48 
 49 void insert(int p, int l, int r, int &o) {
 50     if(!o) o = ++cnt;
 51     if(l == r) {
 52         large[o] = small[o] = r;
 53         sum0[o] = 1;
 54         sum2[o] = 1ll * r * r;
 55         sumd[o] = 0;
 56         return;
 57     }
 58     int mid = (l + r) >> 1;
 59     if(p <= mid) insert(p, l, mid, ls[o]);
 60     else insert(p, mid + 1, r, rs[o]);
 61     pushup(o);
 62     return;
 63 }
 64 
 65 inline void insert(char c, int id) {
 66     int f = c - '0', p = last, np = ++tot;
 67     last = np;
 68     insert(id, 1, n, rt[np]);
 69     len[np] = len[p] + 1;
 70     while(p && !tr[p][f]) {
 71         tr[p][f] = np;
 72         p = fail[p];
 73     }
 74     if(!p) {
 75         fail[np] = 1;
 76     }
 77     else {
 78         int Q = tr[p][f];
 79         if(len[Q] == len[p] + 1) {
 80             fail[np] = Q;
 81         }
 82         else {
 83             int nQ = ++tot;
 84             len[nQ] = len[p] + 1;
 85             fail[nQ] = fail[Q];
 86             fail[Q] = fail[np] = nQ;
 87             memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
 88             while(tr[p][f] == Q) {
 89                 tr[p][f] = nQ;
 90                 p = fail[p];
 91             }
 92         }
 93     }
 94     return;
 95 }
 96 
 97 int ask0(int L, int R, int l, int r, int o) {
 98     if(!o) return 0;
 99     if(L <= l && r <= R) return sum0[o];
100     int mid = (l + r) >> 1, ans = 0;
101     if(L <= mid) ans += ask0(L, R, l, mid, ls[o]);
102     if(mid < R) ans += ask0(L, R, mid + 1, r, rs[o]);
103     return ans;
104 }
105 
106 LL ask2(int L, int R, int l, int r, int o) {
107     if(!o) return 0;
108     if(L <= l && r <= R) return sum2[o];
109     int mid = (l + r) >> 1;
110     LL ans = 0;
111     if(L <= mid) ans += ask2(L, R, l, mid, ls[o]);
112     if(mid < R) ans += ask2(L, R, mid + 1, r, rs[o]);
113     return ans;
114 }
115 
116 LL askd(int L, int R, int l, int r, int o) {
117     if(!o) return 0;
118     if(L <= l && r <= R) return sumd[o];
119     int mid = (l + r) >> 1;
120     if(R <= mid) return askd(L, R, l, mid, ls[o]);
121     if(mid < L) return askd(L, R, mid + 1, r, rs[o]);
122     LL ans = askd(L, R, l, mid, ls[o]) + askd(L, R, mid + 1, r, rs[o]);
123     if(ls[o] && rs[o]) {
124         ans += large[ls[o]] * small[rs[o]];
125     }
126     return ans;
127 }
128 
129 int getKpos(int k, int l, int r, int o) {
130     if(l == r) return r;
131     int mid = (l + r) >> 1;
132     if(k <= sum0[ls[o]]) return getKpos(k, l, mid, ls[o]);
133     else return getKpos(k - sum0[ls[o]], mid + 1, r, rs[o]);
134 }
135 
136 void DFS(int x) {
137     for(int i = e[x]; i; i = edge[i].nex) {
138         int y = edge[i].v;
139         fa[y][0] = x;
140         DFS(y);
141         rt[x] = merge(rt[x], rt[y]);
142     }
143     return;
144 }
145 
146 inline void prework() {
147     for(int i = 2; i <= tot; i++) pw[i] = pw[i >> 1] + 1;
148     for(int j = 1; j <= pw[tot]; j++) {
149         for(int i = 1; i <= tot; i++) {
150             fa[i][j] = fa[fa[i][j - 1]][j - 1];
151         }
152     }
153     return;
154 }
155 
156 inline int getPos(int x, int Len) {
157     int t = pw[tot];
158     while(t >= 0 && len[fa[x][0]] >= Len) {
159         if(len[fa[x][t]] >= Len) {
160             x = fa[x][t];
161         }
162         t--;
163     }
164     return x;
165 }
166 
167 inline LL getAns(int Len, int root) {
168     
169     if(Len == 1) return 0;
170     
171     int r1 = small[root], rn = large[root];
172     int l1 = r1 - Len + 1, ln = rn - Len + 1;
173     
174     if(ask0(r1 + Len - 1, ln, 1, n, root)) {
175         return 0;
176     }
177     LL ans = 0;
178     if(r1 > ln) { /// cross 
179         ans = sum2[root] - sumd[root] - 1ll * r1 * rn;
180         LL temp = (r1 - ln);
181         ans += (temp - 1) * temp / 2 + temp * (n - temp - 1);
182     }
183     else {
184         int ql = ask0(1, r1 + Len - 2, 1, n, root), qr = ask0(1, ln, 1, n, root);
185         int L = qr, R = ql;
186         int rL = getKpos(L, 1, n, root), rR = getKpos(R, 1, n, root), nexr = getKpos(R + 1, 1, n, root);
187         ans = 1ll * (r1 - rR + Len - 1) * (nexr - ln) + 1ll * rL * ln - 1ll * rR * ln;
188         ans += ask2(rL + 1, rR, 1, n, root) - askd(rL, rR, 1, n, root);
189     }
190     return ans;
191 }
192 
193 int main() {
194     memset(small, 0x3f, sizeof(small));
195     scanf("%d%d", &n, &q);
196     scanf("%s", str + 1);
197     for(int i = 1; i <= n; i++) {
198         insert(str[i], i);
199     }
200     int p = 1;
201     for(int i = 1; i <= n; i++) {
202         int f = str[i] - '0';
203         p = tr[p][f];
204         ed[i] = p;
205         //printf("i = %d p = %d 
", i, p);
206     }
207     for(int i = 2; i <= tot; i++) add(fail[i], i);
208     DFS(1);
209     prework();
210     LL SUM = 1ll * (n - 1) * (n - 2) / 2;
211     for(int i = 1, l, r; i <= q; i++) {
212         scanf("%d%d", &l, &r);
213         LL t = getAns(r - l + 1, rt[getPos(ed[r], r - l + 1)]);
214         printf("%lld
", SUM - t);
215     }
216     return 0;
217 } 
AC代码
原文地址:https://www.cnblogs.com/huyufeifei/p/10508986.html