[基本操作]后缀自动机

来介绍一些基本操作

首先,介绍一下 Suffix Automaton

后缀自动机大概由两部分组成—— DAWG 和 Parent Tree

1.DAWG

DAWG 的中文名字叫做“单词的有向无环图”

它由一个初始节点 init ,若干条转移边,若干个节点组成

DAWG 表示的是状态的转移关系,我们可以记一个点能识别的终止位置集合为 $end-pos(i)$,每个点的子串是一个前缀的一些后缀,这些后缀的长度都在 [minlen,maxlen] 这个区间里

2.Parent Tree

Parent Tree 类似 AC 自动机的 fail 树,是由 $end-pos$ 集合的包含关系构成的一棵树,满足 fa[i] 的 maxlen + 1 等于 i 的 minlen

由这个我们可以知道对 DAWG 拓扑排序相当于对 maxlen 数组快速排序/基数排序,由这个我们也可以知道其实并不用记录每个点的 minlen

Parent Tree 是反串的后缀树

由此我们可以做题

bzoj3879 SvT

给你一个串和若干组询问,每组询问包括若干个后缀,你要求出这些后缀两两间最长公共前缀长度的和

sol:后缀 i 和后缀 j 的 lcp 相当于后缀树上 i 的位置和 j 的位置的 LCA 深度

我们把串反过来,然后建 SAM

然后我们虚树 + 树形 dp 就可以了

#include<bits/stdc++.h>
#define LL long long
using namespace std;
inline int read()
{
    int x = 0,f = 1;char ch = getchar();
    for(;!isdigit(ch);ch = getchar())if(ch == '-')f = -f;
    for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0';
    return x * f;
}
const int maxn = 1200010;
const LL mod = 23333333333333333LL;
int n,a[maxn],pos[maxn],rnk[maxn];
int tr[maxn][26];
int fa[maxn],len[maxn],dfn,root,last;
char s[maxn];
void extend(int c)
{
    int p = last,np = last = ++dfn;
    len[np] = len[p] + 1;
    while(p && !tr[p][c])tr[p][c] = np,p = fa[p];
    if(!p)fa[np] = root;
    else
    {
        int q = tr[p][c];
        if(len[q] == len[p] + 1)fa[np] = q;
        else
        {
            int nq = ++dfn;
            len[nq] = len[p] + 1;memcpy(tr[nq],tr[q],sizeof(tr[nq]));fa[nq] = fa[q],fa[np] = fa[q] = nq;
            while(p && tr[p][c] == q)tr[p][c] = nq,p = fa[p];
        }
    }
}
int first[maxn],to[maxn],nx[maxn],cnt;
LL ans;
int val[maxn];
inline void add(int u,int v){to[++cnt] = v;nx[cnt] = first[u];first[u] = cnt;}
inline void ins(int u,int v){add(u,v);add(v,u);}
int size[maxn],dep[maxn],ff[maxn],bl[maxn],ind[maxn],_tim;
inline void dfs1(int x)
{
    size[x] = 1;ind[x] = ++_tim;
    for(int i=first[x];i;i=nx[i])
    {
        if(to[i] == ff[x])continue;
        ff[to[i]] = x;
        dep[to[i]] = dep[x] + 1;
        dfs1(to[i]);
        size[x] += size[to[i]];
    }
}
inline void dfs2(int x,int col)
{
    bl[x] = col;
    int k = 0;
    for(int i=first[x];i;i=nx[i])
        if(to[i] != ff[x] && size[to[i]] > size[k])k = to[i];
    if(!k)return;
    dfs2(k,col);
    for(int i=first[x];i;i=nx[i])
        if(to[i] != ff[x] && to[i] != k)dfs2(to[i],to[i]);
}
inline int lca(int x,int y)
{
    while(bl[x] != bl[y])
    {
        if(dep[bl[x]] < dep[bl[y]])swap(x,y);
        x = ff[bl[x]];
    }return dep[y] < dep[x] ? y : x;
}
inline bool cmp(const int &x,const int &y){return ind[x] < ind[y];}
int stk[maxn],f[maxn];
inline void dp(int x)
{
    f[x] = val[x] ? 1 : 0;
    for(int i=first[x];i;i=nx[i])
    {
        dp(to[i]);
        ans += (LL)f[x] * f[to[i]] * len[x];
        f[x] += f[to[i]];
    }
    first[x] = 0;
    //cout<<x<<endl;
}
int main()
{
#ifdef Ez3real
    freopen("ww.in","r",stdin);
#endif
    root = last = ++dfn;
    n = read();int q = read();scanf("%s",s + 1);
    reverse(s + 1,s + n + 1);
    for(int i=1;i<=n;i++)extend(s[i] - 'a'),pos[n - i + 1] = last;
    for(int i=1;i<=dfn;i++)add(fa[i],i);dfs1(root);dfs2(root,root);
    memset(first,0,sizeof(first));
    while(q--)
    {
        cnt = 0;
        int k = read();//cout<<k<<"!!"<<endl;
        for(int i=1;i<=k;i++)a[i] = pos[read()];
        sort(a + 1,a + k + 1,cmp);
        int nn = 0;a[++nn] = a[1];
        for(int i=2;i<=k;i++)
            if(a[i] != a[i - 1])a[++nn] = a[i];
        for(int i=1;i<=nn;i++)val[a[i]] = 1;
        int top = 0;
        for(int i=1;i<=nn;i++)
        {
            if(!top){stk[++top] = a[i];continue;}
            int x = a[i],l = lca(x,stk[top]);
            while(ind[l] < ind[stk[top]])
            {
                if(ind[l] >= ind[stk[top - 1]])
                {
                    add(l,stk[top--]);
                    if(l != stk[top])stk[++top] = l;
                    break;
                }else add(stk[top - 1],stk[top]),top--;
            }
            stk[++top] = x;
        }
        while(top > 1)add(stk[top - 1],stk[top]),top--;
        ans = 0;
        dp(stk[1]);
        printf("%lld
",ans);
        for(int i=1;i<=nn;i++)val[a[i]] = 0;
    }
}
View Code

upd:把后缀数组里那几个地方拿出来建一个“虚后缀数组”,求出“虚 height ”

然后单调栈就可以了。。。

#include<bits/stdc++.h>
#define LL long long
using namespace std;
inline int read()
{
    int x=0,f=1;char ch;
    for(ch=getchar();!isdigit(ch);ch=getchar())if(ch=='-')f=-f;
    for(;isdigit(ch);ch=getchar())x=10*x+ch-'0';
    return x*f;
}
const int maxn = 1e6 + 10;
const LL mod = 23333333333333333LL;
int n,m;
char s[maxn];LL ans;
#define equ(x) (y[sa[i] + x] == y[sa[i - 1] + x])
int rnk[maxn],tmp[maxn],sa[maxn],hei[maxn];
int x[maxn],y[maxn],wc[maxn];
void radix_sort(int m)
{
    for(int i=1;i<=m;i++)wc[i] = 0;
    for(int i=1;i<=n;i++)wc[x[y[i]]]++;
    for(int i=1;i<=m;i++)wc[i] += wc[i - 1];
    for(int i=n;i>=1;i--)sa[wc[x[y[i]]]--] = y[i];
}
void makesa(char *s,int n,int m)
{
    for(int i=1;i<=n;i++)x[i] = s[i],y[i] = i;radix_sort(m);
    for(int j=1,p=0;j<=n;j<<=1,m = p,p = 0)
    {
        for(int i=n-j+1;i<=n;i++)y[++p] = i;
        for(int i=1;i<=n;i++)
              if(sa[i] > j)y[++p] = sa[i] - j;
        radix_sort(m);swap(x,y);x[sa[1]] = p = 1;
           for(int i=2;i<=n;i++)x[sa[i]] = equ(0) && equ(j) ? p : ++p;
        if(p == n)break;
    }
    for(int i=1;i<=n;i++)rnk[sa[i]] = i;
        for(int i=1,k=0;i<=n;hei[rnk[i++]]=k)
            for(k ? k-- : 0;i+k<=n && s[i+k] == s[sa[rnk[i]-1]+k];k++);
}
int st[maxn][25],lg[maxn];
void initST()
{
     lg[0] = -1;lg[1] = 0;
    for(int i=2;i<=n;i++)lg[i] = lg[i >> 1] + 1;
    for(int i=1;i<=n;i++)st[i][0] = hei[i];
    for(int i=1;i<=lg[n];i++)
         for(int j=1;j + (1 << i) - 1<=n;j++)st[j][i] = min(st[j][i - 1],st[j + (1 << i - 1)][i - 1]);
}
inline int lcp(int u,int v)
{
    if(u == v)return n - u + 1;
    u = rnk[u],v = rnk[v];
    if(u > v)swap(u,v);u++;
    int loog = lg[v - u + 1];
    return min(st[u][loog],st[v - (1 << loog) + 1][loog]);
}
inline int rmq(int u,int v)
{
    int loog = lg[v - u + 1];
    return min(st[u][loog],st[v - (1 << loog) + 1][loog]);
}
int k,q[maxn],stk[maxn],tot[maxn];
int main()
{
    n = read(),m = read();
    scanf("%s",s + 1);
    makesa(s,n,200);initST();
    while(m--)
    {
        k = read();ans = 0;
        for(int i=1;i<=k;i++)q[i] = rnk[read()];
        sort(q + 1,q + k + 1);
        k = unique(q + 1,q + k + 1) - q - 1;
        int top = 0;LL sum = 0,cur,cnt;
        for(int i=2;i<=k;i++)
        {
            cur = rmq(q[i - 1] + 1,q[i]),cnt = 0;
            while(top && cur <= stk[top])
            {
                cnt += tot[top];
                sum = ((sum - (LL)stk[top] * (LL)tot[top])%mod + mod) % mod;
                top--;
            }
            stk[++top] = cur;tot[top] = cnt + 1;
            sum = ((sum + (LL)stk[top] * (LL)tot[top])% mod + mod) % mod;
            (ans += sum) %= mod;
        }
        ans = ((ans % mod) + mod) % mod;
        cout<<ans<<endl;
    }
}
View Code

bzoj2882 工艺

求字符串的最小表示法,也就是说,把字符串组成一个环,从任意一个位置开始读一圈,求读出来的字符串字典序最小的方案

sol:环 -> 二倍链

然后从 init 开始走,每次走字典序最小的那个转移边,走 n 步就是最小表示法

#include<bits/stdc++.h>
#define LL long long
using namespace std;
inline int read()
{
    int x = 0,f = 1;char ch = getchar();
    for(;!isdigit(ch);ch = getchar())if(ch == '-')f = -f;
    for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0';
    return x * f;
}
const int maxn = 1600010;
int n,a[maxn];
map<int,int> tr[maxn];
int fa[maxn],len[maxn],dfn,root,last;
void extend(int c)
{
    int p = last,np = last = ++dfn;
    len[np] = len[p] + 1;
    while(p && !tr[p][c])tr[p][c] = np,p = fa[p];
    if(!p)fa[np] = root;
    else
    {
        int q = tr[p][c];
        if(len[q] == len[p] + 1)fa[np] = q;
        else
        {
            int nq = ++dfn;
            len[nq] = len[p] + 1,tr[nq] = tr[q],fa[nq] = fa[q],fa[np] = fa[q] = nq;
            while(p && tr[p][c] == q)tr[p][c] = nq,p = fa[p];
        }
    }
}
int main()
{
    root = last = ++dfn;
    n = read();map<int,int>::iterator it;
    for(int i=1;i<=n;i++)a[n + i] = a[i] = read();
    int kn = n + n,p = root;
    for(int i=1;i<kn;i++)extend(a[i]);
    while(n--)
    {
        it = tr[p].begin();
        printf("%d",it -> first);
        if(n)putchar(' ');
        p = it -> second;
    }
}
View Code

bzoj4516 生成魔咒

一开始有一个空串,每次加入一个字符,询问当前本质不同的子串数量

sol:一个串本质不同的子串数量就是 $sum maxlen_i - maxlen_{fa[i]}$

因为每次只会加一个字符,每次只要维护增量即可,也就是每次把新的那个 np 对答案的贡献加进去

#include<bits/stdc++.h>
#define LL long long
const int maxn = 200050;
using namespace std;
map<int,int> tr[maxn];
int val[maxn],fa[maxn];
int n,m;
LL LastAns;
 
struct SAM
{
    int SIZE,last,root;
    SAM(){SIZE = last = root = 1;}
    inline int cal(int x){return val[x] - val[fa[x]];}
    inline void extend(int x)
    {
        int p = last,np = last = ++SIZE;
        val[np] = val[p] + 1;
        while(p && !tr[p][x]) tr[p][x] = np,p = fa[p];
        if(!p)fa[np] = root, LastAns += cal(np);
        else
        {
            int q = tr[p][x];
            if(val[p] + 1 == val[q]){fa[np] = q;LastAns += cal(np);}
            else
            {
                int nq = ++SIZE;
                val[nq] = val[p] + 1;tr[nq] = tr[q];
                fa[nq] = fa[q];LastAns += cal(nq) - cal(q);
                fa[np] = fa[q] = nq;LastAns += cal(np) + cal(q);
                while(p && tr[p][x] == q)tr[p][x] = nq, p = fa[p];
            }
        }
    }
}S;
 
int main()
{
    scanf("%d",&n);int x;
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&x);
        S.extend(x);
        printf("%lld
",LastAns);
    }
}
View Code

更新

bzoj4199 品酒大会

给一个字符串 S ,每个位置都有一个权值 $w_i$ ,对每一个 $i ∈ [1,n]$ ,求出 $lcp(a,b) = i$ 的后缀数量和 $w_a imes w_b$ 的最大值

sol:lcp -> 后缀树上 lca

第一问就是枚举一下 lca 然后枚举 lca 的相邻子节点 dp 一下就可以了

第二问记一下每个点的最大最小然后乘一下

 
#include<bits/stdc++.h>
#define LL long long
using namespace std;
inline int read()
{
    int x = 0,f = 1;char ch = getchar();
    for(;!isdigit(ch);ch = getchar())if(ch == '-') f = -f;
    for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0';
    return x * f;
}
void fre()
{
    freopen("mydata.in","r",stdin);
    freopen("mydata.out","w",stdout);
}
const int maxn = 600010;
int n,a[maxn];
char s[maxn];
int tr[maxn][26],mxlen[maxn],fa[maxn],dfn,last,root;
int val[maxn];
LL cnt[maxn],ans[maxn],mn[maxn],mx[maxn];
int size[maxn];
void extend(int c,int v)
{
    int p = last,np = last = ++dfn;
    mxlen[np] = mxlen[p] + 1;
    size[np]++;mx[np] = mn[np] = v;
    for(;p && !tr[p][c];p = fa[p])tr[p][c] = np;
    if(!p)fa[np] = root;
    else
    {
        int q = tr[p][c];
        if(mxlen[q] == mxlen[p] + 1)fa[np] = q;
        else
        {
            int nq = ++dfn;
            mxlen[nq] = mxlen[p] + 1;
            memcpy(tr[nq],tr[q],sizeof(tr[nq]));
            fa[nq] = fa[q];
            fa[np] = fa[q] = nq;
            for(;p && tr[p][c] == q;p = fa[p])tr[p][c] = nq;
        }
    }
}
int first[maxn],to[maxn << 1],nx[maxn << 1],cwt;
inline void add(int u,int v)
{
    to[++cwt] = v;
    nx[cwt] = first[u];
    first[u] = cwt;
}
void dfs(int x)
{
    if(!mx[x] && !mn[x])mx[x] = -1e16,mn[x] = 1e16;
    for(int i=first[x];i;i=nx[i])
    {
        dfs(to[i]);
        if(mx[x] != -1e16 && mn[x] != 1e16 && mx[to[i]] != -1e16 && mn[to[i]] != 1e16)
            ans[mxlen[x]] = max(ans[mxlen[x]],max(mx[x] * mx[to[i]],mn[x] * mn[to[i]]));
        cnt[mxlen[x]] += 1ll * size[x] * size[to[i]];size[x] += size[to[i]];
        mx[x] = max(mx[x],mx[to[i]]);mn[x] = min(mn[x],mn[to[i]]);
    }
}
int main()
{
#ifdef Ez3real
    fre();
#endif
    root = last = ++dfn;
    n = read();scanf("%s",s + 1);
    //reverse(s + 1,s + n + 1);
    for(int i=1;i<=n;i++)
    {
        a[i] = read();
        //extend(s[i] - 'a',a[i]);
    }
    for(int i=n;i>=1;i--)extend(s[i] - 'a',a[i]);
    for(int i=2;i<=dfn;i++)add(fa[i],i);
    //for(int i=0;i<=n;i++)ans[i] = -1e16;
    memset(ans,-63,sizeof(ans));
    dfs(1);
    for(int i=n-1;i>=0;i--)cnt[i] += cnt[i + 1],ans[i] = max(ans[i],ans[i + 1]);
    for(int i=0;i<n;i++)
    {
        if(cnt[i])
            printf("%lld %lld
",cnt[i],ans[i]);
        else puts("0 0");
    } 
}
View Code

bzoj4566 找相同字符

给两个字符串 $S_1,S_2$ 求他们有多少个不同的相同子串,两个子串有一个字符位置不同就算不同

sol:广义后缀自动机,记一下第一个串到过多少点,第二个串到过多少点,如果两个串都到过一个点,答案就加上这个点的子串个数

#include<bits/stdc++.h>
#define LL long long
using namespace std;
inline int read()
{
    int x = 0,f = 1;char ch = getchar();
    for(;!isdigit(ch);ch = getchar())if(ch == '-') f = -f;
    for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0';
    return x * f;
}
void fre()
{
    freopen("mydata.in","r",stdin);
    freopen("mydata.out","w",stdout);
}
const int maxn = 800010;
char s1[maxn],s2[maxn];
int n,m;
int tr[maxn][26],fa[maxn],mxlen[maxn],root,last,dfn;
int c1[maxn],c2[maxn];
void extend(int c)
{
    int p = last,np = last = ++dfn;
    mxlen[np] = mxlen[p] + 1;
    for(;p && !tr[p][c];p = fa[p])tr[p][c] = np;
    if(!p)fa[np] = root;
    else
    {
        int q = tr[p][c];
        if(mxlen[q] == mxlen[p] + 1)fa[np] = q;
        else
        {
            int nq = ++dfn;
            mxlen[nq] = mxlen[p] + 1;
            memcpy(tr[nq],tr[q],sizeof(tr[nq]));
            fa[nq] = fa[q];
            fa[q] = fa[np] = nq;
            for(;p && tr[p][c] == q;p = fa[p])tr[p][c] = nq;
        }
    }
}
int c[maxn],rk[maxn];
void getsize()
{
    for(int i=1;i<=dfn;i++)++c[mxlen[i]];
    for(int i=1;i<=n;i++)c[i] += c[i - 1];
    for(int i=1;i<=dfn;i++)rk[c[mxlen[i]]--] = i;
    for(int i=dfn;i>=1;i--)
    {
        c1[fa[rk[i]]] += c1[rk[i]]; 
        c2[fa[rk[i]]] += c2[rk[i]];
    }
}
int main()
{
#ifdef Ez3real
    fre();
#endif
    root = last = ++dfn;
    scanf("%s%s",s1 + 1,s2 + 1);
    n = strlen(s1 + 1);m = strlen(s2 + 1);
    for(int i=1;i<=n;i++)extend(s1[i] - 'a');
    last = root;
    for(int i=1;i<=m;i++)extend(s2[i] - 'a');
    int now = root;
    for(int i=1;i<=n;i++)
    {
        int p = s1[i] - 'a';
        now = tr[now][p];
        c1[now]++;
    }
    now = root;
    for(int i=1;i<=m;i++)
    {
        int p = s2[i] - 'a';
        now = tr[now][p];
        c2[now]++;
    }getsize();
    LL ans = 0;
    for(int i=1;i<=dfn;i++){ans += 1ll * c1[i] * c2[i] * (mxlen[i] - mxlen[fa[i]]);}
    cout<<ans;
}
View Code

bzoj3998 弦论

对于一个给定长度为 N 的字符串,求它的第 K 小子串是什么。

不同位置的相同子串可以算多个,也可以算一个

sol:后缀自动机,每个点搞一个权值,如果算多个,就是这个点 end-pos 集合大小,算一个,就是 1 

然后按字典序搜一下就可以了

#include<bits/stdc++.h>
#define LL long long
using namespace std;
inline int read()
{
    int x = 0,f = 1;char ch = getchar();
    for(;!isdigit(ch);ch = getchar())if(ch == '-') f = -f;
    for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0';
    return x * f;
}
void fre()
{
    freopen("mydata.in","r",stdin);
    freopen("mydata.out","w",stdout);
}
int n,t,k;
const int maxn = 1e6 + 10;
char s[maxn];
int tr[maxn][26],fa[maxn],mxlen[maxn],root,last,dfn;
int c[maxn],size[maxn],rk[maxn],sum[maxn];
void extend(int c)
{
    int p = last,np = last = ++dfn;
    mxlen[np] = mxlen[p] + 1;size[np] = 1;
    for(;p && !tr[p][c];p = fa[p])tr[p][c] = np;
    if(!p)fa[np] = root;
    else
    {
        int q = tr[p][c];
        if(mxlen[q] == mxlen[p] + 1)fa[np] = q;
        else
        {
            int nq = ++dfn;
            mxlen[nq] = mxlen[p] + 1;
            fa[nq] = fa[q];
            memcpy(tr[nq],tr[q],sizeof(tr[q]));
            fa[np] = fa[q] = nq;
            for(;p && tr[p][c] == q;p = fa[p])tr[p][c] = nq;
        }
    }
}
void build()
{
    for(int i=1;i<=dfn;i++)++c[mxlen[i]];
    for(int i=1;i<=n;i++)c[i] += c[i - 1];
    for(int i=dfn;i;i--)rk[c[mxlen[i]]--] = i;
    for(int i=dfn;i;i--)
    {
        if(t == 1)size[fa[rk[i]]] += size[rk[i]];
        else size[fa[rk[i]]] = 1;
    }
    size[1] = 0;
    for(int i=dfn;i;i--)
    {
        sum[rk[i]] = size[rk[i]];
        for(int j=0;j<26;j++)
            sum[rk[i]] += sum[tr[rk[i]][j]];
    }
}
void dfs(int x,int k)
{
    if(k <= size[x])return;
    k -= size[x];
    for(int i=0;i<26;i++)
    {
        if(!tr[x][i])continue;
        if(k <= sum[tr[x][i]])
        {
            putchar('a' + i);
            dfs(tr[x][i],k);
            return;
        }k -= sum[tr[x][i]];
    }
}
int main()
{
#ifdef Ez3real
    fre();
#endif
    root = last = ++dfn;
    scanf("%s",s + 1);n = strlen(s + 1);
    t = read(),k = read();
    for(int i=1;i<=n;i++)extend(s[i] - 'a');
    build();
    if(k > sum[1])puts("-1");
    else dfs(root,k);
}
View Code

bzoj4566 字符串

多次询问 $s[a,b]$ 的所有子串和 $s[c,d]$ 的所有子串的最长公共前缀的最大值

sol:建反串的后缀自动机,这样最长公共前缀就变成了最长公共后缀,对应的就是 LCA 的 mxlen

我们二分答案 $x$ ,先倍增找到 $d$ 在后缀树上的位置,然后维护一下 $d$ 的 $endpos$ 集合里有没有出现 $[a+x-1,b]$ 这一段子串即可

维护 $endpos$ 集合要线段树合并,然后我们发现好像 $c$ 是打酱油的。。。

要注意,线段树合并要新开一个节点,要不然会挂

#include<bits/stdc++.h>
#define LL long long
using namespace std;
inline int read()
{
    int x = 0,f = 1;char ch = getchar();
    for(;!isdigit(ch);ch = getchar())if(ch == '-') f = -f;
    for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0';
    return x * f;
}
void fre()
{
    freopen("mydata.in","r",stdin);
    freopen("mydata.out","w",stdout);
}
const int maxn = 200010;
int n,m;
char s[maxn];
int tr[maxn][26],fa[maxn],mxlen[maxn],rt,dfn,last;
int mp[maxn],pos[maxn];
void extend(int c)
{
    int p = last,np = last = ++dfn;
    mxlen[np] = mxlen[p] + 1;
    for(;p && !tr[p][c];p = fa[p])tr[p][c] = np;
    if(!p)fa[np] = rt;
    else
    {
        int q = tr[p][c];
        if(mxlen[q] == mxlen[p] + 1)fa[np] = q;
        else
        {
            int nq = ++dfn;mxlen[nq] = mxlen[p] + 1; 
            memcpy(tr[nq],tr[q],sizeof(tr[nq]));
            fa[nq] = fa[q];
            fa[q] = fa[np] = nq;
            for(;p && tr[p][c] == q;p = fa[p])tr[p][c] = nq;
        }
    }
}
int root[maxn << 1],ls[maxn << 6],rs[maxn << 6],val[maxn << 6],ToT;
inline void Insert(int &x,int l,int r,int pos)
{
    if(!x) x = ++ToT;
    if(l == r){val[x]++;return;}
    int mid = (l + r) >> 1;
    if(pos <= mid)Insert(ls[x],l,mid,pos);
    else Insert(rs[x],mid + 1,r,pos);
    val[x] = val[ls[x]] + val[rs[x]];
}
inline int merge(int x,int y)
{
    if(!x || !y)return x + y;
    val[x] += val[y];
    //if(!ls[x] && !rs[x])return x;
    ls[x] = merge(ls[x],ls[y]);
    rs[x] = merge(rs[x],rs[y]);
    return x;
    /*if(!x || !y)return x + y;
    int np = ++ToT;
    ls[np] = merge(ls[y],ls[x]);
    rs[np] = merge(rs[x],rs[y]);
    val[np] = val[ls[np]] + val[rs[np]];
    return np;*/
}
inline int query(int x,int l,int r,int L,int R)
{
    if(L <= l && r <= R)return val[x];
    int mid = (l + r) >> 1,ans = 0;
    if(L <= mid)ans += query(ls[x],l,mid,L,R);
    if(R > mid)ans += query(rs[x],mid + 1,r,L,R);
    return ans;
}
int first[maxn],to[maxn << 1],nx[maxn << 1],cnt;
int dep[maxn],anc[maxn][23];
inline void add(int u,int v){to[++cnt] = v;nx[cnt] = first[u];first[u] = cnt;}
inline void dfs(int x)
{
    for(int i=1;i<=22;i++)
        anc[x][i] = anc[anc[x][i - 1]][i - 1];
    for(int i=first[x];i;i=nx[i])
    {
        if(to[i] == anc[x][0])continue;
        anc[to[i]][0] = x;dep[to[i]] = dep[x] + 1;
        dfs(to[i]);
        root[x] = merge(root[x],root[to[i]]);
    }
}
inline int get_anc(int x,int k) // x zuxian 
{
    //if(!k)return 1;
    for(int i=22;~i;i--)
        if(mxlen[anc[x][i]] >= k) x = anc[x][i];
    return x;
}
bool chk(int mid,int l,int r,int pos)
{
    if(mid == 0)return 1;
    if(l > r)return 0;
    pos = get_anc(pos,mid);
    return query(root[pos],1,n,l,r);
}
int main()
{
#ifdef Ez3real
    fre();
#endif
    rt = last = ++dfn;
    n = read(),m = read();
    scanf("%s",s + 1);
    reverse(s + 1,s + n + 1);
    for(int i=1;i<=n;i++)
    {
        extend(s[i] - 'a');
        mp[last] = i;
        pos[i] = last;
    }
    for(int i=1;i<=dfn;i++)
        if(mp[i])Insert(root[i],1,n,mp[i]);
    for(int i=1;i<=dfn;i++)add(fa[i],i);
    dep[rt] = 1;dfs(rt);
    while(m--)
    {
        int a = read(),b = read(),c = read(),d = read();
        swap(a,b);swap(c,d);a = n - a + 1,b = n - b + 1,c = n - c + 1,d = n - d + 1;
        int l = 0,r = min(b - a + 1,d - c + 1),ans = 0;
        //if(a > b || c > d){puts("0");continue;}
        while(l <= r)
        {
            int mid = (l + r) >> 1;
            if(chk(mid,a + mid - 1,b,pos[d]))l = mid + 1,ans = max(ans,mid);
            else r = mid - 1;
        }
        printf("%d
",ans);
    }
}
在 bzoj 上会 MLE,loj 过了

bzoj3926 诸神眷顾的幻想乡

一棵不超过 19 叉的树,每个点有一个颜色,颜色总共只有 10 种,树的大小一共只有 2000

对于任意两个树上的点 $(a,b)$ 我们称 $str_{(a,b)}$ 为从 $a$ 开始沿简单路径走到 $b$ 途径的每个点的颜色组成的序列

求有多少本质不同的 $str_{(a,b)}$

sol:Trie 树的广义后缀自动机

两点间的有向字符串可以视为以每个叶子节点为根构成的 Trie 树上的某条直链(祖先 -> 儿子)

对于这种问题我们可以建立每个 Trie 树的广义后缀自动机

因为叶子不超过 20 个,暴力即可

最后统计一下这个广义后缀自动机上有多少本质不同的子串,这就是模板题啦

#include<bits/stdc++.h>
#define LL long long
using namespace std;
inline int read()
{
    int x = 0,f = 1;char ch = getchar();
    for(;!isdigit(ch);ch = getchar())if(ch == '-') f = -f;
    for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0';
    return x * f;
}
const int maxn = 2010000;
int n,c;LL ans;
int col[maxn];
int first[maxn],to[maxn],nx[maxn],cnt;
int ind[maxn];
inline void add(int u,int v)
{
    to[++cnt] = v;
    nx[cnt] = first[u];
    first[u] = cnt;
}
int last,root,dfn;
int fa[maxn],mxlen[maxn],tr[maxn][15];
void extend(int c)
{
    int p = last,np = last = ++dfn;
    mxlen[np] = mxlen[p] + 1;
    while(p && !tr[p][c])tr[p][c] = np,p = fa[p];
    if(!p)fa[np] = root;
    else
    {
        int q = tr[p][c];
        if(mxlen[q] == mxlen[p] + 1)fa[np] = q;
        else
        {
            int nq = ++dfn;
            mxlen[nq] = mxlen[p] + 1;memcpy(tr[nq],tr[q],sizeof(tr[nq]));fa[nq] = fa[q],fa[np] = fa[q] = nq;
            while(p && tr[p][c] == q)tr[p][c] = nq,p = fa[p];
        }
    }
}
void dfs(int x,int fa)
{
    extend(col[x]);
    int tmp = last;
    for(int i=first[x];i;i=nx[i])
        if(to[i] != fa)
        {
            last = tmp;
            dfs(to[i],x);
        }
}
int main()
{
    root = last = ++dfn;
    n = read(),c = read();
    for(int i=1;i<=n;i++)col[i] = read();
    for(int i=2;i<=n;i++)
    {
        int u = read(),v = read();
        add(u,v);add(v,u);
        ind[u]++;ind[v]++;
    }
    for(int i=1;i<=n;i++)
        if(ind[i] == 1)
        {
            last = 1;
            dfs(i,0);
        }
    for(int i=1;i<=dfn;i++)ans += (mxlen[i] - mxlen[fa[i]]);
    cout<<ans;
}
View Code

loj6401 字符串

有一个字符串 $S$,每个位置可能是好的或者坏的,定义一个子串是好的,当且仅当它包含了不超过 $k$ 个坏的位置,求有多少本质不同的好的子串

$|S| leq 100000$

sol:每个子串找出最长的合法后缀,沿 parent 更新

#include<bits/stdc++.h>
using namespace std;
const int N=100010;
int K;
char s[N],b[N];
struct Suffix_Automaton{
    static const int M=N<<1;
    int son[M][26],par[M],Mxlen[M],SAM_cnt,Deg[M],Q[M],limit[M];
    int Extend(int p,int c){
        int q=++SAM_cnt;
        Mxlen[q]=Mxlen[p]+1;
        while (p>0 && son[p][c]==0){
            son[p][c]=q;
            p=par[p];
        }
        if (p==0){
            par[q]=1;
        }else{
            int r=son[p][c];
            if (Mxlen[r]==Mxlen[p]+1){
                par[q]=r;
            }else{
                int o=++SAM_cnt;
                par[o]=par[r];
                par[q]=par[r]=o;
                Mxlen[o]=Mxlen[p]+1;
                memcpy(son[o],son[r],sizeof(son[o]));
                while (p>0 && son[p][c]==r){
                    son[p][c]=o;
                    p=par[p];
                }
            }
        }
        return q;
    }
    void build(){
        int i,len=strlen(s),p=SAM_cnt=1,l=0,cnt=0;
        for (i=0;i<len;i++){
            p=Extend(p,s[i]-'a');
            cnt+=(b[i]=='0');
            while (cnt>K){
                cnt-=(b[l]=='0');
                l++;
            }
            limit[p]=i-l+1;
        }
        for (i=2;i<=SAM_cnt;i++){
            Deg[par[i]]++;
        }
        int L=1,R=0;
        for (i=1;i<=SAM_cnt;i++){
            if (Deg[i]==0){
                Q[++R]=i;
            }
        }
        long long Ans=0;
        while (L<=R){
            int x=Q[L++],t=par[x];
            Ans+=max(0,min(limit[x],Mxlen[x])-Mxlen[t]);
            if (t!=0){
                limit[t]=max(limit[t],limit[x]);
                Deg[t]--;
                if (Deg[t]==0){
                    Q[++R]=t;
                }
            }
        }
        printf("%lld
",Ans);
    }
}SAM;
int main(){
    scanf("%s%s%d",s,b,&K);
    SAM.build();
    return 0;
}
View Code

loj6041 事情的相似度

一个 01 串,多次询问一段区间内的前缀的最长公共后缀

$n,q leq 100000$

sol:

实质上是要求区间内两两 LCA 深度的最大值

离线,按右端点排序

每加入一个字母就在这个字母到根的路径上打标记

查询的时候沿查询节点往根跑,如果跑到一个有旧标记的点,则该点为旧标记和新标记的 LCA

每次贪心地更新更大的标记

用一个以询问左端点为下标的树状数组统计答案

然后发现从一个地方走到根这个事情复杂度不是很显然

写一个 LCT ,access 即可

#include<bits/stdc++.h>
#define LL long long
using namespace std;
inline int read()
{
    int x = 0,f = 1;char ch = getchar();
    for(;!isdigit(ch);ch = getchar())if(ch == '-') f = -f;
    for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0';
    return x * f;
}
const int maxn = 200010;
int n,q,l[maxn],reh[maxn];
char s[maxn];
int root,dfn,last;
int tr[maxn][2],par[maxn],mxlen[maxn];
void extend(int c,int id)
{
    int p = last,np = last = ++dfn;
    reh[id] = np;
    mxlen[np] = mxlen[p] + 1;
    for(;p && !tr[p][c];p = par[p])tr[p][c] = np;
    if(!p)par[np] = root;
    else
    {
        int q = tr[p][c];
        if(mxlen[q] == mxlen[p] + 1)par[np] = q;
        else
        {
            int nq = ++dfn;
            mxlen[nq] = mxlen[p] + 1;
            memcpy(tr[nq],tr[q],sizeof(tr[nq]));
            par[nq] = par[q]; par[q] = par[np] = nq;
            for(;p && tr[p][c] == q;p = par[p])tr[p][c] = nq;
        }
    }
}
vector<int> qs[maxn];
int c[maxn];
inline int lowbit(int x){return x & (-x);}
inline int ask(int x){x = n - x + 1;int res = 0;for(;x;x -= lowbit(x))res = max(res,c[x]);return res;}
inline void add(int x,int val){x = n - x + 1;for(;x <= n;x += lowbit(x))c[x] = max(c[x],val);}
#define ls ch[x][0]
#define rs ch[x][1]
int ch[maxn][2],fa[maxn],val[maxn],tag[maxn],rev[maxn],st[maxn],top;
inline void pushdown(int x)
{
    if(!tag[x])return;
    if(ls)val[ls] = tag[ls] = tag[x];
    if(rs)val[rs] = tag[rs] = tag[x];
    tag[x] = 0;
} 
inline int isroot(int x){return (ch[fa[x]][0] != x) && (ch[fa[x]][1] != x);}
inline void rotate(int x)
{
    int y = fa[x],z = fa[y];
    int l = (ch[y][1] == x),r = l ^ 1;
    if(!isroot(y))ch[z][ch[z][1] == y] = x;
    fa[ch[x][r]] = y;fa[x] = z;fa[y] = x;
    ch[y][l] = ch[x][r];ch[x][r] = y;
    //pushup(y);pushup(x);
}
inline void splay(int x)
{
    st[top = 1] = x;
    for(int i=x;!isroot(i);i=fa[i])st[++top] = fa[i];
    for(int i=top;i;i--)pushdown(st[i]);
    while(!isroot(x))
    {
        int y = fa[x],z = fa[y];
        if(!isroot(y))
        {
            if(ch[z][1] == y ^ ch[y][1] == x)rotate(x);
            else rotate(y);
        }
        rotate(x);
    }
}
inline void access(int x,int v)
{
    int y;
    for(y=0;x;y = x,x = fa[x])splay(x),add(val[x],mxlen[x]),rs = y;
    tag[y] = val[y] = v;
}
int ans[maxn];
int main()
{
    last = root = ++dfn;
    n = read(),q = read();
    scanf("%s",s + 1);
    for(int i=1;i<=q;i++)
    {
        l[i] = read();int r = read();
        qs[r].push_back(i); 
    }
    for(int i=1;i<=n;i++)extend(s[i] - '0',i);
    for(int i=1;i<=dfn;i++)fa[i] = par[i];
    for(int i=1;i<=n;i++)
    {
        access(reh[i],i);
        int m = qs[i].size();
        for(int j=0;j<m;j++)
        {
            int now = qs[i][j];
            ans[now] = ask(l[now]);
        }
    }
    for(int i=1;i<=q;i++)printf("%d
",ans[i]);
}
View Code

upd:有一个好东西叫做序列自动机,也就是兹磁识别一个串的所有子序列的自动机

具体实现的话很简单,用一个数组 $next_{(i,j)}$ 记录第 $i$ 位后出现的第一个字符 $j$ 出现的位置即可

这样很多在串上的题可以强行上序列。。。

luogu P4608 所有公共子序列问题

求两个串所有公共子序列的个数,位置不同算不同

sol:建一个序列自动机然后暴力 dp 即可

#include<bits/stdc++.h>
#define MAXN 3002
#define BASE 1e9
using namespace std;
inline int get(char s){return s<='Z'? s-'A':s-'a'+26;}
inline char code(int x){return x<26? 'A'+x:'a'+x-26;}
struct SEGAM
{
    int t[MAXN][52],S,tot,f[52],last[MAXN];
    SEGAM()
    {
        S=tot=1;
        for(int i=0;i<52;i++)f[i]=S; 
    } 
    void Insert(int x)
    {
        last[++tot]=f[x];
        int i,j;
        for(i=0;i<52;i++)
        {
            for(j=f[i];j&&!t[j][x];j=last[j])t[j][x]=tot;
        }
        f[x]=tot;
    }
}A,B;
struct BigNum
{
    int a[20],n;
    void print()
    {
        printf("%d",a[n-1]);
        for(int i=n-2;i>=0;i--)printf("%09d",a[i]);
    }
}dp[MAXN][MAXN],one,zero;
inline BigNum operator + (BigNum x,BigNum y)
{
    x.n=max(x.n,y.n);
    for(int i=0;i<x.n;i++)
    {
        x.a[i]+=y.a[i];
        if(x.a[i]>=BASE)x.a[i+1]++,x.a[i]-=BASE;
    }
    if(x.a[x.n])x.n++;
    return x;
}
bitset<MAXN> vis[MAXN]; 
BigNum Solve1(int a,int b)
{
    if(!a||!b)return zero;
    if(vis[a][b])return dp[a][b];
    vis[a][b]=1;dp[a][b]=one;
    for(int i=0;i<52;i++)dp[a][b]=dp[a][b]+Solve1(A.t[a][i],B.t[b][i]);
    return dp[a][b];
}
char s[MAXN];
void Solve2(int a,int b,int n)
{
    if(!a||!b)return;
    s[n]=0;printf("%s
",s);
    for(int i=0;i<52;i++)
    {
        s[n]=code(i);
        Solve2(A.t[a][i],B.t[b][i],n+1);
    }
}
int N,M,K;
char sa[MAXN],sb[MAXN];
int main()
{
    scanf("%d%d%s%s%d",&N,&M,sa,sb,&K);
    one.n=one.a[0]=zero.n=1;
    for(int i=0;i<N;i++)A.Insert(get(sa[i]));
    for(int i=0;i<M;i++)B.Insert(get(sb[i]));
    if(K)Solve2(1,1,0);
    Solve1(1,1);dp[1][1].print();
    return 0;
} 
View Code
原文地址:https://www.cnblogs.com/Kong-Ruo/p/9985129.html