SPOJ 1812

题目概述

(n)个字符串的最长公共子串。(nle 10)

题解

对于最短字符串建立SAM。
考虑将其他字符串放到这个SAM上匹配,并记录mx[i]表示以i为结尾的字符串最长匹配长度。
从根开始向后匹配。如果自动机上可以向后匹配字符,就走一条转一边,转移到下一个结点;否则跳parent树,知道存在这个字符的转移边为止。
如果永远没有,那么从根开始重新匹配。每匹配到一个节点,就更新mx[i]。
匹配完再对于每一个mx[fa[i]]从mx[i]更新。这是因为一个节点的par必然在他的儿子中出现,但却有可能没有匹配到par这个节点。
容易证明这样每个结点都有最长匹配值。
再记录mn[i]表示i结点在每一个字符串中的最长匹配长度的最小值。
那么所有结点的mn[i]就是答案。

注:之所以选取最短字符串建立SAM是因为每次匹配后都要遍历整个SAM并更新SAM每个节点的mx值,如果不选取最短串有可能复杂度并不是(mathcal O(sum s_i))

代码

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
const int N = 1e5 + 5, NODE = 2 * N;
int n, len[11], mnId, mx[NODE];
char str[11][N];
struct SAM {
	int ch[NODE][26], len[NODE], fa[NODE], mx[NODE], las, t, buc[N], od[NODE], mn[NODE];
	SAM() { las = t = 1; }
	void ins(int c) {
		int p = las, np;
		fa[las = np = ++t] = 1;
		len[t] = len[p] + 1;
		for (; p && !ch[p][c]; p = fa[p])
			ch[p][c] = np;
		if (p) {
			int q = ch[p][c], nq;
			if (len[p] + 1 == len[q])
				fa[np] = q;
			else {
				fa[nq = ++t] = fa[q];
				len[t] = len[p] + 1;
				memcpy(ch[nq], ch[q], sizeof(ch[nq]));
				for (fa[np] = fa[q] = nq; ch[p][c] == q; p = fa[p])
					ch[p][c] = nq;
			}
		}
	}
	void init(int le) {
		for (int i = 1; i <= t; ++i)
			++buc[len[i]];
		for (int i = 1; i <= le; ++i) //这里 n并不是字符串长度... 
			buc[i] += buc[i - 1];
		for (int i = 1; i <= t; ++i)
			od[buc[len[i]]--] = i;
		for (int i = 1; i <= t; ++i)
			mn[i] = 0x3f3f3f3f;
	}
	void calc(char *s, int n) {
		int now = 1, l = 0;
		for (int i = 1; i <= t; ++i)
			mx[i] = 0;
		for (int i = 1; i <= n; ++i) {
			while (now && !ch[now][s[i] - 'a'])
				now = fa[now], l = len[now];
			if (now) {
				now = ch[now][s[i] - 'a'];
				mx[now] = max(mx[now], ++l);
			} else {
				now = 1;
				l = 0;
			}
		}
		for (int i = 1; i <= t; ++i) {
			mx[fa[od[i]]] = min(max(mx[fa[od[i]]], mx[od[i]]), len[fa[od[i]]]);
			mn[od[i]] = min(mn[od[i]], mx[od[i]]);
		}
	}
} sam;
int main() {
	mnId = 1;
	while (~scanf("%s", str[++n] + 1)) {
		len[n] = strlen(str[n] + 1);
		if (len[mnId] > len[n])
			mnId = n;
	}
	--n;
	for (int i = 1; i <= len[mnId]; ++i)
		sam.ins(str[mnId][i] - 'a');
	sam.init(len[mnId]);
	for (int i = 1; i <= n; ++i)
		if (i != mnId)
			sam.calc(str[i], len[i]);
	int ans = 0;
	for (int i = 1; i <= sam.t; ++i)
		ans = max(ans, sam.mn[i]);
	printf("%d
", ans);
	return 0;
}
原文地址:https://www.cnblogs.com/skiceanAKacniu/p/13113487.html