树链剖分入门小结

 最近在学树链剖分,写篇文章记录一下。

原理及代码:https://blog.csdn.net/cdy1206473601/article/details/79189553#_11

题目:https://www.cnblogs.com/hanruyun/p/9577500.html

树链剖分说白了就是把一棵树拆成若干个不相交的链,然后用一些数据结构去维护这些链。因为通常的数据结构处理区间信息很容易,但处理树上的信息就显得捉襟见肘了。于是我们想到把树拍成一个区间用线段树去维护信息。(和树的dfs序是类似的原理)。

 树链剖分的几个常见应用:

①查询/修改树的子树的值:因为dfs的遍历顺序关系,每颗子树在线段树上必定是一段连续的区间,所以容易处理。

②查询/修改树两点路径间的值:因为每条重链是一段连续区间,那么两点间路径肯定是若干条链合起来得到,也就是线段树上的几段区间。

模板题:洛谷P3384

#pragma comment(linker,"/STACK:102400000,102400000")
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
int n,m,rt,MOD,num=0,v[N];
struct node{
    int dep,fa,val,sz,heavy;
    int toseg,top; 
}tree[N];  //原本的树 
int totree[N<<2];  //线段树点i代表原树的点totree[i] 

/*-------------------------以下为线段树-----------------------------*/
int sum[N<<2],tag[N<<2];
void pushdown(int rt,int len1,int len2) {
    sum[rt<<1]+=tag[rt]*len1%MOD; sum[rt<<1]%=MOD; 
    tag[rt<<1]+=tag[rt]%MOD; tag[rt<<1]%=MOD;
    sum[rt<<1|1]+=tag[rt]*len2%MOD; sum[rt<<1|1]%=MOD;
    tag[rt<<1|1]+=tag[rt]%MOD; tag[rt<<1|1]%=MOD;
    tag[rt]=0;
}
void pushup(int rt) {
    sum[rt]=(sum[rt<<1]+sum[rt<<1|1])%MOD;
}

void build(int rt,int l,int r) {
    if (l==r) {
        sum[rt]=tag[rt]=tree[totree[l]].val%MOD;
        return;
    }
    int mid=l+r>>1;
    build(rt<<1,l,mid);
    build(rt<<1|1,mid+1,r);
    pushup(rt);
}

void update(int rt,int l,int r,int ql,int qr,int v) {
    if (ql<=l && r<=qr) {
        sum[rt]+=v*(r-l+1)%MOD; sum[rt]%=MOD;
        tag[rt]+=v%MOD; tag[rt]%=MOD;
        return;
    }
    int mid=l+r>>1;
    pushdown(rt,mid-l+1,r-mid);
    if (ql<=mid) update(rt<<1,l,mid,ql,qr,v);
    if (qr>mid) update(rt<<1|1,mid+1,r,ql,qr,v);
    pushup(rt);
}

int query(int rt,int l,int r,int ql,int qr) {
    if (ql<=l && r<=qr) return sum[rt];
    int mid=l+r>>1,ret=0;
    pushdown(rt,mid-l+1,r-mid);
    if (ql<=mid) ret=(ret+query(rt<<1,l,mid,ql,qr))%MOD;
    if (qr>mid) ret=(ret+query(rt<<1|1,mid+1,r,ql,qr))%MOD;
    return ret;
}

/*--------------------------以下为树链剖分----------------------------*/
int cnt=1,head[N<<1],to[N<<1],nxt[N<<1];
void add_edge(int x,int y) {
    nxt[++cnt]=head[x]; to[cnt]=y; head[x]=cnt;
}

void dfs1(int x,int fa,int dep) {  //点x的父亲为fa深度为dep 
    tree[x].dep=dep;
    tree[x].fa=fa;
    tree[x].sz=1;
    tree[x].val=v[x];
    int maxson=-1;
    for (int i=head[x];i;i=nxt[i]) {
        int y=to[i];
        if (y==fa) continue;
        dfs1(y,x,dep+1);
        tree[x].sz+=tree[y].sz;
        if (tree[y].sz>maxson) tree[x].heavy=y,maxson=tree[y].sz;
    }
}

void dfs2(int x,int top) {  //点x所在树链的top 
    tree[x].toseg=++num;
    tree[x].top=top;
    totree[num]=x;
    if (!tree[x].heavy) return;  //叶子结点 
    dfs2(tree[x].heavy,top);  //先剖分重儿子 
    for (int i=head[x];i;i=nxt[i]) {  //再剖分轻儿子 
        int y=to[i];
        if (y==tree[x].fa || y==tree[x].heavy) continue;
        dfs2(y,y);
    }
}

//以下两个函数是树链剖分的精髓 
void update2(int x,int y,int z) {  //修改x到y路径的值 
    while (tree[x].top!=tree[y].top) {  //不在同一条链上 
        if (tree[tree[x].top].dep<tree[tree[y].top].dep) swap(x,y);  //x为深度大的链 
        update(1,1,n,tree[tree[x].top].toseg,tree[x].toseg,z);  //x向上跳的同时更新 
        x=tree[tree[x].top].fa;  //深度大的向上跳 
    }
    if (tree[x].dep>tree[y].dep) swap(x,y);  //这里x和y在同一条链 
    update(1,1,n,tree[x].toseg,tree[y].toseg,z);  //x和y这条链的更新 
}

int query2(int x,int y) {  //查询x到y路径的值,原理同上 
    int ret=0;
    while (tree[x].top!=tree[y].top) {
        if (tree[tree[x].top].dep<tree[tree[y].top].dep) swap(x,y);
        ret=(ret+query(1,1,n,tree[tree[x].top].toseg,tree[x].toseg))%MOD;
        x=tree[tree[x].top].fa;
    }
    if (tree[x].dep>tree[y].dep) swap(x,y);
    ret=(ret+query(1,1,n,tree[x].toseg,tree[y].toseg))%MOD;
    return ret;
}

int main()
{
    cin>>n>>m>>rt>>MOD;
    for (int i=1;i<=n;i++) scanf("%d",&v[i]);
    for (int i=1;i<n;i++) {
        int x,y; scanf("%d%d",&x,&y);
        add_edge(x,y); add_edge(y,x);
    }
    dfs1(rt,0,1);  //树链剖分准备信息 
    dfs2(rt,rt);  //开始树链剖分 
    build(1,1,n);  //把树链建成线段树 
    for (int i=1;i<=m;i++) {
        int opt,x,y,z; scanf("%d",&opt);
        if (opt==1) {  //修改x到y路径的点的值 
            scanf("%d%d%d",&x,&y,&z);
            update2(x,y,z%MOD);
        }
        if (opt==2) {  //查询x到y路径的值 
            scanf("%d%d",&x,&y);
            printf("%d
",query2(x,y));
        }
        if (opt==3) {  //修改x子树的值 
            scanf("%d%d",&x,&z);
            update(1,1,n,tree[x].toseg,tree[x].toseg+tree[x].sz-1,z%MOD);
        }
        if (opt==4) {  //查询x子树的值 
            scanf("%d",&x);
            printf("%d
",query(1,1,n,tree[x].toseg,tree[x].toseg+tree[x].sz-1));
        }
    }
    return 0;
}

比较简单的树链剖分题目其实就是上面说的这两种应用或其变式,都是对树链剖分本身基本没有变化,都是在维护树链的数据结构做变化 (例如:线段树等等)。这里还有几道练习题:

洛谷P4116 Qtree3

树链剖分后,线段树维护区间黑点的深度最小点(即区间黑点中深度最小的一个),那么第一个操作就是单点更新,第二个操作就是区间查询而已。变成了简单的线段树操作。

#pragma comment(linker,"/STACK:102400000,102400000")
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
const int INF=0x3f3f3f3f;
typedef long long LL;
int n,m,rt,num=0,v[N];
struct node{
    int dep,fa,val,sz,heavy;
    int toseg,top; 
}tree[N];  //原本的树 
int totree[N<<2];  //线段树点i代表原树的点totree[i] 

/*-------------------------以下为线段树-----------------------------*/
LL Min[N<<2];
void pushup(int rt) {
    if (tree[Min[rt<<1]].dep<tree[Min[rt<<1|1]].dep) Min[rt]=Min[rt<<1];
    else Min[rt]=Min[rt<<1|1];
}

void build(int rt,int l,int r) {
    Min[rt]=0;
    if (l==r) return;
    int mid=l+r>>1;
    build(rt<<1,l,mid);
    build(rt<<1|1,mid+1,r);
}

void update(int rt,int l,int r,int ql,int qr,int v) {
    if (ql<=l && r<=qr) {
        Min[rt]=v;
        return;
    }
    int mid=l+r>>1;
    if (ql<=mid) update(rt<<1,l,mid,ql,qr,v);
    if (qr>mid) update(rt<<1|1,mid+1,r,ql,qr,v);
    pushup(rt);
}

int query(int rt,int l,int r,int ql,int qr) {
    if (ql<=l && r<=qr) return Min[rt];
    int mid=l+r>>1,ret=0;
    if (ql<=mid) {
        int t=query(rt<<1,l,mid,ql,qr);
        if (tree[ret].dep>tree[t].dep) ret=t;
    }
    if (qr>mid) {
        int t=query(rt<<1|1,mid+1,r,ql,qr);
        if (tree[ret].dep>tree[t].dep) ret=t;
    }
    return ret;
}

/*--------------------------以下为树链剖分----------------------------*/
int cnt=1,head[N<<1],to[N<<1],nxt[N<<1];
void add_edge(int x,int y) {
    nxt[++cnt]=head[x]; to[cnt]=y; head[x]=cnt;
}

void dfs1(int x,int fa,int dep) {  //点x的父亲为fa深度为dep 
    tree[x].dep=dep;
    tree[x].fa=fa;
    tree[x].sz=1;
    tree[x].val=v[x];
    int maxson=-1;
    for (int i=head[x];i;i=nxt[i]) {
        int y=to[i];
        if (y==fa) continue;
        dfs1(y,x,dep+1);
        tree[x].sz+=tree[y].sz;
        if (tree[y].sz>maxson) tree[x].heavy=y,maxson=tree[y].sz;
    }
}

void dfs2(int x,int top) {  //点x所在树链的top 
    tree[x].toseg=++num;
    tree[x].top=top;
    totree[num]=x;
    if (!tree[x].heavy) return;  //叶子结点 
    dfs2(tree[x].heavy,top);  //先剖分重儿子 
    for (int i=head[x];i;i=nxt[i]) {  //再剖分轻儿子 
        int y=to[i];
        if (y==tree[x].fa || y==tree[x].heavy) continue;
        dfs2(y,y);
    }
}

int query2(int x,int y) {  //查询x到y路径的值,原理同上 
    int ret=0;
    while (tree[x].top!=tree[y].top) {
        if (tree[tree[x].top].dep<tree[tree[y].top].dep) swap(x,y);
        int t=query(1,1,n,tree[tree[x].top].toseg,tree[x].toseg);
        if (tree[ret].dep>tree[t].dep) ret=t;
        x=tree[tree[x].top].fa;
    }
    if (tree[x].dep>tree[y].dep) swap(x,y);
    int t=query(1,1,n,tree[x].toseg,tree[y].toseg);
    if (tree[ret].dep>tree[t].dep) ret=t;
    return ret;
}

int main()
{
    cin>>n>>m;
    for (int i=1;i<n;i++) {
        int x,y; scanf("%d%d",&x,&y);
        add_edge(x,y); add_edge(y,x);
    }
    rt=1;
    dfs1(rt,0,1);  //树链剖分准备信息 
    dfs2(rt,rt);  //开始树链剖分 
    build(1,1,n);  //把树链建成线段树 

    tree[0].dep=INF;
    for (int i=1;i<=m;i++) {
        int opt,x; scanf("%d%d",&opt,&x);
        if (opt==0) {
            if (v[x]==0) update(1,1,n,tree[x].toseg,tree[x].toseg,x);
            if (v[x]==1) update(1,1,n,tree[x].toseg,tree[x].toseg,0);
            v[x]=1-v[x];
        } else {
            int t=query2(1,x);
            if (t==0) puts("-1");
            else printf("%d
",t);
        }
    }
    return 0;
}
View Code

洛谷P4092 HEOI2016/TJOI2016 树

这道题和上题差不多,唯一变化只是这里要维护的是区间标记点深度最大点。

#pragma comment(linker,"/STACK:102400000,102400000")
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
const int INF=0x3f3f3f3f;
typedef long long LL;
int n,m,rt,num=0,v[N];
struct node{
    int dep,fa,val,sz,heavy;
    int toseg,top; 
}tree[N];  //原本的树 
int totree[N<<2];  //线段树点i代表原树的点totree[i] 

/*-------------------------以下为线段树-----------------------------*/
LL Max[N<<2];
void pushup(int rt) {
    if (tree[Max[rt<<1]].dep>tree[Max[rt<<1|1]].dep) Max[rt]=Max[rt<<1];
    else Max[rt]=Max[rt<<1|1];
}

void build(int rt,int l,int r) {
    Max[rt]=0;
    if (l==r) return;
    int mid=l+r>>1;
    build(rt<<1,l,mid);
    build(rt<<1|1,mid+1,r);
}

void update(int rt,int l,int r,int ql,int qr,int v) {
    if (ql<=l && r<=qr) {
        Max[rt]=v;
        return;
    }
    int mid=l+r>>1;
    if (ql<=mid) update(rt<<1,l,mid,ql,qr,v);
    if (qr>mid) update(rt<<1|1,mid+1,r,ql,qr,v);
    pushup(rt);
}

int query(int rt,int l,int r,int ql,int qr) {
    if (ql<=l && r<=qr) return Max[rt];
    int mid=l+r>>1,ret=0;
    if (ql<=mid) {
        int t=query(rt<<1,l,mid,ql,qr);
        if (tree[ret].dep<tree[t].dep) ret=t;
    }
    if (qr>mid) {
        int t=query(rt<<1|1,mid+1,r,ql,qr);
        if (tree[ret].dep<tree[t].dep) ret=t;
    }
    return ret;
}

/*--------------------------以下为树链剖分----------------------------*/
int cnt=1,head[N<<1],to[N<<1],nxt[N<<1];
void add_edge(int x,int y) {
    nxt[++cnt]=head[x]; to[cnt]=y; head[x]=cnt;
}

void dfs1(int x,int fa,int dep) {  //点x的父亲为fa深度为dep 
    tree[x].dep=dep;
    tree[x].fa=fa;
    tree[x].sz=1;
    tree[x].val=v[x];
    int maxson=-1;
    for (int i=head[x];i;i=nxt[i]) {
        int y=to[i];
        if (y==fa) continue;
        dfs1(y,x,dep+1);
        tree[x].sz+=tree[y].sz;
        if (tree[y].sz>maxson) tree[x].heavy=y,maxson=tree[y].sz;
    }
}

void dfs2(int x,int top) {  //点x所在树链的top 
    tree[x].toseg=++num;
    tree[x].top=top;
    totree[num]=x;
    if (!tree[x].heavy) return;  //叶子结点 
    dfs2(tree[x].heavy,top);  //先剖分重儿子 
    for (int i=head[x];i;i=nxt[i]) {  //再剖分轻儿子 
        int y=to[i];
        if (y==tree[x].fa || y==tree[x].heavy) continue;
        dfs2(y,y);
    }
}

int query2(int x,int y) {  //查询x到y路径的值,原理同上 
    int ret=0;
    while (tree[x].top!=tree[y].top) {
        if (tree[tree[x].top].dep<tree[tree[y].top].dep) swap(x,y);
        int t=query(1,1,n,tree[tree[x].top].toseg,tree[x].toseg);
        if (tree[ret].dep<tree[t].dep) ret=t;
        x=tree[tree[x].top].fa;
    }
    if (tree[x].dep>tree[y].dep) swap(x,y);
    int t=query(1,1,n,tree[x].toseg,tree[y].toseg);
    if (tree[ret].dep<tree[t].dep) ret=t;
    return ret;
}

int main()
{
    cin>>n>>m;
    for (int i=1;i<n;i++) {
        int x,y; scanf("%d%d",&x,&y);
        add_edge(x,y); add_edge(y,x);
    }
    rt=1;
    dfs1(rt,0,1);  //树链剖分准备信息 
    dfs2(rt,rt);  //开始树链剖分 
    build(1,1,n);  //把树链建成线段树 

    tree[0].dep=0;
    update(1,1,n,tree[1].toseg,tree[1].toseg,1);
    for (int i=1;i<=m;i++) {
        char opt[3]; int x; scanf("%s%d",&opt,&x);
        if (opt[0]=='C') {
            update(1,1,n,tree[x].toseg,tree[x].toseg,x);
        } else {
            int t=query2(1,x);
            printf("%d
",t);
        }
    }
    return 0;
}
View Code

洛谷P2486 [SDOI2011]染色

这道题要比上面两道复杂一些,但是其实也就是树链剖分后线段树花样操作。这道题要求的维护信息是一段路径上不同颜色的段数,那么我们维护线段树信息:sum[rt]代表区间颜色段数,L[rt]代表区间最左端的颜色,R[rt]代表区间最右端颜色,tag[rt]为懒惰标记。那么L[rt]=L[rt<<1],R[rt]=R[rt<<1|1],sum[rt]=sum[rt<<1]+sum[rt<<1|1],这里值得注意的是如果R[rt<<1]==L[rt<<1|1]的话,sum[rt]--,意思是中间连成一段的话段数减一。

还有一个细节就是,查询答案时注意沿着重链往上跳的时候也会出现上诉中间相同连成段的情况需要特别判断,具体来说就是如果x.top和x.top.fa颜色相同的话就ans-- 。

#pragma comment(linker,"/STACK:102400000,102400000")
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
int n,m,num=0,v[N];
struct node{
    int dep,fa,val,sz,heavy;
    int toseg,top; 
}tree[N];  //原本的树 
int totree[N<<2];  //线段树点i代表原树的点totree[i] 

/*-------------------------以下为线段树-----------------------------*/
int sum[N<<2],L[N<<2],R[N<<2],tag[N<<2];
void pushdown(int rt,int len1,int len2) {
    if (!tag[rt]) return;
    int lc=rt<<1,rc=rt<<1|1;
    sum[lc]=1; L[lc]=R[lc]=tag[lc]=tag[rt];
    sum[rc]=1; L[rc]=R[rc]=tag[rc]=tag[rt];
    tag[rt]=0;
}
void pushup(int rt) {
    sum[rt]=sum[rt<<1]+sum[rt<<1|1];
    if (R[rt<<1]==L[rt<<1|1]) sum[rt]--;
    L[rt]=L[rt<<1]; R[rt]=R[rt<<1|1];
}

void build(int rt,int l,int r) {
    if (l==r) {
        sum[rt]=1; L[rt]=R[rt]=tree[totree[l]].val; tag[rt]=0;
        return;
    }
    int mid=l+r>>1;
    build(rt<<1,l,mid);
    build(rt<<1|1,mid+1,r);
    pushup(rt);
}

void update(int rt,int l,int r,int ql,int qr,int v) {
    if (ql<=l && r<=qr) {
        sum[rt]=1; L[rt]=R[rt]=tag[rt]=v;
        return;
    }
    int mid=l+r>>1;
    pushdown(rt,mid-l+1,r-mid);
    if (ql<=mid) update(rt<<1,l,mid,ql,qr,v);
    if (qr>mid) update(rt<<1|1,mid+1,r,ql,qr,v);
    pushup(rt);
}

int query(int rt,int l,int r,int ql,int qr) {
    if (ql<=l && r<=qr) return sum[rt];
    int mid=l+r>>1,ret=0;
    pushdown(rt,mid-l+1,r-mid);
    if (ql<=mid) ret=(ret+query(rt<<1,l,mid,ql,qr));
    if (qr>mid) ret=(ret+query(rt<<1|1,mid+1,r,ql,qr));
    if (ql<=mid && qr>mid && R[rt<<1]==L[rt<<1|1]) ret--;
    return ret;
}

int query3(int rt,int l,int r,int x) {
    if (x<=l && r<=x) return L[rt];
    int mid=l+r>>1;
    pushdown(rt,mid-l+1,r-mid);
    if (x<=mid) return query3(rt<<1,l,mid,x);
    if (x>mid) return query3(rt<<1|1,mid+1,r,x);
}

/*--------------------------以下为树链剖分----------------------------*/
int cnt=1,head[N<<1],to[N<<1],nxt[N<<1];
void add_edge(int x,int y) {
    nxt[++cnt]=head[x]; to[cnt]=y; head[x]=cnt;
}

void dfs1(int x,int fa,int dep) {  //点x的父亲为fa深度为dep 
    tree[x].dep=dep;
    tree[x].fa=fa;
    tree[x].sz=1;
    tree[x].val=v[x];
    int maxson=-1;
    for (int i=head[x];i;i=nxt[i]) {
        int y=to[i];
        if (y==fa) continue;
        dfs1(y,x,dep+1);
        tree[x].sz+=tree[y].sz;
        if (tree[y].sz>maxson) tree[x].heavy=y,maxson=tree[y].sz;
    }
}

void dfs2(int x,int top) {  //点x所在树链的top 
    tree[x].toseg=++num;
    tree[x].top=top;
    totree[num]=x;
    if (!tree[x].heavy) return;  //叶子结点 
    dfs2(tree[x].heavy,top);  //先剖分重儿子 
    for (int i=head[x];i;i=nxt[i]) {  //再剖分轻儿子 
        int y=to[i];
        if (y==tree[x].fa || y==tree[x].heavy) continue;
        dfs2(y,y);
    }
}

//以下两个函数是树链剖分的精髓 
void update2(int x,int y,int z) {  //修改x到y路径的值 
    while (tree[x].top!=tree[y].top) {  //不在同一条链上 
        if (tree[tree[x].top].dep<tree[tree[y].top].dep) swap(x,y);  //x为深度大的链 
        update(1,1,n,tree[tree[x].top].toseg,tree[x].toseg,z);  //x向上跳的同时更新 
        x=tree[tree[x].top].fa;  //深度大的向上跳 
    }
    if (tree[x].dep>tree[y].dep) swap(x,y);  //这里x和y在同一条链 
    update(1,1,n,tree[x].toseg,tree[y].toseg,z);  //x和y这条链的更新 
}

int query2(int x,int y) {  //查询x到y路径的值,原理同上 
    int ret=0;
    while (tree[x].top!=tree[y].top) {
        if (tree[tree[x].top].dep<tree[tree[y].top].dep) swap(x,y);
        ret=(ret+query(1,1,n,tree[tree[x].top].toseg,tree[x].toseg));
        int t1=query3(1,1,n,tree[tree[x].top].toseg);
        int t2=query3(1,1,n,tree[tree[tree[x].top].fa].toseg);
        if (t1==t2) ret--;
        x=tree[tree[x].top].fa;
    }
    if (tree[x].dep>tree[y].dep) swap(x,y);
    ret=(ret+query(1,1,n,tree[x].toseg,tree[y].toseg));
    return ret;
}

int main()
{
    cin>>n>>m;
    for (int i=1;i<=n;i++) scanf("%d",&v[i]);
    for (int i=1;i<n;i++) {
        int x,y; scanf("%d%d",&x,&y);
        add_edge(x,y); add_edge(y,x);
    }
    dfs1(1,0,1);  //树链剖分准备信息 
    dfs2(1,1);  //开始树链剖分 
    build(1,1,n);  //把树链建成线段树 
    
    for (int i=1;i<=m;i++) {
        char opt[3]; int x,y,z; scanf("%s",&opt);
        if (opt[0]=='C') {
            scanf("%d%d%d",&x,&y,&z);
            update2(x,y,z);
        } else {
            scanf("%d%d",&x,&y);
            printf("%d
",query2(x,y));
        }
    }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/clno1/p/10858546.html