树链剖分

https://www.cnblogs.com/chinhhh/p/7965433.html<-详解

洛谷P3384<-模板

如果我会告诉你我不小心把%写成了*调了一个多小时吗

适用范围

在一棵树上+-点权/边权然后多次提问的问题

原理

将一棵树剖分成若干条链,在链上通过数据结构维护。

照本宣科

将一个点的子树中最大的那个做为重儿子,其他的叫做轻儿子,对于所有点的重儿子连的链称为重链。我们将重链作为主链,将轻儿子构成的其他轻链加在重链的后面。很明显,对于每一条链,它的节点都是相连的(废话)。对于每一个子树,它的所有的节点都排在子树根的后面

所以我们就把一个树成功地退化成了一条链。

因为这条链有如上我所说的几个性质,题目要求如果是能在链上连续维护的(如求子树权值和,求任意两点间距离和),我们就可以用数据结构维护它了,比如线段树。

构造

首先我们需要两次dfs。

dfs1

用于建树,顺便记录每个节点的父亲 和 该点的深度 和 它的子树的大小 和 它的重儿子。

fa[]父亲节点 dep[]深度 siz[]子树大小 son[]重儿子(重儿子为子树大者)

void dfs1(int x,int f,int depth)
{
    siz[x]=1; dep[x]=depth; fa[x]=f;
    int maxson=-1;
    for(R int i=head[x];i;i=e[i].nxt)
    {
        int u=e[i].to;
        if(u==f)continue;
        dfs1(u,x,depth+1);
        siz[x]+=siz[u];
        if(siz[u]>maxson)son[x]=u,maxson=siz[u];
    }
}

dfs2

用于将一个树退化成链,记录节点在链的编号 和 节点的链首节点 和 链上节点的权值。

cnt时间戳 id[]编号 top[]链首节点 a[]原值 w[]链上的值(便于维护)

void dfs2(int x,int topf)
{
    id[x]=++cnt;
    w[cnt]=a[x];
    top[x]=topf;
    if(!son[x])return;
    dfs2(son[x],topf);
    for(R int i=head[x];i;i=e[i].nxt)
    {
        int u=e[i].to;
        if(u==fa[x]||u==son[x])continue;
        dfs2(u,u);
    }
}

求两点间距离

要求两点间距离:

若两点在一条链上(top[x]==top[y])我们直接求两点间的距离即可。

若两点不在一条链上,那么求更深的那个点x到此刻链首top[x]的距离,然后令x=fa[top[x]],即可将x更新到新链上。重复操作,每次只将深度更大的点向上更新。最终两点会处于同一条链(重链或轻链,最远是重链)上,然后再加上两点间的和就可以了。

用线段树维护。

inline int queryrange(int x,int y)
{
    int ans=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        res=0;
        st.query(1,id[top[x]],id[x]);
        ans=(ans+res)%mod;
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    res=0;
    st.query(1,id[x],id[y]);
    ans=(ans+res)%mod;
    return ans;
}

更新两点间距离

和上面是一样的,分成若干条链更新就可以了。

inline void updaterange(int x,int y,int k){
    k%=mod;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        st.update(1,id[top[x]],id[x],k);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    st.update(1,id[x],id[y],k);
}

更新/查询子树权值(和)

如上,子树节点在链上一定是在根的后面并且连续的。

所以要更新以x为根节点的所有子树结点,就更新id[x]~id[x]+size[x]-1的范围即可。

(代码放在代码里)

例题(在上面)

当然要根据题目需要做各种修改,这真的只是一个板子。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<ctype.h>
#define R register
using namespace std;
inline int read()
{
    int x=0,w=0;char c=getchar();
    while(!isdigit(c))w|=c=='-',c=getchar();
    while(isdigit(c))x=(x<<3)+(x<<1)+(c^48),c=getchar();
    return w?-x:x;
}
const int maxn=1e5+5,maxm=1e5+5;
int n,m,root,mod,fa[maxn],dep[maxn],son[maxn],siz[maxn],top[maxn],id[maxn],a[maxn],w[maxn];
struct Edge{
    int to,nxt;
}e[maxn*2];
int ecnt,head[maxn];
inline void addedge(int from,int to)
{
    e[++ecnt]=(Edge){to,head[from]};head[from]=ecnt;
    e[++ecnt]=(Edge){from,head[to]};head[to]=ecnt;
}
int res;
struct SegmentTree{
    #define ls (ro<<1)
    #define rs (ro<<1|1)
    
    struct tree{
        int l,r,val,tag;
    }e[maxn<<2];
    inline void push_up(int ro)
    {
        e[ro].val=(e[ls].val+e[rs].val)%mod;
    }
    inline void push_down(int ro)
    {
        e[ls].tag+=e[ro].tag;
        e[rs].tag+=e[ro].tag;
        e[ls].val+=e[ro].tag*(e[ls].r-e[ls].l+1);
        e[rs].val+=e[ro].tag*(e[rs].r-e[rs].l+1);
        e[ls].val%=mod;
        e[rs].val%=mod;
        e[ro].tag=0;
    }
    void build(int ro,int l,int r)
    {
        e[ro].l=l;e[ro].r=r;
        if(l==r){e[ro].val=w[l]%mod;return ;}
        int mid=(l+r)>>1;
        build(ls,l,mid);build(rs,mid+1,r);
        push_up(ro);
    }
    void update(int ro,int l,int r,int k)
    {
        if(e[ro].l>=l and e[ro].r<=r){
            e[ro].tag+=k;
            e[ro].val+=k*(e[ro].r-e[ro].l+1);
            e[ro].val%=mod;
            return;
        }
        int mid=(e[ro].l+e[ro].r)>>1;
        if(e[ro].tag)push_down(ro);
        if(l<=mid)update(ls,l,r,k);
        if(r>mid)update(rs,l,r,k);
        push_up(ro);
    }
    void query(int ro,int l,int r)
    {
        if(e[ro].l>=l and e[ro].r<=r){
            res+=e[ro].val;res%=mod;return;
        }
        if(e[ro].tag)push_down(ro);
        int mid=(e[ro].l+e[ro].r)>>1;
        if(l<=mid)query(ls,l,r);
        if(r>mid)query(rs,l,r);
    }
    #undef ls
    #undef rs
}st;

void dfs1(int x,int f,int depth)
{
    siz[x]=1; dep[x]=depth; fa[x]=f;
    int maxson=-1;
    for(R int i=head[x];i;i=e[i].nxt)
    {
        int u=e[i].to;
        if(u==f)continue;
        dfs1(u,x,depth+1);
        siz[x]+=siz[u];
        if(siz[u]>maxson)son[x]=u,maxson=siz[u];
    }
}
int cnt;
void dfs2(int x,int topf)
{
    id[x]=++cnt;
    w[cnt]=a[x];
    top[x]=topf;
    if(!son[x])return;
    dfs2(son[x],topf);
    for(R int i=head[x];i;i=e[i].nxt)
    {
        int u=e[i].to;
        if(u==fa[x]||u==son[x])continue;
        dfs2(u,u);
    }
}
inline void updaterange(int x,int y,int k){
    k%=mod;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        st.update(1,id[top[x]],id[x],k);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    st.update(1,id[x],id[y],k);
}
inline int queryrange(int x,int y)
{
    int ans=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        res=0;
        st.query(1,id[top[x]],id[x]);
        ans=(ans+res)%mod;
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    res=0;
    st.query(1,id[x],id[y]);
    ans=(ans+res)%mod;
    return ans;
}

inline void debug()
{
    for(int i=1;i<=n;i++)
    {
        res=0;st.query(1,id[i],id[i]);
        printf("%d ",res);
    }
}

int main()
{
    n=read(),m=read(),root=read(),mod=read();
    for(R int i=1;i<=n;i++)a[i]=read();
    for(R int i=1;i<n;i++)addedge(read(),read());
    dfs1(root,0,1);
    dfs2(root,root);
    st.build(1,1,n);
    while(m--)
    {
        int a,b,c;
        switch(read())
        {
            case(1):{
                a=read(),b=read(),c=read();
                updaterange(a,b,c);
                break;
            }
            case(2):{
                a=read(),b=read();
                printf("%d
",queryrange(a,b));
                break;
            }
            case(3):{
                int a=read(),b=read();
                st.update(1,id[a],id[a]+siz[a]-1,b);
                break;
            }
            case(4):{
                int a=read();res=0;
                st.query(1,id[a],id[a]+siz[a]-1);
                printf("%d
",res);
                break;
            }
        }
    }
    //debug();
    return 0;
}
原文地址:https://www.cnblogs.com/BrotherHood/p/13124187.html