CF666E Forensic Examination——SAM+线段树合并+倍增

RemoteJudge

题目大意

给你一个串(S)以及一个字符串数组(T[1...m])(q)次询问,每次问(S)的子串(S[p_l...p_r])(T[l...r])中的哪个串里的出现次数最多,并输出出现次数。
如有多解输出最靠前的那一个。

思路

第一次见到在(parent tree)上线段树合并的题,感觉好妙
先对(T)建一个广义后缀自动机,考虑对(SAM)上的每一个结点建一颗线段树,值域为([1,m]),维护出现次数最多的串的位置和次数。又因为(endpos)集合(好像也叫(right)集合)有这么一个性质:一个结点的(endpos)集合即为其在(parent tree)上子结点的并集,所以我们在建树时只需要上一个线段树合并即可。
上面的那个思路貌似是个套路?
然后来处理询问,显然我们只需要在(S[p_l...p_r])对应的结点的线段树上查(l-r)的最大值就行了,但如果直接拿(S[p_l...p_r])(SAM)上匹配,复杂度绝壁不对QwQ。于是我们考虑先把整个(S)(SAM)上匹配,需要查哪个子串时通过跳(suflink)来找。具体一下,就是对于(S)的一个前缀(S[1...j]),如果它最后匹配到了结点(u),匹配的长度为(len),然后我们要查的子串是(S[i...j]),就从(u)开始跳(suflink)直到一个(maxlen)大于等于(j-i+1)且深度最小的结点,记其为(v),要查的就是(v)那棵线段树的答案
最后发现跳(suflink)的过程可以用倍增来优化,然后就没了
吐槽1.为什么我写离线的就会(WA),在线的就过了
吐槽2.下午三点多写完,然后(CF)

咕到了六点多,然后交了一发,(WA)了,我...

#include <algorithm>
#include  <iostream>
#include   <cstdlib>
#include   <cstring>
#include    <cstdio>
#include    <string>
#include    <vector>
#include     <cmath>
#include     <ctime>
#include     <queue>
#include       <map>
#include       <set>

using namespace std;

#define ull unsigned long long
#define pii pair<int, int>
#define uint unsigned int
#define mii map<int, int>
#define lbd lower_bound
#define ubd upper_bound
#define INF 0x3f3f3f3f
#define IINF 0x3f3f3f3f3f3f3f3fLL
#define vi vector<int>
#define ll long long
#define mp make_pair
#define pb push_back
#define re register
#define il inline

#define MAXS 500000
#define M 50000
#define Q 500000
#define MAXT 100000
#define LIM 16

char S[MAXS+5], T[MAXT+5];
int n, m, q;
int nxt[26][2*MAXT+5], maxlen[2*MAXT+5], link[2*MAXT+5], in[2*MAXT+5], nid1, lst;
int nid2, root[2*MAXT+5], ch[2][160*MAXT+5];
vi G[2*MAXT+5];
int f[2*MAXT+5][LIM+1];
int tar[MAXS+5], ml[MAXS+5];

struct Data {
  int w, pos;
  friend Data operator + (Data lhs, Data rhs) {
    if(lhs.w > rhs.w) return lhs;
    else {
      if(lhs.w == rhs.w && lhs.pos < rhs.pos) return lhs;
      else return rhs;
    }
  }
  bool operator < (const Data &rhs) const {
    return w == rhs.w ? pos > rhs.pos : w < rhs.w;
  }
}nodes[160*MAXT+5];

void init() {
  nid1 = lst = 1;
  nid2 = 0;
}

void pushup(int o) {
  nodes[o] = nodes[ch[0][o]]+nodes[ch[1][o]];
}

void add(int &u, int l, int r, int x) {
  if(!u) u = ++nid2;
  if(l == r) {
    nodes[u] = Data{++nodes[u].w, nodes[u].pos = l};
    return ;
  }
  int mid = (l+r)>>1;
  if(x <= mid) add(ch[0][u], l, mid, x);
  else add(ch[1][u], mid+1, r, x);
  pushup(u);
}

int merge(int x, int y, int l, int r) {
  if(!x || !y) return x | y;
  int now = ++nid2;
  if(l == r) {
    nodes[now] = Data{nodes[x].w+nodes[y].w, nodes[x].pos};
    return now;
  }
  int mid = (l+r)>>1;
  ch[0][now] = merge(ch[0][x], ch[0][y], l, mid);
  ch[1][now] = merge(ch[1][x], ch[1][y], mid+1, r);
  pushup(now);
  return now;
}

Data query(int o, int l, int r, int L, int R) {
  if(!o) return Data{0, 0};
  if(L <= l && r <= R) return nodes[o];
  int mid = (l+r)>>1;
  Data ret{0, 0};
  if(L <= mid) ret = ret+query(ch[0][o], l, mid, L, R);
  if(R > mid) ret = ret+query(ch[1][o], mid+1, r, L, R);
  return ret;
}

void extend(int c, int id) {
  int cur = ++nid1;
  maxlen[cur] = maxlen[lst]+1;
  add(root[cur], 1, m, id);
  while(lst && !nxt[c][lst]) nxt[c][lst] = cur, lst = link[lst];
  if(!lst) link[cur] = 1;
  else {
    int p = lst, q = nxt[c][lst];
    if(maxlen[q] == maxlen[p]+1) link[cur] = q;
    else {
      int clone = ++nid1;
      maxlen[clone] = maxlen[p]+1;
      link[clone] = link[q], link[q] = link[cur] = clone;
      for(int i = 0; i < 26; ++i) nxt[i][clone] = nxt[i][q];
      while(p && nxt[c][p] == q) nxt[c][p] = clone, p = link[p];
    }
  }
  lst = cur;
}

void insert(int id) {
  int t = strlen(T+1);
  lst = 1;
  for(int i = 1; i <= t; ++i) extend(T[i]-'a', id);
}

void build(int u, int fa) {
  f[u][0] = fa;
  for(int i = 1; i <= LIM; ++i) f[u][i] = f[f[u][i-1]][i-1];
  for(int i = 0, v; i < G[u].size(); ++i) {
    v = G[u][i];
    build(v, u);
    root[u] = merge(root[u], root[v], 1, m);
  }
}

void pre() {
  n = strlen(S+1);
  int u = 1, len = 0;
  for(int i = 1; i <= n; ++i) {
    if(nxt[S[i]-'a'][u]) u = nxt[S[i]-'a'][u], len++;
    else {
      while(u && !nxt[S[i]-'a'][u]) u = link[u];
      if(!u) u = 1, len = 0;
      else len = maxlen[u]+1, u = nxt[S[i]-'a'][u];
    }
    tar[i] = u, ml[i] = len;
  }
}

int main() {
  scanf("%s%d", S+1, &m);
  init();
  for(int i = 1; i <= m; ++i) scanf("%s", T+1), insert(i);
  for(int i = 2; i <= nid1; ++i) G[link[i]].pb(i);
  build(1, 0);
  pre();
  scanf("%d", &q);
  for(int i = 1, l, r, pl, pr, L; i <= q; ++i) {
    scanf("%d%d%d%d", &l, &r, &pl, &pr);
    L = pr-pl+1;
    if(L > ml[pr]) printf("%d 0
", l);
    else {
      int u = tar[pr];
      for(int k = LIM; ~k; --k) if(maxlen[f[u][k]] >= L) u = f[u][k];
      Data ret = query(root[u], 1, m, l, r);
      if(ret.w == 0) ret.pos = l;
      printf("%d %d
", ret.pos, ret.w);
    }
  }
  return 0;
}
原文地址:https://www.cnblogs.com/dummyummy/p/10932735.html