优秀的拆分「NOI2016」

题目描述

如果一个字符串可以被拆分为 ( ext{AABB}) 的形式,其中 ( ext{A})( ext{B}) 是任意非空字符串,则我们称该字符串的这种拆分是优秀的。
例如,对于字符串 ( exttt{aabaabaa}) ,如果令 ( ext{A}= exttt{aab})( ext{B}= exttt{a}),我们就找到了这个字符串拆分成 ( ext{AABB}) 的一种方式。

一个字符串可能没有优秀的拆分,也可能存在不止一种优秀的拆分。
比如我们令 ( ext{A}= exttt{a})( ext{B}= exttt{baa}),也可以用 ( ext{AABB}) 表示出上述字符串;但是,字符串 ( exttt{abaabaa}) 就没有优秀的拆分。

现在给出一个长度为 (n) 的字符串 (S),我们需要求出,在它所有子串的所有拆分方式中,优秀拆分的总个数。这里的子串是指字符串中连续的一段。

以下事项需要注意:

  1. 出现在不同位置的相同子串,我们认为是不同的子串,它们的优秀拆分均会被记入答案。
  2. 在一个拆分中,允许出现 ( ext{A}= ext{B})。例如 ( exttt{cccc}) 存在拆分 ( ext{A}= ext{B}= exttt{c})
  3. 字符串本身也是它的一个子串。

输入格式

每个输入文件包含多组数据。
输入文件的第一行只有一个整数 (T),表示数据的组数。
接下来 (T) 行,每行包含一个仅由英文小写字母构成的字符串 (S),意义如题所述。

输出格式

输出 (T) 行,每行包含一个整数,表示字符串 (S) 所有子串的所有拆分中,总共有多少个是优秀的拆分。

(nle 30000)

题解

太良心了
(85\%)的点(nle 500),直接(O(n^3))暴力枚举区间+断点用哈希判断

然后只要稍微动动脑子:设(a[i])表示以(i)结尾的( ext{AA})串个数,(b[i])表示以(i)开头的( ext{AA})串个数,那么答案就是(sumlimits_{i=1}^{n-1} a[i]*b[i+1])

(O(n^2)) 95分到手

最后五分如果想不出来不拿也感觉无所谓。。。最后五分确实不好想

所以开始说正解:

上面的95分解法问题就在于(a[N], b[N]),我们需要(O(n^2))的时间求出来,考虑怎么样求得更快

我们枚举一个(len)表示我们现在想找到那些长度为(2*len)( ext{AA})

然后在原串上每隔(len)放一个断点

我们枚举相邻的两个断点(i,j),现在我们想要知道 以(i)开头的后缀与以(j)开头的后缀的最长公共前缀(LCP) 和 以(i)结尾的前缀与以(j)结尾的前缀的最长公共后缀(LCS)

LCP可以用后缀数组求;LCS也可以,把原数组翻转之后就变成后缀的LCP了,所以这两个都是可以用ST表(O(1))求出的

那么现在我们求出了这两个值

情况1

对于这种情况,即(LCP+LCS-1<len),我们是找不出( ext{AA})串的

情况2

用脚画图 不愧是我

(LCP+LCS-1<len),这个时候就有很多的长为(2*len)( ext{AA})串了,图中画出的( ext{AA}, ext{BB})就是最靠左和最靠右的两个这样的串

实际上,我画了"OK"的那个橙色区间的每一个点都是一个长为(2*len)( ext{AA})串的开头,应该很好理解吧。。。

如何找哪一段是合法( ext{AA})串的结尾也同理

所以实际上每次就是把(a[N])(b[N])的某一段全部加一 用差分来维护一下就行了

最后来看一下时间复杂度

后缀数组+ST表是(O(nlog n))

(frac{n}{1}+frac{n}{2}+frac{n}{3}+dots+frac{n}{n}) 我记得差不多就是(O(n log n))吧。。。可能要稍微大一点

总之(nle 30000)的数据是完全没有压力的

注意多组数据初始化数组!注意多组数据初始化数组!注意多组数据初始化数组!

代码

#include <bits/stdc++.h>
#define N 60005
using namespace std;

int t, n, nn;
char s[N];
int a[N], b[N];
int sa[N], sa2[N], rnk[N], sum[N], key[N], height[N], ST[N][21]; 

inline bool check(int *num, int aa, int bb, int l) {
	if (aa + l > n || bb + l > n) return false;  //多组数据,一定要加!
	return num[aa] == num[bb] && num[aa+l] == num[bb+l];
}
 
void DA() {
	int i, j, p, m = 128;
	for (i = 1; i <= m; i++) sum[i] = 0;
	for (i = 1; i <= n; i++) sum[rnk[i]=s[i]]++;
	for (i = 2; i <= m; i++) sum[i] += sum[i-1];
	for (i = n; i; i--) sa[sum[rnk[i]]--] = i;
	for (j = 1; j <= n; j <<= 1, m = p) {
		for (p = 0, i = n - j + 1; i <= n; i++) sa2[++p] = i;
		for (i = 1; i <= n; i++) if (sa[i] - j > 0) sa2[++p] = sa[i] - j;
		for (i = 1; i <= n; i++) key[i] = rnk[sa2[i]];
		for (i = 1; i <= m; i++) sum[i] = 0;
		for (i = 1; i <= n; i++) sum[key[i]]++;
		for (i = 2; i <= m; i++) sum[i] += sum[i-1];
		for (i = n; i; i--) sa[sum[key[i]]--] = sa2[i];
		for (swap(sa2, rnk), p = 2, rnk[sa[1]] = 1, i = 2; i <= n; i++) {
			rnk[sa[i]] = check(sa2, sa[i-1], sa[i], j) ? p - 1 : p++;
		}
	} 
}

void geth() {
	int p = 0;
	for (int i = 1; i <= n; i++) rnk[sa[i]] = i;
	for (int i = 1; i <= n; i++) {
		if (p) p--;
		int j = sa[rnk[i]-1];
		while (s[i+p] == s[j+p] && i + p <= n && j + p <= n) p++; //多组数据,一定要加!
		height[rnk[i]] = p;
	}
}

void preST() {
	for (int i = 1; i <= n; i++) ST[i][0] = height[i];
	for (int l = 1; (1 << l) <= n; l++) {
		for (int i = 1; i + (1<<l) - 1 <= n; i++) {
			ST[i][l] = min(ST[i][l-1], ST[i+(1<<(l-1))][l-1]);
		}
	}
} 

inline int QST(int x, int y) {
	if (x > y) swap(x, y); x++;
	int l = log2(y - x + 1);
	return min(ST[x][l], ST[y-(1<<l)+1][l]);
}

inline int LCP(int x, int y) { return QST(rnk[x], rnk[y]); }
inline int LCS(int x, int y) { return QST(rnk[n-x+1], rnk[n-y+1]); }

void Solve() {
	for (int l = 1; l * 2 <= nn; l++) {
		for (int i = 1, j = i + 1; j * l <= nn; i++, j++) {
			int lcp = min(LCP(i*l, j*l), l), lcs = min(LCS(i*l, j*l), l);
			if (lcp + lcs - 1 >= l) {
				a[j*l+l-lcs]++; a[j*l+lcp]--;
				b[i*l-lcs+1]++; b[i*l-l+lcp+1]--;
			} 
		}
	}
	for (int i = 1; i <= nn; i++) {
		a[i] += a[i-1];
		b[i] += b[i-1];
	}
	long long ans = 0;
	for (int i = 1; i < nn; i++) {
		ans += 1ll * a[i] * b[i+1];
	}
	printf("%lld
", ans);
} 

int main() {
	scanf("%d", &t);
	while (t--) {
		memset(a, 0, sizeof(a)); memset(b, 0, sizeof(b));
		scanf("%s", s + 1);
		n = strlen(s + 1); 
		s[n+1] = '$';
		for (int i = n + 2; i <= 2 * n + 1; i++) {
			s[i] = s[2 * n - i + 2];
		}
		nn = n;
		n = (n<<1|1);
		DA(); geth(); preST();
		Solve();
	}
	return 0;
}
原文地址:https://www.cnblogs.com/ak-dream/p/AK_DREAM83.html