The Problem to Slow Down You(Palindromic Tree)

题目链接:http://codeforces.com/gym/100548

今天晚上突然有了些兴致去学习一下数据结构,然后就各种无意中看到了Palindrome Tree的数据结构,据说是2014年新出的数据结构,也让我回想起了西安打铁时候的经历。这道题的题意其实是比较清晰的,给你两个长度200000的字符串,求它们有多少对回文子串。处理字符串有许多常用的工具,像后缀数组,后缀自动机,AC自动机,KMP,但是这些数据结构在针对回文串的处理上都不是特别强,当然稍微处理一下我们还是可以很好的利用好上面这些数据结构处理一些回文的问题。另外一个关于回文的问题就是求某个串的最长公共子串了,一个O(n)的Manachar算法算是解决了这个问题。Palindrome Tree似乎就是专门为了填补对于处理回文串的空白而产生的,而且最神的是,当你有了前面的一些数据结构的基础,再去看这个数据结构的整个结构,算法的流程,你会觉得很清晰,非常容易理解,当然这也是多亏了下面的博客给了我代码的模板以及对数据结构的一些理解。

http://blog.csdn.net/u013368721/article/details/42100363

http://blog.csdn.net/u013368721/article/details/42104207

下面是谈谈自己对上面博客阐述内容的一些个人的理解

首先是对数据结构的成员作一个简单的描述:
nxt[i][ch]   节点i在接收了字符ch之后所指向的节点
fail[i]      以节点i所指代的回文串的末尾字符结束的上一个回文串 fail[ ababa ] = aba  fail[aba] = a
cnt[i]       节点i所表示的回文串的个数
num[i]       以节点i所表示的回文串的末尾字符结束的回文串的个数
len[i]       节点i所表示的回文串的长度
S[i]         第i个字符
last         新添加的字符所产生的最长回文串的节点
p            节点的总的个数
n            当前字符串的长度

其中个人觉得比较难理解的是fail[i]. 假如i节点指代的回文串是"ababa"的话,那么fail[i]就是指以这个回文串的末尾字符结尾(就是'a')的上一个最长的回文串,也即是aba,再上一个自然就是a了。相似的是KMP里的失配指针,或者是在后缀自动机里上一个可以接收的后缀。

接下来就是整个算法的描述了

1.一开始初始化0号节点和1号节点,分别表示长度为偶数的节点(即空串),和1号节点(长度为奇数的空串),显然空串的长度是0,而这样的奇数串是不存在的,len[0]=0,len[1]=-1. 这个len[1]设成-1有着其独道的好处,具体的一些好处可以参看上面的博客。

2.接下来考虑添加1个新字符ch,首先要注意的是getFail()函数的作用,while (S[n - len[x] - 1] != S[n]) x = fail[x]; 的实际含义是找到一个能够接收当前新字符的回文串的编号,  对原串"abacaba"新加1个c,S[n-len[x]-1]=a != c, 所以回到上一个回文串节点,即表示aba的节点,这个时候S[n-len[x]-1]=c,满足题意,于是乎我们就找到了一个可以接收这个新状态的节点cur

3.接下来我们就可以更新nxt[cur][ch],当这个节点不存在的时候,我们需要新建这样的节点,长度是len[cur]+2,然后我们还需要知道新建的这个节点now的fail值,fail[now]=nxt[getFail(fail[cur])][ch],即当前新建立的节点的上一个回文串,getFail(fail[cur])返回的是以fail[cur]对应字符串结尾的可以接收ch的回文串。所以可以这么理解 now=nxt[getFail(cur)][ch],fail[now]=nxt[getFail(fail[cur])][ch].

4.然后更新一下新节点的num值,以及cnt值,注意的是这里的cnt值并不代表该节点对应的回文子串出现的总次数,正如后缀自动机里记数的时候cnt也不能表示这个字符串出现的总次数,最后需要有一个结果回加的过程,即是count的过程。

最后整个代码是有着比较清晰的结构的,下面对于这题的代码就是用了上面博客提供的模板。

对于原本的题目,需要做的就是在偶串中心节点0,和奇串中心节点1,不停地往两边塞字符就好了,当两边的状态都存在的时候继续往下dfs,最后统计一下结果即可。

这个数据结构维护的信息其实可以用作很多特别的用途。

1.直接dfs(0),dfs(1)我们可以递归的打印出所有回文串以及各自出现的次数

2.新添加1个字符是否能产生新的回文串,即可以在线的得出S的前缀i包含的不同的回文子串的个数。

3.num[i]可以得到对每个新增的字符为结尾的回文串的个数,以及len[i]求出以这个字符结尾的最长回文串长度

总之个人觉得这个数据结构可以做绝大多数的回文串的题目了,而且这个数据结构非常的elegent.字符串长度为n,字符集大小为m,则空间复杂度是O(nm),至于时间复杂度的话,我也不知道是怎么证明的,但是据说是O(nlogm),可以说对于字符集经常固定的题目来说就是一个O(n)的算法了。

#pragma warning(disable:4996)
#include <iostream>
#include <cstring>
#include <string>
#include <vector>
#include <cstdio>
#include <algorithm>
using namespace std;

#define MAXN 210000
#define ll long long

struct PalindromeTree
{
	/*static variables*/
	const static int maxn = 210000;
	const static int ch = 26;
	/*
	 * nxt[i][ch]   节点i在接收了字符ch之后所指向的节点
	 * fail[i]      以节点i所指代的回文串的末尾字符结束的上一个回文串 fail[ ababa ] = aba  fail[aba] = a
	 * cnt[i]       节点i所表示的回文串的个数
	 * num[i]       以节点i所表示的回文串的末尾字符结束的回文串的个数
	 * len[i]       节点i所表示的回文串的长度
	 * S[i]         第i个字符
	 * last         新添加的字符所产生的最长回文串的节点
	 * p            节点的总的个数
	 * n            当前字符串的长度
	 */
	int nxt[maxn][ch];
	int fail[maxn];
	int cnt[maxn];
	int num[maxn];
	int len[maxn];
	int S[maxn];
	int last;
	int n;
	int p;

	int newnode(int length){
		memset(nxt[p], 0, sizeof(nxt[p]));
		cnt[p] = num[p] = 0;
		len[p] = length;
		return p++;
	}

	void init(){
		p = 0;
		newnode(0);
		newnode(-1);
		last = 0;
		n = 0;
		S[n] = -1;
		fail[0] = 1;
	}

	int getFail(int x){
		while (S[n - len[x] - 1] != S[n]) x = fail[x];
		return x;
	}

	void add(int c) {
		c -= 'a';
		S[++n] = c;
		int cur = getFail(last);
		if (!nxt[cur][c]) {
			int now = newnode(len[cur] + 2);
			fail[now] = nxt[getFail(fail[cur])][c];
			nxt[cur][c] = now;
			num[now] = num[fail[now]] + 1;
		}
		last = nxt[cur][c];
		cnt[last] ++;
	}

	void count(){
		for (int i = p - 1; i >= 0; --i){
			cnt[fail[i]] += cnt[i];
		}
	}
};

PalindromeTree treex, treey;
char buf1[MAXN], buf2[MAXN];

ll ans;

void dfs(int u, int v)
{
	for (int i = 0; i < treex.ch; ++i){
		int x = treex.nxt[u][i];
		int y = treey.nxt[v][i];
		if (x&&y){
			ans += (ll)treex.cnt[x] * treey.cnt[y];
			dfs(x, y);
		}
	}
}


int main()
{
	int T; cin >> T; int ca = 0;
	while (T--)
	{
		treex.init();
		treey.init();
		scanf("%s%s", buf1, buf2);
		int len1 = strlen(buf1), len2 = strlen(buf2);
		for (int i = 0; i < len1; ++i) {
			treex.add(buf1[i]);
		}
		for (int i = 0; i < len2; ++i){
			treey.add(buf2[i]);
		}
		treex.count();
		treey.count();
		ans = 0;
		dfs(0, 0);
		dfs(1, 1);
		printf("Case #%d: %I64d
", ++ca, ans);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/chanme/p/4461901.html