树状数组入门讲解

平常我们会遇到一些对数组进行维护查询的操作,比较常见的,修改某点的值、求某个区间的和。

即给定一个n个元素的数组$A_1、A_2、..., A_n$,你的任务是设计一个数据结构,支持以下两种操作:

  1. $Add(x,d)$操作:让$A_x$增加$d$。
  2. $Query(L,R)$:计算$A_L+A_{L+1}+...+A_R$。

如果按简单的前缀和处理,修改操作是$O(1)$,区间查询操作是$O(n)$,当操作次数为m时,最坏的时间复杂度是$O(mn)$,$n$很大时显然无法接受。如何让$Query$和$Add$都能快速完成呢?有一种称为二叉搜索树($Binary Indexed Tree, BIT$)的数据结构(俗称树状数组),可以很好地解决这个问题。为此,我们需要先介绍$lowbit$。

lowbit

 对于正整数$x$,我们定义$lowbit(x)$为$x$的二进制表达式中最右边的1所对应的值(而不是这个比特的序号)。比如,38288的二进制是1001010110010000,所以$lowbit(38288)=16$(二进制是10000)。在程序实现中,$lowbit(x)=x&-x$。为什么呢?回忆一下,计算机里的整数采用补码表示,因此$-x$实际上是$x$按位取反末尾加一的结果,如图所示:

两者按位取与之后,前面的部分全部变0,之后lowbit保持不变。

原理

如下图所示是一颗典型的BIT,由15个结点组成,编号为1~15.

    灰色结点是BIT中的结点(白色长条的含义稍后叙述),每一层结点的lowbit相同,而且lowbit越大越靠近根。图中的虚线是BIT中的边(在代码中并不需要存储这些边,这里画出来只是为了更好的理解BIT)。注意编号为0的点是虚拟结点,它并不是树的一部分,但是它的存在可以让算法理解起来更容易一些。

    对于结点$i$,如果它是左子结点,那么它的父节点编号就是$i+lowbit(i)$;如果它是右子结点,那么它的父节点的编号是$i-lowbit(i)$(请自行验证)。搞清楚树的结构之后,构造一个辅助数组C,其中$C_i=A_{i-lowbit(i)+1}+A_{i-lowbit(i)+2}+...+A_i$

    换句话说,C中的每个元素都是A数组中的一段连续和。到底是哪一段呢?BIT中,每个灰色结点$i$都属于一个以它自身结尾的水平长条(对于lowbit=1的那些点,“长条”就是那个结点自身),这个长条中的数之和就是$C_i$。比如结点12的长条就是从9~12,即$C_2=A_9+A_{10}+A_{11}+A_{12}$。同理,$C_6=A_5+A_6$。这个等式及其重要,请花一些时间来验证"$C_i$就是以$i$结尾的水平长条内的元素之和"这一事实。

    有了$C$数组之后,如何计算前缀和$S_i$呢?顺着结点$i$往左走,边走边往上爬(注意并不一定沿着树中的边往爬),把沿途经过的$C_i$累加起来就可以了(请自行验证,沿途经过的$C_i$所对应的长条不重复不遗漏地包含了所有需要累加地元素),如图所示

    而如果修改了一个$A_i$,需要更新$C$数组中哪些元素呢?顺着结点$C_i$开始往右走,边走边“往上爬”(同样不一定沿着树中的边爬),沿途修改所有结点对应的$C_i$即可(请自己验证,有且仅有这些结点对应的长条包含被修改的元素),如图所示:

   不难证明。两个操作的时间复杂度均为O(logn)。预处理的方法是先把$A$数组和$C$数组清空,然后执行$n$次$add$操作,总时间复杂度为$O(nlogn)$。

代码

两个操作的代码如下:

int sum(int x)     //前缀和
{
    int ret = 0;
    while (x > 0)
    {
        ret += C[x];
        x -= lowbit(x);
    }
    return ret;
}
void add(int x, int d)
{
    while (x <= n)
    {
        C[x] += d;
        x += lowbit(x);
    }
}

完整代码:

 1 #include<cstdio>
 2 #include<algorithm>
 3 #include<cstring>
 4 using namespace std;
 5 
 6 const int maxn = 10000 + 10;
 7 int a[maxn],C[maxn],n;
 8 
 9 int lowbit(int x)
10 {
11     return x & -x;
12 }
13 int sum(int x)
14 {
15     int ret = 0;
16     while (x > 0)
17     {
18         ret += C[x];
19         x -= lowbit(x);
20     }
21     return ret;
22 }
23 void add(int x, int d)
24 {
25     while (x <= n)
26     {
27         C[x] += d;
28         x += lowbit(x);
29     }
30 }
31 void init()
32 {
33     memset(C, 0, sizeof(C));
34     for (int i = 1; i <= n; i++)
35         add(i, a[i]);
36 }
37 
38 int main()
39 {
40     scanf("%d", &n);
41     for (int i = 1; i <= n; i++)  scanf("%d", &a[i]);
42     init();
43     printf("%d
", sum(10));
44     printf("%d
", sum(5));
45     add(5, 3);
46     printf("%d", sum(5));
47 
48     return 0;
49 }
View Code
原文地址:https://www.cnblogs.com/lfri/p/10655077.html