2.4.1 线段树

1.线段树的概念:

  线段树是擅长处理区间的,形如下图的数据结构。线段树是一颗完美二叉树(Perfect Binary Tree),树上的每个节点都维护一个区间。根维护的是整个区间,每个节点维护的是父亲的区间二等分后的其中一个子区间。当有n个元素时,对区间的操作可以在O(log n)的时间内完成。

                                                                    

根据节点中维护的数据的不同,线段树可以提供不同的功能。下面我们以实现了Range Minimum Query(RMQ)操作的线段树为例进行说明。

2.基于线段树的RMQ的结构

下面要建立的线段树在给定数列a0,a1,……,a(n-1)的情况下,可以在O(log n)时间内完成如下两种操作:

(1)给定s和t,求a(s),a(s+1),……,a(t)的最小值

(2)给定 i 和 x,把 ai 的值改成x

如下图,线段树的每个节点维护对应区间的最小值。在建树时,只需要按从下到上的顺序分别取左右儿子的值中较小者就可以了。

                                         

3.基于线段树的RMQ的查询

如果要求a0,……,a6的最小值。我们只需要求下图中的三个节点的值的最小值即可。

                                           

像这样,即使查询的是一个比较大的区间,由于较靠上的节点对应较大的区间,通过这些区间就可以知道大部分值的最小值,从而只需访问很少的节点就可以求得最小值。

要求某个区间的最小值,像下面这样递归处理就可以了。

  如果所查询的区间和当前节点对应的区间完全没有交集,那么就返回一个不影响答案的值(例如INT—MAX)。

  如果所查询的区间完全包含了当前节点对应的区间,那么就返回当前节点的值。

  以上两种情况都不满足的话,就对两个儿子递归处理,返回两个结果中的较小者。

4.基于线段树的RMQ的值的更新

在更新a0的值时,需要重新计算下图所示的4个节点的值。

                                                          

在更新ai的值时,需要对包含 i 的所有区间对应的节点的值重新进行计算。在更新时,可以从下面的节点开始向上不断更新,把每个节点的值更新为左右两个儿子的值的较小者就可以了。

5.基于线段树的RMQ的复杂度

不论哪种操作,对于每个深度都最多访问常数个节点。因此对于n个元素,每一次操作的复杂度是O(log n)。对于二叉搜索树,我们曾经提到过可能有因操作不当而导致退化的情况发生,从而使复杂度变得很糟糕。不过因为线段树不会添加或删除节点,所以即使是朴素的实现也都能在O(log n)时间内进行各种操作。

此外,n个元素的线段树的初始化的时间复杂度和总的空间复杂度都是O(n)。这是因为节点数是

n+n/2+n/4+……=2n。直觉上很容易让人产生复杂度是O(n log n)的错觉,需要注意。

6.基于线段树的RMQ的实现

为了简单起见,在建立线段树时,把数列所以的值都初始化为INT—MAX。此外,query的参数中不止传入节点的编号,还传入了节点对应的区间。

虽然从节点的编号也可以计算出对应的区间。但是把区间作为参数传入就可以节省这一步计算,为了简单起见,我们在实现中传入了对应的区间。

    #include<iostream>
    using namespace std;
    const int MAX_N = 1 << 17;
    int n, dat[2 * MAX_N - 1];
    int MAX = 100000;
    void init(int m){
      n = 1;
      while(n < m) n *= 2;
      for(int i = 0; i < 2*n-1; i++)
        dat[i] = MAX;
    }
    void update(int k, int a){
      k += n-1;
      dat[k] = a;
      while(k > 0){
        k = (k - 1) / 2;
        dat[k] = min(dat[k*2+1], dat[k*2+2]);
      }
      /*cout<<k<<" "<<a<<" "<<n <<endl;
      for(int i=0;i<15;i++)
          cout<<dat[i]<<" ";
        cout<<endl;*/
    }
    int query(int a, int b, int k, int l, int r){
      if(r <= a || b <= l) return MAX;
      if(a <= l && r <= b) return dat[k];
      else{
        int vl = query(a, b, k*2+1, l , (l+r)/2);
        int vr = query(a, b, k*2+2, (l+r)/2, r);
        //cout<<"vl= "<<vl<<" vr= "<<vr<<endl;
        return min(vl,vr);
      }
    }
    int main(){
      int y = 8;
      init(y);
      for(int i = 0; i < y; i++){
        int x;
        cin>>x;
        update(i,x);
      }
      //update(0,9);
      /*for(int i = 0; i < 15; i++)
        cout<<dat[i]<<" ";
      cout<<endl;*/
      int m=query(0, y, 0, 0, y);
      cout<<m<<endl;
      getchar();
      return 0;
    }
    //0 1 2 3 4 5 6 7
    //1 2 3
View Code

模板:

#include<iostream>
using namespace std;

struct Node{
    int l;
    int r;
    int maxvalue;
    int sum;
    int add;
}; 

Node a[100000];

// 初始化 区间[left, right], k:当前线段树位置 
void init(int left, int right, int k) {
    a[k].l = left;
    a[k].r = right;
    a[k].maxvalue = 0;
    a[k].sum = 0;
    a[k].add = 0;
    
    if (left != right) {
        int mid = (left + right) / 2;
        init(left, mid, 2 * k);
        init(mid + 1, right, 2 * k + 1);
    }
}

// 单点更新 i:当前线段树位置, k:目标位置 , value:更新值 
void update(int i, int k, int value) {
    if (a[i].l  == a[i].r ) {
        a[i].maxvalue = value;
        a[i].sum = value;
        return;
    }
    
    int mid = (a[i].l + a[i].r ) / 2;
    if (k <= mid) update(2 * i, k, value);
    else update(2 * i + 1, k, value);
    
    a[i].maxvalue = max(a[2 * i].maxvalue , a[2 * i + 1].maxvalue );
    a[i].sum = a[2 * i].sum + a[2 * i + 1].sum ;
}

//区间更新 i:当前位置,  更新区间[x, y], k: 区间同时操作值 
void update_add(int i, int x, int y, int k) {
    if (x == a[i].l  && y == a[i].r  ) {
        a[i].add += (y - x + 1) * k;
        return;
    }
    int mid = (a[i].l + a[i].r ) / 2;
    if(y <= mid) update_add(2 * i , x, mid, k);
    else if(x > mid) update_add(2 * i + 1, mid + 1, y, k);
    else {
        update_add(2 * i, x, mid, k);
        update_add(2 * i + 1, mid + 1, y, k);
    }
}

// 区间和 i:当前线段树位置, 查询区间[x, y] 
int query_sum(int i, int x, int y){
    if (x == a[i].l  && y == a[i].r )
        return a[i].sum  + a[i].add ;
        
    int mid = (a[i].l + a[i].r ) / 2;
    if (y <= mid) return query_sum(2 * i, x, y);
    else if (x > mid) return query_sum(2 * i + 1, x, y);
    else return query_sum(2 * i, x, mid) + query_sum(2 * i + 1, mid + 1, y);    
}

// 最大值  i:当前线段树位置, 查询区间[x, y]; 
int query_max(int i, int x, int y) {
    if (x == a[i].l && y == a[i].r ) return a[i].maxvalue ;
    
    int mid = (a[i].l + a[i].r ) / 2;
    if (y <= mid) return query_max(2 * i, x, y);
    else if (x > mid) return query_max(2 * i + 1, x, y);
    else return max (query_max(2 * i, x, mid), query_max(2 * i + 1, mid + 1, y));
}

int main() {
    int n, m;
    cin >> n >> m;
    init(1, n, 1);
    
    for (int i = 1; i <= n; i++) {
        int value;
        cin >> value;
        update(1, i, value);
    }
    
    //update_add(1, 1, 4, 1);
    //for (int i = 0; i < 10; i++) 
    //    printf("left: %d right: %d maxvalue: %d sum: %d add: %d
", a[i].l , a[i].r , a[i].maxvalue , a[i].sum , a[i].add );
    
    
    for (int i = 0; i < m; i++) {
        int op, x, y;
        cin >> op >> x >> y;
        if(op == 1) update(1, x, y);
        if(op == 2) cout << query_max(1, x, y) << endl;
        if(op == 3) cout << query_sum(1, x, y) << endl;
    }
    return 0;
}
View Code

7.需要运用线段树的问题

ALGO-8. 操作格⼦(线段树)
问题描述
有n个格⼦,从左到右放成⼀排,编号为1-n。
共有m次操作,有3种操作类型:
1.修改⼀个格⼦的权值,
2.求连续⼀段格⼦权值和,
3.求连续⼀段格⼦的最⼤值。
对于每个2、3操作输出你所求出的结果。
输⼊格式
第⼀⾏2个整数n,m。
接下来⼀⾏n个整数表示n个格⼦的初始权值。
接下来m⾏,每⾏3个整数p,x,y,p表示操作类型,p=1时表示修改格⼦x的权值为y,p=2时表示求区
间[x,y]内格⼦权值和,p=3时表示求区间[x,y]内格⼦最⼤的权值。
输出格式
有若⼲⾏,⾏数等于p=2或3的操作总数。
每⾏1个整数,对应了每个p=2或3操作的结果。
样例输⼊
4 3
1 2 3 4
2 1 3
1 4 3
3 1 4
样例输出
6
3
数据规模与约定
对于20%的数据n <= 100,m <= 200。
对于50%的数据n <= 5000,m <= 5000。
对于100%的数据1 <= n <= 100000,m <= 100000,0 <= 格⼦权值 <= 10000。
分析:⽤结构体数组建⽴⼀棵线段树~当p==1时从上到下更新这个线段树的值,当p==2的时候搜索对
应区间内的总和~当p==3的时候搜索对应区间的最⼤值

AC:

  

  #include<iostream>
  using namespace std;
  struct Node{
    int l;
    int r;
    int maxvalue;
    int sum;
  };
  Node a[100000];
  void init(int left, int right, int k){
    a[k].l = left;
    a[k].r = right;
    a[k].maxvalue = 0;
    a[k]. sum = 0;
    if(left != right){
      int mid = (left + right)/2;
      init(left, mid, 2*k);
      init(mid+1, right, 2*k+1);
    }
  }
  void update(int i, int k, int value){
    if(a[i].l == a[i].r){
      a[i].maxvalue = value;
      a[i].sum = value;
      return;
    }
    int mid = (a[i].l + a[i].r)/2;
    if(k <= mid) update(2*i, k, value);
    else update(2*i+1, k, value);
    a[i].maxvalue = max(a[2*i].maxvalue, a[2*i+1].maxvalue);
    a[i].sum = a[2*i].sum + a[2*i+1].sum;
  }
  int query_sum(int i, int x, int y){
    if(x == a[i].l && y == a[i].r) return a[i].sum;
    int mid = (a[i].l + a[i].r)/2;
    if(y <= mid) return query_sum(2*i, x, y);
    else if(x > mid) return query_sum(2*i+1, x, y);
    else return query_sum(2*i, x, mid) + query_sum(2*i+1, mid+1, y);
  }
  int query_max(int i, int x, int y) {
      if(x == a[i].l && y == a[i].r) {
          return a[i].maxvalue;
      }
      int mid = (a[i].l + a[i].r) / 2;
      if(y <= mid)
          return query_max(2*i, x, y);
      else if(x > mid)
          return query_max(2*i+1, x, y);
      else
          return max(query_max(2*i, x, mid), query_max(2*i+1, mid+1, y));
  }
  int main(){
    int n, m;
    cin >> n >> m;
    init(1, n, 1);
    for(int i = 1; i <= n; i++){
      int value;
      cin >> value;
      update(1, i, value);
    }
    for(int i = 0; i < m; i++){
      int op, x, y;
      cin >> op >> x >> y;
      if(op == 1) update(1, x, y);
      if(op == 2) cout << query_max(1, x, y) << endl;
      if(op == 3) cout << query_sum(1, x, y) << endl;
    }
    return 0;
  }
View Code
原文地址:https://www.cnblogs.com/astonc/p/10035263.html