[HNOI2019]JOJO

题目链接
题意:有一个字符串,初始为空,n次操作,每次可以添加一段字符到末尾或回到第x次操作之后的状态。每次操作后,输出所有前缀的next之和。
首先,那个撤销操作可以离线,建版本树解决。
既然求的是最长公共前后缀,自然想到用KMP。
模仿KMP的过程:记录每次添加后的串的next,中间的next无需记录。
然后,考虑添加一段字符:
和KMP一样,沿next链走,不过这里的匹配需要匹配一串字符,所以求出从(i)位置向后,能匹配(c)多少。
当这个长度增加时,对答案的贡献是一段等差数列。
若一个位置向后,正好能匹配(x)个,则将这个位置赋给添加后的next。
注意:如果匹配的个数大于(x)个,则不要做此操作。
理由:因为匹配的个数(x)个,而题目保证相邻两次的(c)不同,所以匹配到中间是无法往下匹配的。
所以,不会有记录的next指向中间,中间的next就无需记录。
这样,就简单多了。
但是,KMP的复杂度是均摊的,不能可持久化(会被卡超时)。
考虑一个优化:
若最长公共前后缀的长度大于串的一半,则说明出现了循环节,那样可以直接跳到第一个循环节的对应位置。
这样,无论如何,走一次next,长度至少会缩减一半。
这样,时间复杂度就是严格的(O(nlogn))的了。
详见代码:

#include <stdio.h> 
#define md 998244353 
struct SJd {
	int l,r,c;
};
int getsum(int x, int y) {
	return 1ll * (x - y + 1 + x) * y / 2 % md;
}
struct String {
	SJd sz[100010];
	int ne[100010],cd;
	String() {
		cd = 0;
		sz[0].l = sz[0].r = 0;
		ne[0] = -1;
	}
	int getne(int x) {
		if (ne[x] == -1 || sz[ne[x]].r * 2 <= sz[x].r + 1) return ne[x];
		else return ne[x] % (x - ne[x]);
	}
	int insert(int x, int c) {
		cd += 1;
		sz[cd].l = sz[cd - 1].r + 1;
		sz[cd].r = sz[cd].l + x - 1;
		sz[cd].c = c;
		ne[cd] = 0;
		if (cd == 1) return getsum(x - 1, x);
		int p = getne(cd - 1),ma = 0,rt = 0,
		tp = ne[cd - 1];
		while (p != -1) {
			int t = (sz[tp + 1].c == c ? sz[tp + 1].r - sz[tp + 1].l + 1 : 0);
			bool b = false;
			if (t > x) {
				t = x;
				b = true;
			}
			if (t > ma) {
				rt = (rt + getsum(sz[tp].r + t, t - ma)) % md;
				ma = t;
			}
			if (t == x && !b) {
				ne[cd] = tp + 1;
				break;
			}
			p = getne(p);
			tp = ne[tp];
		}
		if (ma < x && sz[1].c == c) {
			rt = (rt + 1ll * (x - ma) * (sz[1].r - sz[1].l + 1)) % md;
			ne[cd] = 1;
		}
		return rt;
	}
};
String str;
int tm[100010],ans[100010];
int fr[100010],ne[100010],v[100010],bs = 0;
int lx[100010],x[100010],c[100010];
void addb(int a, int b) {
	v[bs] = b;
	ne[bs] = fr[a];
	fr[a] = bs++;
}
void dfs(int u, int he) {
	int oldcd = str.cd;
	if (u != 0) he = (he + str.insert(x[u], c[u])) % md;
	ans[u] = he;
	for (int i = fr[u]; i != -1; i = ne[i]) dfs(v[i], he);
	str.cd = oldcd;
}
int main() {
	int n;
	scanf("%d", &n);
	for (int i = 0; i <= n; i++) fr[i] = -1;
	for (int i = 1; i <= n; i++) {
		int lx;
		scanf("%d", &lx);
		if (lx == 1) {
			char ch[2];
			scanf("%d%s", &x[i], ch);
			c[i] = ch[0] - 'a';
			addb(tm[i - 1], i);
			tm[i] = i;
		} else {
			int a;
			scanf("%d", &a);
			tm[i] = tm[a];
		}
	}
	dfs(0, 0);
	for (int i = 1; i <= n; i++) {
		if (lx[i] == 1) printf("%d
", ans[i]);
		else printf("%d
", ans[tm[i]]);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/lnzwz/p/11348426.html