BZOJ_4516_[Sdoi2016]生成魔咒_后缀数组+ST表+splay

BZOJ_4516_[Sdoi2016]生成魔咒_后缀数组+ST表+splay

Description

魔咒串由许多魔咒字符组成,魔咒字符可以用数字表示。例如可以将魔咒字符 1、2 拼凑起来形成一个魔咒串 [1,2]。
一个魔咒串 S 的非空字串被称为魔咒串 S 的生成魔咒。
例如 S=[1,2,1] 时,它的生成魔咒有 [1]、[2]、[1,2]、[2,1]、[1,2,1] 五种。S=[1,1,1] 时,它的生成魔咒有 [1]、
[1,1]、[1,1,1] 三种。最初 S 为空串。共进行 n 次操作,每次操作是在 S 的结尾加入一个魔咒字符。每次操作后都
需要求出,当前的魔咒串 S 共有多少种生成魔咒。

Input

第一行一个整数 n。
第二行 n 个数,第 i 个数表示第 i 次操作加入的魔咒字符。
1≤n≤100000。,用来表示魔咒字符的数字 x 满足 1≤x≤10^9

Output

输出 n 行,每行一个数。第 i 行的数表示第 i 次操作后 S 的生成魔咒数量

Sample Input

7
1 2 3 3 3 1 2

Sample Output

1
3
6
9
12
17
22

在后面添加字符每次会产生若干个后缀。于是翻转一下相当于每次只添加一个后缀。
然后如何求不同的子串个数呢?
首先对于一个后缀$Suffix(i)$,会贡献$n-i+1$个子串,但这些子串会有一些重复的,于是找$rank[i]$的前驱和后继所在的后缀设为$j$。
那么会有$Lcp(suffix(i),suffix(j))$这么多是重复的。
于是用splay维护前驱后继,然后ST表求两个后缀的LCP。
 
代码:
#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;
#define N 100050
#define ls ch[p][0]
#define rs ch[p][1]
#define get(x) (ch[f[x]][1]==x)
typedef long long ll;
int n,ws[N],wv[N],wa[N],wb[N],rank[N],sa[N],height[N],r[N];
int f[N],ch[N][2],siz[N],rt,val[N],reimu;
//////////////////////////////////////////////
struct A {
    int num,id,v;
}a[N];
bool cmp1(const A &x,const A &y){return x.num<y.num;}
bool cmp2(const A &x,const A &y){return x.id<y.id;}
///////////////////////////////////////////////
void build_suffix_array() {
    int i,j,p,*x=wa,*y=wb,*t,m=n;
    for(i=0;i<m;i++) ws[i]=0;
    for(i=0;i<n;i++) ws[x[i]=r[i]]++;
    for(i=1;i<m;i++) ws[i]+=ws[i-1];
    for(i=n-1;i>=0;i--) sa[--ws[x[i]]]=i;
    for(j=p=1;p<n;j<<=1,m=p) {
        for(p=0,i=n-j;i<n;i++) y[p++]=i;
        for(i=0;i<n;i++) if(sa[i]-j>=0) y[p++]=sa[i]-j;
        for(i=0;i<n;i++) wv[i]=x[y[i]];
        for(i=0;i<m;i++) ws[i]=0;
        for(i=0;i<n;i++) ws[wv[i]]++;
        for(i=1;i<m;i++) ws[i]+=ws[i-1];
        for(i=n-1;i>=0;i--) sa[--ws[wv[i]]]=y[i];
        for(t=x,x=y,y=t,i=p=1,x[sa[0]]=0;i<n;i++) {
            if(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+j]==y[sa[i-1]+j]) x[sa[i]]=p-1;
            else x[sa[i]]=p++;
        }
    }
    for(i=1;i<n;i++) rank[sa[i]]=i;
    for(i=p=0;i<n-1;height[rank[i++]]=p)
        for(p?p--:0,j=sa[rank[i]-1];r[i+p]==r[j+p];p++);
}
/////////////////////////////////////////////
int newnode(int x) {
    siz[++reimu]=1; val[reimu]=x; return reimu;
}
void pushup(int p) {
    siz[p]=siz[ls]+siz[rs]+1;
}
void rotate(int x) {
    int y=f[x],z=f[y],k=get(x);
    ch[y][k]=ch[x][!k]; f[ch[y][k]]=y;
    ch[x][!k]=y; f[y]=x; f[x]=z;
    if(z) ch[z][ch[z][1]==y]=x;
    pushup(y); pushup(x);
    if(rt==y) rt=x;
}
void splay(int x,int y) {
    for(int fa;(fa=f[x])!=y;rotate(x)) 
        if(f[fa]!=y)
            rotate(get(x)==get(fa)?fa:x);
}
void insert(int x) {
    int p=rt,l,r;
    while(p) {
        if(val[p]>=x) r=p,p=ls;
        else l=p,p=rs;
    }
    splay(l,0); splay(r,rt);
    ch[r][0]=newnode(x);
    f[reimu]=r; pushup(r); pushup(l);
}
int pre(int x) {
    int p=rt,ans;
    while(p) {
        if(val[p]>=x) p=ls;
        else ans=p,p=rs;
    }
    return val[ans];
}
int nxt(int x) {
    int p=rt,ans;
    while(p) {
        if(val[p]<=x) p=rs;
        else ans=p,p=ls;
    }
    return val[ans];
}
/////////////////////////////////////////
struct ST {
    int f[N][20],L[N];
    void init() {
        int i,j;
        memset(f,0x3f,sizeof(f));
        L[1]=0;
        for(i=2;i<=n;i++) L[i]=L[i>>1]+1;
        for(i=0;i<=n;i++) {
            f[i][0]=height[i];
        }
        for(j=1;(1<<j)<=n;j++) {
            for(i=0;i+(1<<j)-1<=n;i++) {
                f[i][j]=min(f[i][j-1],f[i+(1<<j-1)][j-1]);
            }
        }
    }
    int get_min(int l,int r) {
        int len=L[r-l+1];
        return min(f[l][len],f[r-(1<<len)+1][len]);
    }
}S;
int main() {
    scanf("%d",&n);
    int i;
    for(i=1;i<=n;i++) scanf("%d",&a[n-i+1].num),a[i].id=i;
    sort(a+1,a+n+1,cmp1);
    int j=0;a[0].num=23333443;
    for(i=1;i<=n;i++) {
        if(a[i].num!=a[i-1].num) j++;
        a[i].v=j;
    }
    sort(a+1,a+n+1,cmp2);
    for(i=0;i<n;i++) {
        r[i]=a[i+1].v;
    }
    r[n++]=0;
    build_suffix_array();
 
    /*for(i=0;i<n;i++) printf("%d ",r[i]); puts("");
    for(i=0;i<n;i++) printf("%d ",sa[i]); puts("");
    for(i=0;i<n;i++) printf("%d ",height[i]); puts("");
    for(i=0;i<n;i++) printf("%d ",rank[i]);puts("");*/
 
    ll ans=0;
    rt=newnode(-100000000);
    ch[rt][1]=newnode(100000000);
    f[ch[rt][1]]=rt;
    pushup(rt);
 
    S.init();
    for(i=n-2;i>=0;i--) {
        ans+=n-i-1;
        int pr=pre(rank[i]);
        int tmp=0;
        if(pr>=0) {
            tmp=S.get_min(pr+1,rank[i]);
        }
        int nx=nxt(rank[i]);
        if(nx<=n) {
            tmp=max(tmp,S.get_min(rank[i]+1,nx));
        }
        ans-=tmp;
        insert(rank[i]);
        printf("%lld
",ans);
    }
}
/*
7
1 2 3 3 3 1 2
*/
原文地址:https://www.cnblogs.com/suika/p/9022854.html