【hihoCoder 1466】后缀自动机六·重复旋律9

http://hihocoder.com/problemset/problem/1466
建出A串和B串的两个后缀自动机
对后缀自动机的每个状态求出sg值。
求出B串的(sum(x)),表示B有多少子串的sg值等于x(用拓扑序求)。
对A串的每个状态,求出B串有多少子串的sg值不等于这个状态的sg值,再按拓扑序递推一下。
接下来就类似SPOJ 7258这道题了
从A串开始走,按字典序从小到大,定住A串后,根据在A串停住的状态的sg值再在B串上按拓扑序递推一次求出当前状态往后可以走出多少不等于这个sg值的子串,再在B串上按字典序从小到大走定住B串。
注意空串也算子串。
时间复杂度(O(nlog n)),只有求sg函数排序是(O(nlog n))的,其他操作都是(O(n))的。
调了好几天,很恶心啊,把c[nn + 1] = -1;打成c[nn + 1] == -1;了。

要是开-Wall就没这种事了qwq
在周赛结束前10分钟才发现错误,然后改过来A了233

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N = 100003;

int tot = 0, cnt, cnt2;
struct State {
	State *par, *go[26];
	int val, sg; ll f;
} pool[N << 2], *id[N << 1], *tp[N << 1], *tp2[N << 1], *root_A, *root_B, *root, *last, *tmp;

State *newState(int num) {
	pool[++tot].val = num;
	pool[tot].par = 0;
	pool[tot].sg = 0;
	pool[tot].f = 0;
	memset(pool[tot].go, 0, sizeof(pool[tot].go));
	return id[++cnt] = &pool[tot];
}

void extend(int w) {
	State *p = last;
	State *np = newState(p->val + 1);
	while (p && p->go[w] == 0)
		p->go[w] = np, p = p->par;
	if (p == 0) np->par = root;
	else {
		State *q = p->go[w];
		if (q->val == p->val + 1) np->par = q;
		else {
			State *nq = newState(p->val + 1);
			memcpy(nq->go, q->go, sizeof(nq->go));
			nq->par = q->par;
			q->par = np->par = nq;
			while (p && p->go[w] == q)
				p->go[w] = nq, p = p->par;
		}
	}
	last = np;
}

char Sa[N], Sb[N], ansa[N], ansb[N];
int nn, c[N << 1], ansalen = 0, ansblen = 0;
ll sum[N], k;

void pre(int len) {
	cnt2 = cnt;
	for (int i = 1; i <= cnt; ++i) ++c[id[i]->val];
	for (int i = 1; i <= len; ++i) c[i] += c[i - 1];
	for (int i = cnt; i >= 1; --i) tp[c[id[i]->val]--] = id[i];
	
	for (int i = cnt; i >= 1; --i) {
		tp2[i] = tmp = tp[i];
		
		nn = 0;
		for (int w = 0; w < 26; ++w)
			if (tmp->go[w]) {
				tmp->f += tmp->go[w]->f;
				c[++nn] = tmp->go[w]->sg;
			}
		++tmp->f;
		
		stable_sort(c + 1, c + nn + 1);
		if (c[1] != 0 || nn == 0) tmp->sg = 0;
		else {
			c[nn + 1] = -1;
			for (int j = 1; j <= nn; ++j)
				if (c[j] != c[j + 1] && c[j] + 1 != c[j + 1]) {
					tmp->sg = c[j] + 1;
					break;
				}
		}
	}
}

void pre2(int len) {
	memset(c, 0, sizeof(int) * (len + 1));
	for (int i = 1; i <= cnt; ++i) ++c[id[i]->val];
	for (int i = 1; i <= len; ++i) c[i] += c[i - 1];
	for (int i = cnt; i >= 1; --i) tp[c[id[i]->val]--] = id[i];
	
	for (int i = cnt; i >= 1; --i) {
		tmp = tp[i];
		nn = 0;
		for (int w = 0; w < 26; ++w)
			if (tmp->go[w]) {
				tmp->f += tmp->go[w]->f;
				c[++nn] = tmp->go[w]->sg;
			}
		
		stable_sort(c + 1, c + nn + 1);
		if (c[1] != 0 || nn == 0) tmp->sg = 0;
		else {
			c[nn + 1] = -1;
			for (int j = 1; j <= nn; ++j)
				if (c[j] != c[j + 1] && c[j] + 1 != c[j + 1]) {
					tmp->sg = c[j] + 1;
					break;
				}
		}
		
		tmp->f += sum[tmp->sg];
	}
}

void work_B(int nu) {
	for (int i = cnt2; i >= 1; --i) {
		tmp = tp2[i]; tmp->f = 0;
		for (int w = 0; w < 26; ++w)
			if (tmp->go[w])
				tmp->f += tmp->go[w]->f;
		if (tmp->sg != nu) ++tmp->f;
	}
	
	tmp = root_B;
	bool flag;
	while (k) {
		flag = false;
		if (tmp->sg != nu) --k;
		if (k == 0) {flag = true; break;}
		for (int w = 0; w < 26; ++w)
			if (tmp->go[w] && k)
				if (tmp->go[w]->f >= k) {
					flag = true;
					tmp = tmp->go[w];
					ansb[++ansblen] = 'a' + w;
					break;
				} else
					k -= tmp->go[w]->f;
		if (!flag) break;
	}
	
	if (!flag) puts("NO");
	else {
		for (int i = 1; i <= ansalen; ++i) putchar(ansa[i]); puts("");
		for (int i = 1; i <= ansblen; ++i) putchar(ansb[i]); puts("");
	}
}

int main() {
	scanf("%lld%s%s", &k, Sa + 1, Sb + 1);
	int lena = strlen(Sa + 1), lenb = strlen(Sb + 1);
	
	cnt = 0;
	root_B = root = last = newState(0);
	for (int i = 1; i <= lenb; ++i)
		extend(Sb[i] - 'a');
	pre(lenb);
	for (int i = 1; i <= cnt; ++i)
		if (tp[i] != root_B) sum[tp[i]->sg] += tp[i]->val - tp[i]->par->val;
		else ++sum[tp[i]->sg];
	for (int i = 0; i <= lena; ++i)
		sum[i] = root_B->f - sum[i];
	
	cnt = 0;
	root_A = root = last = newState(0);
	for (int i = 1; i <= lena; ++i)
		extend(Sa[i] - 'a');
	pre2(lena);
	
	tmp = root_A;
	bool flag;
	while (k) {
		flag = false;
		if (sum[tmp->sg] >= k) {
			work_B(tmp->sg);
			return 0;
		}
		k -= sum[tmp->sg];
		for (int w = 0; w < 26; ++w)
			if (tmp->go[w] && k)
				if (tmp->go[w]->f >= k) {
					flag = true;
					ansa[++ansalen] = 'a' + w;
					tmp = tmp->go[w];
					break;
				} else
					k -= tmp->go[w]->f;
		if (!flag) break;
	}
	
	puts("NO");
	return 0;
}
原文地址:https://www.cnblogs.com/abclzr/p/6286044.html