[HAOI2016]找相同字符

Description

给定长度分别为 $n$, $m$ 的两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。

$n,m\le 2\times 10^5$

Solution

$yyt$ 的题。考试时我并不会后缀自动机,于是只能把一个串的所有后缀插入 $AC$ 自动机,让另一个串在上面跑,每到一个节点就暴力跳 $fail$ 链。记 $cnt[x]$ 为自动机上节点 $x$ 被多少个后缀经过,把 $fail$ 链上所有节点的 $cnt$ 都加到答案里。

这样做的正确性在于:把 B 串的所有后缀放入 AC 自动机后,让 A 串在上面跑,到达节点 $x$ 表示 A 串某个前缀的后缀和 B 串若干个后缀的前缀匹配上了(即匹配上了若干个子串)。由于每次匹配到的是最长的一段,所以还要沿 $fail$ 往上跳,把 $fail$ 链上各节点的答案累计上。下面是核心代码:

for (int i = 1; i <= n; ++i)
{
    p = ch[p][A[i] - 'a'];
    for (int t = p; t; t = fail[t])
        ans += cnt[t];
}

然而这样空间会爆炸(时间也会爆炸,但在超时之前空间就已经爆了),于是改用后缀自动机来接受所有后缀。

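后缀自动机的做法同样依赖上面的原理:对 B 串建好 SAM 后,让 A 串在 B 串的 SAM 上跑。到达节点 $x$ 代表 A 的某个前缀的后缀和 B 的若干个子串匹配上了。我们记录当前的匹配长度 $L$,那么需要统计的是:长度不超过 $L$ 的这些匹配串在 B 中一共出现了多少次。它等于 parent 树上从 $x$ 的父亲到根的路径上所有点 $y$ 的"接受串种类数 × 接受串出现次数",即 $(maxlen(y)-maxlen(fa(y)))\times|endpos(y)|$ 之和,再加上 $(L-maxlen(fa(x)))\times|endpos(x)|$。注意 $L$ 可能把 $x$ 接受的串从中间劈开,所以 $x$ 这一项要单独加。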

前者预处理即可。

再提供一种 $SA$ 的做法:

不难发现:答案 = (A 与 B 用分隔符拼接成的串)自己匹配自己的个数 − A 串自己匹配自己的个数 − B 串自己匹配自己的个数。

而一个串自己匹配自己的个数,就是枚举两个不同的后缀,对它们的 $lcp$ 求和,即:

$$\sum_{i=1}^{n}\sum_{j=i+1}^{n}\operatorname{lcp}(i,j)$$

也就是把后缀排序后,对每两个后缀,取排名在它们之间的 $height$ 的最小值,再全部求和(规定 $height[1]=0$,于是包含位置 1 的区间贡献为 0):

$$\sum_{i=1}^{n}\sum_{j=i}^{n}\min\{height[i\cdots j]\}$$

求出 $height$ 后,这就是一个经典问题:求一个数组所有子区间最小值之和。可以用单调栈,也可以像下面的代码那样,按值从大到小在双向链表中依次删点。

所以在 A 串和 B 串之间放一个不会在原串中出现的分隔符(比如 `$`),对拼接串、A 串、B 串分别做后缀排序并计算上面那个式子,按上述方式加减即可。

Code(SAM)

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <fstream>

// Shorthand integer types used throughout the solution.
typedef long long LL;
typedef unsigned long long uLL;

// Contest helper macros. Macro arguments are parenthesized where they are the
// operand of a member access, so that e.g. SZ(*p) expands to ((int)(*p).size())
// instead of the mis-parsed (int)*p.size().
#define SZ(x) ((int)(x).size())
#define ALL(x) (x).begin(), (x).end()
#define MP(x, y) std::make_pair(x, y)
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)
#define GO cerr << "GO" << endl;

using namespace std;

// Debug helper: copies the text of /proc/self/status (Linux-specific) to
// stderr, followed by endl.
inline void proc_status()
{
	std::ifstream status_file("/proc/self/status");
	std::istreambuf_iterator<char> first(status_file), last;
	const std::string contents(first, last);
	std::cerr << contents << std::endl;
}

// Reads one integer of type T from stdin.
// Non-digit characters before the number are skipped; if a '-' appears among
// them the result is negated. The character right after the last digit is
// consumed. If EOF is reached before any digit, 0 is returned.
template<class T> inline T read() 
{
	T x = 0;   // accumulate in T (not int) so read<long long>() does not overflow
	int f = 1;
	int c = getchar();   // int, not char: must hold EOF and be a valid isdigit() argument
	while (c != EOF && !isdigit(c))
	{
		if (c == '-') f = -1;
		c = getchar();
	}
	while (c != EOF && isdigit(c))
	{
		x = x * 10 + (c - '0');
		c = getchar();
	}
	return x * f;
}

// Lowers `a` to `b` when `a > b`; returns whether `a` was changed.
template<typename T> inline bool chkmin(T &a, T b)
{
	if (a > b)
	{
		a = b;
		return true;
	}
	return false;
}

// Raises `a` to `b` when `a < b`; returns whether `a` was changed.
template<typename T> inline bool chkmax(T &a, T b)
{
	if (a < b)
	{
		a = b;
		return true;
	}
	return false;
}

constexpr int maxN = 200000;   // upper bound on |A| and |B|

// Suffix automaton over the lowercase alphabet 'a'..'z'.
//   last : state of the whole string inserted so far
//   Ncnt : index of the most recently created state (state 0 is the root)
//   size : insert() sets it to 1 on the state created for each new character;
//          the caller sums it up the suffix-link tree to obtain |endpos|
//   sum  : per-state accumulator, filled in by the caller
namespace SAM
{
	int last, Ncnt, size[maxN * 2];
	LL sum[maxN * 2];

	struct Status
	{
		int len, link;   // longest string length in this state; suffix link (-1 for the root)
		int ch[26];      // transitions on 'a'..'z'; 0 means "none" (the root is never a target)
	} st[maxN * 2]; 

	// Sets up the root state. Relies on Ncnt == 0 and st[] being zero-initialized,
	// so the automaton is meant to be built once.
	void init()
	{
		last = 0;
		st[0].link = -1;
		st[0].len = 0;
	}

	// Appends character ch (expected in 'a'..'z') using the standard online
	// construction: create the state for the new prefix, add transitions along
	// the suffix-link chain, and clone q when its len does not fit.
	void insert(char ch)
	{
		int c = ch - 'a';
		int cur = ++Ncnt;
		int p = last;
		st[cur].len = st[p].len + 1;
		while (p != -1 and !st[p].ch[c])
		{
			st[p].ch[c] = cur;
			p = st[p].link;
		}
		if (p == -1)
			st[cur].link = 0;
		else 
		{
			int q = st[p].ch[c];
			if (st[q].len == st[p].len + 1)
				st[cur].link = q;
			else 
			{
				// Split: the clone takes q's transitions and link but a shorter len.
				int clone = ++Ncnt;
				st[clone] = st[q];
				st[clone].len = st[p].len + 1;
				while (p != -1 and st[p].ch[c] == q)
				{
					st[p].ch[c] = clone;
					p = st[p].link;
				}
				st[q].link = st[cur].link = clone;
			}
		}
		last = cur;
		size[cur] = 1;   // clones keep size 0: they are not new end positions
	}

	// Prints state x's suffix link and outgoing transitions to stdout.
	void debug(int x)
	{
		printf("%d link is %d\n", x, st[x].link);
		for (int i = 0; i < 26; ++i)
			if (st[x].ch[i])
				printf("%d to %d %c\n", x, st[x].ch[i], i + 'a');
		puts("-----------");
	}
}
using namespace SAM;

int n;                   // |A|, set in Init()
int m;                   // |B|, set in Init()
char A[maxN + 2];        // first string, stored from index 1
char B[maxN + 2];        // second string, stored from index 1

// Reads A and then B (whitespace-separated) from stdin, each starting at index 1.
void Input()
{
	scanf("%s", A + 1);
	scanf("%s", B + 1);
}

void Init()
{
	n = strlen(A + 1), m = strlen(B + 1);
	init();
	for (register int i = 1; i <= m; ++i)
		insert(B[i]);

	static int buc[maxN * 2+ 2], rk[maxN * 2 + 2];

	for (register int i = 1; i <= Ncnt; ++i) ++buc[st[i].len];
	for (register int i = 1; i <= Ncnt; ++i) buc[i] += buc[i - 1];
	for (register int i = 1; i <= Ncnt; ++i) rk[buc[st[i].len]--] = i;


	for (register int i = Ncnt; i >= 1; --i)
	{
		int p = rk[i];
		if (p) size[st[p].link] += size[p];
	}
	for (register int i = 1; i <= Ncnt; ++i)
	{
		int p = rk[i];
		if (p) sum[p] = sum[st[p].link] + 1ll * size[p] * (st[p].len - st[st[p].link].len);
	}
}

void Solve()
{
	LL ans = 0;
	int cur = 0, L = 0;
	for (register int i = 1; i <= n; ++i)
	{
		int c = A[i] - 'a';
		while (cur != -1 and !st[cur].ch[c])
		{
			cur = st[cur].link;
			if (cur != -1)
				L = st[cur].len;
		}
		if (cur != -1)
		{
			L++;
			cur = st[cur].ch[c];
			ans += sum[st[cur].link] + (L - st[st[cur].link].len) * size[cur];
		}
		else 
			cur = 0;
	}
	printf("%lld
", ans);
}

int main() 
{
	Input();   // read A and B
	Init();    // build B's automaton and its per-state tables
	Solve();   // walk A through it and print the count
}

Code(SA)

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <fstream>

// Shorthand integer types used throughout the solution.
typedef long long LL;
typedef unsigned long long uLL;

// Contest helper macros. Macro arguments are parenthesized where they are the
// operand of a member access, so that e.g. SZ(*p) expands to ((int)(*p).size())
// instead of the mis-parsed (int)*p.size().
#define SZ(x) ((int)(x).size())
#define ALL(x) (x).begin(), (x).end()
#define MP(x, y) std::make_pair(x, y)
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)
#define GO cerr << "GO" << endl;

using namespace std;

// Debug helper: copies the text of /proc/self/status (Linux-specific) to
// stderr, followed by endl.
inline void proc_status()
{
	std::ifstream status_file("/proc/self/status");
	std::istreambuf_iterator<char> first(status_file), last;
	const std::string contents(first, last);
	std::cerr << contents << std::endl;
}

// Reads one integer of type T from stdin.
// Non-digit characters before the number are skipped; if a '-' appears among
// them the result is negated. The character right after the last digit is
// consumed. If EOF is reached before any digit, 0 is returned.
template<class T> inline T read() 
{
	T x(0);
	int f(1);
	int c = getchar();   // int, not char: must hold EOF and be a valid isdigit() argument
	while (c != EOF && !isdigit(c))
	{
		if (c == '-') f = -1;
		c = getchar();
	}
	while (c != EOF && isdigit(c))
	{
		x = x * 10 + (c - '0');
		c = getchar();
	}
	return x * f;
}

// Lowers `a` to `b` when `a > b`; returns whether `a` was changed.
template<typename T> inline bool chkmin(T &a, T b)
{
	if (a > b)
	{
		a = b;
		return true;
	}
	return false;
}

// Raises `a` to `b` when `a < b`; returns whether `a` was changed.
template<typename T> inline bool chkmax(T &a, T b)
{
	if (a < b)
	{
		a = b;
		return true;
	}
	return false;
}

const int maxN = 2e6 + 2;

// Suffix array by prefix doubling with counting (radix) sort.
// All arrays are 1-indexed and describe the string last passed to Build():
//   sa[r] : start position of the suffix with rank r
//   rk[i] : rank of the suffix starting at position i
//   ht[r] : lcp(suffix sa[r-1], suffix sa[r]); ht[1] = 0
//   tmp   : scratch (second-key order, then the previous round's ranks)
//   n     : length of that string;  M : largest rank value currently in use
namespace SA
{
	int ht[maxN + 2], n;
	int tmp[maxN + 2], sa[maxN + 2], rk[maxN + 2], M;

	// Stable counting sort of the positions listed in tmp[1..n] by key rk[],
	// writing the result into sa[1..n]. Keys must lie in [0, M].
	void Rsort()
	{
		static int buc[maxN + 2];

		std::fill(buc, buc + 1 + M, 0);
		for (int i = 1; i <= n; ++i) ++buc[rk[i]];
		for (int i = 1; i <= M; ++i) buc[i] += buc[i - 1];
		for (int i = n; i >= 1; --i) sa[buc[rk[tmp[i]]]--] = tmp[i];
	}

	// Builds sa/rk/ht for the NUL-terminated, 1-indexed string str[1..n]
	// (n = strlen(str + 1)); str[0] is ignored. May be called repeatedly.
	void Build(char str[])
	{
		// Clear what the previous call left (the old length is still in n).
		// Entries past the current length must be 0, because the doubling step
		// reads tmp[sa[i] + w] with sa[i] + w > n.
		std::fill(ht + 1, ht + 1 + n, 0);
		std::fill(sa + 1, sa + 1 + n, 0);
		std::fill(rk + 1, rk + 1 + n, 0);
		std::fill(tmp + 1, tmp + 1 + n, 0);
		n = std::strlen(str + 1), M = 256;
		for (int i = 1; i <= n; ++i)
			rk[i] = (unsigned char) str[i], tmp[i] = i;   // unsigned: never a negative bucket index
		Rsort();
		for (int w = 1, cnt = 0; cnt < n; w <<= 1, M = cnt)
		{
			cnt = 0;
			// Second-key order: suffixes without a second half first, then by sa.
			for (int i = n - w + 1; i <= n; ++i) tmp[++cnt] = i;
			for (int i = 1; i <= n; ++i) if (sa[i] > w) tmp[++cnt] = sa[i] - w;
			Rsort();
			// Move the old ranks into tmp. Only [1, n] can be non-zero in either
			// array, so swapping that range is equivalent to swapping the whole arrays.
			std::swap_ranges(rk + 1, rk + 1 + n, tmp + 1);
			rk[sa[1]] = cnt = 1;
			for (int i = 2; i <= n; ++i)
				rk[sa[i]] = (tmp[sa[i]] == tmp[sa[i - 1]] && tmp[sa[i] + w] == tmp[sa[i - 1] + w]) ? cnt : ++cnt;
		}
		// Kasai: ht[rk[i]] >= ht[rk[i - 1]] - 1, so k only decreases by one per step.
		for (int i = 1, k = 0; i <= n; ++i)
		{
			if (rk[i] == 1)
			{
				// The smallest suffix has no predecessor; do not compare it with
				// sa[0] / str[0], whose contents are not part of the string.
				ht[1] = 0;
				k = 0;
				continue;
			}
			if (k) k--;
			int j = sa[rk[i] - 1];
			while (j + k <= n and i + k <= n and str[i + k] == str[j + k]) k++;
			ht[rk[i]] = k;
		}
	}
}

// Returns the sum, over all 1 <= l <= r <= n, of min(a[l..r]).
// a is 1-indexed. Values must lie in [0, n] because they are bucket-sorted
// directly; this holds for height arrays, whose entries are < n.
// Method: visit positions from largest value to smallest and unlink each one
// from a doubly linked list of not-yet-visited positions. When p is visited,
// its list neighbours L[p] and R[p] are the nearest remaining positions, and
// their values are <= a[p]. So p is the minimum of exactly
// (p - L[p]) * (R[p] - p) intervals. Among equal values the smaller index is
// visited first, so each interval is counted for its rightmost minimum only.
LL solve(int n, int a[])
{
	LL ans(0);
	static int R[maxN + 2], L[maxN + 2];
	static int rk[maxN + 2], buc[maxN + 2];

	// Counting sort: rk[1..n] lists positions by non-decreasing a[]. Among
	// equal values, the smaller index receives the larger rank.
	std::fill(buc + 0, buc + n + 1, 0);
	for (int i = 1; i <= n; ++i) ++buc[a[i]];
	for (int i = 1; i <= n; ++i) buc[i] += buc[i - 1];
	for (int i = 1; i <= n; ++i) rk[buc[a[i]]--] = i;

	// Positions 0 and n + 1 act as list sentinels.
	for (int i = 1; i <= n; ++i) L[i] = i - 1, R[i] = i + 1;
	for (int i = n; i >= 1; --i)
	{
		int p = rk[i];
		ans += static_cast<LL>(a[p]) * (R[p] - p) * (p - L[p]);
		// Unlink p from the list.
		L[R[p]] = L[p];
		R[L[p]] = R[p];
		L[p] = R[p] = 0;
	}
	return ans;
}

// Sum of lcp over all pairs of distinct suffixes of the 1-indexed string str.
LL work(char str[])
{
	SA::Build(str);
	// Build() leaves strlen(str + 1) in SA::n and the height array in SA::ht.
	return solve(SA::n, SA::ht);
}

// Input strings (1-indexed) and a buffer for their concatenation.
char s1[maxN + 2], s2[maxN + 2], s3[maxN << 1];

// Reads the two whitespace-separated input strings into s1 and s2, from index 1.
void Input()
{
	scanf("%s%s", s1 + 1, s2 + 1);
}

// Prints the number of pairs (substring of s1, substring of s2) that are equal:
//   pairs(s1 '$' s2) - pairs(s1) - pairs(s2),
// where pairs(s) = sum of lcp over pairs of distinct suffixes of s (see work()).
// '$' is assumed not to occur in s1 or s2. It then cuts every lcp that would
// cross it, so only the s1-vs-s2 terms survive the subtraction.
void Solve()
{
	LL ans = 0;

	int len1 = strlen(s1 + 1), len2 = strlen(s2 + 1);
	// Build s1 '$' s2 (1-indexed) in the global buffer s3 and terminate it
	// explicitly, instead of relying on zero-initialization.
	for (int i = 1; i <= len1; ++i) s3[i] = s1[i];
	s3[len1 + 1] = '$';
	for (int i = 1; i <= len2; ++i) s3[i + len1 + 1] = s2[i];
	s3[len1 + len2 + 2] = '\0';

	ans += work(s3);
	ans -= work(s1);
	ans -= work(s2);
	printf("%lld\n", ans);
}

int main() 
{
#ifndef ONLINE_JUDGE
	// When not on the judge, use local files instead of stdin/stdout.
	const char *inputPath = "P3181.in";
	const char *outputPath = "P3181.out";
	freopen(inputPath, "r", stdin);
	freopen(outputPath, "w", stdout);
#endif
	Input();
	Solve();
}
原文地址:https://www.cnblogs.com/cnyali-Tea/p/11478131.html