51nod1600 Simple KMP

题目描述

对于一个字符串|S|,我们定义fail[i],表示最大的x使得S[1..x]=S[i-x+1..i],满足(x<i)
显然对于一个字符串,如果我们将每个0<=i<=|S|看成一个结点,除了i=0以外i向fail[i]连边,这是一颗树的形状,根是0
我们定义这棵树是G(S),设f(S)是G(S)中除了0号点以外所有点的深度之和,其中0号点的深度为-1
定义key(S)等于S的所有非空子串S'的f(S')之和
给定一个字符串S,现在你要实现以下几种操作:
1.在S最后面加一个字符
2.询问key(S)

题解

遇到这种对所有子串统计的问题,考虑差分答案数组。

我们将答案数组二次差分之后,只需要算每个字符在所有以它结尾的子串中的贡献即可。

考虑这个(border)树的深度有什么意义。

观察或手玩即可发现,对于某个串来说,末尾字符位置的深度就是这个串前缀等于后缀的串的数,就是每一对前缀等于后缀都会对这个位置有1的贡献。

所以对于i位置,考虑前(i-1)个位置的字符串,我们现在只需要求所有以这个点结尾的串在前(i-1)的字符串的出现次数的和。

用链剖或(LCT)维护即可。

代码

#include<bits/stdc++.h>
#define N 200002
using namespace std;
typedef long long ll;
const int mod=1e9+7;
int la[N<<2],num[N],tot,head[N],fa[N],deep[N],son[N],dfn[N],_tag[N],top[N],n;
char s[N];
ll tr[N<<2],size[N<<2],ans[N];
inline ll rd(){
	ll x=0;char c=getchar();bool f=0;
	while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
	while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
	return f?-x:x;
}
struct edge{int n,to;}e[N];
inline void MOD(ll &x){x=x>=mod?x-mod:x;}
struct SAM_t{
	int ch[N][26],fa[N],l[N],cnt,last;
	SAM_t(){cnt=last=1;}
	inline void ins(int x,int id){
		int p=last,np=++cnt;l[np]=l[p]+1;last=np;num[id]=cnt;
		for(;p&&!ch[p][x];p=fa[p])ch[p][x]=np;
		if(!p)fa[np]=1;
		else{
			int q=ch[p][x];
			if(l[p]+1==l[q])fa[np]=q;
			else{
				int nq=++cnt;l[nq]=l[p]+1;
				memcpy(ch[nq],ch[q],sizeof(ch[nq]));
				fa[nq]=fa[q];fa[q]=fa[np]=nq;
				for(;ch[p][x]==q;p=fa[p])ch[p][x]=nq; 
			}
		} 
	}
}sam;
inline void pushdown(int cnt){
	la[cnt<<1]+=la[cnt];
	la[cnt<<1|1]+=la[cnt];	
	MOD(tr[cnt<<1]+=size[cnt<<1]*la[cnt]%mod);
	MOD(tr[cnt<<1|1]+=size[cnt<<1|1]*la[cnt]%mod);
	la[cnt]=0;
}
ll upd(int cnt,int l,int r,int L,int R){;
	if(l>=L&&r<=R){
		ll x=tr[cnt];
		MOD(tr[cnt]+=size[cnt]);
		la[cnt]++;
		return x;
	}
	int mid=(l+r)>>1;
	if(la[cnt])pushdown(cnt);ll ans=0;
	if(mid>=L)MOD(ans+=upd(cnt<<1,l,mid,L,R));
	if(mid<R)MOD(ans+=upd(cnt<<1|1,mid+1,r,L,R));
	MOD(tr[cnt]=tr[cnt<<1]+tr[cnt<<1|1]);
	return ans;
}
void build(int cnt,int l,int r){
	if(l==r){
		size[cnt]=sam.l[_tag[l]]-sam.l[sam.fa[_tag[l]]];
		return;
	}
    int mid=(l+r)>>1;
    build(cnt<<1,l,mid);build(cnt<<1|1,mid+1,r);
    MOD(size[cnt]=size[cnt<<1]+size[cnt<<1|1]);
}
inline void add(int u,int v){e[++tot].n=head[u];e[tot].to=v;head[u]=tot;}
void dfs1(int u){
    size[u]=1;
    for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa[u]){
   	    int v=e[i].to;fa[v]=u;deep[v]=deep[u]+1;
   	    dfs1(v);
   	    size[u]+=size[v];
   	    if(size[v]>size[son[u]])son[u]=v;
    }
}
void dfs2(int u){
	dfn[u]=++dfn[0];_tag[dfn[0]]=u;
	if(!top[u])top[u]=u;
	if(son[u])top[son[u]]=top[u],dfs2(son[u]); 
	for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa[u]&&e[i].to!=son[u])dfs2(e[i].to);
}
ll work(int x){
	ll ans=0;
	while(top[x]!=top[1]){
		MOD(ans+=upd(1,1,sam.cnt,dfn[top[x]],dfn[x]));
		x=fa[top[x]]; 
	}
	MOD(ans+=upd(1,1,sam.cnt,dfn[1],dfn[x]));
	return ans;
}
int main(){
	n=rd();
	scanf("%s",s+1);
	for(int i=1;i<=n;++i)sam.ins(s[i]-'a',i);
	for(int i=1;i<=sam.cnt;++i)if(sam.fa[i])add(sam.fa[i],i);
	dfs1(1);dfs2(1);
	build(1,1,sam.cnt);
	for(int i=1;i<=n;++i){
		int x=work(num[i]);
		MOD(ans[i]=ans[i-1]+x);
	} 
	for(int i=1;i<=n;++i){
		MOD(ans[i]+=ans[i-1]);
		printf("%lld
",ans[i]);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/ZH-comld/p/10804489.html