[CF653F] Paper task

[CF653F] Paper task

Description

给定一个括号序列,统计合法的本质不同子串的个数。

Solution

很容易想到,只要在传统统计本质不同子串的基础上修改一下即可。

考虑经典统计过程,对于第 (i) 个后缀,它的贡献为 (n - sa[i] + 1 - h[i])

也就意味着,它产生贡献的区间是 ([sa[i]+h[i], n]) 。换言之,对任意 (j in [sa[i]+h[i], n])(s[sa[i],j]) 是一个答案。

那么我们现在就是要判断这些答案中有多少合法。也就是对某个 (i) ,有多少个 (j in [sa[i]+h[i], n]) , ,满足 (sum_{k=sa[i]}^j a_k = sum[j]-sum[sa[i]-1])(0)。 这里 (a_i) 表示括号序列,左括号对应 (1) ,右括号对应 (-1)

也就是询问下标在 (sa[i]+h[i]) 及之后, (sum[k]=sum[sa[i]-1])(k) 有多少个。

按数值插进若干个 std::vector 然后暴力二分即可。

但这样会忽略那些中途出现右括号比左括号多的情况。因此,对于每个后缀,我们在统计时要找到最远能到达的位置,即第一个小于等于 (sum[sa[i]-1]-1) 出现的位置,我们要把这个位置之后的结果减去。暴力扫一遍,权值线段树维护即可。

#include <bits/stdc++.h>
using namespace std;

int n,m=256,sa[1000005],y[1000005],u[1000005],v[1000005],o[1000005],r[1000005],h[1000005],T;
int a[1000005],sum[1000005],lim[1000005];
char str[1000005];
long long ans;
vector <int> vec[1000005];

namespace seg {
    int val[4000005];
    void build(int p,int l,int r) {
        if(l==r) {
            val[p]=INT_MAX;
        }
        else {
            build(p*2,l,(l+r)/2);
            build(p*2+1,(l+r)/2+1,r);
            val[p]=min(val[p*2],val[p*2+1]);
        }
    }
    void modify(int p,int l,int r,int pos,int key) {
        if(l==r) {
            val[p]=key;
        }
        else {
            if(pos<=(l+r)/2) modify(p*2,l,(l+r)/2,pos,key);
            else modify(p*2+1,(l+r)/2+1,r,pos,key);
            val[p]=min(val[p*2],val[p*2+1]);
        }
    }
    int query(int p,int l,int r,int ql,int qr) {
        if(l>qr || r<ql) return INT_MAX;
        if(l>=ql && r<=qr) return val[p];
        return min(query(p*2,l,(l+r)/2,ql,qr),query(p*2+1,(l+r)/2+1,r,ql,qr));
    }
    void modify(int pos,int key) {
        modify(1,1,2*n+1,pos+n+1,key);
    }
    int query(int ql,int qr) {
        return query(1,1,2*n+1,ql+n+1,qr+n+1);
    }
    void build() {
        build(1,1,2*n+1);
    }
}

int calc(int val,int pos) {
    return vec[val+500000].end() - lower_bound(vec[val+500000].begin(),vec[val+500000].end(),pos);
}

int main(){
    cin>>n;
    cin>>str+1;

    seg::build();

    for(int i=1;i<=n;i++) a[i]=str[i]=='('?1:-1;
    for(int i=1;i<=n;i++) sum[i]=sum[i-1]+a[i];

    for(int i=1;i<=n;i++) vec[sum[i]+500000].push_back(i);

    for(int i=n;i>=1;--i) {
        seg::modify(sum[i],i);
        lim[i]=min(n,seg::query(-n,sum[i-1]-1));
    }

    for(int i=1;i<=n;i++) u[str[i]]++;
    for(int i=1;i<=m;i++) u[i]+=u[i-1];
    for(int i=n;i>=1;i--) sa[u[str[i]]--]=i;
    r[sa[1]]=1;
    for(int i=2;i<=n;i++) r[sa[i]]=r[sa[i-1]]+(str[sa[i]]!=str[sa[i-1]]);

    for(int l=1;r[sa[n]]<n;l<<=1) {
        memset(u,0,sizeof u);
        memset(v,0,sizeof v);
        memcpy(o,r,sizeof r);
        for(int i=1;i<=n;i++) u[r[i]]++, v[r[i+l]]++;
        for(int i=1;i<=n;i++) u[i]+=u[i-1], v[i]+=v[i-1];
        for(int i=n;i>=1;i--) y[v[r[i+l]]--]=i;
        for(int i=n;i>=1;i--) sa[u[r[y[i]]]--]=y[i];
        r[sa[1]]=1;
        for(int i=2;i<=n;i++) r[sa[i]]=r[sa[i-1]]+((o[sa[i]]!=o[sa[i-1]])||(o[sa[i]+l]!=o[sa[i-1]+l]));
    }
    {
        int i,j,k=0;
        for(int i=1;i<=n;h[r[i++]]=k)
            for(k?k--:0,j=sa[r[i]-1];str[i+k]==str[j+k];k++);
    }

    for(int i=1;i<=n;i++) ans+=calc(sum[sa[i]-1],sa[i]+h[i])-calc(sum[sa[i]-1],max(sa[i]+h[i],lim[sa[i]]+1));
    cout<<ans<<endl;
}
原文地址:https://www.cnblogs.com/mollnn/p/11813464.html