并不对劲的Loj6031:「雅礼集训 2017 Day1」字符串

题目传送门:->

看到题目的第一反应当然是暴力:对于串s建后缀自动机,每次询问中,求w对应的子串在s的SAM中的right集合。O(qmk)听上去显然过不了。

数据范围有个∑w<=1e5,也就是说,q*k<=1e5,当q更小或k更小时可以用不同的方法。

k更小时,会发现每个w的子串数可能会很小,子串数可能还没m多。这个时候将m个[li,ri]的询问可能会有重复的,所以可以用vector存一下对于每一个[L,R],满足li=L且ri=R的询问的编号是哪些。求出w对于每个[L,R]有多少个满足li=L且ri=R的询问的编号在[a,b](可以用vector自带的lower_bound),和w[L,R]在s中的出现次数。这样,时间复杂度是O(q*log m*k2)。

q更小时,对于每个w可以求出它的每个[1,i]的前缀会匹配到s的SAM的哪个点(记作pla[i])、能匹配s多长。SAM中的每个点的fail指针指向的点都是它的后缀。所以对于每个w[li,ri],可以先走到pla[ri],再倍增地顺着fail指针走。时间复杂度是O(q*m*log k)。

听说代码很难调?

#include <bits/stdc++.h>
#define rep(i,x,y) for(register int i=(x);i<=(y);++i)
#define dwn(i,x,y) for(register int i=(x);i>=(y);--i)
#define re register
#define LL long long
#define maxn 200010
#define block 333 
using namespace std;
inline LL read()
{
    LL x=0,f=1;
    char ch=getchar();
    while(isdigit(ch)==0 && ch!='-')ch=getchar();
    if(ch=='-')f=-1,ch=getchar();
    while(isdigit(ch))x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
    return x*f;
}
inline void write(LL x)
{
    LL f=0;char ch[20];
    if(!x){puts("0");return;}
    if(x<0){putchar('-');x=-x;}
    while(x)ch[++f]=x%10+'0',x/=10;
    while(f)putchar(ch[f--]);
    putchar('
');
}
int ch[maxn][30],dis[maxn],ord[maxn],c[maxn],fa[maxn],cnt,rt,lst,anc[maxn][21];
int n,m,q,k,ql[maxn],qr[maxn],pla[maxn],lth[maxn];
LL r[maxn];
vector<int>to[block][block];
char s[maxn],w[maxn];
int gx(char c){return c-'a';}
void go(int & u,int & len, int x)
{
	while(!ch[u][x]&&u)u=fa[u],len=dis[u];
	if(ch[u][x])u=ch[u][x],++len;
	else u=rt,len=0;
}
void extend(LL pos)
{
	int x=gx(s[pos]),p=lst,np=++cnt;lst=np;dis[np]=pos;
	for(;p&&!ch[p][x];p=fa[p])ch[p][x]=np;
	if(!p)fa[np]=rt;
	else
	{
		LL q=ch[p][x];
		if(dis[q]==dis[p]+1)fa[np]=q;
		else
		{
			LL nq=++cnt;dis[nq]=dis[p]+1;
			memcpy(ch[nq],ch[q],sizeof(ch[q]));
			fa[nq]=fa[q],fa[q]=fa[np]=nq;
			for(;p&&ch[p][x]==q;p=fa[p])ch[p][x]=nq;
		}
	}
}
void getr()
{
	for(int u=rt,i=1;i<=n;++i)u=ch[u][gx(s[i])],++r[u];
	rep(i,1,cnt)c[dis[i]]++;
	rep(i,1,n)c[i]+=c[i-1];
	rep(i,1,cnt)ord[c[dis[i]]--]=i;
	dwn(i,cnt,1)r[fa[ord[i]]]+=r[ord[i]];
}
void getfa(){rep(l,1,cnt){int i=ord[l];anc[i][0]=fa[i];rep(j,1,20)anc[i][j]=anc[anc[i][j-1]][j-1];}}
int main()
{
	lst=rt=++cnt;
	n=read(),m=read(),q=read(),k=read();
	scanf("%s",s+1);
	rep(i,1,n)extend(i);getr();
	rep(i,1,m){ql[i]=read()+1,qr[i]=read()+1;if(k<=block)to[ql[i]][qr[i]].push_back(i);}
	if(k<=block)
	{
		while(q--)
		{
			scanf("%s",w+1);
			int a=read()+1,b=read()+1;LL ans=0;
			rep(i,1,k)
			{
				int u=rt;
				rep(j,i,k)
				{
					if(ch[u][gx(w[j])])
					{
						u=ch[u][gx(w[j])];
						vector<int>::iterator L=lower_bound(to[i][j].begin(),to[i][j].end(),a);
						vector<int>::iterator R=upper_bound(to[i][j].begin(),to[i][j].end(),b);
						ans+=(R-L)*r[u];
					}
					else break;
				}
			}
			write(ans);
		}
	}
	else
	{
		getfa();
		while(q--)
		{
			scanf("%s",w+1);
			int a=read()+1,b=read()+1;LL ans=0;
			for(int len=0,u=rt,i=1;i<=k;i++)go(u,len,gx(w[i])),lth[i]=len,pla[i]=u;
			rep(i,a,b)
			{
				if(lth[qr[i]]<qr[i]-ql[i]+1)continue;
				else
				{
					int u=pla[qr[i]];
					dwn(j,19,0)if(dis[anc[u][j]]>=qr[i]-ql[i]+1)u=anc[u][j];
					ans+=r[u];
				}
			}
			write(ans);
		}
	}
    return 0;
}

  

原文地址:https://www.cnblogs.com/xzyf/p/9166968.html