树链剖分学习

树链剖分,顾名思义,就是将一棵树上的节点按照一个特殊的方式重新编号,这样我们就可以利用一些数据结构去优化加速一些树上的操作;

现在要介绍的是重链剖分;

首先明确一些概念:

重儿子:父亲节点的所有儿子中子树结点数目最多(size最大)的结点;

轻儿子:父亲节点中除了重儿子以外的儿子;

重边:父亲结点和重儿子连成的边;

轻边:父亲节点和轻儿子连成的边;

重链:由多条重边连接而成的路径;

轻链:由多条轻边连接而成的路径;

有了这些概念,我们就可以愉快地剖分了;

具体操作就是先用两个 dfs作出以下变量:

名称 解释
fa[u] 保存结点u的父亲节点
dep[u] 保存结点u的深度值
size[u] 保存以u为根的子树节点个数
son[u] 保存重儿子
rk[u] 保存当前dfs标号在树中所对应的节点
top[u] 保存当前节点所在链的顶端节点
dfn[u] 保存树中每个节点剖分以后的新编号(DFS的执行顺序)

然后在写一棵线段树(某数据结构),将树上节点以 dfs序映射到线段上,然后就可以优化树上操作了,这就是树链剖分;

附上代码:

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int N = 1e5+10;

int n,m,r,mod;
int val[N];
int dfn[N],top[N],rk[N];
int dep[N],fa[N],size[N],son[N];
struct node{
    int l,r,ls,rs,sum,lazy;
}a[N<<4];
struct edge{
    int next,to;
}e[N<<4];
int head[N],cnt;

void add(int u,int v){
    e[++cnt]=(edge){head[u],v};
    head[u]=cnt;
}

void dfs1(int x){
    size[x]=1;
    dep[x]=dep[fa[x]]+1;
    for(int v,i=head[x];i;i=e[i].next){
        v=e[i].to;
        if(dep[v]) continue;
        fa[v]=x;
        dfs1(v);
        size[x]+=size[v];
        if(size[v]>size[son[x]]) son[x]=v;
    }
}

void dfs2(int x,int t){
    top[x]=t;
    dfn[x]=++cnt;
    rk[cnt]=x;
    if(son[x]) dfs2(son[x],t);
    for(int i=head[x],v;i;i=e[i].next){
        v=e[i].to;
        if(v==fa[x]) continue;
        if(son[x]!=v) dfs2(v,v);
    }
}

void pushup(int o){a[o].sum=(a[a[o].ls].sum+a[a[o].rs].sum)%mod;}

void build(int o,int l,int r){
    if(l==r){
        a[o].sum=val[rk[l]];
        a[o].l=a[o].r=l;
        return ;
    }
    int mid=(l+r)>>1;
    a[o].ls=++cnt,a[o].rs=++cnt;
    build(a[o].ls,l,mid);
    build(a[o].rs,mid+1,r);
    a[o].l=a[a[o].ls].l;
    a[o].r=a[a[o].rs].r;
    pushup(o);
}

void pushdown(int o){
    if(a[o].lazy){
        int ls=a[o].ls,rs=a[o].rs;
        a[ls].lazy=(a[ls].lazy+a[o].lazy)%mod;
        a[rs].lazy=(a[rs].lazy+a[o].lazy)%mod;
        a[ls].sum=(a[ls].sum+(a[ls].r-a[ls].l+1)*a[o].lazy)%mod;
        a[rs].sum=(a[rs].sum+(a[rs].r-a[rs].l+1)*a[o].lazy)%mod;
        a[o].lazy=0;
    }
}

void updata(int o,int x,int y,int d){
    if(a[o].l>=x&&a[o].r<=y){
        a[o].lazy+=d;
        a[o].sum=(a[o].sum+(a[o].r-a[o].l+1)*d)%mod;
        return ;
    }
    pushdown(o);
    int mid=(a[o].l+a[o].r)>>1;
    if(x<=mid) updata(a[o].ls,x,y,d);
    if(y>mid) updata(a[o].rs,x,y,d);
    pushup(o);
}

int query(int o,int x,int y){
    if(a[o].l>=x&&a[o].r<=y) return a[o].sum;
    pushdown(o);
    int mid=(a[o].l+a[o].r)>>1;
    int rel=0;
    if(x<=mid) rel=(rel+query(a[o].ls,x,y))%mod;
    if(y>mid) rel=(rel+query(a[o].rs,x,y))%mod;
    return rel;
}

int getsum(int x,int y){
    int rel=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        rel=(rel+query(1,dfn[top[x]],dfn[x]))%mod;
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    rel=(rel+query(1,dfn[x],dfn[y]))%mod;
    return rel;
}

int change(int x,int y,int d){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        updata(1,dfn[top[x]],dfn[x],d);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    updata(1,dfn[x],dfn[y],d);
}

int main()
{
    scanf("%d%d%d%d",&n,&m,&r,&mod);
    for(int i=1;i<=n;++i) scanf("%d",&val[i]);
    for(int i=1,x,y;i<n;++i){
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    cnt=0,dfs1(r),dfs2(r,r);
    build(1,1,n);
    for(int i=1,op,x,y,z;i<=m;++i){
        scanf("%d",&op);
        if(op==1){
            scanf("%d%d%d",&x,&y,&z);
            change(x,y,z);
        }
        if(op==2){
            scanf("%d%d",&x,&y);
            printf("%d
",getsum(x,y));
        }
        if(op==3){
            scanf("%d%d",&x,&z);
            updata(1,dfn[x],dfn[x]+size[x]-1,z);
        }
        if(op==4){
            scanf("%d",&x);
            printf("%d
",query(1,dfn[x],dfn[x]+size[x]-1));
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/nnezgy/p/11578685.html