树链剖分学习笔记

简介

树链剖分,顾名思义,就是把树剖分成链,在链上进行一系列的操作。

下面我们就来学习一下这个算法。

概念

树链剖分引入了很多新的概念:

  1. 重儿子:一个节点所有的儿子中子树(size)最大的儿子。
  2. 轻儿子:一个节点的儿子中除了重儿子都是轻儿子。
  3. 重边:一个节点与它的重儿子所组成的边。
  4. 轻边:一个节点与它的轻儿子组成的边。
  5. 重链:若干条重边组成的链。
  6. 轻链:若干条轻边组成的链。

思想

树链剖分经常与线段树相结合进行链上的操作。

因此线段树是必须要掌握的。

树链剖分一开始要进行(2)(dfs)

第一次(dfs)需要记录出一个节点的父亲、节点的深度和节点的重儿子。

第二次(dfs)需要对每个节点进行重新标号,按照重儿子优先的顺序遍历;还要记录出节点所在链的顶端;以及当前标号的点的编号。

然后就是线段树的基本操作。

对链进行维护时需要将两端点往上跳,直到它们在同一条剖分好的链上。

代码

这里以ZJOI2008 树的统计为例题讲解一下树链剖分的代码。

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <cctype>
#include <string>
#define itn int
#define gI gi

using namespace std;

inline int gi()
{
	int f = 1, x = 0; char c = getchar();
	while (c < '0' || c > '9') { if (c == '-') f = -1; c = getchar();}
	while (c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar();}
	return f * x;
}

int q, n, m;
int tot, head[100003], nxt[100003], ver[100003];
int dfn[100003], dep[100003], fa[100003];
int top[100003], son[100003], sz[100003];
int pre[100003], tim;
int a[100003];

inline void add(int u, int v)//邻接表存图
{
	ver[++tot] = v, nxt[tot] = head[u], head[u] = tot;
}

void dfs1(itn u, int f)//第一次dfs
{
	fa[u] = f/*记录父亲*/, sz[u] = 1/*记录子树大小*/, dep[u] = dep[f] + 1;/*标记深度*/ 
	int maxsize = -1;//最大子树大小
	for (itn i = head[u]; i; i = nxt[i])//遍历子节点
	{
		int v = ver[i];
		if (v == f) continue;
		dfs1(v, u);
		sz[u] = sz[u] + sz[v];//计算子树大小
		if (sz[v] > maxsize)//当前子树大小超过当前最大的子树大小
		{
			maxsize = sz[v], son[u] = v;//更新最大子树大小并标记重儿子
		} 
	}
}

void dfs2(int u, int f)
{
	dfn[u] = ++tim/*将树重新标号*/, top[u] = f/*记录链顶*/, pre[tim] = u/*重新编号后编号为tim的节点编号*/;
	if (son[u]) dfs2(son[u], f);//优先遍历重儿子
	for (itn i = head[u]; i; i = nxt[i])
	{
		int v = ver[i];
		if (v == son[u] || v == fa[u]) continue;//处理过就不需要再处理了
		dfs2(v, v);//找出下一条链
	}
}

/******以下为线段树******/

int sum[400003], maxs[400003];

inline int ls(int u) {return u << 1;}//左儿子
inline int rs(int u) {return (u << 1) | 1;}//右儿子

inline void pushup(int p)//上传标记
{
	sum[p] = sum[ls(p)] + sum[rs(p)];//区间和
	maxs[p] = max(maxs[ls(p)], maxs[rs(p)]);//区间最大值
}

void build(int l, int r, itn p)//建树
{
	if (l == r) {sum[p] = maxs[p] = a[pre[l]];/*注意是pre[l]*/ return;}//子节点
	int mid = (l + r) >> 1;
	build(l, mid, ls(p)); build(mid + 1, r, rs(p));
	pushup(p);//上传节点
}

void update(int x, int y, itn l, int r, int p)//更新节点信息
{
	if (l == r) {sum[p] = maxs[p] = y; return;}//找到了要更新的节点
	int mid = (l + r) >> 1;
	if (x <= mid) update(x, y, l, mid, ls(p));//左区间寻找
	else update(x, y, mid + 1, r, rs(p));//右区间寻找
	pushup(p);//上传节点
}

itn getmax(int ql, int qr, int l, itn r, int p)//区间最大值查找
{
	if (ql <= l && r <= qr) return maxs[p];//当前区间包含于要寻找的区间
	itn mid = (l + r) >> 1, ans = -1000000000;
	if (ql <= mid) ans = max(ans, getmax(ql, qr, l, mid, ls(p)));//向左寻找最大值
	if (qr > mid) ans = max(ans, getmax(ql, qr, mid + 1, r, rs(p)));//向右寻找最大值
	pushup(p);//上传节点
	return ans;//返回答案
}

itn getans(int ql, int qr, int l, itn r, int p)//区间和查找,与区间最大值查找没有什么区别
{
	if (ql <= l && r <= qr) return sum[p];
	itn mid = (l + r) >> 1, ans = 0;
	if (ql <= mid) ans = ans + getans(ql, qr, l, mid, ls(p));
	if (qr > mid) ans = ans + getans(ql, qr, mid + 1, r, rs(p));
	pushup(p);
	return ans;
}

/******以上为线段树******/

inline int qmax(int l, itn r)//查找路径上最大值
{
	itn ans = -1000000000;
	while (top[l] != top[r])//不在同一条链上
	{
		if (dep[top[l]] < dep[top[r]]) swap(l, r);//找链顶深度大的节点
		ans = max(ans, getmax(dfn[top[l]], dfn[l], 1, n, 1));//更新最大值
		l = fa[top[l]];//跳到当前链顶的父亲
	}
	if (dep[l] > dep[r]) swap(l, r);//要满足左端点深度小
	ans = max(ans, getmax(dfn[l], dfn[r], 1, n, 1));//更新答案
	return ans;//返回
}

inline int qsum(int l, itn r)//求路径权值和,与查找最大值同理
{
	itn ans = 0;
	while (top[l] != top[r])
	{
		if (dep[top[l]] < dep[top[r]]) swap(l, r);
		ans = ans + getans(dfn[top[l]], dfn[l], 1, n, 1);
		l = fa[top[l]];
	}
	if (dep[l] < dep[r]) swap(l, r);
	ans = ans + getans(dfn[r], dfn[l], 1, n, 1);
	return ans;
}

int main()
{
	n = gi();
	for (int i = 1; i < n; i+=1) 
	{
		int u = gI(), v = gI(); 
		add(u, v), add(v, u);
	}
	for (int i = 1; i <= n; i+=1) a[i] = gi();
	dfs1(1, -1); dfs2(1, 1); build(1, n, 1);//预处理
	q = gi();
	while (q--)
	{
		char s[10];
		scanf("%s", s);
		int u = gi(), v = gi();
		if (s[1] == 'M') printf("%d
", qmax(u, v));//区间最大值查找
		else if (s[1] == 'S') printf("%d
", qsum(u, v));//求区间和
		else update(dfn[u], v, 1, n, 1);//更新节点
	}
	return 0;
}

应用

树链剖分求( exttt{LCA})

代码如下(以洛谷模板为例):

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <cctype>
#define itn int
#define gI gi

using namespace std;

inline int gi()
{
	int f = 1, x = 0; char c = getchar();
	while (c < '0' || c > '9') { if (c == '-') f = -1; c = getchar();}
	while (c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar();}
	return f * x;
}

int n, m, rt, dfn[500003], dep[500003], fa[500003], sz[500003], son[500003], pre[500003], top[500003];
int tot, head[2000003], nxt[2000003], ver[2000003];

inline void add(itn u, int v)
{
	ver[++tot] = v, nxt[tot] = head[u], head[u] = tot;
}

void dfs1(int u, int f)
{
	fa[u] = f, dep[u] = dep[f] + 1, sz[u] = 1;
	for (itn i = head[u]; i; i = nxt[i])
	{
		int v = ver[i];
		if (v == f) continue;
		dfs1(v, u);
		sz[u] = sz[u] + sz[v];
		if (sz[v] > sz[son[u]]) son[u] = v;
	}
}

int tim;

void dfs2(itn u, int f)
{
	top[u] = f, dfn[u] = ++tim, pre[tim] = u;
	if (!son[u]) return;
	dfs2(son[u], f);
	for (int i = head[u]; i; i = nxt[i])
	{
		int v = ver[i];
		if (v == fa[u] || v == son[u]) continue;
		dfs2(v, v);
	}
}

int main()
{
	n = gi(), m = gi(), rt = gi();
	for (itn i = 1; i < n; i+=1)
	{
		int u = gi(), v = gi();
		add(u, v), add(v, u);
	}
	dfs1(rt, rt); 
	dfs2(rt, rt);
	while (m--)
	{
		int u = gi(), v = gi();
		while (top[u] != top[v])
		{
			if (dep[top[u]] < dep[top[v]]) swap(u, v);
			u = fa[top[u]];
		}
		if (dep[u] < dep[v]) printf("%lld
", u);
		else printf("%lld
", v);
	}
	return 0;
}

总结

理解一个算法的思想很重要。

代码要熟练地打出来才算真正理解。

记录一下我踩过的坑:

  • 建树时把(pre[l])写成了(l)

  • 跳端点时没有注意左端点编号小于右端点编号;

  • 子树(size)初始化成(0)

  • ( exttt{LCA})时把<写成>

就这样吧~

原文地址:https://www.cnblogs.com/xsl19/p/shupou.html