P4218 [CTSC2010]珠宝商

P4218 [CTSC2010]珠宝商


神题...

可以想到点分治,细节不写了。。。

(学了个新姿势,sam可以在前面加字符

但是一次点分治只能做到(O(m)),考虑(sqrt n)点分治,如果子树大小(>sqrt n)就用(O(m))的点分治做法,否则用蛤希暴力。

然而块大小设为(20,30)(sqrt n)快多了...

#include<bits/stdc++.h>
#define il inline
#define vd void
#define frog 19260817
typedef long long ll;
typedef unsigned long long ull;
il ll gi(){
	ll x=0,f=1;
	char ch=getchar();
	while(!isdigit(ch))f^=ch=='-',ch=getchar();
	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
	return f?x:-x;
}
#define qt 20
std::unordered_map<ull,int>qaq[50010/qt*2+1];
ull Base[50010];
int n,m;
ll ans;
char S[50010],T[50010];
int fir[50010],dis[100010],nxt[100010],id;
il vd link(int a,int b){nxt[++id]=fir[a],fir[a]=id,dis[id]=b;}
int siz[50010],f[50010],FA[50010],N,rt;bool vis[50010];
struct SAM{
	int slink[100010],trans[100010][26],pos[100010],son[100010][26],len[100010],endpos[100010],cnt,lst,leaf[50010],flg;
	char ss[100010];
	SAM(){cnt=0;lst=++cnt;len[lst]=0;}
	il vd extend(int ch,int i){
		int p=lst,np=++cnt;len[np]=len[p]+1;lst=np;endpos[np]=1;leaf[i]=np;pos[np]=i;
		while(p&&!trans[p][ch])trans[p][ch]=np,p=slink[p];
		if(!p)slink[np]=1;
		else{
			int q=trans[p][ch];
			if(len[q]==len[p]+1)slink[np]=q;
			else{
				int nq=++cnt;
				slink[nq]=slink[q],len[nq]=len[p]+1,memcpy(trans[nq],trans[q],sizeof trans[q]);
				while(p&&trans[p][ch]==q)trans[p][ch]=nq,p=slink[p];
				slink[np]=slink[q]=nq;
			}
		}
	}
	int t[100010],st[100010];
	il vd prepare(){
		for(int i=1;i<=m;++i)ss[i]=S[i];
		for(int i=1;i<=cnt;++i)++t[len[i]];
		for(int i=1;i<=cnt;++i)t[i]+=t[i-1];
		for(int i=cnt;i>1;--i)st[t[len[i]]--]=i;
		for(int i=cnt,x;i;--i){
			x=st[i];endpos[slink[x]]+=endpos[x];
			if(!pos[slink[x]])pos[slink[x]]=pos[x];
			if(pos[x]-len[slink[x]])son[slink[x]][ss[pos[x]-len[slink[x]]]-'a']=x;
		}
	}
	int tag[100010];
	il vd calc(){
		for(int i=1,x;i<=cnt;++i)x=st[i],tag[x]+=tag[slink[x]];
	}
	il vd dfs3(int x,int fa,int y,int _len){
		if(_len==len[y])y=son[y][T[x]-'a'];
		else if(ss[pos[y]-_len]!=T[x])y=0;
		if(!y)return;
		++tag[y];++_len;
		//printf("%d %d %d %d
",x,fa,y,_len);
		for(int i=fir[x];i;i=nxt[i]){
			if(fa==dis[i]||vis[dis[i]])continue;
			dfs3(dis[i],x,y,_len);
		}
	}
}sam,rsam;
il vd getrt(int x,int fa=-1){
	siz[x]=1,f[x]=0;
	for(int i=fir[x];i;i=nxt[i]){
		if(fa==dis[i]||vis[dis[i]])continue;
		FA[dis[i]]=x;
		getrt(dis[i],x);
		siz[x]+=siz[dis[i]];
		f[x]=std::max(f[x],siz[dis[i]]);
	}
	f[x]=std::max(f[x],N-siz[x]);
	if(f[rt]>f[x])rt=x;
}
std::vector<int>G;
il vd dfs(int x,int fa=-1){
	G.push_back(x);
	for(int i=fir[x];i;i=nxt[i]){
		if(fa==dis[i]||vis[dis[i]])continue;
		dfs(dis[i],x);
	}
}
il vd dfs2(int x,int y,int fa=-1){
	if(!y)return;
	ans+=sam.endpos[y];
	for(int i=fir[x];i;i=nxt[i]){
		if(fa==dis[i]||vis[dis[i]])continue;
		dfs2(dis[i],sam.trans[y][T[dis[i]]-'a'],x);
	}
}
std::vector<ull>A,B;
std::vector<int>LA,LB;
il vd dfs2_(int x,ull HA,ull HB,int len,int fa=-1){
	HA=(HA+T[x])*frog,HB+=Base[++len]*T[x];
	A.push_back(HA),LA.push_back(len+1),B.push_back(HB),LB.push_back(len);
	for(int i=fir[x];i;i=nxt[i]){
		if(fa==dis[i]||vis[dis[i]])continue;
		dfs2_(dis[i],HA,HB,len,x);
	}
}
il vd work(int x,int fa,ll o){
	if(siz[x]<=qt){
		A.clear(),B.clear();LA.clear();LB.clear();
		dfs2_(x,(ull)frog*T[fa],0,0,fa);
		for(int i=0;i<A.size();++i)
			for(int j=0;j<B.size();++j){
				ull H=A[i]+B[j]*Base[LA[i]];
				if(qaq[LA[i]+LB[j]].count(H))ans+=o*qaq[LA[i]+LB[j]][H];
			}
		return;
	}
	memset(sam.tag,0,(sam.cnt+1)*4);memset(rsam.tag,0,(rsam.cnt+1)*4);
	if(fa)sam.dfs3(x,fa,sam.son[1][T[fa]-'a'],1),rsam.dfs3(x,fa,rsam.son[1][T[fa]-'a'],1);
	else sam.dfs3(x,fa,1,0),rsam.dfs3(x,fa,1,0);
	sam.calc(),rsam.calc();
	for(int i=1;i<=m;++i)ans+=o*sam.tag[sam.leaf[i]]*rsam.tag[rsam.leaf[m-i+1]];
}
il vd solve(int x){
	if(siz[x]<=qt){
		G.clear();dfs(x);
		for(int i:G)dfs2(i,sam.trans[1][T[i]-'a']);
		for(int i:G)vis[i]=1;
		return;
	}
	work(x,0,1);
	vis[x]=1;
	for(int i=fir[x];i;i=nxt[i]){
		if(vis[dis[i]])continue;
		work(dis[i],x,-1);
	}
	for(int i=fir[x];i;i=nxt[i]){
		if(vis[dis[i]])continue;
		rt=0,N=siz[dis[i]],getrt(dis[i]),solve(rt);
	}
}
int main(){
#ifdef XZZSB
	freopen("in.in","r",stdin);
	freopen("out.out","w",stdout);
#endif
	sam.flg=1,rsam.flg=0;
	n=gi(),m=gi();int a,b;
	for(int i=1;i<n;++i)a=gi(),b=gi(),link(a,b),link(b,a);
	scanf("%s",T+1),scanf("%s",S+1);
	Base[0]=1;for(int i=1;i<=m;++i)Base[i]=Base[i-1]*frog;
	for(int i=1;i<=m;++i){
		ull Hash=0;
		for(int j=1;i+j-1<=m&&j<=qt*2+1;++j)Hash+=Base[j]*S[i+j-1],++qaq[j][Hash];
	}
	for(int i=1;i<=m;++i)sam.extend(S[i]-'a',i);
	sam.prepare();
	std::reverse(S+1,S+m+1);
	for(int i=1;i<=m;++i)rsam.extend(S[i]-'a',i);
	rsam.prepare();
	std::reverse(S+1,S+m+1);
	N=n;f[0]=1e9,rt=0,getrt(1),solve(rt);
	printf("%lld
",ans);
	return 0;
}
原文地址:https://www.cnblogs.com/xzz_233/p/11152253.html