NOI.AC#2144子串【SAM,倍增】

正题

题目链接:http://noi.ac/problem/2144


题目大意

给出一个字符串\(s\)和一个序列\(a\)。将字符串\(s\)的所有本质不同子串降序排序后,求有多少个区间\([l,r]\)使得子串\(s_{l,r}\)排名等于\(a_{l\sim r}\)的和。

\(1\leq n\leq 2\times 10^5\)


解题思路

因为是降序排序,所以每加一个字符排名是在下降的,而\(a_i\)的和又是不降的,所以对于每个左端点最多只有一个右端点,且可以考虑二分求出这个位置。

如何快速得到子串排名,开始不会还去\(\text{LA}\)群问了一下才知道。

\(SAM\)的一个节点代表多个串,不能通过节点来得到排名。后缀树上的一个节点也是代表多个串,但是这些串的排名是连续的(因为这些串都有相同的前缀)。

所以我们可以根据后缀树上确定每个节点的最小排名,然后用倍增找出子串\(s_{l,r}\)的节点,再根据长度确定具体排名。此时我们可以做到\(O(n\log^2 n)\),可以通过本题了。

但还可以继续优化,发现我们倍增的过程有大量重复,越往上排名越后,所以我们类似二分的判断方法直接用倍增跳到答案节点,然后在答案节点处再二分就好了。

时间复杂度\(O(n\log n)\)


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const ll N=4e5+10,T=20;
ll n,cnt,num,last,tot,len[N],fa[N],ch[N][26],pos[N],id[N];
ll w[N],dep[N],f[N][T+1],t[N][26],rk[N],p1[N],p2[N];
char s[N];
void Insert(ll c){
	ll p=last,np=last=++cnt;
	len[np]=len[p]+1;
	for(;p&&!ch[p][c];p=fa[p])ch[p][c]=np;
	if(!p)fa[np]=1;
	else{
		ll q=ch[p][c];
		if(len[p]+1==len[q])fa[np]=q;
		else{
			ll nq=++cnt;len[nq]=len[p]+1;pos[nq]=pos[q];
			memcpy(ch[nq],ch[q],sizeof(ch[nq]));
			fa[nq]=fa[q];fa[q]=fa[np]=nq;
			for(;p&&ch[p][c]==q;p=fa[p])ch[p][c]=nq;
		}
	}
	return;
}
void dfs(ll x){
	for(ll i=25;i>=0;i--){
		ll y=t[x][i];
		if(!y)continue;
		dep[y]=dep[x]+1;
		f[y][0]=x;dfs(y);
	}
	rk[x]=tot;
	tot+=len[x]-len[fa[x]];
	return;
}
signed main()
{
	scanf("%s",s+1);n=strlen(s+1);
	for(ll i=1;i<=n;i++)scanf("%lld",&w[i]),w[i]+=w[i-1];
	last=cnt=1;
	for(ll i=n;i>=1;i--)Insert(s[i]-'a'),pos[last]=i,id[i]=last;
	for(ll i=2;i<=cnt;i++)
		t[fa[i]][s[pos[i]+len[fa[i]]]-'a']=i;
	dfs(1);
	for(ll j=1;j<=T;j++)
		for(ll i=1;i<=cnt;i++)
			f[i][j]=f[f[i][j-1]][j-1];
	for(ll p=1;p<=n;p++){
		ll x=id[p];
		for(ll i=T;i>=0;i--){
			ll y=f[x][i];if(y<=1)continue;
			if(rk[y]+1<=w[p+len[y]-1]-w[p-1])x=y;
		}
		ll l=len[fa[x]]+1,r=len[x];
		while(l<=r){
			ll mid=(l+r)>>1;
			if(rk[x]+len[x]-mid+1<=w[p+mid-1]-w[p-1])r=mid-1;
			else l=mid+1;
		}
		if(rk[x]+len[x]-l+1==w[p+l-1]-w[p-1])
			num++,p1[num]=p,p2[num]=p+l-1;
	}
	printf("%lld\n",num);
	for(ll i=1;i<=num;i++)
		printf("%lld %lld\n",p1[i],p2[i]);
	return 0;
}
原文地址:https://www.cnblogs.com/QuantAsk/p/14596157.html