bzoj 3230: 相似子串【SA+st表+二分】

总是犯低级错误,st表都能写错……
正反分别做一遍SA,预处理st表方便查询lcp,然后处理a[i]表示前i个后缀一共有多少个本质不同的子串,这里的子串是按字典序的,所以询问的时候直接在a上二分排名就能得到询问区间,然后用正反st表查lcp即可

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=200005;
int n,q,b[N],sa1[N],sa2[N],rk1[N],rk2[N],he1[N],he2[N],st1[20][N],st2[20][N],wa[N],wb[N],wv[N],wsu[N];
long long a[N];
char s[N];
long long read()
{
	long long r=0,f=1;
	char p=getchar();
	while(p>'9'||p<'0')
	{
		if(p=='-')
			f=-1;
		p=getchar();
	}
	while(p>='0'&&p<='9')
	{
		r=r*10+p-48;
		p=getchar();
	}
	return r*f;
}
bool cmp(int r[],int a,int b,int l)
{
	return r[a]==r[b]&&r[a+l]==r[b+l];
}
void saa(char r[],int n,int m,int sa[],int rk[],int he[])
{
	int *x=wa,*y=wb;
	for(int i=0;i<=m;i++)
		wsu[i]=0;
	for(int i=1;i<=n;i++)
		wsu[x[i]=r[i]]++;
	for(int i=1;i<=m;i++)
		wsu[i]+=wsu[i-1];
	for(int i=n;i>=1;i--)
		sa[wsu[x[i]]--]=i;
	for(int j=1,p=1;j<=n&&p<n;j<<=1,m=p)
	{
		p=0;
		for(int i=n-j+1;i<=n;i++)
			y[++p]=i;
		for(int i=1;i<=n;i++)
			if(sa[i]>j)
				y[++p]=sa[i]-j;
		for(int i=1;i<=n;i++)
			wv[i]=x[y[i]];
		for(int i=0;i<=m;i++)
			wsu[i]=0;
		for(int i=1;i<=n;i++)
			wsu[wv[i]]++;
		for(int i=1;i<=m;i++)
			wsu[i]+=wsu[i-1];
		for(int i=n;i>=1;i--)
			sa[wsu[wv[i]]--]=y[i];
		swap(x,y);
		x[sa[1]]=1;
		p=1;
		for(int i=2;i<=n;i++)
			x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p:++p;
	}
	for(int i=1;i<=n;i++)
		rk[sa[i]]=i;
	for(int i=1,j,k=0;i<=n;he[rk[i++]]=k)
		for(k?k--:0,j=sa[rk[i]-1];r[i+k]==r[j+k];k++);
}
int ef(long long x)
{
	int l=1,r=n,ans=1;
	while(l<=r)
	{
		int mid=(l+r)>>1;
		if(a[mid]>=x)
			r=mid-1,ans=mid;
		else
			l=mid+1;
	}
	return sa1[ans];
}
long long ques1(int x,int y)
{
	if(x==y)
		return n-x+1;
	int l=min(rk1[x],rk1[y])+1,r=max(rk1[x],rk1[y]),k=b[r-l+1];
	return min(st1[k][l],st1[k][r-(1<<k)+1]);
}
long long ques2(int x,int y)
{
	if(x==y)
		return n-x+1;
	int l=min(rk2[x],rk2[y])+1,r=max(rk2[x],rk2[y]),k=b[r-l+1];
	return min(st2[k][l],st2[k][r-(1<<k)+1]);
}
int main()
{
	scanf("%d%d%s",&n,&q,s+1);
	saa(s,n,200,sa1,rk1,he1);
	reverse(s+1,s+1+n);
	saa(s,n,200,sa2,rk2,he2);
	b[0]=-1;
	for(int i=1;i<=n;i++)
		b[i]=b[i>>1]+1;
	for(int i=1;i<=n;i++)
		st1[0][i]=he1[i],st2[0][i]=he2[i];
	for(int i=1;i<=17;i++)
		for(int j=1;j+(1<<i)-1<=n;j++)
		{
			st1[i][j]=min(st1[i-1][j],st1[i-1][j+(1<<(i-1))]);
			st2[i][j]=min(st2[i-1][j],st2[i-1][j+(1<<(i-1))]);
		}
	for(int i=1;i<=n;i++)
		a[i]=a[i-1]+n-sa1[i]+1-he1[i];
	// for(int i=1;i<=n;i++)
		// cerr<<sa1[i]<<" "<<he1[i]<<" "<<a[i]<<endl;
	while(q--)
	{
		long long x=read(),y=read();
		if(max(x,y)>a[n])
		{
			puts("-1");
			continue;
		}
		long long xl=ef(x),xr=xl+he1[rk1[xl]]-1+(x-a[rk1[xl]-1]),yl=ef(y),yr=yl+he1[rk1[yl]]-1+(y-a[rk1[yl]-1]),xx,yy;
		// cerr<<xl<<" "<<xr<<"   "<<yl<<" "<<yr<<endl;
		xx=min(min(xr-xl+1,yr-yl+1),ques1(xl,yl)),yy=min(min(xr-xl+1,yr-yl+1),ques2(n-xr+1,n-yr+1));
		printf("%lld
",xx*xx+yy*yy);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/lokiii/p/10436246.html