浅谈 线段树

其实真不太想写的。

Update:2020.11.23 例题。

前置芝士:可以先看看其他的 RMQ 算法,如树状数组,ST之类的。

毕竟线段树这个东西很重要,而看懂了其他的会方便理解一些,早学不一定精。

当然,也欢迎翻翻我之前写的。给个链。

0x01 基本概念

线段树(Segment Tree)是一个基于分治の数据结构。

通常处理区间,序列中的查询,更改问题。但因为其可维护变量的多样性,所以常在各类题目中遇到。准确说,是各类优化中遇到。

线段树是个有根二叉树,我们记为 (t),其每个节点 (t[p]) 均保存着一组关键信息:(l)(r)。我通常将其称为钥匙信息。(意会即可)

它们合在一起表示当前这个节点保存的是哪一个区间的信息。

比如,如果你的线段树维护的有区间最值,那么对于一个结点 (t[p]) 满足 (t[p].l = L, t[p].r = R),则其维护的值就是 (max{a[L], a[L + 1]...a[R - 1], a[R]}),其中 (a[i]) 为给定区间的第 (i) 号元素。

而对于每个节点 (t[p]),我们定义 (mid = frac {(t[p].l + t[p].r)} {2}) ,并定义,对于其左儿子节点 (t[p imes 2]),有 (t[p imes 2].l = t[p].l)(t[p imes 2].r = mid)

同理对于其右儿子 (t[p imes 2 + 1]),有 (t[p imes 2 + 1].l = mid + 1)(t[p imes 2 + 1].r = t[p].r)

这样,在我们知道所有的叶子节点后就可以往上更新至全部的区间信息,这就是线段树的大体思想。

不过我们还需要完善最后一个定义。对于叶子节点 (t[q]),显然拥有一个特性:没有儿子。那也就是说它们无法找到合理的 (mid) 将其表示的区间再次分化,那就是 (t[q].l = t[q].r) 呗。我们在序列里称这样的区间为元区间,它也是我们的边界,与我们的答案来源。毕竟这是一开始题目就会给你的东西。

我不想画图。。。

但为了方便读者理解,我还是扒一个下来吧。(doge

上面这张图中,线段树起到了保存区间最小值的作用。

其原序列 (a) 为:

1 3 5 7 9 10 2 4 8 6

节点上标注的 ([x, y]) 表示 (b_i(i in [x, y], b_i = a_i)) 这个序列,而 (minv[p])(Min Value),表示节点 (t[p]) 保存的 (b) 序列的最小值。

很容易理解了吧,接下来我们来看看实现。

0x02 代码实现

没点基础码力还真不敢碰这玩意。

注:这里的实现均以区间和为例。

step0.前言:关于空间

线段树如果不看所有的叶子节点,它一定是颗满二叉树,这个毋庸置疑。

那么我们设除去最下面一层以外的深度最深的一层有 (x) 个节点。

则上一层定有 (frac x 2) 个节点,毕竟两个儿子对应一个父亲嘛。

如果我们设最下面一层的节点数为 (y),那么整棵线段树的节点总和为:(y + x + frac x 2 + frac x 4 +...+ 1)。等比数列求和,答案为 (2 imes n - 1 + y)。又因为最下层的节点均为元区间,且原序列中元区间个数为 (n) 个,那么所以 (y < n)。所以总节点数一定小于 (3 imes n)

但是因为我们的存图方式是有空点存在的,详见上图。且最后一行最多容纳 (2 imes n) 个点,那么我们线段树的数组就需要开足四倍空间。

这也是线段树的一个小缺陷所在。

step1.建树

首先,我们已经了解了线段树的构造,那么这一步就相当于是去模拟了。

再理一下。

  • 显然,(t[p].Sum = t[p imes 2].Sum + t[p imes 2 + 1].Sum)
  • (l = r) 时,属于元区间,来到叶子节点,直接更新,并返回。
  • 每次记得保存每个节点的钥匙信息,即 (l)(r) 的值,并从 (mid) 开始继续往下划分。
void Make_Tree(int p, int l, int r) {
	t[p].l = l;
	t[p].r = r;
    // 记录节点の钥匙信息。
	if(l == r) {
		t[p].Sum = a[l];
		return ;
	}
    // 叶子节点(元区间)
	int mid = (l + r) >> 1;
	Make_Tree(p << 1, l, mid);
	Make_Tree(p << 1 | 1, mid + 1, r);
	t[p].Sum = t[p << 1].Sum + t[p << 1 | 1].Sum; 
    // 递归建树并维护区间和
}

step2.更新

这里的例子中我们完成操作:单点修改。

即一次修改一个原序列中一个元素的值。记我们要改的元素为 (a_index),表示它在原序列 (a) 中的第 (index) 位。

那么在线段树中改一个点,其实就是找到其对应的元区间,更改元区间后再更新它的所有父亲嘛。

首先明确,对于节点 (t[p]),如果 (index <= mid),则我们需要往左儿子找,因为根据定义,此时的 (index) 一定属于区间 ([t[p imes 2].l, t[p imes 2].r]) 中。那么反之,如果 (index > mid),则需要往右儿子找。

始终记住,在大多情况下,元区间都是我们的边界条件。

不就结了吗。

void Update(int p, int index, int x) {
	if(t[p].l == t[p].r) {
		t[p].Sum += x;
		return ;
	}
    // 找到元区间,直接修改
	int mid = (t[p].l + t[p].r) >> 1;
	if(index <= mid)
		Update(p << 1, index, x);
	else
		Update(p << 1 | 1, index, x);
	t[p].Sum = t[p << 1].Sum + t[p << 1 | 1].Sum; 
    // 左右儿子依次访问,并再次更新
}

step3.查询

单点修改,单点查询显然没必要对吧。那个不是有手就行吗。

于是我们考虑单点修改,区间查询。

首先对于我们想要的区间 ([L, R]),如果 ([L, R]) 完全覆盖一个区间 ([t[p].l, t[p].r]),则 (t[p].Sum) 一定会对我们想要的答案产生价值。

那么这次我们就不需要再查询 (t[p]) 的儿子节点了,因为它的儿子节点可以带来的价值一定全部包含在 (t[p].Sum) 中。

但是同样,对于 (t[p]),也完全有可能不完全覆盖。 那么我们就暂时不能累加价值,因为这时候 (t[p]) 的有部分价值可能是我们不想要的。

那这个时候,我们就应该去访问 (t[p]) 的儿子节点了,必定在整棵线段树中一定能找到完全覆盖的情况。

真就往下搜呗。

不过不一定是两个儿子都需要访问。大多时候我们都是需要访问才访问。需要访问也就是说要查的区间与某个儿子表示的区间拥有交集。那就直接比较 (L)(R)(mid) 的大小即可。如果 (L <= mid),即我们要查的区间有一丢丢在左儿子里,那么就去拜访它。同理,如果 (R > mid),则还需去看看右儿子。当然也可能会出现两边都要访问的情况,也就是 (mid) 刚好把要查区间从中截断(两边各有一丢丢嘛)。

LL Query(int p, int l, int r) {
	if(l <= t[p].l && t[p].r <= r) 
		return t[p].Sum;
	// 完全覆盖就直接返回
	int mid = (t[p].l + t[p].r) >> 1;
	LL val = 0;
	if(l <= mid)
		val += Query(p << 1, l, r);
	if(r > mid)
		val += Query(p << 1 | 1, l, r);
	// 看情况访问左右儿子
	return val;
}

step4.完整代码

还是放一个吧。线段树一定要慢慢调哦。

可以交交这道题。「线段树」模板题1

#include <cstdio>

const int MAXN = 1e6 + 5;
const int MAXT = 1e6 * 4 + 5;
typedef long long LL;
struct Segment_Tree {
	int l, r;
	LL Sum;
	Segment_Tree() {}
	Segment_Tree(int L, int R, LL S) {
		l = L;
		r = R;
		Sum = S;
	}
} t[MAXT];
int a[MAXN];

void Make_Tree(int p, int l, int r) { // 建树
	t[p].l = l;
	t[p].r = r;
	if(l == r) {
		t[p].Sum = a[l];
		return ;
	}
	int mid = (l + r) >> 1;
	Make_Tree(p << 1, l, mid);
	Make_Tree(p << 1 | 1, mid + 1, r);
	t[p].Sum = t[p << 1].Sum + t[p << 1 | 1].Sum; 
}

void Update(int p, int index, int x) { // 更新
	if(t[p].l == t[p].r) {
		t[p].Sum += x;
		return ;
	}
	int mid = (t[p].l + t[p].r) >> 1;
	if(index <= mid)
		Update(p << 1, index, x);
	else
		Update(p << 1 | 1, index, x);
	t[p].Sum = t[p << 1].Sum + t[p << 1 | 1].Sum; 
}

LL Query(int p, int l, int r) { // 询问
	if(l <= t[p].l && t[p].r <= r) 
		return t[p].Sum;
	int mid = (t[p].l + t[p].r) >> 1;
	LL val = 0;
	if(l <= mid)
		val += Query(p << 1, l, r);
	if(r > mid)
		val += Query(p << 1 | 1, l, r);
	return val;
}

int main() {
	int n, q;
	scanf ("%d %d", &n, &q);
	for(int i = 1; i <= n; i++)
		scanf ("%d", &a[i]);
	Make_Tree(1, 1, n);
	for(int i = 1; i <= q; i++) {
		int flag;
		scanf ("%d", &flag);
		if(flag == 1) {
			int v, x;
			scanf ("%d %d", &v, &x);
			Update(1, v, x);
		}
		else {
			int l, r;
			scanf ("%d %d", &l, &r);
			printf("%lld
", Query(1, l, r));
		}
	}
	return 0;
}

其实话说,整体把线段树的代码拉出来还挺好看的。

上一个有这种感觉的是笛卡尔曲线方程:(r = a(1 - sin Theta))。抱歉扯远了。

0x03 推广

我们在上一个版块中只讲到了单点修改。

那如果想要区间修改呢?之前知道的只有树状数组有个非常麻烦的实现方法(雾。

于是引入:懒惰标记 (Lazy Tag),又叫延迟标记 (Daley Tag)。

首先,关于区间修改,我们可以按照刚刚的思路将这个区间修改改为很多个小的单点修改,但这显然会超时。那么考虑优化。

你会发现如果我们每次都跑到元区间其实是很不划算的,因为我们在查询的时候并不是每次都查到了元区间。也就是说我们只需要将我们需要的点,即会对答案产生价值的点进行精确更改即可。

这个很显然吧,一个小贪心。

那么我们可以将每个节点保存的信息多加一个:(add)。这个 (add) 表示,之前区间修改时没累加在当前节点但其实需要去累加的价值。

也就是说我们需要将所有之前的操作更改的价值累加起来,在我们需要查询 (t[p imes 2]),我们再由标记在 (t[p]) 上的 (add) 去更新 (t[p imes 2]),求得实际的值。(t[p imes 2 + 1]) 同理。

好像很抽象?我还是扒个图吧。

这里面的绿点表示修改区间 ([3, 9]) 时本来会改变的线段树上的节点。

而我们每次 lazy 只标记黄点。

在我们下次要去查询某个绿点时,我们再由黄点的 lazy tag 去更新绿点的信息。

也就是说真实的绿点 (t[q]) 满足:(t[q].Sum = t[q].Sum + t[p].add * (t[q].r - t[q].l + 1))。其中 (t[p])(t[q]) 的父亲节点。而 (t[p].add) 需要乘上它表示的节点个数,因为我们存的是单个节点的 lazy tag。

啊,没有智商了,那就结合代码再分析吧。

#include <cstdio>

typedef long long LL;
const int MAXN = 1e6 + 5;
const int MAXT = 1e6 * 4 + 5;
struct Segment_Tree {
	int l, r, len;
	LL Sum, add; // lazy tag
	Segment_Tree() {}
	Segment_Tree(int L, int R, LL S, LL A, int Len) {
		l = L;
		r = R;
		Sum = S;
		add = A;
		len = Len;
	}
} t[MAXT];
int a[MAXN];

void Spread(int p) { 
	// 从父亲往儿子更新标记
	if(t[p].add) {
		t[p << 1].Sum += t[p].add * t[p << 1].len;
		t[p << 1 | 1].Sum += t[p].add * t[p << 1 | 1].len;
		t[p << 1].add += t[p].add;
		t[p << 1 | 1].add += t[p].add;
		t[p].add = 0; 
	}
}

void Make_Tree(int p, int l, int r) {
	t[p].l = l;
	t[p].r = r;
	t[p].len = r - l + 1;
	if(l == r) {
		t[p].Sum = a[l];
		return ;
	}
	int mid = (l + r) >> 1;
	Make_Tree(p << 1, l, mid);
	Make_Tree(p << 1 | 1, mid + 1, r);
	t[p].Sum = t[p << 1].Sum + t[p << 1 | 1].Sum; 
}

void Update(int p, int l, int r, int x) {
	if(l <= t[p].l && t[p].r <= r) {
		t[p].Sum += (LL)x * t[p].len;
		t[p].add += x;
		return ;
	}
	Spread(p);
	// 更新标记。
	int mid = (t[p].l + t[p].r) >> 1;
	if(l <= mid)
		Update(p << 1, l, r, x);
	if(r > mid)
		Update(p << 1 | 1, l, r, x);
	t[p].Sum = t[p << 1].Sum + t[p << 1 | 1].Sum;
}

LL Query(int p, int l, int r) {
	if(l <= t[p].l && t[p].r <= r) 
		return t[p].Sum;
	Spread(p);
	// 更新标记。
	int mid = (t[p].l + t[p].r) >> 1;
	LL val = 0;
	if(l <= mid)
		val += Query(p << 1, l, r);
	if(r > mid)
		val += Query(p << 1 | 1, l, r);	
	return val;
}

int main() {
	int n, q;
	scanf ("%d %d", &n, &q);
	for(int i = 1; i <= n; i++)
		scanf ("%d", &a[i]);
	Make_Tree(1, 1, n);
	for(int i = 1; i <= q; i++) {
		int flag;
		scanf ("%d", &flag);
		if(flag == 1) {
			int l, r, x;
			scanf ("%d %d %d", &l, &r, &x);
			Update(1, l, r, x);
		}
		else {
			int l, r;
			scanf ("%d %d", &l, &r);
			printf("%lld
", Query(1, l, r));
		}
	}
	return 0;
}

0x04 例题

题目描述

给定一个长度为 (n) 的序列 (a),以及 (q) 次操作,操作有两类。

  • 1: 对于所有 (i in [l, r]),将 (a[i]) 加上 (x)(换言之,将 (a_l, a_{l + 1} ,..., a_r) 分别加上 (x) );
  • 2: 给定 (l, r),求 (gcd(a_l, a_{l + 1} ,..., a_r)) 的值。

输入格式

第一行包含 (2) 个正整数 (n, q),表示数列长度和询问个数。保证 (1 leq n, q leq 5 imes 10^5)

第二行 (n) 个整数 (a_1, a_2,...,a_n),表示初始数列。保证 (|a_i| leq 2^{63} - 1)

接下来 (q) 行,每行一个操作,为以下两种之一:

  • 1 l r x: 对于所有 (i in [l, r]),将 (a[i]) 加上 (x)
  • 2 l r: 输出 (gcd(a_l, a_{l + 1} ,..., a_r)) 的值。

保证 (1 leq l leq r leq n, |x| leq 2^{63} - 1)

输出格式

对于每个 2 l r 操作,输出一行,每行有一个整数,表示所求的结果。

样例输入

5
12 15 20 24 36
3
2 1 2
1 4 5 4
2 3 5

样例输出

3
4

分析

你会发现这个求数列的 (gcd) 的操作非常的迷。

但是我们能大概看出来这道题需要区间修改,单点查询对吧。而区间修改,单点查询可以直接差分维护解决。

于是我们考虑如何将所求答案向差分转化。

尝试去证:对于序列 (a),以及其差分序列 (b) 有:(gcd(a_1, a_2,...a_n) = gcd(b_1, b_2,...,b_n))

我们设 (q = gcd(a_1, a_2,...,a_n))

则序列 (a) 可以写为:(q imes k_1, q imes k_2,..., q imes k_n)

显然序列 (b) 可以写为:(q imes (k_1 - k_0), q imes (k_2 - k_1),..., q imes (k_n - k_{n - 1}))

可以看出 (q) 一定是 (gcd(b_1, b_2,...,b_n)) 的一个因数。

假设 (q eq gcd(b_1, b_2,...,b_n)),此时不妨设 (gcd(b_1, b_2,...,b_n) = q imes p(p eq 1))

则序列 (b) 还可以写为 (q imes p imes (k_1 - k_0), q imes p imes (k_2 - k_1),..., q imes p imes (k_n - k_{n - 1}))

总所周知,(sum_{i = 1}^{n}b_i = a_i)

所以 (a_i = q imes p imes sum_{j = 1}^{n}(k_j - k_{j - 1})(i in [1, n]))

(p imes q)(gcd(a_1, a_2,...,a_n)) 的一个因数。

(∵p eq 1)(q = gcd(a_1, a_2,...,a_n))。推出矛盾。

(q = gcd(b_1, b_2,...,b_n) = gcd(a_1, a_2,...,a_n))

那么这道题不就结了嘛。

我们跑线段树,维护一个差分序列的区间 (gcd)。(很显然 (gcd) 可以由左右两颗子树的 (gcd) 合并而成。

利用差分数组很简单的一个性质,如果要对原数组的区间 ([i, j](i leq j leq n)) 增加 (x),那么就是将差分数组的第 (i) 个元素增加 (x),将第 (j + 1) 个元素减去 (x)即可。

那么我们的答案就很容易可以求到了。

不过需要注意的是:

1.我们相当于是把一段区间的 (gcd) 转换为了其差分数组的 (gcd),除了第一位外,该差分数组的元素都与原数组的差分数组相同,而第一位 (x) 就是对应的原数组的元素,而我们有求道了原数组的差分数组,那么直接树状数组维护原数组的值即可。

2.当你在进行修改操作时,可能会出现 (r = n + 1),((r) 为需修改区间的右端点),而若你的线段树是建的 (1)(n) 的节点,那么你在更改 (r + 1) 时,它会作用在节点 (n) 上,因为线段树是想要改的点在哪里就往当前方向走而没有判断是否走到,也就是说可能走不到,那就只能在走到的最远的地方更改咯。对于这个问题,我们只需要在更新的时候特判一些,或者线段树 (1)(n + 1) 建树即可。

AC代码

#include <cstdio>

typedef long long LL;
LL read_LL() {
    int k = 1;
    LL x = 0;
    char s = getchar();
    while (s < '0' || s > '9') {
        if (s == '-')
            k = -1;
        s = getchar();
    }
    while (s >= '0' && s <= '9') {
        x = (x << 3) + (x << 1) + s - '0';
        s = getchar();
    }
    return x * k;
}
int read_INT() {
    int k = 1;
    int x = 0;
    char s = getchar();
    while (s < '0' || s > '9') {
        if (s == '-')
            k = -1;
        s = getchar();
    }
    while (s >= '0' && s <= '9') {
        x = (x << 3) + (x << 1) + s - '0';
        s = getchar();
    }
    return x * k;
}
void write(LL x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    if (x > 9)
        write(x / 10);
    putchar(x % 10 + '0');
}
void print(LL x, char s) {
    write(x);
    putchar(s);
}
inline LL Abs(LL x) { return x < 0 ? -x : x; }
LL gcd(LL x, LL y) {
    if (!y)
        return x;
    return gcd(y, x % y);
}
const int MAXN = 5 * 1e5 + 5;
struct Segment_Tree {
    int l, r;
    LL res;
    Segment_Tree() {}
    Segment_Tree(int L, int R, LL Res) {
        l = L;
        r = R;
        res = Res;
    }
} t[MAXN * 4];
LL pre[MAXN], BIT[MAXN];
int n;

void Make_Tree(int p, int l, int r) {
    t[p].l = l;
    t[p].r = r;
    if (l == r) {
        t[p].res = pre[l];
        return;
    }
    int mid = (l + r) >> 1;
    Make_Tree(p << 1, l, mid);
    Make_Tree(p << 1 | 1, mid + 1, r);
    t[p].res = gcd(Abs(t[p << 1].res), Abs(t[p << 1 | 1].res));
}

void Update_Seg(int p, int index, LL x) {
    if (index == n + 1)
        return;
    if (t[p].l == t[p].r) {
        t[p].res += x;
        return;
    }
    int mid = (t[p].l + t[p].r) >> 1;
    if (index <= mid)
        Update_Seg(p << 1, index, x);
    else
        Update_Seg(p << 1 | 1, index, x);
    t[p].res = gcd(Abs(t[p << 1].res), Abs(t[p << 1 | 1].res));
}

LL Query_Seg(int p, int l, int r) {
    if (l <= t[p].l && t[p].r <= r)
        return t[p].res;
    int mid = (t[p].l + t[p].r) >> 1;
    LL val = 0;
    if (l <= mid)
        val = gcd(val, Abs(Query_Seg(p << 1, l, r)));
    if (r > mid)
        val = gcd(val, Abs(Query_Seg(p << 1 | 1, l, r)));
    return val;
}

int Low_Bit(int x) { return x & (-x); }

void Update_BIT(int k, LL x) {
    for (int i = k; i <= n; i += Low_Bit(i)) BIT[i] += x;
    return;
}

LL Query_BIT(int k) {
    LL val = 0;
    for (int i = k; i >= 1; i -= Low_Bit(i)) val += BIT[i];
    return val;
}

int main() {
    n = read_INT();
    LL last = 0;
    for (int i = 1; i <= n; i++) {
        LL x = read_LL();
        pre[i] = x - last;
        Update_BIT(i, pre[i]);
        last = x;
    }
    Make_Tree(1, 1, n);
    int q = read_INT();
    for (int i = 1; i <= q; i++) {
        int flag = read_INT();
        if (flag == 1) {
            int l = read_INT(), r = read_INT();
            LL x = read_LL();
            Update_Seg(1, l, x);
            Update_Seg(1, r + 1, -x);
            Update_BIT(l, x);
            Update_BIT(r + 1, -x);
        } else if (flag == 2) {
            int l = read_INT(), r = read_INT();
            print(gcd(Abs(Query_Seg(1, l + 1, r)), Query_BIT(l)), '
');
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/Chain-Forward-Star/p/14016463.html