BZOJ4650 NOI2016优秀的拆分(后缀数组)

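  设 $f_i$ 为以位置 $i$ 开头的 AA 串个数,$g_i$ 为以位置 $i$ 结尾的 AA 串个数,则答案为 $\sum_i g_{i-1} f_i$。$g$ 可以在反串上用同样的方法求出:反串中以 $n-i+1$ 开头的 AA 串,对应原串中以 $i$ 结尾的 AA 串。下面对每种长度分别计数。设当前 A 串长度为 $x$,每隔 $x$ 个字符设一个关键点。每个长度为 $2x$ 的 AA 串,其前一个 A 恰好包含一个关键点 $p$,后一个 A 则包含 $p+x$。求出相邻两关键点 $p$、$p+x$ 之后的后缀最长公共前缀 lcp,以及到它们为止的前缀最长公共后缀 lcs。为避免重复计数,把 lcs 截断到 $x$,把从 $p+1$ 起算的 lcp 截断到 $x-1$。若 $\mathrm{lcp}+\mathrm{lcs}\ge x$,则跨过这两个关键点、A 长为 $x$ 的 AA 串,其起点恰好构成一段连续区间 $[p-\mathrm{lcs}+1,\ p+\mathrm{lcp}-x+1]$。用差分数组对这些区间加一,最后做前缀和,即得每个位置的 $f_i$。枚举 $x$ 时关键点总数为调和级数,总复杂度 $O(n\log n)$。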

#include<iostream> 
#include<cstdio>
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<algorithm>
using namespace std;
#define ll long long
#define N 30010
// Reads characters from stdin until an ASCII letter or digit is found and
// returns it. Returns '\0' if end of input is reached first; previously the
// skip loop never terminated on EOF because EOF is neither a letter nor a digit.
// The character is kept in an int so EOF stays distinguishable from 0xFF.
char getc()
{
    int c=getchar();
    while (c!=EOF&&!((c>='A'&&c<='Z')||(c>='a'&&c<='z')||(c>='0'&&c<='9'))) c=getchar();
    return c==EOF?'\0':static_cast<char>(c);
}
// Greatest common divisor by the Euclidean algorithm, written as a loop:
// repeatedly replace (n, m) with (m, n % m) until m becomes zero.
int gcd(int n,int m)
{
    while (m!=0)
    {
        const int rem=n%m;
        n=m;
        m=rem;
    }
    return n;
}
// Reads a decimal integer from stdin. Every non-digit before the first digit
// is skipped, and a '-' among the skipped characters makes the result
// negative. Stops at end of input instead of looping forever: if EOF is
// reached before any digit, 0 is returned.
// The character is held in an int so EOF (a negative value) is detected reliably.
int read()
{
    int x=0,f=1;
    int c=getchar();
    while (c!=EOF&&(c<'0'||c>'9')) {if (c=='-') f=-1;c=getchar();}
    while (c>='0'&&c<='9') x=x*10+(c-'0'),c=getchar();
    return x*f;
}
// Shared state for all routines below; arrays are used 1-based.
// T: number of test cases; n: length of the current string;
// a[]: letter codes 1..26 of the string being indexed (main zeroes it first).
int T,n,a[N];
// Counting-sort buckets.
int cnt[N];
// rk[op][i]: rank of suffix i in string op (0 = original, 1 = reversed);
// tmp: copy of the previous round's ranks. Both have 2N slots so reads at
// positions past n fall on zero entries.
int rk[2][N<<1],tmp[N<<1];
// sa[r]: start of the suffix with rank r; sa2: order by second key.
int sa[N],sa2[N];
// f[op]: range-min sparse table over heights, indexed by rank;
// h[i]: LCP of suffix i with its predecessor in rank order;
// lg2[x]: floor(log2 x).
int f[2][N][17],h[N],lg2[N];
// ans[op][i]: number of squares AA starting at position i of string op.
int ans[2][N];
// Input buffer, filled starting at s+1.
char s[N];
// Builds the suffix array of a[1..n] (letter codes 1..26; a[0] and a[n+1]
// are expected to be 0) by prefix doubling with counting sort. Then computes
// the height array (Kasai) and a range-min sparse table over heights by rank.
// op selects the output slot: ranks go to rk[op][] and the sparse table to
// f[op][][] (0 = forward string, 1 = reversed string, as called from main).
// sa[], sa2[], cnt[], tmp[] and h[] are shared scratch, overwritten each call.
// Also (re)fills lg2[2..n] with floor(log2(i)).
void make(int op)
{
    int m=26; // number of distinct rank values; alphabet size at the start
    memset(cnt,0,sizeof(cnt));memset(rk[op],0,sizeof(rk[op]));
    // Initial ranks are the letter codes; counting sort by the first character.
    for (int i=1;i<=n;i++) cnt[rk[op][i]=a[i]]++;
    for (int i=1;i<=m;i++) cnt[i]+=cnt[i-1];
    for (int i=n;i>=1;i--) sa[cnt[a[i]]--]=i;
    for (int k=1;k<=n;k<<=1)
    {
        // sa2: suffixes ordered by second key rk[i+k]. Suffixes with i+k>n
        // (empty second key) come first.
        int p=0;
        for (int i=n-k+1;i<=n;i++) sa2[++p]=i;
        for (int i=1;i<=n;i++) if (sa[i]>k) sa2[++p]=sa[i]-k;
        // Stable counting sort of sa2 by the first key rk[i].
        memset(cnt,0,sizeof(cnt));
        for (int i=1;i<=n;i++) cnt[rk[op][i]]++;
        for (int i=1;i<=m;i++) cnt[i]+=cnt[i-1];
        for (int i=n;i>=1;i--) sa[cnt[rk[op][sa2[i]]]--]=sa2[i];
        // Re-rank: suffixes with equal (first, second) key pairs share a rank.
        // tmp[j] for j>n is 0 because rk[op] was zeroed above and only
        // entries 1..n are ever written.
        memcpy(tmp,rk[op],sizeof(tmp));
        p=1;rk[op][sa[1]]=1;
        for (int i=2;i<=n;i++)
        {
            if (tmp[sa[i]]!=tmp[sa[i-1]]||tmp[sa[i]+k]!=tmp[sa[i-1]+k]) p++;
            rk[op][sa[i]]=p;
        }
        if (p==n) break; // all ranks distinct: the suffix array is final
        m=p;
    }
    // Kasai: h[i] = LCP(suffix i, suffix ranked just before it), using the
    // property h[i] >= h[i-1]-1. For the rank-1 suffix, sa[0] is never
    // written (stays 0), so the comparison is against a[0].
    for (int i=1;i<=n;i++)
    {
        h[i]=max(h[i-1]-1,0);
        while (a[i+h[i]]==a[sa[rk[op][i]-1]+h[i]]) h[i]++;
    }
    // Sparse table by rank: f[op][r][0] = height of the suffix at rank r;
    // f[op][r][j] = min over ranks r .. min(n, r+2^j-1) (the right half is clamped at n).
    for (int i=1;i<=n;i++) f[op][i][0]=h[sa[i]];
    for (int j=1;j<17;j++)
        for (int i=1;i<=n;i++)
        f[op][i][j]=min(f[op][i][j-1],f[op][min(n,i+(1<<j-1))][j-1]);
    // lg2[i] = floor(log2(i)); lg2[1] keeps its zero initial value.
    for (int i=2;i<=n;i++)
    {
        lg2[i]=lg2[i-1];
        if ((2<<lg2[i])<=i) lg2[i]++;
    }
}
// LCP of the two suffixes whose ranks in string op are x and y (either order).
// This is the minimum height over ranks min(x,y)+1 .. max(x,y), answered in
// O(1) from the sparse table f[op]. Equal ranks return N, meaning "unbounded".
int query(int x,int y,int op)
{
    const int lo=min(x,y)+1;
    const int hi=max(x,y);
    if (lo>hi) return N;
    const int lv=lg2[hi-lo+1];
    return min(f[op][lo][lv],f[op][hi-(1<<lv)+1][lv]);
}
// Fills ans[op][i] with the number of squares AA (|A| >= 1) that start at
// position i of string op (0 = original, 1 = reversed), for i in 1..n.
// For each half-length len, keypoints len, 2len, ... are placed. Every square
// of half-length len has exactly one keypoint x in its first half, and then
// x+len lies in its second half. The matches around (x, x+len) give an
// interval of valid start positions, which is added through a difference array.
void solve(int op)
{
    int *diff=ans[op];
    memset(diff,0,sizeof(ans[op]));
    const int other=op^1;
    for (int len=1;len<=n;len++)
    {
        for (int x=len;x+len<=n;x+=len)
        {
            const int y=x+len;
            // Forward match strictly after the keypoints, capped at len-1.
            const int fwd=min(query(rk[op][x+1],rk[op][y+1],op),len-1);
            // Backward match ending at the keypoints (inclusive), capped at len.
            // Position p of this string is position n-p+1 of the other one.
            const int bwd=min(query(rk[other][n-x+1],rk[other][n-y+1],other),len);
            const int extra=fwd+bwd-len;
            if (extra<0) continue;
            // Valid starts are x-bwd+1 .. x-bwd+1+extra.
            diff[x-bwd+1]++;
            diff[x-bwd+extra+2]--;
        }
    }
    for (int i=1;i<=n;i++) diff[i]+=diff[i-1];
}
int main()
{
#ifndef ONLINE_JUDGE
    freopen("bzoj4650.in","r",stdin);
    freopen("bzoj4650.out","w",stdout);
    const char LL[]="%I64d
";
#else
    const char LL[]="%lld
";
#endif
    T=read();
    while (T--)
    {
        scanf("%s",s+1);
        n=strlen(s+1);memset(a,0,sizeof(a));
        for (int i=1;i<=n;i++) a[i]=s[i]-'a'+1;
        make(0);
        for (int i=1;i<=n;i++) a[i]=s[n-i+1]-'a'+1;
        make(1);
        solve(0),solve(1);
        ll tot=0;
        for (int i=1;i<=n;i++) tot+=ans[0][i]*ans[1][n-i+2];
        cout<<tot<<endl;
    }
    return 0;
}
原文地址:https://www.cnblogs.com/Gloid/p/10288891.html