主席树学习笔记

主席树是可持久化线段树。

什么叫可持久化?就是能维护历史版本。

如何维护历史版本?

你会发现修改操作最多只修改logn个节点。

那么你另外建logn个结点表示修改后的线段树信息,其余不变,而root[i]指向第i个历史版本的线段树的根节点。

时空复杂度均为O(nlogn)

查询操作呢?

以静态区间第k小(大)为例。

建树时,相当从空线段树开始,修改n个结点。求[l, r]区间第k大时,用第r个版本的线段树与第l-1个版本的线段树作差即可。

详情可见代码。

主席树ppt

http://seter.is-programmer.com/posts/31907.html

https://blog.finaltheory.me/algorithm/Chairman-Tree.html#content-heading

HDU2665

题意:n个数,多个询问。每次询问[l, r]区间内第k大的数。

题解:离散化。相当于有n个操作。询问[l, r]内第k大的数就相当于询问哪个数在r时刻的sum比l时刻的sum大k.注意query()函数的写法。类似线段树上无二分的全局第k大。

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 const int N = 1e5+5;
 4 int n, m, cnt, root[N], a[N], x, y, k;
 5 vector<int> ve;
 6 struct node{
 7     int l, r, sum;
 8 };
 9 node T[N*20];
10 
11 void update(int& x, int y, int pos, int l, int r){
12     T[++cnt] = T[y], T[cnt].sum++, x = cnt;//新建一个拷贝自y节点的点, sum++表示插入pos, 令x指向该节点。
13     if(l == r) return ;
14     int m = l+r >> 1;
15     if(pos <= m) update(T[x].l, T[y].l, pos, l, m);
16     else update(T[x].r, T[y].r, pos, m+1, r);
17 }
18 int query(int x, int y, int k, int l, int r){
19     if(l == r) return l;
20     int m = l+r >> 1;
21     int sum = T[ T[y].l ].sum-T[ T[x].l ].sum;
22     if(sum >= k) return query(T[x].l, T[y].l, k, l, m);
23     else return query(T[x].r, T[y].r, k-sum, m+1, r);
24 }
25 
26 int main(){
27     int t; scanf("%d", &t);
28     while(t--){
29         ve.clear();
30         scanf("%d%d", &n, &m);
31         for(int i = 1; i <= n; i++){
32             scanf("%d", a+i);
33             ve.push_back(a[i]);
34         }
35         sort(ve.begin(), ve.end());
36         ve.erase( unique(ve.begin(), ve.end()), ve.end() );
37 
38         cnt = 0;
39         for(int i = 1; i <= n; i++) {
40             int kth = lower_bound(ve.begin(), ve.end(), a[i])-ve.begin()+1;
41             update(root[i], root[i-1], kth, 1, n);
42         }
43         for(int i = 1; i <= m; i++){
44             scanf("%d%d%d", &x, &y, &k);
45             printf("%d
", ve[ query(root[x-1], root[y], k, 1, n)-1 ]);
46         }
47     }
48     return 0;
49 }
View Code

spoj 3267

题意:给出一个数组。求区间内不同的数的个数。

题解:离散化。相当于有n个操作。相同数字中只保存最右的数字。那么询问[l, r]内不同的数的个数就相当于询问r时刻[l, r]的sum值,也就是r时刻[l, n]的sum值(r以后都是0).

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 
 4 const int N = 3e4+5;
 5 int last[N];
 6 int n, m, cnt, root[N], a[N], x, y, k;
 7 vector<int> ve;
 8 struct node{
 9     int l, r, sum;
10 };
11 node T[N*20];
12 
13 void update(int& x, int y, int pos, int add, int l, int r){
14     T[++cnt] = T[y], T[cnt].sum += add, x = cnt;//新建一个拷贝自y节点的点, sum++/--表示插入/删除pos, 令x指向该节点。
15     if(l == r) return ;
16     int m = l+r >> 1;
17     if(pos <= m) update(T[x].l, T[y].l, pos, add, l, m);
18     else update(T[x].r, T[y].r, pos, add, m+1, r);
19 }
20 int query(int x, int y, int l, int r){
21     if(x <= l) return T[y].sum;
22     if(x > r) return 0;
23     int m = l+r >> 1;
24     return query(x, T[y].l, l, m)+query(x, T[y].r, m+1, r);
25 }
26 
27 int main(){
28     scanf("%d", &n);
29     ve.clear();
30     for(int i = 1; i <= n; i++){
31         scanf("%d", a+i);
32         ve.push_back(a[i]);
33     }
34     sort(ve.begin(), ve.end());
35     ve.erase( unique(ve.begin(), ve.end()), ve.end() );
36     for(int i = 1; i <= n; i++)
37         a[i] = lower_bound(ve.begin(), ve.end(), a[i])-ve.begin()+1;
38     memset(last, 0, sizeof(last));
39 
40     cnt = 0;
41     for(int i = 1; i <= n; i++){
42         int num = a[i];
43         if(last[num])
44             update(root[i], root[i-1], last[num], -1, 1, n),
45             update(root[i], root[i], i, 1, 1, n);//
46         else
47             update(root[i], root[i-1], i, 1, 1, n);//
48         last[num] = i;
49     }
50     int q;
51     scanf("%d", &q);
52     while(q--){
53         int l, r;
54         scanf("%d%d", &l, &r);
55         printf("%d
", query(l, root[r], 1, n));
56     }
57     return 0;
58 }
View Code

bzoj1901 

题意:HDU2665升级版。支持修改操作。修改第i个数位x,查询[l, r]内第k大数。

题解:先按静态建好主席树。另加树状数组套主席树维护修改的值(也就是说,本来修改第i个数需要同时修改根为i ~ n的主席树,O(n)个主席树,套树状数组后只需修改O(logn)个主席树。)

一开始RE,后来想想静态主席树是nlogn个节点,修改的时候需要新增mlognlogn个节点,数组比平时要多一个log,那自然就RE了~

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 //nlogn + mlognlogn
 4 const int N = 2e4+5;
 5 int n, m, cnt, root[N], a[N], tot;
 6 vector<int> ve, p, q;
 7 struct node{
 8     int l, r, sum;
 9 };
10 node T[N*200];
11 
12 void update(int& x, int y, int pos, int add, int l, int r){
13     T[++cnt] = T[y], T[cnt].sum += add, x = cnt;//新建一个拷贝自y节点的点, sum++/--表示插入/删除pos, 令x指向该节点。
14     if(l == r) return ;
15     int m = l+r >> 1;
16     if(pos <= m) update(T[x].l, T[y].l, pos, add, l, m);
17     else update(T[x].r, T[y].r, pos, add, m+1, r);
18 }
19 int query(int k, int l, int r){
20     if(l == r) return l;
21     int m = l+r >> 1, cnt1 = 0, cnt2 = 0;
22     for(int i = 0; i < p.size(); i++) cnt1 += T[ T[p[i]].l ].sum;
23     for(int i = 0; i < q.size(); i++) cnt2 += T[ T[q[i]].l ].sum;
24     if(cnt2-cnt1 >= k){
25         for(int i = 0; i < p.size(); i++) p[i] = T[p[i]].l;
26         for(int i = 0; i < q.size(); i++) q[i] = T[q[i]].l;
27         return query(k, l, m);
28     }
29     else{
30         for(int i = 0; i < p.size(); i++) p[i] = T[p[i]].r;
31         for(int i = 0; i < q.size(); i++) q[i] = T[q[i]].r;
32         return query(k-(cnt2-cnt1), m+1, r);
33     }
34 }
35 
36 int lowbit(int x){ return x&-x;}
37 int add(int x, int pos, int add){
38     for(int i = x; i <= n; i += lowbit(i))
39         update(root[i], root[i], pos, add, 1, tot);
40 }
41 int solve(int l, int r, int k){
42     l--;
43     p.clear(), q.clear();
44     if(l) p.push_back(root[n+l]);
45     q.push_back(root[n+r]);
46     for(int i = l; i; i -= lowbit(i))
47         p.push_back(root[i]);
48     for(int i = r; i; i -= lowbit(i))
49         q.push_back(root[i]);
50     return query(k, 1, tot);
51 }
52 
53 char op[N];
54 int L[N], R[N], K[N];
55 int main(){
56     scanf("%d%d", &n, &m);
57     ve.clear();
58     for(int i = 1; i <= n; i++){
59         scanf("%d", a+i);
60         ve.push_back(a[i]);
61     }
62 
63     for(int i = 1; i <= m; i++){
64         scanf(" %c", op+i);
65         if(op[i] == 'Q') scanf("%d%d%d", L+i, R+i, K+i);
66         else scanf("%d%d", L+i, R+i), ve.push_back(R[i]);
67     }
68 
69     sort(ve.begin(), ve.end());
70     ve.erase( unique(ve.begin(), ve.end()), ve.end() );
71 
72     tot = ve.size();
73     cnt = 0;
74 
75     for(int i = 1; i <= n; i++){
76         int pos = lower_bound(ve.begin(), ve.end(), a[i])-ve.begin()+1;
77         update(root[n+i], root[n+i-1], pos, 1, 1, tot);
78     }
79 
80     for(int i = 1; i <= m; i++){
81         if(op[i] == 'Q'){
82             int val = solve(L[i], R[i], K[i]);
83             printf("%d
", ve[val-1]);
84         }
85         else {
86             int val = lower_bound(ve.begin(), ve.end(), a[L[i]])-ve.begin()+1;
87             add(L[i], val, -1);
88             a[ L[i] ] = R[i];
89             val = lower_bound(ve.begin(), ve.end(), a[L[i]])-ve.begin()+1;
90             add(L[i], val, 1);
91         }
92     }
93     return 0;
94 }
View Code

hdu4417

题意:求区间内比k小的数的个数。

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 const int N = 1e5+5;
 4 int n, m, cnt, root[N], a[N], tot;
 5 vector<int> ve, p, q;
 6 struct node{
 7     int l, r, sum;
 8 };
 9 node T[N*20];
10 
11 void update(int& x, int y, int pos, int add, int l, int r){
12     T[++cnt] = T[y], T[cnt].sum += add, x = cnt;
13     if(l == r) return ;
14     int m = l+r >> 1;
15     if(pos <= m) update(T[x].l, T[y].l, pos, add, l, m);
16     else update(T[x].r, T[y].r, pos, add, m+1, r);
17 }
18 int query(int x, int y, int h, int l, int r){
19     if(l == r) return T[y].sum-T[x].sum;
20     int m = l+r >> 1;
21     if(h <= m) return query(T[x].l, T[y].l, h, l, m);
22     else return T[ T[y].l ].sum-T[ T[x].l ].sum+query(T[x].r, T[y].r, h, m+1, r);
23 }
24 
25 int main(){
26     int t, ca = 1; scanf("%d", &t);
27     while(t--){
28         printf("Case %d:
", ca++);
29         scanf("%d%d", &n, &m);
30         ve.clear();
31         for(int i = 1; i <= n; i++){
32             scanf("%d", a+i);
33             ve.push_back(a[i]);
34         }
35         sort(ve.begin(), ve.end());
36         ve.erase( unique(ve.begin(), ve.end()), ve.end() );
37         tot = ve.size();
38 
39         for(int i = 1; i <= n; i++){
40             int pos = lower_bound(ve.begin(), ve.end(), a[i])-ve.begin()+1;
41             update(root[i], root[i-1], pos, 1, 1, tot);
42         }
43 
44         int l, r, h;
45         for(int i = 1; i <= m; i++){
46             scanf("%d%d%d", &l, &r, &h);
47             l++, r++;
48             h = upper_bound(ve.begin(), ve.end(), h)-ve.begin();
49             printf("%d
", h? query(root[l-1], root[r], h, 1, tot):0);
50         }
51     }
52     return 0;
53 }
View Code

hdu5919

题意:置。kk/2少。

区间第k大,区间不同数都是主席树套路。

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 const int N = 2e5+5;
 4 int n, m, cnt, root[N], a[N];
 5 int last[N];
 6 vector<int> ve, p, q;
 7 struct node{
 8     int l, r, sum;
 9 };
10 node T[N*40];
11 
12 void update(int& x, int y, int pos, int add, int l, int r){
13     T[++cnt] = T[y], T[cnt].sum += add, x = cnt;//新建一个拷贝自y节点的点, sum++/--表示插入/删除pos, 令x指向该节点。
14     if(l == r) return ;
15     int m = l+r >> 1;
16     if(pos <= m) update(T[x].l, T[y].l, pos, add, l, m);
17     else update(T[x].r, T[y].r, pos, add, m+1, r);
18 }
19 int query(int x, int y, int l, int r){//x节点的根树中下标 >= y的个数
20     if(l == r) return T[x].sum;
21     int m = l+r >> 1;
22     if(y <= m) return query(T[x].l, y, l, m)+T[T[x].r].sum;
23     else return query(T[x].r, y, m+1, r);
24 }
25 int querypos(int x, int y, int l, int r){
26     if(l == r) return l;
27     int m = l+r >> 1;
28     if(T[T[x].l].sum >= y) return querypos(T[x].l, y, l, m);
29     else return querypos(T[x].r, y-T[T[x].l].sum, m+1, r);
30 }
31 
32 int main(){
33     int t, ca = 1; scanf("%d", &t);
34     while(t--){
35         scanf("%d%d", &n, &m);
36         ve.clear();
37         for(int i = 1; i <= n; i++){
38             scanf("%d", a+i);
39             ve.push_back(a[i]);
40         }
41         sort(ve.begin(), ve.end());
42         ve.erase( unique(ve.begin(), ve.end()), ve.end() );
43 
44         memset(last, 0, sizeof(last));
45         cnt = 0, root[n+1] = 0;
46         for(int i = n; i; i--){
47             int num = lower_bound(ve.begin(), ve.end(), a[i])-ve.begin()+1;
48             if(last[num])
49                 update(root[i], root[i+1], last[num], -1, 1, n), update(root[i], root[i], i, 1, 1, n);
50             else
51                 update(root[i], root[i+1], i, 1, 1, n);
52             last[num] = i;
53         }
54         printf("Case #%d: ", ca++);
55         int l, r, ans = 0;
56         for(int i = 1; i <= m; i++){
57             scanf("%d%d", &l, &r);
58             l = (l+ans)%n+1, r = (r+ans)%n+1;
59             if(l > r) swap(l, r);
60             int sum = T[ root[l] ].sum, rsum = query(root[l], r+1, 1, n);
61             int pos = (sum-rsum+1)/2;
62             ans = querypos(root[l], pos, 1, n);
63             printf("%d%c", ans, " 
"[i == m]);
64         }
65     }
66     return 0;
67 }
View Code

hdu4605

题意:有一棵树二叉树,每个节点有一个数字。每次询问一个节点和一个带数字的球,问这个球经过这个节点的概率。一个带数字的球从根开始往下走,如果走过的节点数字和他相同,就停在这里;如果节点数字比他大,向左向右的概率都是1/2;否则向左的概率是1/8,向右的概率是7/8。

题解:主席树维护每个结点到根节点的权值线段树。统计以该结点为根表示的线段树中是否存在该数字,不存在的话,统计比该数字小的个数及比该数字大的个数。

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 const int N = 1e5+10;
 4 int cnt, root[N], tot;
 5 int m, n, q, w[N], v[N], X[N];
 6 vector<int> ve[N];
 7 struct node{
 8     int l, r, sum[2];
 9 };
10 node T[N*40];
11 struct pp{
12     int a, b, c;
13     pp(int a, int b, int c):a(a), b(b), c(c){}
14 };
15 void up(int &x, int y, int pos, int add, int l, int r){
16     T[++cnt] = T[y], T[cnt].sum[add]++, x = cnt;
17     if(l == r) return ;
18     int m = l+r >> 1;
19     if(pos <= m) up(T[x].l, T[y].l, pos, add, l, m);
20     else up(T[x].r, T[y].r, pos, add, m+1, r);
21 }
22 pp query(int x, int y, int l, int r){//< y left0, < y right1, > y
23     if(r <= y) return pp(T[x].sum[0], T[x].sum[1], 0);
24     if(l >= y) return pp(0, 0, T[x].sum[0]+T[x].sum[1]);
25     int m = l+r >> 1;
26     pp ret1 = query(T[x].l, y, l, m), ret2 = query(T[x].r, y, m+1, r);
27     return pp(ret1.a+ret2.a, ret1.b+ret2.b, ret1.c+ret2.c);
28 }
29 int query2(int rt, int x, int l, int r){//以rt为根的子树是否包含有x
30     if(l == r) return T[rt].sum[0]+T[rt].sum[1];
31     int m = l+r >> 1;
32     if(x <= m) return query2(T[rt].l, x, l, m);
33     else return query2(T[rt].r, x, m+1, r);
34 }
35 
36 void dfs(int f, int x, int lr){
37     if(f) up(root[x], root[f], w[f], lr, 1, tot);
38     for(int i = 0; i < ve[x].size(); i++){
39         int y = ve[x][i];
40         dfs(x, y, i);
41     }
42 }
43 int main(){
44     int t; scanf("%d", &t);
45     while(t--){
46         vector<int> tmp;
47         scanf("%d", &n);
48         for(int i = 1; i <= n; i++){
49             ve[i].clear();
50             scanf("%d", w+i);
51             tmp.push_back(w[i]);
52         }
53         scanf("%d", &m);
54         int u, a, b;
55         for(int i = 0; i < m; i++){
56             scanf("%d%d%d", &u, &a, &b);
57             ve[u].push_back(a), ve[u].push_back(b);
58         }
59         scanf("%d", &q);
60         for(int i = 0; i < q; i++){
61             scanf("%d%d", v+i, X+i);
62             tmp.push_back(X[i]);
63         }
64         sort(tmp.begin(), tmp.end());
65         tmp.erase( unique(tmp.begin(), tmp.end()), tmp.end() );
66         for(int i = 1; i <= n; i++)
67             w[i] = lower_bound(tmp.begin(), tmp.end(), w[i])-tmp.begin()+1;
68         for(int i = 0; i < q; i++)
69             X[i] = lower_bound(tmp.begin(), tmp.end(), X[i])-tmp.begin()+1;
70 
71         cnt = 0, tot = tmp.size();
72         dfs(0, 1, 0);
73 
74         for(int i = 0; i < q; i++){
75             if(query2(root[ v[i] ], X[i], 1, tot)){
76                 puts("0");
77                 continue ;
78             }
79             pp ret = query(root[ v[i] ], X[i], 1, tot);
80             //ret.a 1/8 ret.b 7/8 ret.c 1/2
81             int x = ret.b, y = 3*(ret.a+ret.b)+ret.c;
82             printf("%d %d
", x, y);
83         }
84     }
85     return 0;
86 }
View Code
原文地址:https://www.cnblogs.com/dirge/p/6095456.html