[bzoj3998][TJOI2015]弦论

后缀自动机丝薄题。

求给定字符串$s$的第$k$大的子串。分unique之后的和不unique的两种询问。


首先构建出SAM。

相同子串算一个的情况:

SAM上所有路径组成字符串$s$的全部子串,每个状态向下不管怎么走,形成的串都是以当前状态为前缀的。(废话)

所以我们只要知道以当前串为前缀的串有多少,不就可以像在平衡树上搜索一样找到第$k$大吗。

即从当前点开始,向下有多少路径。拓扑跑一下就行了。$sum[u]=sum sum[v] +1$。

另一种情况:

思路和第一种完全一样,区别在于所在位置不同的相同子串本质不同了。我们发现$right$集合的大小不就是当前状态的有多少个吗。所以我们先求出每个点$u$的$right$集合有多大($u$的$right$集合就是所有以$u$为$parent$(也有叫$link$,我叫$fa$)的点的$right$集合的并集呀),一个道理,拓扑跑一下。$sum[u]=sum sum[v]+|right(u)|$

代码可以简化的地方很多。。

看大佬们的代码都不用dfs去统计,蒟蒻想了半天发现自己傻掉了。因为我们知道不管在SAM的DAG中,还是在$parent$树中,儿子的$len$($max$?就是丽洁神犇ppt里的$max$)都比父亲大(DAG上说父子的话不太严密,但意思就是那样)。于是按$len$排序(注意可以基数排序!)然后再转移就一定不会错啦。

#include<bits/stdc++.h>
using namespace std;
const int N=1000010;
typedef long long ll;
int len[N],fa[N],ch[N][26];
int las,sz;
int opt,n,val[N],sum[N],v[N],q[N];
void ins(int c){
    int now=++sz;len[now]=len[las]+1;
    int p,q;val[now]=1;
    for(p=las;~p&&!ch[p][c];p=fa[p])
    ch[p][c]=now;
    if(!~p)fa[now]=0;
    else{
        q=ch[p][c];
        if(len[q]==len[p]+1)
            fa[now]=q;
        else{
            int r=++sz;
            fa[r]=fa[q],len[r]=len[p]+1;
            for(int i=0;i<26;i++)
            ch[r][i]=ch[q][i];
            for(;~p&&ch[p][c]==q;p=fa[p])
            ch[p][c]=r;
            fa[now]=fa[q]=r;
        }
    }
    las=now;
}
void init(){
    for(int i=0;i<=sz;i++)v[len[i]]++;
    for(int i=1;i<=n;i++)v[i]+=v[i-1];
    for(int i=sz;~i;i--)
    q[v[len[i]]--]=i;
    for(int i=sz+1;i;i--){
        int t=q[i];
        if(opt==1)val[fa[t]]+=val[t];
        else val[t]=1;
    }
    val[0]=0;
    for(int i=sz+1;i;i--){
        int t=q[i];sum[t]=val[t];
        for(int j=0;j<26;j++)
        sum[t]+=sum[ch[t][j]];
    }
}
void dfs(int u,int rk){
    if(rk<=val[u])return;
    rk-=val[u];
    for(int i=0;i<26;i++)
    if(ch[u][i]){
        int t=ch[u][i];
        if(rk<=sum[t]){
            putchar(i+'a');
            dfs(t,rk);
            return;
        }
        rk-=sum[t];
    }
}
void solve(){
    init();
    int k;scanf("%d",&k);
    dfs(0,k);
}
char s[N];
int main(){
    scanf("%s",s);fa[0]=-1;
    n=strlen(s);
    for(int i=0;i<n;i++)
    ins(s[i]-'a');
    scanf("%d",&opt);
    solve();
}

 

原文地址:https://www.cnblogs.com/orzzz/p/8306726.html