权值线段树 简单总结 相关例题

一、简单定义

  本质上仍然是一棵线段树,但它和普通线段树不同,其每个节点用来表示一个区间内元素出现的次数,可以理解为维护区间的值域。

二、应用

   1.维护一段区间的数出现的次数,快速计算一段区间的数的出现次数。

   2.快速找到第k大或第k小值。 

   缺点:只能离线操作,不能进行在线询问。

三、原理

   例如,最初有一个序列 7 2 3 5 6 1 4,线段树初始状态各个结点的值都是0。如下图:

依次先插入7,线段树变为下图

然后再插入元素2,此时根节点更新为2,根节点维护了1~7数字出现的个数,出现了一次7和一次2,总次数那自然是2了

再插入元素3,同样地去递归更新结点,如下图所示

此时,当更新元素3的个数之前,我们可以查询[1,2]的权值和区间[4,7]的权值,可以得到比3小的元素有几个,比3大的元素有几个,那么可以轻易地得出当前出现的元素3是第k小和第k大

四、相关代码

1.单点更新。依然是递归到叶子节点p,令t[p]++

1 void upd(int l,int r,int x,int p){
2     if(l==r) {t[p]++;return;}
3     int mid = (l+r)>>1;
4     if(x<=mid) upd(l,mid,x,p<<1);
5     else upd(mid+1,r,x,p<<1|1);
6     t[p] = t[p<<1]+t[p<<1|1]; 
7 }
View Code

2.查询一段区间[l,r]数字出现的总和

1 int query(int ql,int qr,int l,int r,int p){
2     if(l>=ql && r<=qr) return t[p];
3     int mid = (l+r)>>1;
4     int ans = 0;
5     if(ql<=mid) ans+=query(ql,qr,l,mid,p<<1);
6     if(qr>mid) ans+=query(ql,qr,mid+1,r,p<<1|1);
7     return ans;
8 }
View Code

3.查询所有数中的第k大(第k小)

 1 int kth(int l,int r,int k,int p)
 2 {
 3     if(l == r) return l;
 4     else{
 5         int mid = l+r>>1;
 6         int s1 = t[p<<1],s2 = t[p<<1|1];
 7         if(k<=s2) return kth(mid+1,r,p<<1|1,k);
 8         else return kth(l,mid,p<<1,k - s2); 
 9     }
10 }
View Code

五、例题

以查询第k大为例,权值线段树的核心是到每个结点,如果右子树的权值总和大于了k,则说明其第k大值在右子树,递归进入右子树。反之则说明第k大值在左子树。

特别注意:若要进入左子树,需要k减去右子树的总和,比如要找的元素是第5大,右子树权值总和为3,则需5-3=2,说明该节点的第5大值存在于左子树的第2大值中。那么从左子树递归下去,直到递归到一个数,那就是答案了。

1.hdu1394 Minimum Inversion Number

http://acm.hdu.edu.cn/showproblem.php?pid=1394


在先询问逆序对个数的最小值。给定一个序列,每次把序列的第一个数移动到最后,求每次操作后新序列的逆序对个数的最小值。
首先元素的数据范围不大,不必离散化,直接开权值线段树,结点维护元素的个数。离线,每次输入一个数,查询操作求出比其大的数字的个数,然后做单点更新。
再for一遍序列,每次把元素a[i]移动到最后,新增的答案贡献等于ans先减去比a[i]小的元素(因为把a[i]要放置最后),再加上比a[i]大的元素,这样更新下去,取min即可。

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 const int maxn = 1e4 + 5;
 5 int t[maxn<<2];
 6 int a[maxn];
 7 int n;
 8 void upd(int l,int r,int x,int p){
 9     if(l==r) {t[p]++;return;}
10     int mid = (l+r)>>1;
11     if(x<=mid) upd(l,mid,x,p<<1);
12     else upd(mid+1,r,x,p<<1|1);
13     t[p] = t[p<<1]+t[p<<1|1]; 
14 }
15 int query(int ql,int qr,int l,int r,int p){
16     if(l>=ql && r<=qr) return t[p];
17     int mid = (l+r)>>1;
18     int ans = 0;
19     if(ql<=mid) ans+=query(ql,qr,l,mid,p<<1);
20     if(qr>mid) ans+=query(ql,qr,mid+1,r,p<<1|1);
21     return ans;
22 }
23 int main() {
24     while(~scanf("%d",&n)){
25         memset(t,0,sizeof(t));
26         int x,ans=0;
27         for(int i=1;i<=n;++i){
28             scanf("%d",&a[i]);
29             ans+=query(a[i]+1,n,1,n,1);
30 //            cout<<ans<<" ";
31             upd(1,n,a[i]+1,1);
32         }int res=ans;
33         for(int i=1;i<=n;++i){
34             res-=query(1,a[i]+1,1,n,1)-1;
35             res+=query(a[i]+1,n,1,n,1)-1;
36 //            cout<<res<<" ";
37             ans=min(ans,res);
38         }printf("%d
",ans);
39     }
40     return 0;
41 }
View Code

2.洛谷P1637 三元上升子序列

https://www.luogu.com.cn/problem/P1637
求ai<aj<ak的三元组个数
首先每个元素的数据范围很大在longlong范围内,需要离散化处理一下,否则线段树MLE。
开权值线段树,一遍从1到n遍历维护一个权值线段树,用来预处L[i](左边比ai小的元素个数)。一遍从n到1遍历维护一个权值线段树,预处理R[i](右边比ai大的元素的个数)。
最后再for一遍a数组,根据乘法原理对于ai其组成的三元组为R[i]*L[i],整体求出∑R[i]*L[i]即可。具体请看代码

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 const int maxn = 3e4 + 5;
 5 int t[maxn<<2];
 6 ll a[maxn],tmp[maxn],L[maxn],R[maxn];
 7 int n;
 8 void upd(int l,int r,int x,int p){
 9     if(l==r) {t[p]++;return;}
10     int mid = (l+r)>>1;
11     if(x<=mid) upd(l,mid,x,p<<1);
12     else upd(mid+1,r,x,p<<1|1);
13     t[p] = t[p<<1]+t[p<<1|1]; 
14 }
15 int query(int ql,int qr,int l,int r,int p){
16     if(l>=ql && r<=qr) return t[p];
17     int mid = (l+r)>>1;
18     int ans = 0;
19     if(ql<=mid) ans+=query(ql,qr,l,mid,p<<1);
20     if(qr>mid) ans+=query(ql,qr,mid+1,r,p<<1|1);
21     return ans;
22 }
23 
24 
25 int main() {
26     scanf("%d",&n);
27     for(int i = 1;i<=n;i++){
28         scanf("%d",&a[i]);
29         tmp[i] = a[i];
30     }
31     sort(tmp+1,tmp+1+n);
32     int up = unique(tmp+1,tmp+1+n) - (tmp+1);
33     unordered_map<ll,int> m;
34     for(int i = 1;i<=up;i++) m[tmp[i]] = i;
35     
36     for(int i = 1;i<=n;i++){
37         int pos = m[a[i]];//离散化后的大小 
38         if(pos!=1) L[i] = query(1,pos-1,1,up,1);
39         upd(1,up,pos,1);
40     }
41     memset(t,0,sizeof(t));
42     for(int i = n;i>=1;i--){
43         int pos = m[a[i]];
44         if(pos!=up) R[i] = query(pos+1,up,1,up,1);
45         upd(1,up,pos,1);
46     }
47     ll ans = 0;
48     for(int i = 1;i<=n;i++){
49         ans+=(L[i]*R[i]);
50     }
51     printf("%lld",ans);
52     return 0;
53 }
54 //
View Code

3.hdu4217 Data Structure?

http://acm.hdu.edu.cn/showproblem.php?pid=4217

每次查询第k小,从序列中拿出,求所有询问的总和。权值线段树的板子题,直接开权值线段树,每次查询第k小,随后单点更新删除即可。

 1 #include<cstring>
 2 #include<iostream>
 3 #include<cstdio>
 4 using namespace std;
 5 typedef long long ll;
 6 const int maxn = 265000;
 7 int t[maxn<<2];
 8 int n;
 9 void build(int l,int r,int p){
10     if(l == r) {t[p] = 1;return;}
11     int mid = l+r>>1;
12     build(l,mid,p<<1);
13     build(mid+1,r,p<<1|1);
14     t[p] = t[p<<1] + t[p<<1|1]; 
15 }
16 void upd(int l,int r,int x,int p,int v){
17     if(l==r) {t[p]+=v;return;}
18     int mid = (l+r)>>1;
19     if(x<=mid) upd(l,mid,x,p<<1,v);
20     else upd(mid+1,r,x,p<<1|1,v);
21     t[p] = t[p<<1]+t[p<<1|1]; 
22 }
23 int query(int ql,int qr,int l,int r,int p){
24     if(l>=ql && r<=qr) return t[p];
25     int mid = (l+r)>>1;
26     int ans = 0;
27     if(ql<=mid) ans+=query(ql,qr,l,mid,p<<1);
28     if(qr>mid) ans+=query(ql,qr,mid+1,r,p<<1|1);
29     return ans;
30 }
31 int findkth(int l,int r,int k,int p)
32 {
33     if(l == r) return l;
34     else{
35         int mid = l+r>>1;
36         int s1 = t[p<<1],s2 = t[p<<1|1];
37         if(k<=s1) return findkth(l,mid,k,p<<1);
38         else return findkth(mid+1,r,k - s1,p<<1|1); 
39     }
40 }
41 int main() {
42     int T;
43     scanf("%d",&T);
44     int cnt = 1;
45     while(T--){
46         int n,k;
47         scanf("%d%d",&n,&k);;
48         build(1,n,1);
49         ll sum = 0; 
50         for(int i = 1;i<=k;i++){
51             int kth;
52             scanf("%d",&kth);
53             int take = findkth(1,n,kth,1);
54             sum +=take;
55             upd(1,n,take,1,-1);
56         }
57         printf("Case %d: %lld
",cnt,sum);
58         cnt++;
59     }
60     return 0;
61 }
62 //
View Code
原文地址:https://www.cnblogs.com/AaronChang/p/12650589.html