bzoj3756pty的字符串(后缀自动机+计数)

题目描述

题解

我们可以先对trie树建出广义SAM,然后维护每个状态的right集合大小r[x](注意right集合在广义SAM上的维护方式:每个trie节点只给它对应的状态贡献1;若插入时需要拆分出新状态nq,这个1记在nq上;最后再按l从大到小沿后缀链接把r累加到祖先)。

然后把匹配串放到广义SAM上匹配。假设当前匹配到了x节点,那么x在后缀链接树上的所有祖先都可以被完整匹配上,其中一个节点的贡献为 r[x]*(l[x]-l[fa[x]])。

对每个节点预处理它到根路径上贡献的前缀和sum即可;最下面的节点x需要特判:它只匹配到了当前长度le,贡献为 r[x]*(le-l[fa[x]])。

代码

#include<cctype>
#include<cstdio>
#include<cstring>
#include<iostream>
#define N 1600002   // capacity for SAM states / trie nodes; assumed >= 2*n for the problem limits — TODO confirm
#define M 8000002   // capacity of the pattern buffer s (stored 1-based, plus '\0')
using namespace std;
typedef long long ll;
// c: scratch buffer for a one-character edge label read with scanf("%s");
//    it must hold the character AND the terminating '\0', hence size 2.
// s: pattern string, stored 1-based as s[1..len].
char c[2],s[M];
// cnt: number of SAM states in use (state 1 is the root); n: number of trie nodes;
// pa[i]: SAM state that trie node i maps to; len: pattern length;
// tong/rnk: counting-sort buckets and the states ordered by l[].
int cnt,n,father,pa[N],len,tong[N],rnk[N];
// sum[x]: total contribution of every state on the suffix-link path from x to the root.
ll sum[N],ans;
// l[x]: longest string length in state x; ch: transitions; fa: suffix link;
// r[x]: size of the right (occurrence) set of state x.
int l[N],ch[N][26],fa[N],r[N]; 
// Reads a decimal integer from stdin, skipping any non-digit characters first.
// A '-' seen anywhere in the skipped prefix makes the result negative.
// Returns 0 if end of input is reached before any digit (instead of looping forever).
inline int rd(){
    int x=0;bool f=0;
    int c=getchar();   // int, not char: it must be able to represent EOF, and
                       // isdigit() is only defined for EOF and unsigned-char values
    while(c!=EOF&&!isdigit(c)){if(c=='-')f=1;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}   // x = x*10 + digit
    return f?-x:x;
}
// Adds one trie edge to the generalized suffix automaton.
//   last : SAM state of the trie parent node
//   x    : character index of the edge (used as ch[..][x])
// Returns the SAM state that represents the new trie node. That state gets
// one occurrence in r[]; clones created only to restore the automaton do not.
inline int ins(int last,int x){
    int cur=last;
    int existing=ch[cur][x];
    if(existing){
        // Another trie branch already created this transition.
        if(l[existing]==l[cur]+1){
            ++r[existing];
            return existing;
        }
        // The transition is not "solid": split it. The clone represents this
        // trie node, so the clone (not `existing`) receives the occurrence.
        int clone=++cnt;
        l[clone]=l[cur]+1;
        r[clone]=1;
        memcpy(ch[clone],ch[existing],sizeof(ch[existing]));
        fa[clone]=fa[existing];
        fa[existing]=clone;
        while(ch[cur][x]==existing){
            ch[cur][x]=clone;
            cur=fa[cur];
        }
        return clone;
    }
    // Ordinary SAM extension with a brand-new state.
    int fresh=++cnt;
    l[fresh]=l[cur]+1;
    r[fresh]=1;
    while(cur&&!ch[cur][x]){
        ch[cur][x]=fresh;
        cur=fa[cur];
    }
    if(!cur){
        fa[fresh]=1;
        return fresh;
    }
    int target=ch[cur][x];
    if(l[target]==l[cur]+1){
        fa[fresh]=target;
        return fresh;
    }
    // Split `target`; this clone carries no occurrence of its own.
    int clone=++cnt;
    l[clone]=l[cur]+1;
    memcpy(ch[clone],ch[target],sizeof(ch[target]));
    fa[clone]=fa[target];
    fa[target]=clone;
    fa[fresh]=clone;
    while(ch[cur][x]==target){
        ch[cur][x]=clone;
        cur=fa[cur];
    }
    return fresh;
}
// Reads the trie (n nodes; for i >= 2 a parent index and a one-character edge
// label) and the pattern, builds the generalized SAM, and prints the sum over
// all pattern prefixes s[1..i] of the contributions r[x]*(matched lengths in x)
// of every automaton state matched by a suffix of that prefix.
int main(){
    n=rd();cnt=1;pa[1]=1;   // trie node 1 is the root and maps to SAM state 1
    for(int i=2;i<=n;++i){
        father=rd();
        char e='a';
        scanf(" %c",&e);    // exactly one character; unlike "%s" into a 1-byte buffer, cannot overflow
        // assumes father < i so pa[father] is already built — TODO confirm input order
        pa[i]=ins(pa[father],e-'a');
    }
    scanf("%s",s+1);len=strlen(s+1); 
    // Counting sort of states by l[]; l[] is bounded by the trie depth, which is below n.
    for(int i=1;i<=cnt;++i)tong[l[i]]++;
    for(int i=1;i<=n;++i)tong[i]+=tong[i-1];
    for(int i=1;i<=cnt;++i)rnk[tong[l[i]]--]=i;
    // Longest states first: push right-set sizes up the suffix-link tree.
    for(int i=cnt;i>=1;--i){int x=rnk[i];r[fa[x]]+=r[x];}
    r[1]=0;   // the root (empty string) must not contribute
    // Shortest states first, so sum[fa[x]] is final before sum[x] reads it.
    for(int i=1;i<=cnt;++i){int x=rnk[i];sum[x]=sum[fa[x]]+1ll*(l[x]-l[fa[x]])*r[x];}
    // Run the pattern through the automaton: now = state of the longest matched
    // suffix of s[1..i], le = its length.
    int now=1,le=0;
    for(int i=1;i<=len;++i){
        int k=s[i]-'a';
        if(ch[now][k])le++,now=ch[now][k];
        else{
            for(;now&&!ch[now][k];now=fa[now]);
            if(now)le=l[now]+1,now=ch[now][k];
            else le=0,now=1;
        }
        // Ancestors are matched completely (sum[fa[now]]); `now` only up to length le.
        if(now!=1)ans+=sum[fa[now]]+1ll*(le-l[fa[now]])*r[now];
    }
    cout<<ans;
    return 0;
}
#include<iostream>
#include<cstdio>
#include<cstring>
#define N 1600002   // capacity for SAM states / trie nodes; assumed >= 2*n for the problem limits — TODO confirm
#define M 8000002   // capacity of the pattern buffer s (stored 1-based, plus '\0')
using namespace std;
typedef long long ll;
// c: scratch buffer for a one-character edge label read with scanf("%s");
//    it must hold the character AND the terminating '\0', hence size 2.
// s: pattern string, stored 1-based as s[1..len].
char c[2],s[M];
// cnt: number of SAM states in use (state 1 is the root); n: number of trie nodes;
// pa[i]: SAM state that trie node i maps to; len: pattern length;
// tong/rnk: counting-sort buckets and the states ordered by l[].
int cnt,n,father,pa[N],len,tong[N],rnk[N];
// sum[x]: total contribution of every state on the suffix-link path from x to the root.
ll sum[N],ans;
// l[x]: longest string length in state x; ch: transitions; fa: suffix link;
// r[x]: size of the right (occurrence) set of state x.
int l[N],ch[N][26],fa[N],r[N]; 
// Reads a decimal integer from stdin, skipping any non-digit characters first.
// A '-' seen anywhere in the skipped prefix makes the result negative.
// Returns 0 if end of input is reached before any digit (instead of looping forever).
inline int rd(){
    int x=0;bool f=0;
    int c=getchar();   // int, not char: it must be able to represent EOF, and
                       // isdigit() is only defined for EOF and unsigned-char values
    while(c!=EOF&&!isdigit(c)){if(c=='-')f=1;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}   // x = x*10 + digit
    return f?-x:x;
}
// Adds one trie edge to the generalized suffix automaton.
//   last : SAM state of the trie parent node
//   x    : character index of the edge (used as ch[..][x])
// Returns the SAM state that represents the new trie node. That state gets
// one occurrence in r[]; clones created only to restore the automaton do not.
inline int ins(int last,int x){
    int cur=last;
    int existing=ch[cur][x];
    if(existing){
        // Another trie branch already created this transition.
        if(l[existing]==l[cur]+1){
            ++r[existing];
            return existing;
        }
        // The transition is not "solid": split it. The clone represents this
        // trie node, so the clone (not `existing`) receives the occurrence.
        int clone=++cnt;
        l[clone]=l[cur]+1;
        r[clone]=1;
        memcpy(ch[clone],ch[existing],sizeof(ch[existing]));
        fa[clone]=fa[existing];
        fa[existing]=clone;
        while(ch[cur][x]==existing){
            ch[cur][x]=clone;
            cur=fa[cur];
        }
        return clone;
    }
    // Ordinary SAM extension with a brand-new state.
    int fresh=++cnt;
    l[fresh]=l[cur]+1;
    r[fresh]=1;
    while(cur&&!ch[cur][x]){
        ch[cur][x]=fresh;
        cur=fa[cur];
    }
    if(!cur){
        fa[fresh]=1;
        return fresh;
    }
    int target=ch[cur][x];
    if(l[target]==l[cur]+1){
        fa[fresh]=target;
        return fresh;
    }
    // Split `target`; this clone carries no occurrence of its own.
    int clone=++cnt;
    l[clone]=l[cur]+1;
    memcpy(ch[clone],ch[target],sizeof(ch[target]));
    fa[clone]=fa[target];
    fa[target]=clone;
    fa[fresh]=clone;
    while(ch[cur][x]==target){
        ch[cur][x]=clone;
        cur=fa[cur];
    }
    return fresh;
}
// Reads the trie (n nodes; for i >= 2 a parent index and a one-character edge
// label) and the pattern, builds the generalized SAM, and prints the sum over
// all pattern prefixes s[1..i] of the contributions r[x]*(matched lengths in x)
// of every automaton state matched by a suffix of that prefix.
int main(){
    n=rd();cnt=1;pa[1]=1;   // trie node 1 is the root and maps to SAM state 1
    for(int i=2;i<=n;++i){
        father=rd();
        char e='a';
        scanf(" %c",&e);    // exactly one character; unlike "%s" into a 1-byte buffer, cannot overflow
        // assumes father < i so pa[father] is already built — TODO confirm input order
        pa[i]=ins(pa[father],e-'a');
    }
    scanf("%s",s+1);len=strlen(s+1); 
    // Counting sort of states by l[]; l[] is bounded by the trie depth, which is below n.
    for(int i=1;i<=cnt;++i)tong[l[i]]++;
    for(int i=1;i<=n;++i)tong[i]+=tong[i-1];
    for(int i=1;i<=cnt;++i)rnk[tong[l[i]]--]=i;
    // Longest states first: push right-set sizes up the suffix-link tree.
    for(int i=cnt;i>=1;--i){int x=rnk[i];r[fa[x]]+=r[x];}
    r[1]=0;   // the root (empty string) must not contribute
    // Shortest states first, so sum[fa[x]] is final before sum[x] reads it.
    for(int i=1;i<=cnt;++i){int x=rnk[i];sum[x]=sum[fa[x]]+1ll*(l[x]-l[fa[x]])*r[x];}
    // Run the pattern through the automaton: now = state of the longest matched
    // suffix of s[1..i], le = its length.
    int now=1,le=0;
    for(int i=1;i<=len;++i){
        int k=s[i]-'a';
        if(ch[now][k])le++,now=ch[now][k];
        else{
            for(;now&&!ch[now][k];now=fa[now]);
            if(now)le=l[now]+1,now=ch[now][k];
            else le=0,now=1;
        }
        // Ancestors are matched completely (sum[fa[now]]); `now` only up to length le.
        if(now!=1)ans+=sum[fa[now]]+1ll*(le-l[fa[now]])*r[now];
    }
    cout<<ans;
    return 0;
}
原文地址:https://www.cnblogs.com/ZH-comld/p/10184192.html