树链剖分学习笔记(一)

0. 简介

树链剖分用于处理树上问题(废话)
思想是把一棵树分成若干条链,而链上的操作显然比树上好做,这样就降低了处理难度

有多种方式能将树拆成链,因此树链剖分也有不同的种类
这篇文章写的是其中一种:轻重链剖分(最常见,有时也被直接称为树链剖分)

1. 概念

定义 \(size_x\) 表示以 \(x\) 为根的子树大小(即包含的节点个数)

\(x\) 的所有儿子中 \(size\) 值最大的那个为 \(x\)重儿子,记作 \(son_x\)
(如果有多个儿子满足条件,随便选其中一个即可)

对于每个节点 \(x\) ,将连接 \(x\)\(son_x\) 的边定为重边,其它的定为轻边

将全部由重边组成的路径称为重链

这样树就被分成了若干条重链和若干条轻边

分完之后我们发现了一些有用的性质:

  1. \(y\)\(x\) 的儿子,但不是重儿子,则 \(size_y\le size_x/2.\)
    反证即可,假设 \(size_y>size_x/2\) ,则 \(size_y\) 比其它儿子的 \(size\) 值之和还大,故 \(y\) 是重儿子,与条件矛盾
    所以 \(size_y\le size_x/2\)

  2. 从任意一个节点到根节点的路径上最多有 \(O(\log n)\) 条轻边。
    由性质 1 可知,通过一条轻边往下走,子树大小至少减半,故显然性质成立

  3. 从任意一个节点到根节点的路径上最多有 \(O(\log n)\) 条重链。
    重链之间是用轻边分隔的(废话)
    因此重链的条数和轻边一样也是 \(O(\log n)\) 级别

题目常会让我们对两点间路径上的所有点执行某操作
由上面的性质易知,这条路径可被分成不超过 \(O(\log n)\) 条重链和轻边
我们希望能快速处理重链上的操作

对整棵树跑一遍深度优先遍历,并且优先遍历重儿子
则同一条重链对应的 dfs 序必然是连续的一段
然后重链上的操作就转化成了序列问题,用合适的数据结构维护即可

下面通过一道模板题讲一下具体怎么使用

2. 实现

P3384 【模板】轻重链剖分/树链剖分

首先预处理出必要的信息

/**
 * ​ fa[x]: 节点 x 的爹
 * dep[x]: 节点 x 的深度
 *  sz[x]: 节点 x 的子树大小,即上文中的 size
 * son[x]: 节点 x 的重儿子
 * top[x]: 节点 x 所在重链的顶端的节点(深度最小)
 * ord[x]: 节点 x 的 dfs 序
 * rev[n]: dfs 序为 n 的节点,即 rev 是 ord 的逆映射
**/
void dfs(int x,int f) { // 第一遍 dfs ,处理 fa,dep,sz,son
   ​fa[x]=f,dep[x]=dep[f]+1,sz[x]=1;
   ​for (int i=last[x]; i; i=E[i].pre) {
       ​int y=E[i].y;
       ​if (y==f) continue;
       ​dfs(y,x);
       ​sz[x]+=sz[y];
       ​if (sz[son[x]]<sz[y]) son[x]=y;
   ​}
}
void dfs2(int x); // 第二遍 dfs ,处理 top,ord,rev
void work(int x,int topx) {
   ​top[x]=topx;
   ​++n2,ord[x]=n2,rev[n2]=x;
   ​dfs2(x);
}
void dfs2(int x) {
   ​if (!son[x]) return ;
   ​work(son[x],top[x]); // 优先遍历重儿子
   ​for (int i=last[x]; i; i=E[i].pre)
       ​if (!top[E[i].y]) work(E[i].y,E[i].y);
}

操作 1,2 是关于两节点 \(x,y\) 之间路径的
先给出操作 1 的代码实现:

void solve1(int x,int y,int v) {
    int tx=top[x],ty=top[y];
    while (tx!=ty) {
        if (dep[tx]<dep[ty]) swap(x,y),swap(tx,ty);
        update(ord[tx],ord[x],v,1,n,1); // 修改
        x=fa[tx],tx=top[x]; // 跳到下一条重链底端 
    }
    if (dep[x]>dep[y]) swap(x,y);
    update(ord[x],ord[y],v,1,n,1);
}

其实就是一直让深度较大的那个点往上跳
每次跳过一整条重链,并修改这条链上的信息
两点位于同一条重链上时停止,此时剩下的一小段路径都在这条链上,直接修改即可

操作 2 也没啥区别,把修改换成查询,最后返回总和即可

ll solve2(int x,int y) {
    int tx=top[x],ty=top[y]; ll ans=0;
    while (tx!=ty) {
        if (dep[tx]<dep[ty]) swap(x,y),swap(tx,ty);
        ans=(ans+query(ord[tx],ord[x],1,n,1))%P;
        x=fa[tx],tx=top[x];
    }
    if (dep[x]>dep[y]) swap(x,y);
    return (ans+query(ord[x],ord[y],1,n,1))%P;
}

操作 3,4 和子树相关,所以一次操作涉及的所有节点的 dfs 序是连续的一段
直接在对应区间上做修改/查询即可

void solve3(int x,int v) {
    update(ord[x],ord[x]+sz[x]-1,v,1,n,1);
}
ll solve4(int x) {
    return query(ord[x],ord[x]+sz[x]-1,1,n,1);
}

其中 updatequery 的写法取决于你用什么数据结构

顺便讲一下树状数组怎么做区间修改和区间查询

对于区间修改,差分一波
\(c_i=a_i-a_{i-1}\) (规定 \(a_0=0\)
若要将 \(a_l\sim a_r\) 全部加 \(x\) ,则 \(\Delta c_l=x,\Delta c_{r+1}=-x\)

对于区间查询,只需考虑如何求前缀和 \(S(x)=\sum\limits_{i=1}^x a_i\)

\(\begin{aligned} S(x)&=\sum\limits_{i=1}^x a_i\\ &=\sum\limits_{i=1}^x\sum\limits_{j=1}^ic_j\\ &=\sum\limits_{j=1}^x(x+1-j)c_j\\ &=(x+1)\sum\limits_{i=1}^xc_i-\sum\limits_{i=1}^xic_i \end{aligned}\)

\(d_i=ic_i\) ,则 \(S(x)\) 可以用 \(c\)\(d\) 的前缀和表示
用树状数组维护 \(c\)\(d\) 即可

因此本题中树状数组和线段树均可使用,时间复杂度都是 \(O(n+m\log ^2n)\) (跑不满)

3. 完整代码

线段树(使用懒标记)版本
线段树(标记永久化)和树状数组的写得太丑了 也懒得重写 就不放了(

P3384 SGT ver.
#include<stdio.h>
#include<ctype.h>
#define Tl T[p<<1]
#define Tr T[p<<1|1]
typedef long long ll;
const int N=100010;
int n,m,root,P,n2,cnt,opt,x,y,v,value[N],last[N];
int fa[N],dep[N],sz[N],son[N],top[N],ord[N],rev[N];
struct edge { int y,pre; }E[N<<1];
inline void swap(int &x,int &y) { int t=x; x=y,y=t; }
void read(int &x);

struct SGT { int len; ll tag,sum; }T[N<<2];
void build(int l,int r,int p);
void pushup(int p);
void pushdown(int p);
void update(int x,int y,int v,int l,int r,int p);
ll query(int x,int y,int l,int r,int p);

void dfs(int x,int f) {
    fa[x]=f,dep[x]=dep[f]+1,sz[x]=1;
    for (int i=last[x]; i; i=E[i].pre) {
        int y=E[i].y;
        if (y==f) continue;
        dfs(y,x);
        sz[x]+=sz[y];
        if (sz[son[x]]<sz[y]) son[x]=y;
    }
}
void dfs2(int x);
void work(int x,int topx) {
    top[x]=topx;
    ++n2,ord[x]=n2,rev[n2]=x;
    dfs2(x);
}
void dfs2(int x) {
    if (!son[x]) return ;
    work(son[x],top[x]);
    for (int i=last[x]; i; i=E[i].pre)
        if (!top[E[i].y]) work(E[i].y,E[i].y);
}

void solve1(int x,int y,int v) {
    int tx=top[x],ty=top[y];
    while (tx!=ty) {
        if (dep[tx]<dep[ty]) swap(x,y),swap(tx,ty);
        update(ord[tx],ord[x],v,1,n,1);
        x=fa[tx],tx=top[x];
    }
    if (dep[x]>dep[y]) swap(x,y);
    update(ord[x],ord[y],v,1,n,1);
}
ll solve2(int x,int y) {
    int tx=top[x],ty=top[y]; ll ans=0;
    while (tx!=ty) {
        if (dep[tx]<dep[ty]) swap(x,y),swap(tx,ty);
        ans=(ans+query(ord[tx],ord[x],1,n,1))%P;
        x=fa[tx],tx=top[x];
    }
    if (dep[x]>dep[y]) swap(x,y);
    return (ans+query(ord[x],ord[y],1,n,1))%P;
}
void solve3(int x,int v) {
    update(ord[x],ord[x]+sz[x]-1,v,1,n,1);
}
ll solve4(int x) {
    return query(ord[x],ord[x]+sz[x]-1,1,n,1);
}

int main() {
    read(n),read(m),read(root),read(P);
    for (int i=1; i<=n; ++i) read(value[i]);
    for (int i=1; i<n; ++i) {
        read(x),read(y);
        E[++cnt]={y,last[x]},last[x]=cnt;
        E[++cnt]={x,last[y]},last[y]=cnt;
    }
    dfs(root,0);
    work(root,root);
    build(1,n,1);
    while (m--) {
        read(opt),read(x);
        if (opt==1) read(y),read(v),solve1(x,y,v);
        if (opt==2) read(y),printf("%lld\n",solve2(x,y));
        if (opt==3) read(v),solve3(x,v);
        if (opt==4) printf("%lld\n",solve4(x));
    }
    return 0;
}

void read(int &x) {
    x=0; char ch=getchar();
    while (!isdigit(ch)) ch=getchar();
    while (isdigit(ch)) x=x*10+(ch^48),ch=getchar();
}
void build(int l,int r,int p) {
    T[p].len=r-l+1;
    if (l==r) {
        T[p].sum=value[rev[l]];
        return ;
    }
    int mid=(l+r>>1);
    build(l,mid,p<<1);
    build(mid+1,r,p<<1|1);
    pushup(p);
}
inline void pushup(int p) {
    T[p].sum=(Tl.sum+Tr.sum)%P;
}
inline void pushdown(int p) {
    int t=T[p].tag;
    Tl.tag=(Tl.tag+t)%P;
    Tl.sum=(Tl.sum+t*Tl.len)%P;
    Tr.tag=(Tr.tag+t)%P;
    Tr.sum=(Tr.sum+t*Tr.len)%P;
    T[p].tag=0;
}
void update(int x,int y,int v,int l,int r,int p) {
    if (x<=l&&y>=r) {
        T[p].tag=(T[p].tag+v)%P;
        T[p].sum=(T[p].sum+1ll*v*T[p].len)%P;
        return ;
    }
    pushdown(p);
    int mid=(l+r>>1);
    if (x<=mid) update(x,y,v,l,mid,p<<1);
    if (y>mid) update(x,y,v,mid+1,r,p<<1|1);
    pushup(p);
}
ll query(int x,int y,int l,int r,int p) {
    if (x<=l&&y>=r) return T[p].sum;
    pushdown(p);
    int mid=(l+r>>1); ll ans=0;
    if (x<=mid) ans=(ans+query(x,y,l,mid,p<<1))%P;
    if (y>mid) ans=(ans+query(x,y,mid+1,r,p<<1|1))%P;
    pushup(p);
    return ans;
}
原文地址:https://www.cnblogs.com/REKonib/p/15550441.html