bzoj4566: [Haoi2016]找相同字符

一个串建SAM,一个串在上面跑DP

需要注意,走到当前节点的时候,有可能走的是近路,并不能把当前节点表示的所有子串匹配,这个时候就要记录一下走的步数(类似caioj那题),那些被当前点表示的,长度不超过步数的子串才有资格更新答案。

这个东西我用g来维护

然后他去更新其他人就没有这个限制了,用h表示覆盖的次数,减去f表示直接走到的次数,然后乘上这个点代表的子串数和出现次数,就是其他人更新我的答案

g和这个东西加起来就是答案了

#include<cstdio>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long LL;

int a[210000],len;
struct SAM
{
    int w[30],dep,fail;
}ch[410000];int last,cnt;
void insert(int dep,int x)
{
    int pre=last,now=++cnt; ch[now].dep=dep;
    last=now;
    
    while(pre!=0&&ch[pre].w[x]==0)
        ch[pre].w[x]=now, pre=ch[pre].fail;
    if(pre==0)ch[now].fail=1;
    else
    {
        int nxt=ch[pre].w[x];
        if(ch[nxt].dep==ch[pre].dep+1)ch[now].fail=nxt;
        else
        {
            int nnxt=++cnt;
            ch[nnxt]=ch[nxt];
            ch[nnxt].dep=ch[pre].dep+1;
            
            ch[nxt].fail=ch[now].fail=nnxt;
            while(pre!=0&&ch[pre].w[x]==nxt)
                ch[pre].w[x]=nnxt, pre=ch[pre].fail;
        }
    }
}
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~init~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
int T,Right[410000];//根到当前节点组成的子串(当前节点管理的子串)出现次数,当前节点管理的子串数=ch[x].dep-ch[ch[x].fail].dep
int Rsort[410000],sa[410000];
void GetRight()
{
    memset(Rsort,0,sizeof(Rsort));
    for(int i=1;i<=cnt;i++)Rsort[ch[i].dep]++;
    for(int i=1;i<=len;i++)Rsort[i]+=Rsort[i-1];
    for(int i=cnt;i>=1;i--)sa[Rsort[ch[i].dep]--]=i;
    
    int now=1;
    memset(Right,0,sizeof(Right));
    for(int i=1;i<=len;i++) now=ch[now].w[a[i]], Right[now]++;
    for(int i=cnt;i>=1;i--)
    {
        int u=sa[i],v=ch[u].fail;
        Right[v]+=Right[u];
    }
    Right[1]=0;
}
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

//------------------------------------------------------SAM-----------------------------------------------------------------

LL f[410000],g[410000],h[410000];
void solve()
{
    int now=1,L=0;
    memset(f,0,sizeof(f));
    memset(g,0,sizeof(g));
    memset(h,0,sizeof(h));
    for(int i=1;i<=len;i++)
    {
        int x=a[i];
        while(now!=1&&ch[now].w[x]==0) now=ch[now].fail, L=ch[now].dep;
        if(ch[now].w[x]!=0)
        {
            L++;
            now=ch[now].w[x];
            f[now]++,h[now]++;
            g[now]+=Right[now]*(L-ch[ch[now].fail].dep);
        }
    }
    for(int i=cnt;i>=1;i--)
    {
        int u=sa[i],v=ch[u].fail;
        f[v]+=f[u];
    }
    LL ans=0;
    for(int i=2;i<=cnt;i++)ans+=g[i]+(f[i]-h[i])*Right[i]*(ch[i].dep-ch[ch[i].fail].dep);
    printf("%lld
",ans);
}
char ss[210000];
int main()
{
    freopen("a.in","r",stdin);
    freopen("a.out","w",stdout);
    scanf("%s",ss+1);len=strlen(ss+1);
    last=cnt=1; ch[1].dep=0;
    for(int i=1;i<=len;i++)
        a[i]=ss[i]-'a'+1, insert(i,a[i]);
    GetRight();
    
    scanf("%s",ss+1);len=strlen(ss+1);
    for(int i=1;i<=len;i++)a[i]=ss[i]-'a'+1;
    solve();
    
    return 0;
}

 

原文地址:https://www.cnblogs.com/AKCqhzdy/p/10053632.html