Gym

#include<bits/stdc++.h>
#define ll long long
#define rep(i,a,b) for(int i=a;i<=b;i++)
using namespace std;
const int maxn=1000010;
char S[maxn],T[maxn];
struct PT
{
    struct in{
        int dep,fail,len,son[26];
    }p[maxn];
    int cnt,last;
    void init()
    {
        //memset(p,0,sizeof(p));
        cnt=last=1;p[0].dep=p[1].dep=0;
        p[0].fail=p[1].fail=1;
        p[0].len=0; p[1].len=-1;
    }
    int add(int c,int n)
    {
        int np=last;
        while(S[n]!=S[n-1-p[np].len]) np=p[np].fail;
        if(!p[np].son[c]){
            int v=++cnt,k=p[np].fail; p[v].len=p[np].len+2;
            while(S[n]!=S[n-p[k].len-1]) k=p[k].fail;
            p[v].fail=p[k].son[c];
            p[np].son[c]=v; //这一句放前面会出现矛盾,因为np可能=k
            p[v].dep=p[p[v].fail].dep+1;
        }
        last=p[np].son[c];
        return p[last].dep;
    }
}Tree;
int N,M,num[maxn],Next[maxn],extand[maxn]; ll ans;
void getnext(){// next[i]: 以第i位置开始的子串与T的公共前缀长度
     int i,length=strlen(T+1);
     Next[1]=length;
     for(i=0;i+1<length&&T[i+1]==T[i+2];i++);
     Next[2]=i;
     int a=2;   //!
     for(int k=3;k<=length;k++){//长度+1,位置-1。
          int p=a+Next[a]-1, L=Next[k-a+1];
          if(L>=p-k+1){
              int j=(p-k+1)>0?(p-k+1):0;//中断后可能是负的
              while(k+j<=length&&T[k+j]==T[j+1]) j++;// 枚举(p+1,length) 与(p-k+1,length) 区间比较
              Next[k]=j, a=k;
          }
          else Next[k]=L;
    }
}
void getextand(){
    memset(Next,0,sizeof(Next));
    getnext();
    int Slen=strlen(S+1),Tlen=strlen(T+1),a=0;
    int MinLen=Slen>Tlen?Tlen:Slen;
    while(a<MinLen&&S[a+1]==T[a+1]) a++;
    extand[1]=a; a=1;
    for(int k=2;k<=Slen;k++){
        int p=a+extand[a]-1,L=Next[k-a+1];
        if(L>=p-k+1){
            int j=(p-k+1)>0?(p-k+1):0;
            while(k+j<=Slen&&j+1<=Tlen&&S[k+j]==T[j+1]) j++;
            extand[k]=j;a=k;
        }
        else extand[k]=L;
    }
}
int main()
{
    scanf("%s%s",S+1,T+1);
    N=strlen(S+1); M=strlen(T+1);
    reverse(S+1,S+N+1); Tree.init();
    rep(i,1,N) num[i]=Tree.add(S[i]-'a',i);
    getextand();
    rep(i,1,N) ans+=(ll)num[i-1]*extand[i];
    printf("%lld
",ans);
    return 0;
}

不知道写什么。。。

原文地址:https://www.cnblogs.com/hua-dong/p/10400366.html