【CTSC2010】珠宝商

题目大意

给出一棵(n)个点的树,每个节点有一个字符,再给出一个长度为(m)的字符串(S)。求树上所有路径所代表的字符串在(S)中的出现次数。
(n,mleq 50000)

题解

首先一个(O(n^2))的暴力是枚举每个点(dfs)一次,一边(dfs)一边在(S)(SAM)上跑,开个栈记一下之前跑到过哪些点可以做到(O(n^2))
因为是关于树上所有路径的问题所以考虑点分,设当前子树根为(u)
把路径分为(u)前面和后面的部分,求出所有可能的到(u)前面的那一段路径在(S)的每个位置结尾了多少次以及所有可能的从(u)出发的路径在每个位置开始了多少次,这个可以一边(dfs)一边在后缀树上跑,相当于对跑到的点的所有(endpos)加了(1)
然后枚举一下两段接起来的位置求出答案。
但是这样做复杂度是(O(k+m))(k)为子树大小)的,总时间复杂度是(O(nm))
点分树有一个性质:子树大小大于等于(k)的点最多只有(O(lfloor frac{n}{k} floor))个。
这个不难证明,只考虑子树大小大于等于(k)且没有任何儿子大于等于(k)的点,最多(lfloor frac{n}{k} floor)个,每往上加一个点都会把至少两棵树合并起来,否则把重心下移子树分布更为平均。总共加最多(lfloor frac{n}{k} floor -1)个点后合并成一整棵树。
这样我们可以对于大于(sqrt{n})的子树用上述做法,小于等于(sqrt{n})的子树用(O(n^2))的暴力。前者复杂度是(O((n+m)sqrt{n})),后者总大小(O(n)),单个大小(O(sqrt{n})),复杂度(O(nsqrt{n}))
注意所有大于(sqrt{n})的点的儿子总数最多(O(n)),所以去重的时候对于小于等于(sqrt{n})的儿子要特殊处理(因为这个(hack)掉了一大群人)。
预处理(sa)(dfs)序而不是每次(dfs)一遍下传标记可以极大降低常数。
(一开始写了个(O(mlog m))(dfs)然后(T)飞了)

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long LL;
const int mxn=100010,K=220;
int lens;
char s[mxn],t[mxn];
namespace Suf{
	int n,cur,trans[mxn][26],son[mxn][26],id[mxn],len[mxn],fa[mxn];
	int ins(int u,int c,int idd){
		int x=++n,v;
		len[x]=len[u]+1,id[x]=idd;
		for (;u&&!trans[u][c];trans[u][c]=x,u=fa[u]);
		if (!u) fa[x]=1;
		else if (len[v=trans[u][c]]==len[u]+1) fa[x]=v;
		else{
			len[++n]=len[u]+1,id[n]=idd;
			fa[n]=fa[v],fa[x]=fa[v]=n;
			for (int i=0;i<26;++i) trans[n][i]=trans[v][i];
			for (;u&&trans[u][c]==v;trans[u][c]=n,u=fa[u]);
		}
		return x;
	}
	int ln,curr,his[mxn],idx[mxn],idy[mxn],tot,ans[mxn>>1],sa[mxn];
	void dfs(int u){
		idx[u]=tot+1;
		if (id[u]+len[u]==lens) sa[++tot]=id[u]+1;
		for (int i=0;i<26;++i)
			if (son[u][i]) dfs(son[u][i]);
		idy[u]=tot;
	}
	void init(){
		n=cur=his[0]=1;
		for (int i=lens-1;i>=0;--i)
			cur=ins(cur,s[i]-'a',i);
		for (int i=2;i<=n;++i)
			son[fa[i]][s[id[i]+len[fa[i]]]-'a']=i;
		dfs(1);
	}
	void add(int c){
		if (ln==len[curr]) curr=son[curr][c];
		else if (c!=s[id[curr]+ln]-'a') curr=0;
		his[++ln]=curr;
		++ans[idx[curr]],--ans[idy[curr]+1];
	}
	void add2(int c){
		if (ln==len[curr]) curr=son[curr][c];
		else if (c!=s[id[curr]+ln]-'a') curr=0;
		his[++ln]=curr;
	}
	void del(){
		curr=his[--ln];
	}
	void clr(){
		memset(ans,0,sizeof(ans));
		ln=0,curr=1;
	}
}
namespace Pre{
	int n,cur,trans[mxn][26],son[mxn][26],id[mxn],len[mxn],fa[mxn];
	int ins(int u,int c,int idd){
		int x=++n,v;
		len[x]=len[u]+1,id[x]=idd;
		for (;u&&!trans[u][c];trans[u][c]=x,u=fa[u]);
		if (!u) fa[x]=1;
		else if (len[v=trans[u][c]]==len[u]+1) fa[x]=v;
		else{
			len[++n]=len[u]+1,id[n]=idd;
			fa[n]=fa[v],fa[x]=fa[v]=n;
			for (int i=0;i<26;++i) trans[n][i]=trans[v][i];
			for (;u&&trans[u][c]==v;trans[u][c]=n,u=fa[u]);
		}
		return x;
	}
	int siz[mxn],ln,curr,his[mxn],idx[mxn],idy[mxn],tot,ans[mxn>>1],sa[mxn];
	void dfs(int u){
		idx[u]=tot+1;
		if (id[u]+1==len[u]) sa[++tot]=id[u]+1;
		for (int i=0;i<26;++i)
			if (son[u][i]) dfs(son[u][i]);
		idy[u]=tot;
	}
	void dfss(int u){
		siz[u]=id[u]+1==len[u];
		for (int i=0;i<26;++i)
			if (son[u][i]) dfss(son[u][i]),siz[u]+=siz[son[u][i]];
	}
	void init(){
		n=cur=his[0]=1;
		for (int i=0;i<lens;++i)
			cur=ins(cur,s[i]-'a',i);
		for (int i=2;i<=n;++i)
			son[fa[i]][s[id[i]-len[fa[i]]]-'a']=i;
		dfss(1);
		dfs(1);
	}
	void add(int c){
		if (ln==len[curr]) curr=son[curr][c];
		else if (c!=s[id[curr]-ln]-'a') curr=0;
		his[++ln]=curr;
		++ans[idx[curr]],--ans[idy[curr]+1];
	}
	int add2(int c){
		his[++ln]=curr=trans[curr][c];
		return siz[curr];
	}
	void del(){
		curr=his[--ln];
	}
	void clr(){
		memset(ans,0,sizeof(ans));
		ln=0,curr=1;
	}
}
LL ans;
int A[mxn],B[mxn];
void getans(int x){
	for (int i=1;i<=lens;++i)
		Pre::ans[i]+=Pre::ans[i-1],A[Pre::sa[i]]=Pre::ans[i];
	for (int i=1;i<=lens;++i)
		Suf::ans[i]+=Suf::ans[i-1],B[Suf::sa[i]]=Suf::ans[i];
	for (int i=0;i<lens;++i)
		ans+=1ll*(A[i]+x)*B[i+1];
}
int n,m,head[mxn],siz[mxn],exist[mxn];
struct ed{int to,nxt;}edge[mxn];
void addedge(int u,int v){
	edge[++m]=(ed){v,head[u]},head[u]=m;
	edge[++m]=(ed){u,head[v]},head[v]=m;
}
int num,rt;
void getrt(int u,int fa,int N){
	siz[u]=1;
	int mx=0;
	for (int i=head[u],v;i;i=edge[i].nxt)
		if ((v=edge[i].to)!=fa&&!exist[v]){
			getrt(v,u,N);
			siz[u]+=siz[v];
			mx=max(mx,siz[v]);
		}
	mx=max(mx,N-siz[u]);
	if (mx<num) num=mx,rt=u;
}
void dfs(int u,int fa){
	ans+=Pre::add2(t[u]-'a');
	for (int i=head[u],v;i;i=edge[i].nxt)
		if ((v=edge[i].to)!=fa&&!exist[v]) dfs(v,u);
	Pre::del();
}
void solve_1(int u,int fa){
	Pre::ln=0;
	Pre::curr=1;
	dfs(u,0);
	for (int i=head[u],v;i;i=edge[i].nxt)
		if ((v=edge[i].to)!=fa&&!exist[v]) solve_1(v,u);
}
int stk[mxn],tp;
void solve_2(int u,int fa,int rt,int rtf){
	Pre::ln=0;
	Pre::curr=1;
	stk[++tp]=t[u]-'a';
	for (int i=tp;i>=0;--i) Pre::add2(stk[i]);
	dfs(rt,rtf);
	for (int i=head[u],v;i;i=edge[i].nxt)
		if ((v=edge[i].to)!=fa&&!exist[v]) solve_2(v,u,rt,rtf);
	--tp;
}
void cal(int u,int fa){
	Pre::add(t[u]-'a');
	Suf::add(t[u]-'a');
	for (int i=head[u],v;i;i=edge[i].nxt)
		if ((v=edge[i].to)!=fa&&!exist[v]) cal(v,u);
	Pre::del();
	Suf::del();
}
void solve(int u,int N){
	if (N<=K) return solve_1(u,0);
	Suf::clr(),Pre::clr();
	Suf::add(t[u]-'a');
	for (int i=head[u],v;i;i=edge[i].nxt)
		if (!exist[v=edge[i].to]) cal(v,u);
	getans(1);
	LL tmp=ans;
	ans=0;
	getrt(u,0,N);
	for (int i=head[u],v;i;i=edge[i].nxt)
		if (!exist[v=edge[i].to])
			if (siz[v]<=K){
				stk[0]=t[u]-'a';
				solve_2(v,u,v,u);
			}
			else{
				Suf::clr(),Pre::clr();
				Suf::add2(t[u]-'a');
				cal(v,u);
				getans(0);
			}
		else;
	ans=tmp-ans;
	exist[u]=1;
	for (int i=head[u],v;i;i=edge[i].nxt)
		if (!exist[v=edge[i].to]){
			num=1e9,getrt(v,0,siz[v]);
			solve(rt,siz[v]);
		}
}
int main()
{
	scanf("%d%d",&n,&lens);
	for (int i=1,x,y;i<n;++i)
		scanf("%d%d",&x,&y),addedge(x,y);
	scanf("%s%s",t+1,s);
	Suf::init();
	Pre::init();
	num=1e9,getrt(1,0,n);
	solve(rt,n);
	printf("%lld
",ans);
	return 0;
}
原文地址:https://www.cnblogs.com/zzqtxdy/p/12081144.html