[CF653F] Paper task
Description
给定一个括号序列,统计合法的本质不同子串的个数。
Solution
很容易想到,只要在传统统计本质不同子串的基础上修改一下即可。
考虑经典统计过程,对于第 (i) 个后缀,它的贡献为 (n - sa[i] + 1 - h[i])
也就意味着,它产生贡献的区间是 ([sa[i]+h[i], n]) 。换言之,对任意 (j in [sa[i]+h[i], n]) , (s[sa[i],j]) 是一个答案。
那么我们现在就是要判断这些答案中有多少合法。也就是对某个 (i) ,有多少个 (j in [sa[i]+h[i], n]) , ,满足 (sum_{k=sa[i]}^j a_k = sum[j]-sum[sa[i]-1]) 为 (0)。 这里 (a_i) 表示括号序列,左括号对应 (1) ,右括号对应 (-1) 。
也就是询问下标在 (sa[i]+h[i]) 及之后, (sum[k]=sum[sa[i]-1]) 的 (k) 有多少个。
按数值插进若干个 std::vector 然后暴力二分即可。
但这样会忽略那些中途出现右括号比左括号多的情况。因此,对于每个后缀,我们在统计时要找到最远能到达的位置,即第一个小于等于 (sum[sa[i]-1]-1) 出现的位置,我们要把这个位置之后的结果减去。暴力扫一遍,权值线段树维护即可。
#include <bits/stdc++.h>
using namespace std;
int n,m=256,sa[1000005],y[1000005],u[1000005],v[1000005],o[1000005],r[1000005],h[1000005],T;
int a[1000005],sum[1000005],lim[1000005];
char str[1000005];
long long ans;
vector <int> vec[1000005];
namespace seg {
int val[4000005];
void build(int p,int l,int r) {
if(l==r) {
val[p]=INT_MAX;
}
else {
build(p*2,l,(l+r)/2);
build(p*2+1,(l+r)/2+1,r);
val[p]=min(val[p*2],val[p*2+1]);
}
}
void modify(int p,int l,int r,int pos,int key) {
if(l==r) {
val[p]=key;
}
else {
if(pos<=(l+r)/2) modify(p*2,l,(l+r)/2,pos,key);
else modify(p*2+1,(l+r)/2+1,r,pos,key);
val[p]=min(val[p*2],val[p*2+1]);
}
}
int query(int p,int l,int r,int ql,int qr) {
if(l>qr || r<ql) return INT_MAX;
if(l>=ql && r<=qr) return val[p];
return min(query(p*2,l,(l+r)/2,ql,qr),query(p*2+1,(l+r)/2+1,r,ql,qr));
}
void modify(int pos,int key) {
modify(1,1,2*n+1,pos+n+1,key);
}
int query(int ql,int qr) {
return query(1,1,2*n+1,ql+n+1,qr+n+1);
}
void build() {
build(1,1,2*n+1);
}
}
int calc(int val,int pos) {
return vec[val+500000].end() - lower_bound(vec[val+500000].begin(),vec[val+500000].end(),pos);
}
int main(){
cin>>n;
cin>>str+1;
seg::build();
for(int i=1;i<=n;i++) a[i]=str[i]=='('?1:-1;
for(int i=1;i<=n;i++) sum[i]=sum[i-1]+a[i];
for(int i=1;i<=n;i++) vec[sum[i]+500000].push_back(i);
for(int i=n;i>=1;--i) {
seg::modify(sum[i],i);
lim[i]=min(n,seg::query(-n,sum[i-1]-1));
}
for(int i=1;i<=n;i++) u[str[i]]++;
for(int i=1;i<=m;i++) u[i]+=u[i-1];
for(int i=n;i>=1;i--) sa[u[str[i]]--]=i;
r[sa[1]]=1;
for(int i=2;i<=n;i++) r[sa[i]]=r[sa[i-1]]+(str[sa[i]]!=str[sa[i-1]]);
for(int l=1;r[sa[n]]<n;l<<=1) {
memset(u,0,sizeof u);
memset(v,0,sizeof v);
memcpy(o,r,sizeof r);
for(int i=1;i<=n;i++) u[r[i]]++, v[r[i+l]]++;
for(int i=1;i<=n;i++) u[i]+=u[i-1], v[i]+=v[i-1];
for(int i=n;i>=1;i--) y[v[r[i+l]]--]=i;
for(int i=n;i>=1;i--) sa[u[r[y[i]]]--]=y[i];
r[sa[1]]=1;
for(int i=2;i<=n;i++) r[sa[i]]=r[sa[i-1]]+((o[sa[i]]!=o[sa[i-1]])||(o[sa[i]+l]!=o[sa[i-1]+l]));
}
{
int i,j,k=0;
for(int i=1;i<=n;h[r[i++]]=k)
for(k?k--:0,j=sa[r[i]-1];str[i+k]==str[j+k];k++);
}
for(int i=1;i<=n;i++) ans+=calc(sum[sa[i]-1],sa[i]+h[i])-calc(sum[sa[i]-1],max(sa[i]+h[i],lim[sa[i]]+1));
cout<<ans<<endl;
}