【数据结构】树状数组

使用目的

树状数组是为了解决多次单点更新,区间查询等问题的数据结构。
树状数组的更新与查询复杂度均为O(logn)。

为了方便理解树状数组的优势,这里先给出一道题目:
给一大小固定的A数组,现用户可随意更改此数组的任何n个元素为任何值,且用户还想知道每次更改元素后数组中下标从0到m的元素的和。请你用快速的方法解决这个问题。
那么最简单的思路是在每次查询时从0到m做一次求和。但当更改和查询次数巨大的时候,我们不得不换一种思路以免超时。
那么树状数组就是可选的一种结构。

数据结构

为了解决问题,我们希望在求和的时候不要一个一个元素求和,而是一段一段求和。
因此,我们考虑设计一些节点存储一段元素的和,在用户求和的时候可以一段一段求。并且这个节点的数量应该恰到好处,因为我们后续更新这些节点必须快速方便。
庆幸的是有一种方便且快速的数据结构可供我们使用,那就是树状数组。

pic1
(这里引用了他人图片,原作者信息在图片上)

为了构造这样一个结构,我们先把数组A的一些元素“提”起来,做成图中的数组C的样子。
我们考虑把数组C看成上述的节点,这些节点存储着相应的一段元素的和。
为了方便找到其中的规律,把数字的二进制形式写出来:
0001 C[1] = A[1]
0010 C[2] = C[1] + A[2] = A[1] + A[2]
0011 C[3] = A[3]
0100 C[4] = C[2] + A[3] + A[4] = A[1] + A[2] + A[3] + A[4]
0101 C[5] = A[5]
0110 C[6] = C[5] + A[6] =A[5] + A[6]
0111 C[7] = A[7]
1000 C[8] = C[4] + C[6] + C[7] + A[8] = A[1] + A[2] + A[3] + A[4] + A[5] + A[6] + A[7] + A[8]
1001 C[9]= A[9]

先不管为什么这样写,我们先尝试求和,找出其中的规律:
sum[7] (0111) = C[7] (0111) + C[6] (0110) +C[4] (0100)
sum[5] (0101) = C[5] (0101) + C[4] (0100)
sum[9] (1001) = C[9] (1001) + C[8] (1000)

可以发现,sum[n]的值就是 二进制数n每次去掉最后一个1的数作为C的下标的元素和(原谅我的表达能力: -) )。
写成代码就是这样:

for (int i=n; i>0; i-=i的最后一位1)
    sum+=C[i];

这就是区间查询。

那么如果用户继续更新数组A中下标为x的元素的值怎么办?
我们只需按照x一直向上到最大的下标,一路加减数组C的值即可。
此时我们意识到这就像是刚才的查询一样,x每次加上最后一个1不就可以咯。
那么写成代码是这样:

for (int i=x; i<=n; i+=i的最后一位1)
    C[i]+=val;

那么最重要的事就是如何取得i的最后一位1。
学习过补码的同学们在此时占了优势,取得最后一位1只需 x&(-x)即可。(为什么?详情前往我之后写的一篇《如何理解原码、移码、反码与补码?》的最后一个表格)
设计代码:

int lowbit(int x){
    return x&(-x);
}

示例

我们之前没有给出完整的代码,这里给出代码,思想到位即可。
选用UVa12086(完整题解见另一随笔:)
其中getSum()即为从头求和,add()为单点更新。

#include <cstdio>
#include <cstring>
#define lowbit(x) ((x)&(-x))
const int max=200005;
long long arr[max], tree[max], n;
long long getSum(int x){
    long long sum=0;
    for (int i=x; i>0; i-=lowbit(i)) sum+=tree[i];
    return sum;
}

void add(int x, int val){
    for (int i=x; i<=n; i+=lowbit(i))
	tree[i]+=val;
}

int main(void){
    int T=0, a, b;
    while (scanf("%lld", &n)==1 && n){
	memset(tree, 0, sizeof(tree));
	for (int i=1; i<=n; i++){
	    scanf("%lld", &arr[i]);
	    add(i, arr[i]);
	}

	char str[5];
	if (T) printf("
");
	printf("Case %d:
", ++T);
	while (scanf("%s", str)==1){
	    if (str[0]=='E') break;
	    scanf("%d%d", &a, &b);
	    if (str[0]=='M') printf("%lld
", getSum(b)-getSum(a-1));
	    else if (str[0]=='S'){add(a, b-arr[a]); arr[a]=b;}
	}
    }

    return 0;
}
原文地址:https://www.cnblogs.com/tanglizi/p/7710417.html