2019 CCPC 网络选拔 Kth-occurrence

题意

给出一个字符串,每次询问其一个子串([S_l,S_r])在原串中第(k)次出现所在的位置(开头位置)


解法

题意很简洁,思路也很简洁

就是代码巨难打

总之这道题还是让我很大程度上加深了对于(SAM)的认识啦,还去学了一下线段树合并

首先,根据后缀自动机的性质我们能知道题目所要求的的实际上是

([S_l,S_r])所代表串所在后缀自动机的结点(endpos)集合中的第(k)个数

求区间第(k)大?权值线段树上啦

我们知道,对于后缀自动机上的某个结点,其(endpos)集合是它在(parent)树上所有儿子(endpos)集合的并

我们对于(parent)树上的每个节点,开一颗权值线段树,每个前缀初始化为其尾位置

那么对于我们想要求出某个结点的(endpos)集合,只要将其儿子的权值线段树合并到它上面来就行了

为了节省空间,这里每颗权值线段树都是动态开点的

那么我们如何查询一个子串([S_l,S_r])在原串的后缀自动机上对应的节点呢?

我们可以采用树上倍增

先预处理所有前缀所对应的结点编号,建出(parent)树,预处理倍增数组(f)并进行一轮线段树合并求出所有节点的(endpos)集合

那么在查询([S_l,S_r])时,我们先找到([1,S_r])对应的结点,然后开始倍增跳

([S_l,S_r])对应节点所代表的的最长串一定是([x,S_r]),我们需要找到一个最大的(x)使得([S_l,S_r])([x,S_r])的后缀

倍增判断即可。其实我们完全可以把倍增理解为一种不断二分的过程,一次次的接近答案(因为在跳(fa)的过程中(len)是单调递减的)

找到对应的结点后在那个节点的权值线段树上查询第(k)大即可


代码

都封装了,打得巨好理解

#include <cstdio>
#include <cstring>

using namespace std;

const int N = 1e6 + 10;

int T, n, q;

char a[N];

struct segTree {
	
	int sz;
	int ls[N << 2], rs[N << 2], val[N << 2];
	
	void clear() {
		sz = 0;
	}
	
	void update(int cur) {
		val[cur] = val[ls[cur]] + val[rs[cur]];	
	}
	
	int newnode() {
		++sz;
		val[sz] = ls[sz] = rs[sz] = 0;
		return sz;
	}
	
	void mkchain(int& cur, int l, int r, int k) {
		if (!cur)	cur = newnode();
		if (l == r)	return val[cur]++, void();
		int mid = l + r >> 1;
		if (k <= mid)	
			mkchain(ls[cur], l, mid, k);
		else 
			mkchain(rs[cur], mid + 1, r, k);
		update(cur);
	}
	
	int merge(int x1, int x2) {
		if (!x1 || !x2)	return x1 + x2;
		int x = ++sz;
		ls[x] = merge(ls[x1], ls[x2]);
		rs[x] = merge(rs[x1], rs[x2]);
		val[x] = val[x1] + val[x2];
		return x;
	}
			
	int query(int cur, int l, int r, int k) {
		if (l == r)
			return l;
		int mid = l + r >> 1;
		if (k <= val[ls[cur]])	
			return query(ls[cur], l, mid, k);
		else if (k <= val[cur])
			return query(rs[cur], mid + 1, r, k - val[ls[cur]]);
		else 
			return -1;
	}
	
} tr;

struct SAM {
	
	int sz, lst;
	int len[N], fa[N], ch[N][30];
	
	int pos[N], dep[N], rt[N], f[N][30];
	
	int cap;
	int head[N], to[N << 1], nxt[N << 1];
	
	void clear() {
		sz = lst = 1, cap = 0;
		memset(head, 0, sizeof head);
		memset(len, 0, sizeof len);
		memset(fa, 0, sizeof fa);
		memset(ch, 0, sizeof ch);
		memset(rt, 0, sizeof rt);
	}
	
	void add(int x, int y) {
		to[++cap] = y, nxt[cap] = head[x], head[x] = cap;	
	}
	
	void insert(int po, int c) {
		int cur = ++sz, p = lst;
		pos[po] = cur, len[cur] = po;
		for (; p && !ch[p][c]; p = fa[p])	ch[p][c] = cur;
		if (!p)	
			fa[cur] = 1;
		else {
			int q = ch[p][c];
			if (len[q] == len[p] + 1)	fa[cur] = q;
			else {
				int nq = ++sz;
				fa[nq] = fa[q], len[nq] = len[p] + 1;
				for (; p && ch[p][c] == q; p = fa[p])	ch[p][c] = nq;
				memcpy(ch[nq], ch[q], sizeof ch[q]);
				fa[q] = fa[cur] = nq;
			}
		}
		lst = cur;
		tr.mkchain(rt[cur], 1, n, po);
 	}
	
 	void DFS(int cur) {
 		for (int i = 1; i <= 20; ++i)	f[cur][i] = f[f[cur][i - 1]][i - 1];
	 		for (int i = head[cur]; i; i = nxt[i]) {
			f[to[i]][0] = cur;
			DFS(to[i]);
			rt[cur] = tr.merge(rt[cur], rt[to[i]]);
		}
 	}
 	
 	void link() {
 		for (int i = 2; i <= sz; ++i)	add(fa[i], i);
		DFS(1);	
 	}
 	
 	int solve(int l, int r, int k) {
 		int cur = pos[r];
 		for (int i = 20; i >= 0; --i) {
 			int p = f[cur][i];
			if (l + len[p] - 1 >= r)	cur = p;			
 		} 
 		int ans = tr.query(rt[cur], 1, n, k);
 		return (ans == -1) ? ans : ans - (r - l);
 	}
	
} sam;

int main() {
	
	scanf("%d", &T);
	
	while (T--) {
		
		scanf("%d%d", &n, &q);
		scanf("%s", a + 1);
		
		sam.clear(), tr.clear();
		for (int i = 1; i <= n; ++i)	sam.insert(i, a[i] - 'a' + 1);		
		
		sam.link();
		
		int l, r, k;
		while (q--) {
			scanf("%d%d%d", &l, &r, &k);
			printf("%d
", sam.solve(l, r, k));
		}
		
	}
	
	return 0;
}

原文地址:https://www.cnblogs.com/VeniVidiVici/p/11449473.html