[SinGuLaRiTy] ZKW线段树

【SinGuLaRiTy-1007】 Copyrights (c) SinGuLaRiTy 2017. All Rights Reserved.

关于ZKW线段树

Zkw线段树是清华大学张昆玮发明非递归线段树的写法。实践证明,这种线段树常数更小,速度更快,写起来也并不复杂。

建树

ZKW线段树本质上就是依赖于满二叉树中父节点与子节点的编号关系。

如上图中的一个简单的满二叉树,我们可以发现如下规律:

1>父子节点编号关系: 假设父节点的编号为 n ,那么,它的两个子节点的编号就分别为 n*2(n<<1)、n*2+1(n<<1|1);

2>二叉树层数与底层叶子节点数的关系:假设这个二叉树的层数为 m ,那么,这个二叉树的底层叶子节点数(由于是满二叉树,这也就是所有的叶子节点了)就是2^(m-1),同时,我们还可以知道,所有叶子节点中编号最小的,即在这个满二叉树左下角的叶子节点的编号也为 2^(m-1);

通过以上的两大关系,我们在存储一个数组的初始数据时,就可以直接将初始数据存储在满二叉树的底层。假设数组中有 x 个元素,那么这 x 个元素在这个满二叉树中的编号就是2^(m-1)~2^(m-1)+x-1,访问起来就很方便了。

于是就有了建树代码:(其中n代表的是初始数组中的元素个数,M代表的是最底层的叶子节点个数)

inline void Build()
{
    for(M=1;M<n;M<<=1) ;//由于要构建一个满二叉树,所以我们不能直接让二叉树的叶子节点数等于元素个数,M可能会大于n;本层循环使底层叶子节点数在满足“满二叉树的前提下最小”
    for(int i=M;i<n+M;i++)//由于M也同样是本层最左侧叶子节点的编号,所以直接从这里开始赋值
        Tree[i]=Read();
}

(有些博客总是会在这里自问自答“建完了吗?没有。”,对于这种有点SB的行为,我表示无法理解)

不过确实,到目前为止,建树还未完成,我们还需要从下往上更新其它节点的值。当然,知道了父子节点编号的关系,这个操作就非常好用了。

inline void upgrade()
{
    for(int i=M-1;i;i--)
    {
        Tree[i]=Tree[i<<1|1]+Tree[i<<1];//维护为区间和
    }
}

当然,你也可以将其维护为最大值,最小值之类的,代码都大同小异:

<最大值>

Tree[i]=max(Tree[i<<1|1],Tree[i<<1]);

<最小值>

Tree[i]=min(Tree[i<<1|1],Tree[i<<1]);

到目前为止,我们才算是完成了ZKW线段树的建树工作。

<ZKW线段树中的差分思想>

在建ZKW线段树的过程中,可以用的Tree[i]表示父子节点的差值,也同样可以达到存储数据的目的。

void Build(int n)
{ 
    for(M=1;M<=n+1;M<<=1);
    for(int i=M;i<M+n;i++)
        Tree[i]=in();
    for(int i=M-1;i;--i) 
        Tree[i]=min(Tree[i<<1],Tree[i<<1|1]),Tree[i<<1]-=Tree[i],Tree[i<<1|1]-=Tree[i];
}

觉得稍微复杂了一些?但这样的存储方式可以为RMQ问题做准备。

<关于空间>

我们都知道,在建线段树时,需要开的数组(或是结构体)的大小是 4n ;在这里 , 我们来计算一下ZKW线段树的所需要的空间。(设初始数据中元素个数为 n )

最好的情况: 当 n=2^k 时,由于此时刚好可以把最底层排满,则数组大小大概为 2n ;

最坏的情况: 当 n=2^k+1时,即底层刚好多出一个,仍需要把底层排满时,则数组大小大概为 4n-1 ;

因此,即使是最坏的情况,ZKW线段树也比普通线段树的空间表现要好。

单点查询

假设数组中有 x 个元素,二叉树层数为 m ,那么这 x 个元素在这个满二叉树中的编号就是2^(m-1)~2^(m-1)+x-1,访问起来很方便。

<单点查询-差分版>

其实差分版单点查询写起来也不是很复杂,也比较利于理解。

void Sum(int x,int res=0)
{ 
    while(x) 
        res+=Tree[x],x>>=1;
    return res;
}

区间查询

<区间求和>

先看一下代码:

int Sum(int s,int t,int Ans=0)
{ 
    s+=M-1,t+=M-1;
    Ans+=d[s]+d[t]; 
    if(s==t)
        return Ans-=d[s];
    for(;s^t^1;s>>=1,t>>=1)//s^t^1 : 如果s与t在同一个父亲节点以下,就说明我们已经累加到这棵树的根部了。当s与t在同一个父亲节点下时,t-s=1,那么s^t=1,s^t^1=0,此时就退出循环。
    {
        if(~s&1)//这里在下面解释
            Ans+=d[s^1]; //d[s^1]是d[s]的兄弟
        if(t&1)//这里在下面解释
            Ans+=d[t^1];//d[t^1]是d[t]的兄弟
    }
    return Ans; 
}

<*关于代码中的 ~s&1 与 t&1>

首先我们可以将这两个位运算式转化为好理解一点的式子:

if(~s&1) ->  if(s%2==0)

if(t&1) -> if(t%2!=0)
也就是说,这里是在判断奇偶,结合满二叉树的编号规律我们很容易发现:若编号为奇,则为右儿子;若编号为偶,则为左儿子。那么,这里判断左/右儿子有什么用呢?

我们看上面的这幅图。如果我们知道要查询的区间的两个端点为编号4、7的节点,由于这是满二叉树,因此我们可以在图中寻找位于4号节点右边且位于7号节点左边的节点,这些节点一定位于我们要查询的区间之中。而我们又知道,在两个兄弟节点A,B之中,若A为左儿子,那么B一定在A的右边;若A为右儿子,那么B一定在A的左边。也就是说,如果我们知道区间的两个端点是左儿子还是右儿子,我们就可以知道它们的兄弟节点是否在区间的覆盖范围之内。又由于在ZKW线段树中,我们已经将父节点维护成为包含其子节点信息的节点,因此不用担心有漏算的情况。(要注意是开区间还是闭区间)

我们不妨画个图来验证一下:

(注:图中的红点为累加过的点,橙色为访问过的点)

图中的累加节点覆盖了所有的查询范围。

<区间查询最大值>

和 区间求和 的代码思路差不多,直接上代码:

void Sum(int s,int t,int L=0,int R=0)
{ 
    for(s=s+M-1,t=t+M-1;s^t^1;s>>=1,t>>=1)
    { 
        L+=d[s],R+=d[t]; 
        if(~s&1) L=max(L,d[s^1]);
        if(t&1) R=max(R,d[t^1]); 
    } 
    int res=max(L,R);
    while(s) res+=d[s>>=1]; 
}

<区间查询最小值>

void Add(int s,int t,int v,int A=0)
{
    for(s=s+M-1,t=t+M-1;s^t^1;s>>=1,t>>=1)
    {
        if(~s&1) d[s^1]+=v;
        if(t&1) d[t^1]+=v;
        A=min(d[s],d[s^1]);d[s]-=A,d[s^1]-=A,d[s>>1]+=A;
        A=min(d[t],d[t^1]);d[t]-=A,d[t^1]-=A,d[t>>1]+=A;
    }
    while(s) A=min(d[s],d[s^1]),d[s]-=A,d[s^1]-=A,d[s>>=1]+=A;
}

单点更新

void Change(int x,int v) 
{ 
    d[x=M+x-1]+=v; 
    while(x) 
        d[x>>=1]=d[x<<1]+d[x<<1|1];
 }

区间更新

举个模板题例子。结合题目来看看代码吧。

区间修改的RMQ问题

题目描述

给出N(1 ≤ N ≤ 50,000)个数的序列A,下标从1到N,每个元素值均不超过1000。有两种操作:

(1)  Q i j:询问区间[i, j]之间的最大值与最小值的差值

(2) C i j k:将区间[i, j]中的每一个元素增加k,k是一个整数,k的绝对值不超过1000。

一共有M (1 ≤ M ≤ 200,000) 次操作,对每个Q操作,输出一行,回答提问。

输入

第1行:2个整数N, M
第2行:N个元素
接下来M行,每行一个操作

输出

对每个Q操作,在一行上输出答案

样例输入 样例输出
5 4
1 2 3 4 5
Q 2 4
C 1 1 1
C 1 3 2
Q 1 5
2
1

 

 

 

 

 

 

代码:

#include<cstdio>
#include<algorithm>
using namespace std;
 
#define lson pos << 1
#define rson pos << 1 | 1
#define fa(x) (x >> 1)
const int MAXN = 50000;
int d1[MAXN << 2], d2[MAXN << 2], M = 1, n, m;
// d1 -> max // d2 -> min
 
inline void Read(int &Ret){
    char ch; int flg = 1;
    while(ch = getchar(), ch < '0' || ch > '9')
        if(ch == '-') flg = -1;
    Ret = ch - '0';
    while(ch = getchar(), ch >= '0' && ch <= '9')
        Ret = Ret * 10 + ch - '0';
    Ret *= flg;
}
 
void build(int n){
    while(M < n) M <<= 1;
    int pos = M --;
    while(pos <= M + n){
        Read(d1[pos]);
        d2[pos] = d1[pos];
        pos ++;
    }
    pos = M;
    while(pos){
        d1[pos] = max(d1[lson], d1[rson]);
        d2[pos] = min(d2[lson], d2[rson]);
        d1[lson] -= d1[pos]; d1[rson] -= d1[pos];
        d2[lson] -= d2[pos]; d2[rson] -= d2[pos];
        pos --;
    }
}
 
inline void Insert(int L, int R, int v){//区间更新
    L += M; R += M;
    int A;
    if(L == R){
        d1[L] += v; d2[L] += v;
        while(L > 1){
            A = max(d1[L], d1[L ^ 1]); d1[L] -= A; d1[L ^ 1] -= A; d1[fa(L)] += A;
            A = min(d2[L], d2[L ^ 1]); d2[L] -= A; d2[L ^ 1] -= A; d2[fa(L)] += A;
            L >>= 1;
        }
        return;
    }
    d1[L] += v; d1[R] += v; d2[L] += v; d2[R] += v;
    while(L ^ R ^ 1){
        if(~L & 1) d1[L ^ 1] += v, d2[L ^ 1] += v;
        if(R & 1) d1[R ^ 1] += v, d2[R ^ 1] += v;
        A = max(d1[L], d1[L ^ 1]); d1[L] -= A; d1[L ^ 1] -= A; d1[fa(L)] += A;
        A = max(d1[R], d1[R ^ 1]); d1[R] -= A; d1[R ^ 1] -= A; d1[fa(R)] += A;
        A = min(d2[L], d2[L ^ 1]); d2[L] -= A; d2[L ^ 1] -= A; d2[fa(L)] += A;
        A = min(d2[R], d2[R ^ 1]); d2[R] -= A; d2[R ^ 1] -= A; d2[fa(R)] += A;
        L >>= 1; R >>= 1;
    }
    while(L > 1){
        A = max(d1[L], d1[L ^ 1]); d1[L] -= A; d1[L ^ 1] -= A; d1[fa(L)] += A;
        A = min(d2[L], d2[L ^ 1]); d2[L] -= A; d2[L ^ 1] -= A; d2[fa(L)] += A;
        L >>= 1;
    }
}
 
inline int getans(int L, int R){
    L += M; R += M;
    int ans1 = 0, ans2 = 0;
    if(L == R){
        while(L){
            ans1 += d1[L]; ans2 += d2[L];
            L >>= 1;
        }
        return ans1 - ans2;
    }
    int l1 = 0, r1 = 0, l2 = 0, r2 = 0;
    while(L ^ R ^ 1){
        l1 += d1[L]; r1 += d1[R];
        l2 += d2[L]; r2 += d2[R];
        if(~L & 1) l1 = max(l1, d1[L ^ 1]), l2 = min(l2, d2[L ^ 1]);
        if(R & 1) r1 = max(r1, d1[R ^ 1]), r2 = min(r2, d2[R ^ 1]);
        L >>= 1; R >>= 1;
    }
    l1 += d1[L]; r1 += d1[R]; l2 += d2[L]; r2 += d2[R];
    ans1 = max(l1, r1); ans2 = min(l2, r2);
    while(L > 1){
        L >>= 1;
        ans1 += d1[L]; ans2 += d2[L];
    }
    //printf("max=%d min=%d
",ans1, ans2);
    return ans1 - ans2;
}
 
int main(){
    int a, b, c;
    char id[3];
    Read(n); Read(m);
    build(n);
    while(m --){
        scanf("%s",id);
        Read(a); Read(b);
        switch(id[0]){
            case 'C': Read(c), Insert(a, b, c); break;
            default: printf("%d
",getans(a, b));
        }
    }
    return 0;
}

Time: 2017-03-11

 

原文地址:https://www.cnblogs.com/SinGuLaRiTy2001/p/6591718.html