[BZOJ2119] 股市的预测

Description

给你一个长度为 (n) 的数组 (a[]),在差分数组里面找 (ABA) 的形式,这里 (B) 的长度要求为 (m),找这样的连续段的个数。

Solution

枚举每种 (A) 长度 (i),则每隔 (i) 设置一个关键点,在这些关键点上求出 LCS 和 LCP,则我们会得到一段合法的区间

于是我们得到这段这个关键点对答案的贡献为 (lenlcp + lenlcs -i)

由于在 (len(A)=i) 的情况下,有且仅有对 (j) 这个关键点的统计中,(A_{left}) 是跨过 (j) 的,因此统计不重不漏

为了求 LCP 和 LCS,对差分序列正反各建 SAM,用 ST 表维护后缀树支持 LCA 询问,根据调和级数,时间复杂度为 (O(n log n))

(其实明显应该写 SA 嘛,为了尝试用 SAM 搞 LCP 才写的)

#include <bits/stdc++.h>
using namespace std;

const int N = 600005;
const int lgn = 19;

const int dbg = 0;

struct lcasolver
{
    int rt,st[N][lgn],dep[N],dis[N],vis[N],ind,lg2[N],s[N],bg[N],ed[N];
    vector <int> g[N];

    lcasolver()
    {
        memset(st,0,sizeof st);
        memset(dep,0,sizeof dep);
        memset(vis,0,sizeof vis);
        memset(s,0,sizeof s);
        memset(bg,0,sizeof bg);
        memset(ed,0,sizeof ed);
    }

    void dfs(int p)
    {
        vis[p]=1;
        s[++ind]=p;
        bg[p]=ind;
        for(int q:g[p])
        {
            if(vis[q]==0)
            {
                dep[q]=dep[p]+1;
                dfs(q);
                s[++ind]=p;
            }
        }
        ed[p]=ind;
    }

    int lca(int p,int q)
    {
        p=bg[p];
        q=bg[q];
        if(p>q) swap(p,q);
        int l=lg2[q-p+1];
        int x=st[p][l];
        int y=st[q-(1<<l)+1][l];
        return s[dis[x]<dis[y]?x:y];
    }

    void solve()
    {
        rt=1;
        for(int i=0; i<=lgn; i++) for(int j=1<<i; j<1<<(i+1); j++) lg2[j]=i;
        dfs(rt);
        for(int i=1; i<=ind; i++) dis[i]=dep[s[i]];
        for(int i=1; i<=ind; i++) st[i][0]=i;
        for(int i=1; i<lgn; i++)
        {
            for(int j=1; j<=ind; j++)
            {
                st[j][i]=dis[st[j][i-1]]<dis[st[j+(1<<(i-1))][i-1]]?st[j][i-1]:st[j+(1<<(i-1))][i-1];
            }
        }
    }

    void make(int p,int q)
    {
        if(dbg) cout<<" makeedge "<<p<<" "<<q<<endl;
        g[p].push_back(q);
        g[q].push_back(p);
    }
} lca1,lca2;

struct SAM
{
    int len[N], fa[N], ind, last;
    map<int,int> ch[N];
    int t[N], a[N], cnt[N], f[N], ep[N], tot;
    void clear()
    {
        ind = last = 1;
        tot = 0;
        memset(len,0,sizeof len);
        memset(fa,0,sizeof fa);
        memset(t,0,sizeof t);
        memset(a,0,sizeof a);
        memset(cnt,0,sizeof cnt);
        memset(f,0,sizeof f);
        for(int i=1; i<=ind; i++) ch[i].clear();
    }
    SAM()
    {
        clear();
    }
    inline void extend(int id)
    {
        int cur = (++ ind), p;
        len[cur] = len[last] + 1;
        cnt[cur] = 1;
        for (p = last; p && !ch[p][id]; p = fa[p]) ch[p][id] = cur;
        if (!p) fa[cur] = 1;
        else
        {
            int q = ch[p][id];
            if (len[q] == len[p] + 1) fa[cur] = q;
            else
            {
                int tmp = (++ ind);
                len[tmp] = len[p] + 1;
                ch[tmp] = ch[q];
                fa[tmp] = fa[q];
                for (; p && ch[p][id] == q; p = fa[p]) ch[p][id] = tmp;
                fa[cur] = fa[q] = tmp;
            }
        }
        last = cur;
        ep[++tot] = last;
    }
    void calcEndpos()
    {
        memset(t, 0, sizeof t);
        for(int i=1; i<=ind; i++) t[len[i]]++;
        for(int i=1; i<=ind; i++) t[i]+=t[i-1];
        for(int i=1; i<=ind; i++) a[t[len[i]]--]=i;
        for(int i=ind; i>=1; --i) cnt[fa[a[i]]]+=cnt[a[i]];
        cnt[1] = 0;
    }
    void solve(lcasolver& tr)
    {
        calcEndpos();
        for(int i=1;i<=ind;i++)
        {
            if(fa[i]!=0) tr.make(fa[i],i);
        }
        tr.solve();
    }
    int query(lcasolver& tr, int p, int q)
    {
        int l = tr.lca(ep[p],ep[q]);
        return len[l];
    }
} sam1,sam2;

int n,m,a[N],b[N];

int getlcs(int p,int q)
{
    return sam1.query(lca1,p,q);
}

int getlcp(int p,int q)
{
    return sam2.query(lca2,n-p,n-q);
}

signed main()
{
    ios::sync_with_stdio(false);
    cin>>n>>m;
    for(int i=0;i<n;i++) cin>>a[i], b[i]=a[i]-a[i-1];
    for(int i=1;i<n;i++) sam1.extend(b[i]);
    for(int i=n-1;i;i--) sam2.extend(b[i]);
    sam1.solve(lca1);
    sam2.solve(lca2);

    long long ans=0;
    for(int i=1;i<n;i++)
    {
        for(int j=1;j<n;j+=i)
        {
            int p=j, q=j+i+m;
            if(q>=n) break;
            int lcs=getlcs(p,q), lcp=getlcp(p,q);
            lcs=min(lcs,i);
            lcp=min(lcp,i);
            if(lcs+lcp-1>=i)
            {
                ans+=lcs+lcp-i;
            }
        }
    }

    cout<<ans<<endl;
}

原文地址:https://www.cnblogs.com/mollnn/p/13283724.html