关于主席树的入门,讲解和题单

主席树真是神仙操作啊……搞了好久才弄懂一点点QAQ

参考文章:https://www.cnblogs.com/zyf0163/p/4749042.html

     https://blog.csdn.net/creatorx/article/details/75446472

     https://blog.csdn.net/jerans/article/details/75807666

     http://www.cnblogs.com/zcysky/p/6832876.html

ps:本文章中的题目我都写过题解了,可以自己去找

 1.前言

据说主席树这个名字的由来呢,是因为创始人的名字缩写hjt与某位相同,然后他因为不会划分树于是自创了这一个数据结构。好强啊orz

主席树能实现什么操作呢?最经典的就是查询区间第k小了,其他的还有诸如树上路径第k小啦,带修改第k小啦之类的。以静态区间第k小为例

 2.定义

先贴一下某神犇对主席树的理解:所谓主席树呢,就是对原来的数列[1..n]的每一个前缀[1..i](1≤i≤n)建立一棵线段树,线段树的每一个节点存某个前缀[1..i]中属于区间[L..R]的数一共有多少个(比如根节点是[1..n],一共i个数,sum[root] = i;根节点的左儿子是[1..(L+R)/2],若不大于(L+R)/2的数有x个,那么sum[root.left] = x)。若要查找[i..j]中第k大数时,设某结点x,那么x.sum[j] - x.sum[i - 1]就是[i..j]中在结点x内的数字总数。而对每一个前缀都建一棵树,会MLE,观察到每个[1..i]和[1..i-1]只有一条路是不一样的,那么其他的结点只要用回前一棵树的结点即可,时空复杂度为O(nlogn)。

然而没有什么用,因为感觉根本没看懂

然后来说说我自己的理解吧。如何求出一个区间内第k小呢?直接sort当然可以,但是复杂度爆表。于是我们可以换一个思路,能否将$[l,r]$之间出现过的数都建成线段树呢?设节点为$p$,区间为$[l,r]$,左儿子是$[l,mid]$,右儿子是$[mid+1,r]$

要查找第k大的话,先看左儿子里有多少个数(表示小于等于$mid$的数的个数),如果大于$k$,进左子树找,否则令$k-=左儿子数的个数$,进右子树找

先来考虑一个序列:3,2,1,4

建完树之后是这样的

然后要查第2大,一下子就能发现是2了

(上面画的可能不是很严谨,大家将就下)

但我们不可能对每一个区间都建一棵树,那样的话空间复杂度绝对爆炸

然后可以转化一下思路:前缀和

区间$[l,r]$中小于等于$mid$的数的个数,可以转换为$[1,r]$中小于等于$mid$的数的个数减去$[1,l-1]$中小于等于$mid$的数的个数

于是我们只要对每一个前缀建一棵树即可

然后空间复杂度还是爆炸

然而我们又发现,区间$[1,l-1]$的树和区间$[1,l]$的树最多只会有$log n$个节点不同(因为每次新插入一个节点最多只会更新$log n$个节点),有许多空间是可以重复利用的

只要能将这些空间重复利用起来,就可以解决空间的问题了

还是上面那个序列:3,2,1,4

一开始先建一棵空树,然后一个个把每一个节点加进去

如果要看图的话可以点这里

这个时候有人就要问了,万一序列的数字特别大呢?

当然是离散化

将这些所有值离散一下就行了,可以保证所有数在$1~n$之间

然而感觉讲太多也没啥用……上代码好了,有详细的注释

以区间第k小为例 洛谷p3834

 1 //minamoto
 2 #include<bits/stdc++.h>
 3 #define N 200005
 4 using namespace std;
 5 inline int read(){
 6     #define num ch-'0'
 7     char ch;bool flag=0;int res;
 8     while(!isdigit(ch=getchar()))
 9     (ch=='-')&&(flag=true);
10     for(res=num;isdigit(ch=getchar());res=res*10+num);
11     (flag)&&(res=-res);
12     #undef num
13     return res;
14 }
15 int sum[N<<5],L[N<<5],R[N<<5];
16 int a[N],b[N],t[N];
17 int n,q,m,cnt=0;
18 int build(int l,int r){
19     int rt=++cnt;
20     //建树 
21     sum[rt]=0;
22     if(l<r){
23         int mid=(l+r)>>1;
24         L[rt]=build(l,mid);
25         R[rt]=build(mid+1,r);
26     }
27     return rt;
28 }
29 int update(int last,int l,int r,int x){
30     int rt=++cnt;
31     L[rt]=L[last],R[rt]=R[last],sum[rt]=sum[last]+1;
32     //先继承上一次的信息 
33     //L是左节点,R是右节点,sum是节点内数的个数 
34     if(l<r){
35         int mid=(l+r)>>1;
36         if(x<=mid) L[rt]=update(L[last],l,mid,x);
37         else R[rt]=update(R[last],mid+1,r,x);
38         //如果有需要更新的信息,更新
39         //可以发现每一次更新的节点最多只有log n个 
40     }
41     return rt;
42 }
43 int query(int u,int v,int l,int r,int k){
44     if(l>=r) return l;
45     int x=sum[L[v]]-sum[L[u]];
46     //查询操作 
47     int mid=(l+r)>>1;
48     if(x>=k) return query(L[u],L[v],l,mid,k);
49     else return query(R[u],R[v],mid+1,r,k-x);
50     //如果左节点个数大于等于k,进左子树找第k小
51     //否则进右子树 
52 }
53 int main(){
54     //freopen("testdata.in","r",stdin);
55     n=read(),q=read();
56     for(int i=1;i<=n;++i)
57     b[i]=a[i]=read();
58     sort(b+1,b+1+n);
59     m=unique(b+1,b+1+n)-b-1;
60     t[0]=build(1,m);
61     //先建一棵空树 
62     for(int i=1;i<=n;++i){
63         int k=lower_bound(b+1,b+1+m,a[i])-b;
64         //离散 
65         t[i]=update(t[i-1],1,m,k);
66         //然后每次在上一次的基础上建树 
67     }
68     while(q--){
69         int x,y,z;
70         x=read(),y=read(),z=read();
71         int k=query(t[x-1],t[y],1,m,z);
72         printf("%d
",b[k]);
73     }
74     return 0;
75 }
View Code

如果熟练了之后,可以发现其实第一步的建树过程是可以省略的,直接每一步加节点就行了

 1 //minamoto
 2 #include<bits/stdc++.h>
 3 #define N 200005
 4 using namespace std;
 5 inline int read(){
 6     #define num ch-'0'
 7     char ch;bool flag=0;int res;
 8     while(!isdigit(ch=getchar()))
 9     (ch=='-')&&(flag=true);
10     for(res=num;isdigit(ch=getchar());res=res*10+num);
11     (flag)&&(res=-res);
12     #undef num
13     return res;
14 }
15 int sum[N<<5],L[N<<5],R[N<<5];
16 int a[N],b[N],t[N];
17 int n,q,m,cnt=0;
18 void update(int last,int &now,int l,int r,int x){
19     //注意这里开的是引用 
20     if(!now) now=++cnt;
21     sum[now]=sum[last]+1;
22     if(l==r) return;
23     int mid=(l+r)>>1;
24     if(x<=mid) R[now]=R[last],update(L[last],L[now],l,mid,x);
25     else L[now]=L[last],update(R[last],R[now],mid+1,r,x);
26 }
27 int query(int u,int v,int l,int r,int k){
28     if(l>=r) return l;
29     int x=sum[L[v]]-sum[L[u]];
30     int mid=(l+r)>>1;
31     if(x>=k) return query(L[u],L[v],l,mid,k);
32     else return query(R[u],R[v],mid+1,r,k-x);
33 }
34 int main(){
35     //freopen("testdata.in","r",stdin);
36     n=read(),q=read();
37     for(int i=1;i<=n;++i)
38     b[i]=a[i]=read();
39     sort(b+1,b+1+n);
40     m=unique(b+1,b+1+n)-b-1;
41     for(int i=1;i<=n;++i){
42         int k=lower_bound(b+1,b+1+m,a[i])-b;
43         update(t[i-1],t[i],1,m,k);
44         //省略建树过程,直接加入节点 
45     }
46     while(q--){
47         int x,y,z;
48         x=read(),y=read(),z=read();
49         int k=query(t[x-1],t[y],1,m,z);
50         printf("%d
",b[k]);
51     }
52     return 0;
53 }
View Code

 还有一道板子题洛谷SP3946 poj2104 K-th Number

 1 //minamoto
 2 #include<bits/stdc++.h>
 3 #define N 100005
 4 using namespace std;
 5 inline int read(){
 6     #define num ch-'0'
 7     char ch;bool flag=0;int res;
 8     while(!isdigit(ch=getchar()))
 9     (ch=='-')&&(flag=true);
10     for(res=num;isdigit(ch=getchar());res=res*10+num);
11     (flag)&&(res=-res);
12     #undef num
13     return res;
14 }
15 int sum[N<<5],L[N<<5],R[N<<5];
16 int a[N],b[N],rt[N];
17 int n,q,m,cnt=0;
18 void update(int last,int &now,int l,int r,int x){
19     sum[now=++cnt]=sum[last]+1;
20     if(l==r) return;
21     int mid=(l+r)>>1;
22     if(x<=mid) R[now]=R[last],update(L[last],L[now],l,mid,x);
23     else L[now]=L[last],update(R[last],R[now],mid+1,r,x);
24 }
25 int query(int u,int v,int l,int r,int k){
26     if(l>=r) return l;
27     int x=sum[L[v]]-sum[L[u]];
28     int mid=(l+r)>>1;
29     if(x>=k) return query(L[u],L[v],l,mid,k);
30     else return query(R[u],R[v],mid+1,r,k-x);
31 }
32 int main(){
33     //freopen("testdata.in","r",stdin);
34     n=read(),q=read();
35     for(int i=1;i<=n;++i)
36     b[i]=a[i]=read();
37     sort(b+1,b+1+n);
38     m=unique(b+1,b+1+n)-b-1;
39     for(int i=1;i<=n;++i){
40         int k=lower_bound(b+1,b+1+m,a[i])-b;
41         update(rt[i-1],rt[i],1,m,k);
42     }
43     while(q--){
44         int x,y,z;
45         x=read(),y=read(),z=read();
46         int k=query(rt[x-1],rt[y],1,m,z);
47         printf("%d
",b[k]);
48     }
49     return 0;
50 }
View Code

还有一道题,也是主席树的一般应用 洛谷P3567 [POI2014]KUR-Couriers

出现次数可以转化为左右节点的大小,如果符合条件就递归

题解

 1 //minamoto
 2 #include<bits/stdc++.h>
 3 #define N 500005
 4 using namespace std;
 5 inline int read(){
 6     #define num ch-'0'
 7     char ch;bool flag=0;int res;
 8     while(!isdigit(ch=getchar()))
 9     (ch=='-')&&(flag=true);
10     for(res=num;isdigit(ch=getchar());res=res*10+num);
11     (flag)&&(res=-res);
12     #undef num
13     return res;
14 }
15 int sum[N*20],L[N*20],R[N*20],t[N];
16 int n,q,cnt=0;
17 void update(int last,int &now,int l,int r,int x){
18     if(!now) now=++cnt;
19     sum[now]=sum[last]+1;
20     if(l==r) return;
21     int mid=(l+r)>>1;
22     if(x<=mid) R[now]=R[last],update(L[last],L[now],l,mid,x);
23     else L[now]=L[last],update(R[last],R[now],mid+1,r,x);
24 }
25 int query(int u,int v,int l,int r,int k){
26     if(l==r) return l;
27     int x=sum[L[v]]-sum[L[u]],y=sum[R[v]]-sum[R[u]];
28     int mid=(l+r)>>1;
29     if(x*2>k) return query(L[u],L[v],l,mid,k);
30     if(y*2>k) return query(R[u],R[v],mid+1,r,k);
31     return 0;
32 }
33 int main(){
34     //freopen("testdata.in","r",stdin);
35     n=read(),q=read();
36     for(int i=1;i<=n;++i){
37         int x=read();
38         update(t[i-1],t[i],1,n,x);
39     }
40     while(q--){
41         int x,y;
42         x=read(),y=read();
43         int k=query(t[x-1],t[y],1,n,y-x+1);
44         printf("%d
",k);
45     }
46     return 0;
47 }
View Code

然后区间静态第k大就解决了~(≧▽≦)/~啦啦啦

树上路径

有些题目会给你一棵树,问你树上两点间路径上的第k大

怎么解决呢?

可以发现,这个东西是可以进行差分的

比如说,$u$到$v$路径上的权值和,可以变成$sum[u]+sum[v]-sum[lca]-sum[lca_fa]$

然后套到主席树上,就是小于某个数的个数,同样也可以差分出来表示

但问题是主席树怎么建呢?

我们发现,因为要求lca,我们可以在树剖dfs的时候顺便加点

具体来说,就是用$fa[i]$的信息更新$i$点的信息

以bzoj2588 洛谷p2633. count on a tree为例

题解

 1 //minamoto
 2 #include<bits/stdc++.h>
 3 #define N 100005
 4 #define M 2000005
 5 using namespace std;
 6 inline int read(){
 7     #define num ch-'0'
 8     char ch;bool flag=0;int res;
 9     while(!isdigit(ch=getchar()))
10     (ch=='-')&&(flag=true);
11     for(res=num;isdigit(ch=getchar());res=res*10+num);
12     (flag)&&(res=-res);
13     #undef num
14     return res;
15 }
16 int sum[M],L[M],R[M];
17 int a[N],b[N],rt[N];
18 int fa[N],sz[N],d[N],ver[N<<1],Next[N<<1],head[N],son[N],top[N];
19 int n,q,m,cnt=0,tot=0,ans=0;
20 void update(int last,int &now,int l,int r,int x){
21     sum[now=++cnt]=sum[last]+1;
22     if(l==r) return;
23     int mid=(l+r)>>1;
24     if(x<=mid) R[now]=R[last],update(L[last],L[now],l,mid,x);
25     else L[now]=L[last],update(R[last],R[now],mid+1,r,x);
26 }
27 inline void add(int u,int v){
28     ver[++tot]=v,Next[tot]=head[u],head[u]=tot;
29     ver[++tot]=u,Next[tot]=head[v],head[v]=tot;
30 }
31 void dfs(int u){
32     sz[u]=1,d[u]=d[fa[u]]+1;
33     update(rt[fa[u]],rt[u],1,m,a[u]);
34     for(int i=head[u];i;i=Next[i]){
35         int v=ver[i];
36         if(v==fa[u]) continue;
37         fa[v]=u,dfs(v);
38         sz[u]+=sz[v];
39         if(!son[u]||sz[v]>sz[son[u]]) son[u]=v;
40     }
41 }
42 void dfs(int u,int tp){
43     top[u]=tp;
44     if(!son[u]) return;
45     dfs(son[u],tp);
46     for(int i=head[u];i;i=Next[i]){
47         int v=ver[i];
48         if(v==son[u]||v==fa[u]) continue;
49         dfs(v,v);
50     }
51 }
52 int LCA(int x,int y){
53     while(top[x]!=top[y])
54     d[top[x]]>=d[top[y]]?x=fa[top[x]]:y=fa[top[y]];
55     return d[x]>=d[y]?y:x;
56 }
57 int query(int ql,int qr,int lca,int lca_fa,int l,int r,int k){
58     if(l>=r) return l;
59     int x=sum[L[ql]]+sum[L[qr]]-sum[L[lca]]-sum[L[lca_fa]];
60     int mid=(l+r)>>1;
61     if(x>=k) return query(L[ql],L[qr],L[lca],L[lca_fa],l,mid,k);
62     else return query(R[ql],R[qr],R[lca],R[lca_fa],mid+1,r,k-x);
63 }
64 int main(){
65     //freopen("testdata.in","r",stdin);
66     n=read(),q=read();
67     for(int i=1;i<=n;++i)
68     b[i]=a[i]=read();
69     sort(b+1,b+1+n);
70     m=unique(b+1,b+1+n)-b-1;
71     for(int i=1;i<=n;++i)
72     a[i]=lower_bound(b+1,b+1+m,a[i])-b;
73     for(int i=1;i<n;++i){
74         int u=read(),v=read();
75         add(u,v);
76     }
77     dfs(1),dfs(1,1);
78     while(q--){
79         int x,y,z,lca;
80         x=read(),y=read(),z=read();
81         x^=ans,lca=LCA(x,y);
82         ans=b[query(rt[x],rt[y],rt[lca],rt[fa[lca]],1,m,z)];
83         printf("%d
",ans);
84     }
85     return 0;
86 }
View Code

 还有一题[bzoj3123][洛谷P3302] [SDOI2013]森林

路经查询就是主席树维护,而连接两棵树就是用启发式合并

题解

  1 //minamoto
  2 #include<bits/stdc++.h>
  3 using namespace std;
  4 inline int read(){
  5     #define num ch-'0'
  6     char ch;bool flag=0;int res;
  7     while(!isdigit(ch=getchar()))
  8     (ch=='-')&&(flag=true);
  9     for(res=num;isdigit(ch=getchar());res=res*10+num);
 10     (flag)&&(res=-res);
 11     #undef num
 12     return res;
 13 }
 14 const int N=80005,M=N*200;
 15 int ver[N<<2],Next[N<<2],head[N];
 16 int a[N],fa[N],sz[N],b[N];
 17 int n,m,tot,q,size,ans;
 18 void add(int u,int v){
 19     ver[++tot]=v,Next[tot]=head[u],head[u]=tot;
 20     ver[++tot]=u,Next[tot]=head[v],head[v]=tot;
 21 }
 22 int L[M],R[M],sum[M],rt[N],cnt;
 23 void update(int last,int &now,int l,int r,int x){
 24     sum[now=++cnt]=sum[last]+1;
 25     if(l==r) return;
 26     int mid=(l+r)>>1;
 27     if(x<=mid) R[now]=R[last],update(L[last],L[now],l,mid,x);
 28     else L[now]=L[last],update(R[last],R[now],mid+1,r,x);
 29 }
 30 int query(int u,int v,int lca,int lca_fa,int l,int r,int k){
 31     if(l>=r) return l;
 32     int x=sum[L[v]]+sum[L[u]]-sum[L[lca]]-sum[L[lca_fa]];
 33     int mid=(l+r)>>1;
 34     if(x>=k) return query(L[u],L[v],L[lca],L[lca_fa],l,mid,k);
 35     else return query(R[u],R[v],R[lca],R[lca_fa],mid+1,r,k-x);
 36 }
 37 inline int hash(int x){
 38     return lower_bound(b+1,b+1+size,x)-b;
 39 }
 40 int ff(int x){
 41     return fa[x]==x?x:fa[x]=ff(fa[x]);
 42 }
 43 int st[N][17],d[N],vis[N];
 44 void dfs(int u,int father,int root){
 45     st[u][0]=father;
 46     for(int i=1;i<=16;++i)
 47     st[u][i]=st[st[u][i-1]][i-1];
 48     ++sz[root];
 49     d[u]=d[father]+1;
 50     fa[u]=root;
 51     vis[u]=1;
 52     update(rt[father],rt[u],1,size,hash(a[u]));
 53     for(int i=head[u];i;i=Next[i]){
 54         int v=ver[i];
 55         if(v==father) continue;
 56         dfs(v,u,root);
 57     }
 58 }
 59 int LCA(int x,int y){
 60     if(x==y) return x;
 61     if(d[x]<d[y]) swap(x,y);
 62     for(int i=16;i>=0;--i){
 63         if(d[st[x][i]]>=d[y]) x=st[x][i];
 64     }
 65     if(x==y) return x;
 66     for(int i=16;i>=0;--i){
 67         if(st[x][i]!=st[y][i])
 68         x=st[x][i],y=st[y][i];
 69     }
 70     return st[x][0];
 71 }
 72 int main(){
 73     //freopen("testdata.in","r",stdin);
 74     int t=read();
 75     n=read(),m=read(),q=read();
 76     for(int i=1;i<=n;++i)
 77     a[i]=b[i]=read(),fa[i]=i;
 78     sort(b+1,b+1+n);
 79     size=unique(b+1,b+1+n)-b-1;
 80     for(int i=1;i<=m;++i){
 81         int u=read(),v=read();
 82         add(u,v);
 83     }
 84     for(int i=1;i<=n;++i)
 85     if(!vis[i]) dfs(i,0,i);
 86     while(q--){
 87         char ch;int x,y;
 88         while(!isupper(ch=getchar()));
 89         x=read()^ans,y=read()^ans;
 90         if(ch=='Q'){
 91             int k=read()^ans;
 92             int lca=LCA(x,y);
 93             ans=b[query(rt[x],rt[y],rt[lca],rt[st[lca][0]],1,size,k)];
 94             printf("%d
",ans);
 95         }
 96         else{
 97             add(x,y);
 98             int u=ff(x),v=ff(y);
 99             if(sz[u]<sz[v]) swap(x,y),swap(u,v);
100             dfs(y,x,u);
101         }
102     }
103     return 0;
104 }
View Code

洛谷P3066 [USACO12DEC]逃跑的BarnRunning Away From…

要对每一个子树进行操作,怎么做呢?

我们可以直接dfs这棵树,并记录下进入一个点的编号$l[i]$和从这个点出去时的编号$r[i]$

那么这个点的子树的区间一定是$[l[i],r[i]]$

然后直接在树上查询就行了

题解

 1 //minamoto
 2 #include<bits/stdc++.h>
 3 #define N 200005
 4 #define M 4000005
 5 #define ll long long
 6 #define inf 0x3f3f3f3f
 7 using namespace std;
 8 inline ll read(){
 9     #define num ch-'0'
10     char ch;bool flag=0;ll res;
11     while(!isdigit(ch=getchar()))
12     (ch=='-')&&(flag=true);
13     for(res=num;isdigit(ch=getchar());res=res*10+num);
14     (flag)&&(res=-res);
15     #undef num
16     return res;
17 }
18 int sum[M],L[M],R[M],rt[N];
19 int ver[N<<1],Next[N<<1],head[N];ll edge[N<<1];
20 int ls[N],rs[N];ll a[N],b[N];
21 int n,m,cnt,tot;ll p;
22 void update(int last,int &now,int l,int r,int x){
23     sum[now=++cnt]=sum[last]+1;
24     if(l==r) return;
25     int mid=(l+r)>>1;
26     if(x<=mid) R[now]=R[last],update(L[last],L[now],l,mid,x);
27     else L[now]=L[last],update(R[last],R[now],mid+1,r,x);
28 }
29 int query(int u,int v,int l,int r,int k){
30     if(r<k) return sum[v]-sum[u];
31     if(l>=k) return 0;
32     int mid=(l+r)>>1;
33     if(k<=mid) return query(L[u],L[v],l,mid,k);
34     else return query(R[u],R[v],mid+1,r,k)+sum[L[v]]-sum[L[u]];
35 }
36 inline void add(int u,int v,ll e){
37     ver[++tot]=v,Next[tot]=head[u],head[u]=tot,edge[tot]=e;
38 }
39 void dfs(int u,int fa,ll d){
40     b[ls[u]=++m]=d,a[m]=d;
41     for(int i=head[u];i;i=Next[i])
42     if(ver[i]!=fa) dfs(ver[i],u,d+edge[i]);
43     rs[u]=m;
44 }
45 int main(){
46     n=read(),p=read();
47     for(int u=2;u<=n;++u){
48         int v=read();ll e=read();
49         add(v,u,e);
50     }
51     dfs(1,0,0);
52     sort(b+1,b+1+m);
53     m=unique(b+1,b+1+m)-b-1;
54     for(int i=1;i<=n;++i){
55         int k=lower_bound(b+1,b+1+m,a[i])-b;
56         update(rt[i-1],rt[i],1,m,k);
57     }
58     b[m+1]=inf;
59     for(int i=1;i<=n;++i){
60         int k=upper_bound(b+1,b+2+m,a[ls[i]]+p)-b;
61         k=query(rt[ls[i]-1],rt[rs[i]],1,m,k);
62         printf("%d
",k);
63     }
64     return 0;
65 }
View Code

bzoj 1803: Spoj1487 Query on a tree III(主席树)。基础的树上查询

题解

 1 //minamoto
 2 #include<iostream>
 3 #include<cstdio>
 4 #include<algorithm>
 5 using namespace std;
 6 #define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
 7 char buf[1<<21],*p1=buf,*p2=buf;
 8 inline int read(){
 9     #define num ch-'0'
10     char ch;bool flag=0;int res;
11     while(!isdigit(ch=getc()))
12     (ch=='-')&&(flag=true);
13     for(res=num;isdigit(ch=getc());res=res*10+num);
14     (flag)&&(res=-res);
15     #undef num
16     return res;
17 }
18 char obuf[1<<24],*o=obuf;
19 inline void print(int x){
20     if(x>9) print(x/10);
21     *o++=x%10+48;
22 }
23 const int N=100005,M=N*30;
24 int sum[M],L[M],R[M],rt[N];
25 int ver[N<<1],Next[N<<1],head[N];
26 int ls[N],rs[N],a[N],b[N],id[N],pos[N];
27 int n,m,cnt,tot,q;
28 void update(int last,int &now,int l,int r,int x){
29     sum[now=++cnt]=sum[last]+1;
30     if(l==r) return;
31     int mid=(l+r)>>1;
32     if(x<=mid) R[now]=R[last],update(L[last],L[now],l,mid,x);
33     else L[now]=L[last],update(R[last],R[now],mid+1,r,x);
34 }
35 int query(int u,int v,int l,int r,int k){
36     if(l>=r) return l;
37     int x=sum[L[v]]-sum[L[u]];
38     int mid=(l+r)>>1;
39     if(x>=k) return query(L[u],L[v],l,mid,k);
40     else return query(R[u],R[v],mid+1,r,k-x);
41 }
42 inline void add(int u,int v){
43     ver[++tot]=v,Next[tot]=head[u],head[u]=tot;
44     ver[++tot]=u,Next[tot]=head[v],head[v]=tot;
45 }
46 void dfs(int u,int fa){
47     a[ls[u]=++m]=b[u],id[m]=u;
48     for(int i=head[u];i;i=Next[i])
49     if(ver[i]!=fa) dfs(ver[i],u);
50     rs[u]=m;
51 }
52 int main(){
53     //freopen("testdata.in","r",stdin);
54     n=read();
55     for(int i=1;i<=n;++i) b[i]=read();
56     for(int i=1;i<n;++i){
57         int u,v;
58         u=read(),v=read();
59         add(u,v);
60     }
61     dfs(1,0);
62     sort(b+1,b+1+m);
63     for(int i=1;i<=n;++i){
64         int k=lower_bound(b+1,b+1+m,a[i])-b;
65         update(rt[i-1],rt[i],1,m,k);
66         pos[k]=id[i];
67     }
68     q=read();
69     while(q--){
70         int u=read(),k=read();
71         int ans=pos[query(rt[ls[u]-1],rt[rs[u]],1,m,k)];
72         print(ans),*o++='
';
73     }
74     fwrite(obuf,o-obuf,1,stdout);
75     return 0;
76 }
View Code

带修改主席树

我们可以发现,主席树每一棵线段树维护的都是一个前缀和

如果有修改操作,每一次都要对后面的所有的前缀和都进行修改,那样的话时间复杂度就太爆炸了

我们可以考虑一下树状数组

树状数组维护的也是前缀和,但它的每一次修改是$O(log n)$的

他的节点存的并不是前缀和,但我们仍可以用树状数组来求出前缀和

于是我们可以用树状数组的思想来维护,主席树

用树状数组存一下每个节点的位置,每一次修改都按树状数组的方法去修改,也就是说并不需要修改那么多节点

查询的时候,也按树状数组的方法查询就好了

建议对这段话仔细理解,我当初也是懵逼了好久,最后看了zcysky大佬的那篇blog才蓦然醒悟的

拿bzoj1901洛谷P2617 Dynamic Rankings为例

是一个带修改主席树的板子

思路就按我上面所说的

题解

 1 //minamoto
 2 #include<bits/stdc++.h>
 3 #define N 10005
 4 using namespace std;
 5 inline int read(){
 6     #define num ch-'0'
 7     char ch;bool flag=0;int res;
 8     while(!isdigit(ch=getchar()))
 9     (ch=='-')&&(flag=true);
10     for(res=num;isdigit(ch=getchar());res=res*10+num);
11     (flag)&&(res=-res);
12     #undef num
13     return res;
14 }
15 inline int lowbit(int x){return x&(-x);}
16 int sum[N*600],L[N*600],R[N*600];
17 int xx[N],yy[N],rt[N],a[N],b[N<<1],ca[N],cb[N],cc[N];
18 int n,q,m,cnt=0,totx,toty;
19 void update(int last,int &now,int l,int r,int x,int v){
20     sum[now=++cnt]=sum[last]+v;
21     L[now]=L[last],R[now]=R[last];
22     if(l==r) return;
23     int mid=(l+r)>>1;
24     if(x<=mid) update(L[last],L[now],l,mid,x,v);
25     else update(R[last],R[now],mid+1,r,x,v);
26 }
27 int query(int l,int r,int q){
28     if(l==r) return l;
29     int x=0,mid=(l+r)>>1;
30     for(int i=1;i<=totx;++i) x-=sum[L[xx[i]]];
31     for(int i=1;i<=toty;++i) x+=sum[L[yy[i]]];
32     if(q<=x){
33         for(int i=1;i<=totx;++i) xx[i]=L[xx[i]];
34         for(int i=1;i<=toty;++i) yy[i]=L[yy[i]];
35         return query(l,mid,q);
36     }
37     else{
38         for(int i=1;i<=totx;++i) xx[i]=R[xx[i]];
39         for(int i=1;i<=toty;++i) yy[i]=R[yy[i]];
40         return query(mid+1,r,q-x);
41     }
42 }
43 void add(int x,int y){
44     int k=lower_bound(b+1,b+1+m,a[x])-b;
45     for(int i=x;i<=n;i+=lowbit(i)) update(rt[i],rt[i],1,m,k,y);
46 }
47 int main(){
48     //freopen("testdata.in","r",stdin);
49     n=read(),q=read();
50     for(int i=1;i<=n;++i)
51     b[++m]=a[i]=read();
52     for(int i=1;i<=q;++i){
53         char ch;
54         while(!isupper(ch=getchar()));
55         ca[i]=read(),cb[i]=read();
56         if(ch=='Q') cc[i]=read();else b[++m]=cb[i];
57     }
58     sort(b+1,b+1+m);
59     m=unique(b+1,b+1+m)-b-1;
60     for(int i=1;i<=n;++i) add(i,1);
61     for(int i=1;i<=q;++i){
62         if(cc[i]){
63             totx=toty=0;
64             for(int j=ca[i]-1;j;j-=lowbit(j)) xx[++totx]=rt[j];
65             for(int j=cb[i];j;j-=lowbit(j)) yy[++toty]=rt[j];
66             printf("%d
",b[query(1,m,cc[i])]);
67         }
68         else{add(ca[i],-1),a[ca[i]]=cb[i],add(ca[i],1);}
69     }
70     return 0;
71 }
View Code

还有一道[BZOJ3295] [Cqoi2011]洛谷p3157动态逆序对

题解

 1 //minamoto
 2 #include<bits/stdc++.h>
 3 #define N 100005
 4 #define M 5000005
 5 #define ll long long
 6 using namespace std;
 7 inline ll read(){
 8     #define num ch-'0'
 9     char ch;bool flag=0;ll res;
10     while(!isdigit(ch=getchar()))
11     (ch=='-')&&(flag=true);
12     for(res=num;isdigit(ch=getchar());res=res*10+num);
13     (flag)&&(res=-res);
14     #undef num
15     return res;
16 }
17 int L[M],R[M],sum[M],rt[N];
18 int val[N],pos[N],xx[N],yy[N],c[N],a1[N],a2[N];
19 int n,cnt,q;ll ans=0;
20 inline int lowbit(int x){return x&(-x);}
21 int ask(int x){
22     int s=0;
23     for(int i=x;i;i-=lowbit(i)) s+=c[i];
24     return s;
25 }
26 void update(int &now,int l,int r,int k){
27     if(!now) now=++cnt;
28     ++sum[now];
29     if(l==r) return;
30     int mid=(l+r)>>1;
31     if(k<=mid) update(L[now],l,mid,k);
32     else update(R[now],mid+1,r,k);
33 }
34 int querysub(int x,int y,int v){
35     int cntx=0,cnty=0,ans=0;--x;
36     for(int i=x;i;i-=lowbit(i)) xx[++cntx]=rt[i];
37     for(int i=y;i;i-=lowbit(i)) yy[++cnty]=rt[i];
38     int l=1,r=n;
39     while(l<r){
40         int mid=(l+r)>>1;
41         if(v<=mid){
42             for(int i=1;i<=cntx;++i) ans-=sum[R[xx[i]]];
43             for(int i=1;i<=cnty;++i) ans+=sum[R[yy[i]]];
44             for(int i=1;i<=cntx;++i) xx[i]=L[xx[i]];
45             for(int i=1;i<=cnty;++i) yy[i]=L[yy[i]];
46             r=mid;
47         }
48         else{
49             for(int i=1;i<=cntx;++i) xx[i]=R[xx[i]];
50             for(int i=1;i<=cnty;++i) yy[i]=R[yy[i]];
51             l=mid+1;
52         }
53     }
54     return ans;
55 }
56 int querypre(int x,int y,int v){
57     int cntx=0,cnty=0,ans=0;--x;
58     for(int i=x;i;i-=lowbit(i)) xx[++cntx]=rt[i];
59     for(int i=y;i;i-=lowbit(i)) yy[++cnty]=rt[i];
60     int l=1,r=n;
61     while(l<r){
62         int mid=(l+r)>>1;
63         if(v>mid){
64             for(int i=1;i<=cntx;++i) ans-=sum[L[xx[i]]];
65             for(int i=1;i<=cnty;++i) ans+=sum[L[yy[i]]];
66             for(int i=1;i<=cntx;++i) xx[i]=R[xx[i]];
67             for(int i=1;i<=cnty;++i) yy[i]=R[yy[i]];
68             l=mid+1;
69         }
70         else{
71             for(int i=1;i<=cntx;++i) xx[i]=L[xx[i]];
72             for(int i=1;i<=cnty;++i) yy[i]=L[yy[i]];
73             r=mid;
74         }
75     }
76     return ans;
77 }
78 int main(){
79     //freopen("testdata.in","r",stdin);
80     n=read(),q=read();
81     for(int i=1;i<=n;++i){
82         val[i]=read(),pos[val[i]]=i;
83         a1[i]=ask(n)-ask(val[i]);
84         ans+=a1[i];
85         for(int j=val[i];j<=n;j+=lowbit(j)) ++c[j];
86     }
87     memset(c,0,sizeof(c));
88     for(int i=n;i;--i){
89         a2[i]=ask(val[i]-1);
90         for(int j=val[i];j<=n;j+=lowbit(j)) ++c[j];
91     }
92     while(q--){
93         printf("%lld
",ans);
94         int x=read();x=pos[x];
95         ans-=(a1[x]+a2[x]-querysub(1,x-1,val[x])-querypre(x+1,n,val[x]));
96         for(int j=x;j<=n;j+=lowbit(j)) update(rt[j],1,n,val[x]);
97     }
98     return 0;
99 }
View Code

 进阶

个人认为主席树的一些好题

【bzoj2653】【middle】

可以加深对主席树的应用,不再只会求第k大之类的套路

题解

 1 //minamoto
 2 #include<iostream>
 3 #include<cstdio>
 4 #include<algorithm>
 5 using namespace std;
 6 #define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
 7 char buf[1<<21],*p1=buf,*p2=buf;
 8 inline int read(){
 9     #define num ch-'0'
10     char ch;bool flag=0;int res;
11     while(!isdigit(ch=getc()))
12     (ch=='-')&&(flag=true);
13     for(res=num;isdigit(ch=getc());res=res*10+num);
14     (flag)&&(res=-res);
15     #undef num
16     return res;
17 }
18 char obuf[1<<24],*o=obuf;
19 inline void print(int x){
20     if(x>9) print(x/10);
21     *o++=x%10+48;
22 }
23 const int N=20005,M=N*30;
24 int n,Pre,q,cnt;
25 int rt[N],p[5];
26 struct node{
27     int l,r,lmx,rmx,sum;
28 }t[M],op;
29 struct data{
30     int x,id;
31     inline bool operator <(const data &b)const
32     {return x<b.x;}
33 }a[N];
34 inline void pushup(int x){
35     t[x].sum=t[t[x].l].sum+t[t[x].r].sum;
36     t[x].lmx=max(t[t[x].l].lmx,t[t[x].l].sum+t[t[x].r].lmx);
37     t[x].rmx=max(t[t[x].r].rmx,t[t[x].r].sum+t[t[x].l].rmx);
38 }
39 void build(int &now,int l,int r){
40     now=++cnt;
41     if(l==r){t[now].lmx=t[now].rmx=t[now].sum=1;return;}
42     int mid=(l+r)>>1;
43     build(t[now].l,l,mid);
44     build(t[now].r,mid+1,r);
45     pushup(now);
46 }
47 void update(int last,int &now,int l,int r,int k){
48     now=++cnt;
49     if(l==r){t[now].lmx=t[now].rmx=t[now].sum=-1;return;}
50     int mid=(l+r)>>1;
51     if(k<=mid) t[now].r=t[last].r,update(t[last].l,t[now].l,l,mid,k);
52     else t[now].l=t[last].l,update(t[last].r,t[now].r,mid+1,r,k);
53     pushup(now);
54 }
55 node merge(node x,node y){
56     node z;
57     z.sum=x.sum+y.sum;
58     z.lmx=max(x.lmx,x.sum+y.lmx);
59     z.rmx=max(y.rmx,y.sum+x.rmx);
60     return z;
61 }
62 node find(int x,int l,int r,int y,int z){
63     if(y>z) return op;
64     if(l==y&&r==z) return t[x];
65     int mid=(l+r)>>1;
66     if(z<=mid) return find(t[x].l,l,mid,y,z);
67     else if(y>mid) return find(t[x].r,mid+1,r,y,z);
68     else return merge(find(t[x].l,l,mid,y,mid),find(t[x].r,mid+1,r,mid+1,z));
69 }
70 int query(int x){
71     return find(rt[x],1,n,p[1],p[2]).rmx+find(rt[x],1,n,p[2]+1,p[3]-1).sum+find(rt[x],1,n,p[3],p[4]).lmx;
72 }
73 int main(){
74     //freopen("testdata.in","r",stdin);
75     n=read();
76     for(int i=1;i<=n;++i) a[i].x=read(),a[i].id=i;
77     sort(a+1,a+1+n);
78     build(rt[1],1,n);
79     for(int i=2;i<=n;++i) update(rt[i-1],rt[i],1,n,a[i-1].id);
80     q=read();
81     while(q--){
82         int x=read(),y=read(),z=read(),k=read();
83         p[1]=(x+Pre)%n+1,p[2]=(y+Pre)%n+1,p[3]=(z+Pre)%n+1,p[4]=(k+Pre)%n+1;
84         sort(p+1,p+5);
85         int l=1,r=n,ans=1;
86         while(l<=r){
87             int mid=(l+r)>>1;
88             int f=query(mid);
89             if(f>=0) ans=mid,l=mid+1;
90             else r=mid-1;
91         }
92         Pre=a[ans].x;
93         print(a[ans].x),*o++='
';
94     }
95     fwrite(obuf,o-obuf,1,stdout);
96     return 0;
97 }
View Code

hdu 4348 To the moon

主席树的区间修改,应该算是真正的可持久化?

题解

 1 //minamoto
 2 #include<bits/stdc++.h>
 3 #define ll long long
 4 using namespace std;
 5 const int N=100005,M=N*30;
 6 int n,m,cnt,rt[N];
 7 int L[M],R[M];ll sum[M],add[M];
 8 #define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
 9 char buf[1<<21],*p1=buf,*p2=buf;
10 inline ll read(){
11     #define num ch-'0'
12     char ch;bool flag=0;ll res;
13     while(!isdigit(ch=getc()))
14     (ch=='-')&&(flag=true);
15     for(res=num;isdigit(ch=getc());res=res*10+num);
16     (flag)&&(res=-res);
17     #undef num
18     return res;
19 }
20 void build(int &now,int l,int r){
21     add[now=++cnt]=0;
22     if(l==r) return (void)(sum[now]=read());
23     int mid=(l+r)>>1;
24     build(L[now],l,mid);
25     build(R[now],mid+1,r);
26     sum[now]=sum[L[now]]+sum[R[now]];
27 }
28 void update(int last,int &now,int l,int r,int ql,int qr,int x){
29     now=++cnt;
30     L[now]=L[last],R[now]=R[last],add[now]=add[last],sum[now]=sum[last];
31     sum[now]+=1ll*x*(qr-ql+1);
32     if(ql==l&&qr==r) return (void)(add[now]+=x);
33     int mid=(l+r)>>1;
34     if(qr<=mid) update(L[last],L[now],l,mid,ql,qr,x);
35     else if(ql>mid) update(R[last],R[now],mid+1,r,ql,qr,x);
36     else return (void)(update(L[last],L[now],l,mid,ql,mid,x),update(R[last],R[now],mid+1,r,mid+1,qr,x));
37 }
38 ll query(int now,int l,int r,int ql,int qr){
39     if(l==ql&&r==qr) return sum[now];
40     int mid=(l+r)>>1;
41     ll res=1ll*add[now]*(qr-ql+1);
42     if(qr<=mid) res+=query(L[now],l,mid,ql,qr);
43     else if(ql>mid) res+=query(R[now],mid+1,r,ql,qr);
44     else res+=query(L[now],l,mid,ql,mid)+query(R[now],mid+1,r,mid+1,qr);
45     return res;
46 }
47 int main(){
48     //freopen("testdata.in","r",stdin);
49     n=read(),m=read();
50     cnt=-1;
51     build(rt[0],1,n);
52     int now=0;
53     while(m--){
54         char ch;int l,r,x;
55         while(!isupper(ch=getc()));
56         switch(ch){
57             case 'C':{
58                 l=read(),r=read(),x=read();
59                 ++now;
60                 update(rt[now-1],rt[now],1,n,l,r,x);
61                 break;
62             }
63             case 'Q':{
64                 l=read(),r=read();
65                 printf("%lld
",query(rt[now],1,n,l,r));
66                 break;
67             }
68             case 'H':{
69                 l=read(),r=read(),x=read();
70                 printf("%lld
",query(rt[x],1,n,l,r));
71                 break;
72             }
73             case 'B':{
74                 now=read();
75                 cnt=rt[now+1]-1;
76                 break;
77             }
78         }
79     }
80     return 0;
81 }
View Code

鉴于本人十分弱鸡,可能讲的不是非常清楚,欢迎大家在下面补充

原文地址:https://www.cnblogs.com/bztMinamoto/p/9398329.html