CF504E Misha and LCP on Tree 题解

首先序列上的问题可以hash加二分搞
到树上依然可以hash加二分搞, 配合 $O(1)$ 的 RMQ 求 LCA 和长链剖分求 $k$ 级祖先, 可以做到 $O(m\log n)$
但是我天生自带大常数, 会 T (超时) qwq (于是我放弃了双hash, 然后过了淦)


会在第六个测试点 TLE 的代码(双hash):

#include<bits/stdc++.h>
using namespace std;
#define li long long
// Modular exponentiation by repeated squaring: returns a^b mod p.
li ksm(li a, li b, li p) {
	li acc = 1ll;
	while(b) {
		if(b & 1) acc = acc * a % p;
		a = a * a % p;
		b >>= 1;
	}
	return acc % p;
}
// Array capacity (vertices) and the two hash moduli.
const int maxn = 3e5 + 5;
const int mod1 = 1000000007;
const int mod2 = 100000007;
// Hash bases, chosen at random in main().
int base1, base2;
// p[i] = base^i mod m; ip[i] = p[i]^(m-2) mod m (Fermat inverse of p[i]).
li p1[maxn], p2[maxn];
li ip1[maxn], ip2[maxn];
// rn[v]: hash of the root->v string read top-down (root is the most significant digit).
li rn1[maxn], rn2[maxn];
// nr[v]: sum of s[u] * base^(dep[u]-1) over u on root->v; after dividing by a
// power of base, a difference of two nr values is the hash of an upward path.
li nr1[maxn], nr2[maxn];
 
// n = number of vertices, m = number of queries; s[v] = character of vertex v (1-based).
int n,m;
char s[maxn];
// Head-insertion adjacency list: hed[v] = first edge of v, ver/nxt = target/next edge.
// Sized for two directed edges per tree edge.
int ct, hed[maxn], ver[maxn<<1], nxt[maxn<<1];
// Appends the directed edge a -> b to the front of a's adjacency list.
void ad(int a,int b) {
	++ct;
	ver[ct]=b;
	nxt[ct]=hed[a];
	hed[a]=ct;
}
 
// treedep[v]: greatest depth in v's subtree; son[v]: child with the deepest subtree
// (0 for a leaf); top[v]: head of v's long chain; highbit[k] = floor(log2 k).
int treedep[maxn], son[maxn], top[maxn], highbit[maxn];
// For a chain head h: kf[h][j] is the j-th ancestor of h, ks[h][j] the j-th vertex down h's chain.
vector<int> kf[maxn], ks[maxn];
 
// Euler tour size, sparse table over the tour (st[k][i] = shallowest vertex in
// positions [i, i+2^k)), first tour position of each vertex, depth (root = 1).
int tot, st[21][maxn<<1], fis[maxn], dep[maxn];
// f[k][v] = 2^k-th ancestor of v (0 above the root).
int f[21][maxn];
// First DFS from the root. Fills the prefix hashes, parent f[0], depth, the
// Euler tour (st[0], fis) used for RMQ-LCA, and son[]/treedep[] (deepest child
// and greatest subtree depth) for the long-chain decomposition.
void dfs(int x,int fa,int Dep) {
	// top-down hash: root is the most significant character
	rn1[x] = (rn1[fa]*base1 + s[x]) % mod1;
	rn2[x] = (rn2[fa]*base2 + s[x]) % mod2;
	// x contributes s[x] * base^(dep[x]-1), since dep[fa] == dep[x]-1
	nr1[x] = (nr1[fa] + p1[dep[fa]]*s[x]) % mod1;
	nr2[x] = (nr2[fa] + p2[dep[fa]]*s[x]) % mod2;

	f[0][x] = fa;
	dep[x] = Dep;
	treedep[x] = Dep;
	++tot;
	fis[x] = tot;
	st[0][tot] = x;
	for(int e=hed[x]; e; e=nxt[e]) {
		int to = ver[e];
		if(to==fa) continue;
		dfs(to,x,Dep+1);
		// x reappears in the Euler tour after each child subtree
		++tot;
		st[0][tot] = x;
		if(treedep[to] > treedep[son[x]]) {
			son[x] = to;
			treedep[x] = treedep[to];
		}
	}
}
// Second DFS: assigns top[] (chain head). The deepest child son[x] continues
// x's chain; every other child starts a new chain headed by itself.
void po(int x,int fa,int tp) {
	top[x]=tp;
	if(son[x]!=0) po(son[x],x,tp);
	for(int e=hed[x]; e!=0; e=nxt[e]) {
		int to=ver[e];
		if(to!=fa && to!=son[x]) po(to,x,to);
	}
}
// Returns the shallower of x and y (y on ties); the "min" of the RMQ table.
int calc(int x,int y) {
	if(dep[x]<dep[y]) return x;
	return y;
}
// Lowest common ancestor of x and y: the shallowest vertex between their first
// Euler-tour positions, answered in O(1) from the sparse table st.
int lca(int x,int y) {
	int l=fis[x], r=fis[y];
	if(l>r) swap(l,r);
	// floor(log2(length)) by an integer bit scan: exact, and avoids a
	// floating-point log2 call on every query (length >= 1, so clz is defined)
	int k = 31 - __builtin_clz(static_cast<unsigned>(r-l+1));
	return calc(st[k][l], st[k][r-(1<<k)+1]);
}
void init() {
	p1[0] = p2[0] = 1ll;
	ip1[0] = ip2[0] = 1ll;
	for(int i=1; i<=n; ++i) {
		p1[i] = p1[i-1]*base1 % mod1;
		p2[i] = p2[i-1]*base2 % mod2;
		ip1[i] = ksm(p1[i],mod1-2,mod1);
		ip2[i] = ksm(p2[i],mod2-2,mod2);
//		cout<<p1[i]*ip1[i]%mod1<<' '<<p2[i]*ip2[i]%mod2<<'
';
	}
	dfs(1,0,1);
	po(1,0,1);
	for(int k=1;k<=20;++k)
		for(int i=1;i+(1<<k)-1<=tot;++i)
			st[k][i] = calc(st[k-1][i], st[k-1][i+(1<<(k-1))]);
	for(int k=1;k<=20;++k)
		for(int i=1;i<=n;++i)
			f[k][i] = f[k-1][f[k-1][i]];
	for(int i=1;i<=n;++i) highbit[i] = log2(i);
	for(int i=1;i<=n;++i) if(i==top[i]) {
		ks[i].push_back(i), kf[i].push_back(i);
		int ns=i, nf=i;
		for(int j=1;j<=treedep[i]-dep[i]+1;++j) {
			ns=son[ns], nf=f[0][nf];
			ks[i].push_back(ns), kf[i].push_back(nf);
		}
	}
}
// Returns the k-th ancestor of x (x itself when k == 0); assumes it exists.
// One binary-lifting jump of 2^highbit[k] levels, then a single lookup in the
// up/down ladder stored at the head of the landing vertex's long chain.
int kfa(int x,int k) {
	if(k==0) return x;
	int bit = highbit[k];
	int land = f[bit][x];
	int rest = k - (1<<bit);
	int head = top[land];
	int want = dep[land] - rest;   // depth of the answer
	if(want >= dep[head])
		return ks[head][want - dep[head]];
	return kf[head][dep[head] - want];
}
// First hash (base1, mod1) of the first `len` vertex characters on the path
// x -> lcaxy -> y, where lcaxy = lca(x, y). The first character of the path is
// the most significant digit, so equal prefixes of different paths hash equally.
// Expects 1 <= len <= number of vertices on the path (as main() guarantees).
li get1(int x,int lcaxy,int y,int len) {
	// Path goes straight down from x: read the prefix off the top-down hash.
	if(x==lcaxy) {
		int lo = kfa(y, dep[y]-dep[x]+1-len);   // vertex len-1 levels below x
		int above = f[0][x];
		li h = (rn1[lo] - rn1[above]*p1[dep[lo]-dep[above]]%mod1) % mod1;
		return (h%mod1+mod1)%mod1;
	}
	// Vertices from x up to and including the LCA.
	int upLen = dep[x] - dep[lcaxy] + 1;
	// Prefix stays on the upward part (always so when the LCA is y itself).
	if(lcaxy==y || len<=upLen) {
		int above = f[0][kfa(x,len-1)];
		li h = (nr1[x]-nr1[above])*ip1[dep[above]]%mod1;
		return (h%mod1+mod1)%mod1;
	}
	// Prefix crosses the LCA: x..lcaxy upward, then downLen vertices towards y.
	int downLen = len - upLen;
	int above = f[0][lcaxy];
	li hu = (nr1[x]-nr1[above])*ip1[dep[above]]%mod1;
	int lo = kfa(y, dep[y]-dep[lcaxy]-downLen);
	li hd = (rn1[lo]-rn1[lcaxy]*p1[downLen]%mod1)%mod1;
	li h = hu*p1[downLen]%mod1 + hd;
	return (h%mod1+mod1)%mod1;
}
// Second hash (base2, mod2) of the first `len` vertex characters on the path
// x -> lcaxy -> y, where lcaxy = lca(x, y). The first character of the path is
// the most significant digit, so equal prefixes of different paths hash equally.
// Expects 1 <= len <= number of vertices on the path (as main() guarantees).
li get2(int x,int lcaxy,int y,int len) {
	// Path goes straight down from x: read the prefix off the top-down hash.
	if(x==lcaxy) {
		int lo = kfa(y, dep[y]-dep[x]+1-len);   // vertex len-1 levels below x
		int above = f[0][x];
		li h = (rn2[lo] - rn2[above]*p2[dep[lo]-dep[above]]%mod2) % mod2;
		return (h%mod2+mod2)%mod2;
	}
	// Vertices from x up to and including the LCA.
	int upLen = dep[x] - dep[lcaxy] + 1;
	// Prefix stays on the upward part (always so when the LCA is y itself).
	if(lcaxy==y || len<=upLen) {
		int above = f[0][kfa(x,len-1)];
		li h = (nr2[x]-nr2[above])*ip2[dep[above]]%mod2;
		return (h%mod2+mod2)%mod2;
	}
	// Prefix crosses the LCA: x..lcaxy upward, then downLen vertices towards y.
	int downLen = len - upLen;
	int above = f[0][lcaxy];
	li hu = (nr2[x]-nr2[above])*ip2[dep[above]]%mod2;
	int lo = kfa(y, dep[y]-dep[lcaxy]-downLen);
	li hd = (rn2[lo]-rn2[lcaxy]*p2[downLen]%mod2)%mod2;
	li h = hu*p2[downLen]%mod2 + hd;
	return (h%mod2+mod2)%mod2;
}
// Reads the tree and the queries. For each query (a, b, c, d) prints the length
// of the longest common prefix of the strings on paths a->b and c->d, found by
// binary search on the prefix length while comparing both hashes.
int main()
{
	srand((unsigned)time(0));
	// hash bases are picked at random on every run
	base1 = rand()%100+200, base2 = rand()%300+400;

	scanf("%d",&n); scanf("%s",s+1);
	for(int i=1;i<n;++i) {
		int x,y; scanf("%d%d",&x,&y);
		ad(x,y); ad(y,x);
	}
	init();
	scanf("%d",&m);
	while(m--) {
		int a,b,c,d; scanf("%d%d%d%d", &a,&b,&c,&d);
		int lcaab = lca(a,b), lcacd = lca(c,d);
		// number of vertices on each path
		int lenab = dep[a]+dep[b]-2*dep[lcaab]+1;
		int lencd = dep[c]+dep[d]-2*dep[lcacd]+1;
		// Prefix equality is monotone in the length, so binary search for the
		// largest equal length. Length 1 is assumed by the search and checked
		// directly afterwards.
		int L=1, R=min(lenab,lencd);
		while(L!=R) {
			int mid = (L+R+1) >> 1;
			if(
				(get1(a,lcaab,b,mid)==get1(c,lcacd,d,mid))&&
				(get2(a,lcaab,b,mid)==get2(c,lcacd,d,mid))
			) L=mid;
			else R = mid-1;
		}
		if(L==1 && s[a]!=s[c]) L=0;
		printf("%d\n", L);
	}
	return 0;
}

AC代码(单hash)

#include<bits/stdc++.h>
using namespace std;
#define li long long
// Modular exponentiation by repeated squaring: returns a^b mod p.
li ksm(li a, li b, li p) {
	li acc = 1ll;
	while(b) {
		if(b & 1) acc = acc * a % p;
		a = a * a % p;
		b >>= 1;
	}
	return acc % p;
}
// Array capacity (vertices) and the hash modulus.
const int maxn = 3e5 + 5;
const int mod1 = 1000000007;
//const int mod2 = 100000007;
// Hash base, chosen at random in main(); base2 is a leftover of the
// double-hash version and is not used here.
int base1, base2;
// p1[i] = base1^i mod mod1
li p1[maxn];
// ip1[i] = p1[i]^(mod1-2) mod mod1 (Fermat inverse of p1[i])
li ip1[maxn];
// rn1[v]: hash of the root->v string read top-down (root is the most significant digit)
li rn1[maxn];
// nr1[v]: sum of s[u] * base1^(dep[u]-1) over u on root->v; after dividing by a
// power of base1, a difference of two values is the hash of an upward path
li nr1[maxn];
 
// n = number of vertices, m = number of queries; s[v] = character of vertex v (1-based).
int n,m;
char s[maxn];
// Head-insertion adjacency list: hed[v] = first edge of v, ver/nxt = target/next edge.
// Sized for two directed edges per tree edge.
int ct, hed[maxn], ver[maxn<<1], nxt[maxn<<1];
// Appends the directed edge a -> b to the front of a's adjacency list.
void ad(int a,int b) {
	++ct;
	ver[ct]=b;
	nxt[ct]=hed[a];
	hed[a]=ct;
}
 
// treedep[v]: greatest depth in v's subtree; son[v]: child with the deepest subtree
// (0 for a leaf); top[v]: head of v's long chain; highbit[k] = floor(log2 k).
int treedep[maxn], son[maxn], top[maxn], highbit[maxn];
// For a chain head h: kf[h][j] is the j-th ancestor of h, ks[h][j] the j-th vertex down h's chain.
vector<int> kf[maxn], ks[maxn];
 
// Euler tour size, sparse table over the tour (st[k][i] = shallowest vertex in
// positions [i, i+2^k)), first tour position of each vertex, depth (root = 1).
int tot, st[21][maxn<<1], fis[maxn], dep[maxn];
// f[k][v] = 2^k-th ancestor of v (0 above the root).
int f[21][maxn];
// First DFS from the root. Fills the prefix hashes, parent f[0], depth, the
// Euler tour (st[0], fis) used for RMQ-LCA, and son[]/treedep[] (deepest child
// and greatest subtree depth) for the long-chain decomposition.
void dfs(int x,int fa,int Dep) {
	// top-down hash: root is the most significant character
	rn1[x] = (rn1[fa]*base1 + s[x]) % mod1;
	// x contributes s[x] * base1^(dep[x]-1), since dep[fa] == dep[x]-1
	nr1[x] = (nr1[fa] + p1[dep[fa]]*s[x]) % mod1;

	f[0][x] = fa;
	dep[x] = Dep;
	treedep[x] = Dep;
	++tot;
	fis[x] = tot;
	st[0][tot] = x;
	for(int e=hed[x]; e; e=nxt[e]) {
		int to = ver[e];
		if(to==fa) continue;
		dfs(to,x,Dep+1);
		// x reappears in the Euler tour after each child subtree
		++tot;
		st[0][tot] = x;
		if(treedep[to] > treedep[son[x]]) {
			son[x] = to;
			treedep[x] = treedep[to];
		}
	}
}
// Second DFS: assigns top[] (chain head). The deepest child son[x] continues
// x's chain; every other child starts a new chain headed by itself.
void po(int x,int fa,int tp) {
	top[x]=tp;
	if(son[x]!=0) po(son[x],x,tp);
	for(int e=hed[x]; e!=0; e=nxt[e]) {
		int to=ver[e];
		if(to!=fa && to!=son[x]) po(to,x,to);
	}
}
// Returns the shallower of x and y (y on ties); the "min" of the RMQ table.
int calc(int x,int y) {
	if(dep[x]<dep[y]) return x;
	return y;
}
// Lowest common ancestor of x and y: the shallowest vertex between their first
// Euler-tour positions, answered in O(1) from the sparse table st.
int lca(int x,int y) {
	int l=fis[x], r=fis[y];
	if(l>r) swap(l,r);
	// floor(log2(length)) by an integer bit scan: exact, and avoids a
	// floating-point log2 call on every query (length >= 1, so clz is defined)
	int k = 31 - __builtin_clz(static_cast<unsigned>(r-l+1));
	return calc(st[k][l], st[k][r-(1<<k)+1]);
}
void init() {
	p1[0] = 1ll;
	ip1[0] = 1ll;
	for(int i=1; i<=n; ++i) {
		p1[i] = p1[i-1]*base1 % mod1;
		ip1[i] = ksm(p1[i],mod1-2,mod1);
	}
	dfs(1,0,1);
	po(1,0,1);
	for(int k=1;k<=20;++k)
		for(int i=1;i+(1<<k)-1<=tot;++i)
			st[k][i] = calc(st[k-1][i], st[k-1][i+(1<<(k-1))]);
	for(int k=1;k<=20;++k)
		for(int i=1;i<=n;++i)
			f[k][i] = f[k-1][f[k-1][i]];
	for(int i=1;i<=n;++i) highbit[i] = log2(i);
	for(int i=1;i<=n;++i) if(i==top[i]) {
		ks[i].push_back(i), kf[i].push_back(i);
		int ns=i, nf=i;
		for(int j=1;j<=treedep[i]-dep[i]+1;++j) {
			ns=son[ns], nf=f[0][nf];
			ks[i].push_back(ns), kf[i].push_back(nf);
		}
	}
}
// Returns the k-th ancestor of x (x itself when k == 0); assumes it exists.
// One binary-lifting jump of 2^highbit[k] levels, then a single lookup in the
// up/down ladder stored at the head of the landing vertex's long chain.
int kfa(int x,int k) {
	if(k==0) return x;
	int bit = highbit[k];
	int land = f[bit][x];
	int rest = k - (1<<bit);
	int head = top[land];
	int want = dep[land] - rest;   // depth of the answer
	if(want >= dep[head])
		return ks[head][want - dep[head]];
	return kf[head][dep[head] - want];
}
// Hash (base1, mod1) of the first `len` vertex characters on the path
// x -> lcaxy -> y, where lcaxy = lca(x, y). The first character of the path is
// the most significant digit, so equal prefixes of different paths hash equally.
// Expects 1 <= len <= number of vertices on the path (as main() guarantees).
li get1(int x,int lcaxy,int y,int len) {
	// Path goes straight down from x: read the prefix off the top-down hash.
	if(x==lcaxy) {
		int lo = kfa(y, dep[y]-dep[x]+1-len);   // vertex len-1 levels below x
		int above = f[0][x];
		li h = (rn1[lo] - rn1[above]*p1[dep[lo]-dep[above]]%mod1) % mod1;
		return (h%mod1+mod1)%mod1;
	}
	// Vertices from x up to and including the LCA.
	int upLen = dep[x] - dep[lcaxy] + 1;
	// Prefix stays on the upward part (always so when the LCA is y itself).
	if(lcaxy==y || len<=upLen) {
		int above = f[0][kfa(x,len-1)];
		li h = (nr1[x]-nr1[above])*ip1[dep[above]]%mod1;
		return (h%mod1+mod1)%mod1;
	}
	// Prefix crosses the LCA: x..lcaxy upward, then downLen vertices towards y.
	int downLen = len - upLen;
	int above = f[0][lcaxy];
	li hu = (nr1[x]-nr1[above])*ip1[dep[above]]%mod1;
	int lo = kfa(y, dep[y]-dep[lcaxy]-downLen);
	li hd = (rn1[lo]-rn1[lcaxy]*p1[downLen]%mod1)%mod1;
	li h = hu*p1[downLen]%mod1 + hd;
	return (h%mod1+mod1)%mod1;
}
// Reads the tree and the queries. For each query (a, b, c, d) prints the length
// of the longest common prefix of the strings on paths a->b and c->d, found by
// binary search on the prefix length using a single hash.
int main()
{
	srand((unsigned)time(0));
	// hash base is picked at random on every run
	base1 = rand()%100+200;

	scanf("%d",&n); scanf("%s",s+1);
	for(int i=1;i<n;++i) {
		int x,y; scanf("%d%d",&x,&y);
		ad(x,y); ad(y,x);
	}
	init();
	scanf("%d",&m);
	while(m--) {
		int a,b,c,d; scanf("%d%d%d%d", &a,&b,&c,&d);
		int lcaab = lca(a,b), lcacd = lca(c,d);
		// number of vertices on each path
		int lenab = dep[a]+dep[b]-2*dep[lcaab]+1;
		int lencd = dep[c]+dep[d]-2*dep[lcacd]+1;
		// Prefix equality is monotone in the length, so binary search for the
		// largest equal length. Length 1 is assumed by the search and checked
		// directly afterwards.
		int L=1, R=min(lenab,lencd);
		while(L!=R) {
			int mid = (L+R+1) >> 1;
			if(
				get1(a,lcaab,b,mid)==get1(c,lcacd,d,mid)
			) L=mid;
			else R = mid-1;
		}
		if(L==1 && s[a]!=s[c]) L=0;
		printf("%d\n", L);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/tztqwq/p/12784650.html