LOJ 2065 [SDOI 2010]模式字符串

题意

hhhhhhhh并不知道为啥,只有三组数据,不过第三组数据够强,达到了题目描述的\(\sum n\leq10^6\)

实际上\(n=1e6\)才算强吧

只有一个全局询问,考虑点分治,统计串的时候因为\(<u,v>\)为有序对,所以solve一个点的时候

要分别记录x到当前点时作为prefix和suffix的不同hash值,同时对预处理出来的pre和suf值做匹配

如果匹配上了就找对应的在之前子树里的残串,这样对于每个串匹配前后缀统计答案可以做到\(O(1)\)

这样的话要开spre和ssuf表示在之前的子树里的\(\mod m\)意义下残串的数量

这样的话calc每个点复杂度就是线性的,总复杂度\(O(n\log n)\)

注意将一个子树的匹配前后缀计入的时候要dfs完整个子树再进行记录,不然会有重复答案。

我这里是把calc的点\(x\)作为后缀计入,所以\(spre_m\)(由\(x\)和其子树中的一条路径恰好是模式串的循环)一开始为1

当然这样的话如果\(x\)恰好为模式串最后一位对于缺少\(x\)作为后缀的前缀也要计入,即\(ssuf_1\)初始值为1

#include<cstdio>
#include<cstring>
const int N = 1e5+7;
typedef long long LL;
typedef unsigned long long uLL;
#define R register
inline int max(int a, int b) {
  return a > b ? a : b;
}
int last[N], cnt, n, m, up, vis[N];
char A[N];
uLL S[N], pre[N], suf[N];
LL ans;
struct Edge {
  int to, nxt;
}e[N*2];
inline void add(int u, int v) {
  e[++cnt].nxt = last[u], e[cnt].to = v, last[u] = cnt;
}
int SIZ, mxson[N], siz[N], rt;
#define BASE 233uLL
#define cerr(x) printf("%d ", x)
//#define cerr(x, y) printf("%d %d\n", x, y)
#define debug printf("GG\n")
void findrt(int x, int fa) {
  siz[x] = 1, mxson[x] = 0;
  for (int o = last[x]; o; o = e[o].nxt) {
    int to = e[o].to;
    if (to == fa || vis[to]) continue;
    findrt(to, x);
    siz[x] += siz[to];
    mxson[x] = max(mxson[x], siz[to]);
  } mxson[x] = max(SIZ - siz[x], mxson[x]);
  if (mxson[x] < mxson[rt]) rt = x;
}
uLL spre[N], ssuf[N];
int poi, total;
struct Node {
  int pos, id;
}p[N];
/* 
inline uLL FST(uLL x, uLL k) {
  uLL res = 1uLL;
  while (k) {
    if (k & 1uLL) res = res * x;
    x = x * x; k >>= 1uLL; 
  } return res;
}*/
uLL FST[N]; 
void getdis(int x, int fa, uLL suffix, uLL prefix, uLL dep) {
  //if (dep > up) return;
  //cerr(x);
  uLL omgx = suffix + S[x] * FST[dep];//, omgy = prefix * BASE + S[x];
  uLL omgy = prefix + S[x] * FST[dep - 1];
  if (omgx == suf[up - dep])
    ans += spre[m - (dep + 1) % m], p[++total] = (Node){((dep + 1) % m) == 0 ? m : ((dep + 1) % m), 1};
  if (omgy == pre[dep])
    ans += ssuf[m - dep % m], p[++total] = (Node){(dep % m == 0) ? m : (dep % m), 2};
  for (int o = last[x]; o; o = e[o].nxt) {
    int to = e[o].to;
    if (to == fa || vis[to]) continue;
    getdis(to, x, omgx, omgy, dep + 1);
  }
}
void div(int x) {
  //cerr(x);
  if (S[x] == A[m]) ssuf[1]++;
  spre[m]++;
  total = 0, poi = 0;
  for (R int o = last[x]; o; o = e[o].nxt) {
    int to = e[o].to;
    if (vis[to]) continue;
    getdis(to, x, S[x], 0, 1);
    while (poi < total) {
      poi++;
      if (p[poi].id == 1) ssuf[p[poi].pos]++;
      if (p[poi].id == 2) spre[p[poi].pos]++;
    }
  }
  while (poi >= 1) {
    if (p[poi].id == 1) ssuf[p[poi].pos]--;
    if (p[poi].id == 2) spre[p[poi].pos]--;
    poi--;
  }
  if (S[x] == A[m]) ssuf[1]--;
  spre[m]--;
  //cerr(ans, x);
  //printf("%d %d\n", ans, x);
}
void solve(int x) {
  //cerr(x);
  //printf("%d %d\n", ans, x);
  vis[x] = 1, div(x);
  //printf("%d %d\n", ans, x);
  for (int o = last[x]; o; o = e[o].nxt) {
    int to = e[o].to;
    if (vis[to]) continue;
    SIZ = mxson[0] = siz[to], rt = 0;
    findrt(to, x);
    solve(rt);
  }
  vis[x] = 0;
}
int main() {
  //freopen("random.in", "r", stdin);
  //freopen("check.out", "w", stdout);
  int T;
  scanf("%d", &T);
  FST[0] = 1uLL;
  for (int i = 1; i <= 100000; i++)
    FST[i] = FST[i - 1] * BASE;
  while(T--) {
    cnt = 0;
    ans = 0;
    memset(e, 0, sizeof(e));
    memset(last, 0, sizeof(last));
    memset(A, 0, sizeof(A));
    memset(S, 0, sizeof(S));
    memset(pre, 0, sizeof(pre));
    memset(suf, 0, sizeof(suf));
    memset(mxson, 0, sizeof(mxson));
    memset(siz, 0, sizeof(siz));
    //memset(vis, 0, sizeof(vis));
    memset(spre, 0, sizeof(spre));
    memset(ssuf, 0, sizeof(ssuf));
    scanf("%d%d", &n, &m);
    scanf("%s", A + 1);
    up = (n / m) * m;
    for (R int i = 1; i <= n; i++) S[i] = A[i];
    for (R int i = 1, x, y; i < n; i++) scanf("%d%d", &x, &y), add(x, y), add(y, x);
    scanf("%s", A + 1);
    for (R int i = 1; i <= up; i++) pre[i] = pre[i - 1] * BASE + A[((i % m) == 0) ? m : (i % m)];
    for (R int i = up; i >= 1; i--) suf[i] = suf[i + 1] * BASE + A[((i % m) == 0) ? m : (i % m)];
    rt = 0, mxson[0] = SIZ = n;
    findrt(1, 0);
    solve(rt);
    printf("%lld\n", ans);
  } //debug;
}

// 97 ~ 122
/* 
2
10 4
aaaaaaabbb
2 1
3 1
4 1
5 2
6 4
7 2
8 3
9 7
10 7
baaa 
10 4
aaaaaaabbb
2 1
3 1
4 1
5 2
6 4
7 2
8 3
9 7
10 7
baaa 

*/
原文地址:https://www.cnblogs.com/cjc030205/p/11638106.html