bzoj 4199: [Noi2015]品酒大会 后缀树

题目大意:

给定一个长为n的字符串,每个下标有一个权(w_i),定义下标(i,j)是r相似的仅当(r leq LCP(suf(i),suf(j)))且这个相似的权为(w_i,w_j)
分别求出所有满足1 .. r相似的下标对数,及最大权.

题解:

我们发现这道题可以在后缀树上瞎搞
我们知道:(LCP(suf(i),suf(j)) = len(lca(i,j)))
所以我们可以对后缀树上的所有节点dp一下,求出每个点的子树包含的点对数
同时dp出子树中存在的权的最大值,次大值,最小值,次小值
然后累加答案即可.

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
inline void read(int &x){
	x=0;char ch;bool flag = false;
	while(ch=getchar(),ch<'!');if(ch == '-') ch = getchar(),flag = true;
	while(x=10*x+ch-'0',ch=getchar(),ch>'!');if(flag) x=-x;
}
const int maxn = 1000010;
struct Edge{
	int to,next;
}G[maxn];
int head[maxn],cnt;
void add(int u,int v){
	G[++cnt].to = v;
	G[cnt].next = head[u];
	head[u] = cnt;
}
struct Node{
	int nx[26];
	int len,fa;
}T[maxn];
int last,nodecnt = 0,n;
int a[maxn],siz[maxn],mx[maxn],cmx[maxn];
int mn[maxn],cmn[maxn];
inline void insert(char cha,int i){
	int c = cha - 'a',cur = ++ nodecnt,p;
	T[cur].len = T[last].len + 1;
	for(p = last;p != -1 && !T[p].nx[c];p = T[p].fa) T[p].nx[c] = cur;
	if(p == -1) T[cur].fa = 0;
	else{
		int q = T[p].nx[c];
		if(T[q].len == T[p].len + 1) T[cur].fa = q;
		else{
			int co = ++ nodecnt;T[co] = T[q];T[co].len = T[p].len + 1;
			for(;p != -1 && T[p].nx[c] == q;p = T[p].fa) T[p].nx[c] = co;
			T[cur].fa = T[q].fa = co;
		}
	}
	siz[last = cur]++;
	mx[cur] = mn[cur] = a[i];
}
ll num[maxn];
char s[maxn];
ll ans1[maxn],ans2[maxn];
inline void update(int &x,int &y,int z){
	if(z >= x) y = x,x = z;
	else if(z >= y) y = z;
}
inline void downpdate(int &x,int &y,int z){
	if(z <= x) y = x,x = z;
	else if(z <= y) y = z;
}
#define v G[i].to
void dfs(int u,int fa){
	for(int i = head[u];i;i = G[i].next){
		if(v == fa) continue;
		dfs(v,u);
		num[u] += 1LL*siz[u]*siz[v];
		siz[u] += siz[v];
		update(mx[u],cmx[u],mx[v]);update(mx[u],cmx[u],cmx[v]);
		downpdate(mn[u],cmn[u],mn[v]);
		downpdate(mn[u],cmn[u],cmn[v]);
	}
	if(mx[u] != mx[maxn-1] && cmx[u] != cmx[maxn-1]){
		ans2[T[u].len] = max(ans2[T[u].len],max(1LL*mx[u]*cmx[u],1LL*mn[u]*cmn[u]));
	}
	ans1[T[u].len] += num[u];
}
#undef v
int main(){
	memset(mx,-0x3f,sizeof mx);memset(cmx,-0x3f,sizeof cmx);
	memset(mn,0x3f,sizeof mn);memset(cmn,0x3f,sizeof cmn);
	memset(ans2,-0x3f,sizeof ans2);
	T[last = nodecnt = 0].fa = -1;
	read(n);scanf("%s",s);
	for(int i=0;i<n;++i) read(a[i]);
	reverse(s,s+n);reverse(a,a+n);
	for(int i=0;i<n;++i) insert(s[i],i);
	for(int i=1;i<=nodecnt;++i) add(T[i].fa,i);
	dfs(0,0);
	for(int i=n-2;i>=0;--i){
		ans1[i] += ans1[i+1];
		ans2[i] = max(ans2[i],ans2[i+1]);
	}
	for(int i=0;i<n;++i){
		if(ans2[i] == ans2[maxn-1]) ans2[i] = 0;
		printf("%lld %lld
",ans1[i],ans2[i]);
	}
	getchar();getchar();
	return 0;
}
原文地址:https://www.cnblogs.com/Skyminer/p/6540823.html