可持久化线段树(主席树)

定义

  可持久化线段树是可以保留历史版本的线段树,相当于保留了每次修改后的线段树,并且可以对每次修改后的结果进行查询

主要思想

  对于询问历史版本的问题,我们对每次修改都新建一棵线段树
  但是如果修改次数特别多,这样肯定是不可行的,这就要用到可持久化线段树了

  我们通过观察可以发现,对于线段树的每次修改,不论是单点还是区间,所改变的节点都是log级别的
  于是我们所建的线段树只新建那log级别改变的节点,不变的节点直接指向历史版本就行了

具体实现

    • 建树

      和普通线段树基本相同

      int build(int L,int R){//建树 
          int k=size++;
          if(L==R){sum[k]=a[L-1];return k;}
          int mid=(L+R)>>1;
          l[k]=build(L,mid),r[k]=build(mid+1,R);
          sum[k]=sum[l[k]]+sum[r[k]];
          return k;
      }
    • 区间修改

      对于所有的修改,都将原节点复制一遍,然后在新建的节点上修改
      对于没有修改的子节点就直接指向上个版本的节点即可

      int modify(int history,int L,int R,int w,int l1,int r1){//区间修改,history是上个版本的当前节点的编号,L,R为当前节点范围,l1,l2为要修改的范围 
          int k=size++;
          lazy[k]=lazy[history],sum[k]=sum[history]+1ll*(r1-l1+1)*w,l[k]=l[history],r[k]=r[history];//克隆节点 
          if(l1<=L&&r1>=R){//整个区间都要被修改,打个标记返回 
              lazy[k]+=w;return k;
          }
          int mid=(L+R)>>1;
          if(l1<=mid)l[k]=modify(l[history],L,mid,w,l1,min(mid,r1));
          if(r1>mid)r[k]=modify(r[history],mid+1,R,w,max(mid+1,l1),r1);
          return k;
      }
      int change(int history,int L,int R,int w,int x){//单点修改 
          int k=size++;
          lazy[k]=lazy[history],sum[k]=sum[history]+w,l[k]=l[history],r[k]=r[history];//克隆节点
          if(L==R)return k;//如果是叶子节点就直接返回 
          int mid=(L+R)>>1;
          if(x<=mid)l[k]=change(l[history],L,mid,w,x);
          else r[k]=change(r[history],mid+1,R,w,x);
          return k;
      }
    • 区间查询

      和普通线段树区间查询一样,只是lazy标记不能下放(因为子节点是共用的),而是统计lazy标记造成的影响

      int query(int x,int L,int R,int l1,int r1){//询问 
          if(l1<=L&&r1>=R)return sum[x];
          int mid=(L+R)>>1;
          int ans=lazy[x]*(r1-l1+1);//标记不能下放,因为子节点也是共用的,只能统计标记造成的影响 
          if(l1<=mid)ans+=query(l[x],L,mid,l1,min(mid,r1));
          if(r1>mid)ans+=query(r[x],mid+1,R,max(mid+1,l1),r1);
          return ans;
      }

模板

    • 非动态开点

      #include<cstdio>
      #include<algorithm>
      using namespace std;
      #define maxn 100005
      int n,m,a[maxn],root[maxn],size,lazy[maxn*70],l[maxn*70],r[maxn*70],sum[maxn*70];//非动态开点则节点要开到n*logn*4 
      int build(int L,int R){//建树 
          int k=size++;
          if(L==R){sum[k]=a[L-1];return k;}
          int mid=(L+R)>>1;
          l[k]=build(L,mid),r[k]=build(mid+1,R);
          sum[k]=sum[l[k]]+sum[r[k]];
          return k;
      }
      int modify(int history,int L,int R,int w,int l1,int r1){//区间修改,history是上个版本的当前节点的编号,L,R为当前节点范围,l1,l2为要修改的范围 
          int k=size++;
          lazy[k]=lazy[history],sum[k]=sum[history]+1ll*(r1-l1+1)*w,l[k]=l[history],r[k]=r[history];//克隆节点 
          if(l1<=L&&r1>=R){//整个区间都要被修改,打个标记返回 
              lazy[k]+=w;return k;
          }
          int mid=(L+R)>>1;
          if(l1<=mid)l[k]=modify(l[history],L,mid,w,l1,min(mid,r1));
          if(r1>mid)r[k]=modify(r[history],mid+1,R,w,max(mid+1,l1),r1);
          return k;
      }
      int change(int history,int L,int R,int w,int x){//单点修改 
          int k=size++;
          lazy[k]=lazy[history],sum[k]=sum[history]+w,l[k]=l[history],r[k]=r[history];//克隆节点
          if(L==R)return k;//如果是叶子节点就直接返回 
          int mid=(L+R)>>1;
          if(x<=mid)l[k]=change(l[history],L,mid,w,x);
          else r[k]=change(r[history],mid+1,R,w,x);
          return k;
      }
      int query(int x,int L,int R,int l1,int r1){//询问 
          if(l1<=L&&r1>=R)return sum[x];
          int mid=(L+R)>>1;
          int ans=lazy[x]*(r1-l1+1);//标记不能下放,因为子节点也是共用的,只能统计标记造成的影响 
          if(l1<=mid)ans+=query(l[x],L,mid,l1,min(mid,r1));
          if(r1>mid)ans+=query(r[x],mid+1,R,max(mid+1,l1),r1);
          return ans;
      }
      int main(){
          return 0;
      } 
    • 动态开点

      #include<cstdio>
      #include<algorithm>
      #include<cstring>
      using namespace std;
      #define maxn 100005
      struct Node{
          Node *l,*r;
          bool l1,r1;//记录左右子树是不是当前节点创建的,如果不delete就不需要 
          int lazy,sum;
          Node(){
              memset(this,0,sizeof(Node));
          }
      }root[maxn];
      int n,m,a[maxn];
      void build(Node &x,int L,int R){//建树 
          if(L==R){x.sum=a[L-1];return;}
          int mid=(L+R)>>1;
          build(*(x.l1=1,x.l=new Node),L,mid),build(*(x.r1=1,x.r=new Node),mid+1,R);
          x.sum=x.l->sum+x.r->sum;
      }
      void modify(Node &history,Node &x,int L,int R,int w,int l1,int r1){//区间修改 
          x=history,x.sum+=(r1-l1+1)*w,x.r1=x.l1=0;
          if(l1<=L&&r1>=R){
              x.lazy+=w;return;
          }
          int mid=(L+R)>>1;
          if(l1<=mid)modify(*history.l,*(x.l1=1,x.l=new Node),L,mid,w,l1,min(mid,r1));
          if(r1>mid)modify(*history.r,*(x.r1=1,x.r=new Node),mid+1,R,w,max(mid+1,l1),r1);
      }
      void change(Node &history,Node &x,int L,int R,int w,int l1){//单点修改 
          x=history,x.sum+=w,x.l1=x.r1=0;//克隆节点
          if(L==R)return;//如果是叶子节点就直接返回 
          int mid=(L+R)>>1;
          if(l1<=mid)change(*history.l,*(x.l1=1,x.l=new Node),L,mid,w,l1);
          else change(*history.r,*(x.r1=1,x.r=new Node),mid+1,R,w,l1);
      }
      int query(Node &x,int L,int R,int l1,int r1){//区间查询 
          if(l1<=L&&r1>=R)return x.sum;
          int mid=(L+R)>>1,ans=x.lazy*(r1-l1+1);
          if(l1<=mid)ans+=query(*x.l,L,mid,l1,min(mid,r1));
          if(r1>mid)ans+=query(*x.r,mid+1,R,max(mid+1,l1),r1);
          return ans;
      }
      void remove_node(Node *x){//节点空间释放 
          if(x->l1)remove_node(x->l);
          if(x->r1)remove_node(x->r);
          delete x;
      }
      void remove(int n){//线段树空间释放 
          for(int i=0;i<=n;i++){
              if(root[i].l1)remove_node(root[i].l);
              if(root[i].r1)remove_node(root[i].r);
          }
          memset(root,0,sizeof(root));
      }
      int main(){
          return 0;
      } 

例题hdu4348.To the moon

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
#define maxn 100005
#define LL long long
struct Node{
    Node *l,*r;
    bool l1,r1;//记录左右子树是不是当前节点创建的,如果不delete就不需要 
    int lazy;
    LL sum;
    Node(){
        memset(this,0,sizeof(Node));
    }
}root[maxn];
int n,m,a[maxn];
void build(Node &x,int L,int R){//建树 
    if(L==R){x.sum=a[L-1];return;}
    int mid=(L+R)>>1;
    build(*(x.l1=1,x.l=new Node),L,mid),build(*(x.r1=1,x.r=new Node),mid+1,R);
    x.sum=x.l->sum+x.r->sum;
}
void modify(Node &history,Node &x,int L,int R,int w,int l1,int r1){//区间修改 
    x=history,x.sum+=1ll*(r1-l1+1)*w,x.r1=x.l1=0;
    if(l1<=L&&r1>=R){
        x.lazy+=w;return;
    }
    int mid=(L+R)>>1;
    if(l1<=mid)modify(*history.l,*(x.l1=1,x.l=new Node),L,mid,w,l1,min(mid,r1));
    if(r1>mid)modify(*history.r,*(x.r1=1,x.r=new Node),mid+1,R,w,max(mid+1,l1),r1);
}
LL query(Node &x,int L,int R,int l1,int r1){//区间查询 
    if(l1<=L&&r1>=R)return x.sum;
    int mid=(L+R)>>1;LL ans=x.lazy*(r1-l1+1);
    if(l1<=mid)ans+=query(*x.l,L,mid,l1,min(mid,r1));
    if(r1>mid)ans+=query(*x.r,mid+1,R,max(mid+1,l1),r1);
    return ans;
}
void remove_node(Node *x){//节点空间释放 
    if(x->l1)remove_node(x->l);
    if(x->r1)remove_node(x->r);
    delete x;
}
void remove(int n){//线段树空间释放 
    for(int i=0;i<=n;i++){
        if(root[i].l1)remove_node(root[i].l);
        if(root[i].r1)remove_node(root[i].r);
    }
    memset(root,0,sizeof(root));
}
void work(){
    for(int i=0;i<n;i++)scanf("%d",a+i);
    build(root[0],1,n);
    int L,R,d,time=0;char s[2];
    for(int i=0;i<m;i++){
        scanf("%s",s);
        if(s[0]=='C'){
            scanf("%d%d%d",&L,&R,&d);
            time++;
            modify(root[time-1],root[time],1,n,d,L,R);
        }
        else if(s[0]=='Q'){
            scanf("%d%d",&L,&R);
            printf("%lld
",query(root[time],1,n,L,R));
        }
        else if(s[0]=='H'){
            scanf("%d%d%d",&L,&R,&d);
            printf("%lld
",query(root[d],1,n,L,R));
        }
        else{
            scanf("%d",&d);
            for(;time>d;time--){
                if(root[time].l1)remove_node(root[time].l);
                if(root[time].r1)remove_node(root[time].r);
            }
        }
    }
    remove(time);
}
int main(){
    while(~scanf("%d%d",&n,&m)){
        work();
    }
    return 0;
}

求区间第k大

  可持久化线段树有一个非常重要的用法——求区间第k大

 具体做法

  将原序列排序,开一个1-n的线段树,每个节点刚开始都为0
  然后按照原序列的顺序,每次找到原序列一个数在排好序的序列的位置,在线段树相应位置+1

  这时,第i个线段树维护的就是[1,i]在排好序的序列中的位置,我们可以快速求出[1,i]的第k大
  即如果左子节点的值x>=k,那么这个数就是左子节点的第k大,否则是右子节点的第k-x大

  如果求[l,r]第k大,只需要对第r棵线段树和第l-1棵作差,然后按照上面的方法即可

int query_k(Node &x,Node &y,int L,int R,int k){//询问x线段树-y线段树中的第k大 
    if(L==R)return L;
    int num=x.l->sum-y.l->sum,mid=(L+R)>>1;
    if(num>=k)return query_k(*x.l,*y.l,L,mid,k);
    else return query_k(*x.r,*y.r,mid+1,R,k-num);
}

 例题luoguP3834 【模板】可持久化线段树 1(主席树)

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
#define maxn 200005
struct Node{
    Node *l,*r;
    bool l1,r1;//记录左右子树是不是当前节点创建的,如果不delete就不需要 
    int lazy,sum;
    Node(){
        memset(this,0,sizeof(Node));
    }
}root[maxn];
int n,m,a[maxn],b[maxn],st;
void build(Node &x,int L,int R){//建树 
    if(L==R){x.sum=a[L-1];return;}
    int mid=(L+R)>>1;
    build(*(x.l1=1,x.l=new Node),L,mid),build(*(x.r1=1,x.r=new Node),mid+1,R);
    x.sum=x.l->sum+x.r->sum;
}
void modify(Node &history,Node &x,int L,int R,int w,int l1,int r1){//区间修改 
    x=history,x.sum+=(r1-l1+1)*w,x.r1=x.l1=0;
    if(l1<=L&&r1>=R){
        x.lazy+=w;return;
    }
    int mid=(L+R)>>1;
    if(l1<=mid)modify(*history.l,*(x.l1=1,x.l=new Node),L,mid,w,l1,min(mid,r1));
    if(r1>mid)modify(*history.r,*(x.r1=1,x.r=new Node),mid+1,R,w,max(mid+1,l1),r1);
}
void change(Node &history,Node &x,int L,int R,int w,int l1){//单点修改 
    x=history,x.sum+=w,x.l1=x.r1=0;//克隆节点
    if(L==R)return;//如果是叶子节点就直接返回 
    int mid=(L+R)>>1;
    if(l1<=mid)change(*history.l,*(x.l1=1,x.l=new Node),L,mid,w,l1);
    else change(*history.r,*(x.r1=1,x.r=new Node),mid+1,R,w,l1);
}
int query(Node &x,int L,int R,int l1,int r1){//区间查询 
    if(l1<=L&&r1>=R)return x.sum;
    int mid=(L+R)>>1,ans=x.lazy*(r1-l1+1);
    if(l1<=mid)ans+=query(*x.l,L,mid,l1,min(mid,r1));
    if(r1>mid)ans+=query(*x.r,mid+1,R,max(mid+1,l1),r1);
    return ans;
}
int query_k(Node &x,Node &y,int L,int R,int k){//询问x-y线段树中的第k大 
    if(L==R)return L;
    int num=x.l->sum-y.l->sum,mid=(L+R)>>1;
    if(num>=k)return query_k(*x.l,*y.l,L,mid,k);
    else return query_k(*x.r,*y.r,mid+1,R,k-num);
}
void remove_node(Node *x){//节点空间释放 
    if(x->l1)remove_node(x->l);
    if(x->r1)remove_node(x->r);
    delete x;
}
void remove(int n){//线段树空间释放 
    for(int i=0;i<=n;i++){
        if(root[i].l1)remove_node(root[i].l);
        if(root[i].r1)remove_node(root[i].r);
    }
    memset(root,0,sizeof(root));
}
int main(){
    int n,m;scanf("%d%d",&n,&m);
    for(int i=0;i<n;i++)scanf("%d",a+i),b[i]=a[i];
    sort(b,b+n),st=unique(b,b+n)-b;
    build(root[0],1,st);
    for(int i=0;i<n;i++){
        int x=lower_bound(b,b+st,a[i])-b;
        change(root[i],root[i+1],1,st,1,x+1);
    }
    int L,R,k;
    for(int i=0;i<m;i++){
        scanf("%d%d%d",&L,&R,&k);
        printf("%d
",b[query_k(root[R],root[L-1],1,st,k)-1]);
    }
    remove(n); 
    return 0;
}
原文地址:https://www.cnblogs.com/bennettz/p/8351894.html