【NOI OL #2】T2 子序列问题

题目链接

设$a_i$上一次出现的位置为$p_{a_i}$。设$g(r)=sumlimits_{l=1}^r f^2(l,r)$。

考虑移动右端点,每次累加相同右端点所有区间的答案,发现移动右端点时新增的$a_r$会使左端点属于$(p_{a_r},r]$的$f(l,r)$加一。

那么$$g(r)-g(r-1)=sum_{l=p_{a_r}+1}^r( (f(l,r)+1)^2-f(l,r) )=sum_l(2 imes f(l,r)+1)=2sum_lf(l,r)+r-p_{a_r}$$

我们需要一个数据结构,支持区间加,查询区间和。用树状数组能更好地避免被卡常。

关于支持用树状数组维护区间加以及查询区间和的式子:$$sum a_i=sum_{i=1}^nsum_{j=1}^i c_j=sum_{i=1}^n(n-i+1) imes c_i=(n+1) imessum_{i+1}^n c_i-i imessum_{i=1}^n c_i$$

则使用两个数组,一个是差分数组的树状数组,一个是差分数值乘上下标的树状数组。

注意:一个修改/查询操作(即加一趟$lowbit$)对应修改/查询差分数组的一个元素,而不是一组元素。谨记!

程序(100分): 

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
#include<queue>
#include<map>
#include<set>
#define IL inline
#define RG register
#define _1 first
#define _2 second
using namespace std;
typedef long long LL;
const int N=1e6;
const LL mod=1e9+7;

IL LL add(LL x,LL y){return (x+y)%mod;}
IL LL mul(LL x,LL y){return x*y%mod;}
IL LL mns(LL x,LL y){return (x+mod-y)%mod;}
IL void fadd(LL &x,LL y){x=add(x,y);}
IL void fmul(LL &x,LL y){x=mul(x,y);}

    int n,a[N+3],p[N+3];
    
IL int bsch(int l,int r,int x){
    int mid,ans;
    while(l<=r){
        mid=(l+r)>>1;
        if(x<=p[mid]){
            r=mid-1;    ans=mid;
        }
        else 
            l=mid+1;
        
    }
    return ans;
    
}
    
IL void discrete(){
    memcpy(p,a,sizeof a);
    sort(p+1,p+n+1);
    int m=1;
    for(int i=2;i<=n;i++)
    if(p[i]!=p[i-1])
        p[++m]=p[i];
    
    for(int i=1;i<=n;i++)
        a[i]=bsch(1,m,a[i]);
    
}

    LL c1[N+3],c2[N+3];

IL int lowbit(int x){
    return x&(-x);
}

IL void mdf(int p,LL x){
    for(int i=p;i<=n;i+=lowbit(i)){
        fadd(c1[i],x);
        fadd(c2[i],1LL*x*p);
    }
}

IL void mdf(int l,int r,LL x){
    mdf(l,x);    mdf(r+1,-x);
}

IL LL qry(int p){
    LL ret=0;
    for(int i=p;i;i-=lowbit(i))
        fadd(ret,mns( mul(p+1,c1[i]) , c2[i] ));
    return ret;
}

IL LL qry(int l,int r){
    return mns(qry(r),qry(l-1));
}

    LL sum,ans;

int main(){
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
        scanf("%d",&a[i]);
        
    discrete();
    
    memset(c1,0,sizeof c1);
    memset(c2,0,sizeof c2);
    memset(p,0,sizeof p);
    sum=ans=0;
    for(int i=1;i<=n;i++){
        fadd(sum,add(mul(2LL,qry(p[a[i]]+1,i)),mns(i,p[a[i]])));
        fadd(ans,sum);
        mdf(p[a[i]]+1,i,1);
        p[a[i]]=i;
        
    }
    
    printf("%lld",ans);

    return 0;

}
View Code
原文地址:https://www.cnblogs.com/Hansue/p/12891569.html