noip2017集训测试赛(十一)Problem C: 循环移位

题面

Description

给定一个字符串 ss 。现在问你有多少个本质不同的 ss 的子串 t=t1t2⋯tm(m>0)t=t1t2⋯tm(m>0) 使得将 tt 循环左移一位后变成的 t′=t2⋯tmt1t′=t2⋯tmt1 也是 ss 的一个子串。

Input

输入仅有一行,一个字符串 s(1≤lens≤300000)s(1≤lens≤300000) 。字符串 ss 仅包含小写字母。

Output

输出一个整数表示答案。

Sample Input

(样例输入1)
abaac
(样例输入2)
aaa

Sample Output

(样例输出1)
7
(样例输出2)
3

HINT

(样例解释)

第一组数据:符合条件的字符串 tt 有: a, b, c, aa, ab, ba, aba

第二组数据:符合条件的字符串 tt 有: a, aa, aaa

(数据范围与约定)

子任务1(10分): 1≤lens≤2001≤lens≤200

子任务2(30分): 1≤lens≤50001≤lens≤5000

子任务3(60分): 1≤lens≤300000

Solution

这题不算特别难, 但确实是一道好题.

一般而言, 后缀自动机的题目只需要用到后缀树上的连边, 但这题既用了后缀树的边, 又用了后缀自动机上的边.

题目的本质是要我们求出有多少组这样的((s, c)), 其中(s)为字符串, (c)为字符, 使得(sc)(cs)都是原串的子串.

考虑枚举原串的每个子串(s), 再枚举每个字符(c), 则我们只需要判定(cs)(sc)是否都是原串的子串即可.

考虑如何枚举原串的每个子串(s), 不难想到用后缀树; (cs)可以通过后缀树上记录每个节点中包含的每个字符数量以及是否有(c)这个儿子来统计; (sc)则只需要判断一个节点在后缀自动机上是否有(c)这个后继即可.

口胡了这么多, 总之, 就是用后缀树上的边来找前缀, 后缀自动机上的边来找后缀.

#include <cstdio>
#include <cstring>

typedef long long LL;
const int N = 5000, K = 47, MOD = (int)1e9 + 7;
int pw[N + 7], pwInv[N + 7];
int a[N + 7], f[N + 7][N + 7], sum[N + 7][N + 7], hsh[N + 7];
inline int getInverse(int a)
{
	int res = 1;
	for (int x = MOD - 2; x; x >>= 1, a = (LL)a * a % MOD) if (x & 1) res = (LL)res * a % MOD;
	return res;
}
inline int getHash(int L, int R)
{
	return (LL)(hsh[R] - hsh[L - 1] + MOD) * pwInv[L] % MOD;
}
int main()
{

#ifndef ONLINE_JUDGE

	freopen("sequence.in", "r", stdin);
	freopen("sequence.out", "w", stdout);
	
#endif

	int n; scanf("%d
", &n);
	pw[0] = 1; for (int i = 1; i <= n; ++ i) pw[i] = (LL)pw[i - 1] * K % MOD, pwInv[i] = getInverse(pw[i]);
	hsh[0] = 0;
	for (int i = 1; i <= n; ++ i) a[i] = getchar() - '0', hsh[i] = (hsh[i - 1] + (LL)a[i] * pw[i] % MOD) % MOD;
	memset(f, 0, sizeof f); memset(sum, 0, sizeof sum);
	f[0][0] = 1; for (int i = 0; i <= n; ++ i) sum[0][i] = 1;
	for (int i = 1; i <= n; ++ i) 
	{
		for (int j = 1; j <= i; ++ j)
		{
			if (a[i - j + 1] == 0) continue;
			f[i][j] = sum[i - j][j - 1];
			if (j <= i - j && getHash(i - j - j + 1, i - j) != getHash(i - j + 1, i))
			{
				int L = 1, R = j, p;
				while (L <= R)
				{
					int mid = L + R >> 1;
					if (getHash(i - j - j + 1, i - j - j + mid) != getHash(i - j + 1, i - j + mid)) p = mid, R = mid - 1;
					else L = mid + 1;
				}
				if (a[i - j + p] > a[i - j - j + p]) f[i][j] = (f[i][j] + f[i - j][j]) % MOD;
			}
		}
		sum[i][0] = 0;
		for (int j = 1; j <= n; ++ j) sum[i][j] = (sum[i][j - 1] + f[i][j]) % MOD;
	}
	printf("%d
", sum[n][n]);
}
原文地址:https://www.cnblogs.com/ZeonfaiHo/p/7649868.html