树链剖分(附带LCA和换根)——基于dfs序的树上优化

。。。。

有点懒;


需要先理解几个概念:

1. LCA

2. 线段树(熟练,要不代码能调一天)

3. 图论的基本知识(dfs序的性质)

这大概就好了;

定义

  1.重儿子:一个点所连点树size最大的,这个son被称为这个点的重儿子;

  2.轻儿子:一个点所连点除重儿子以外的都是轻儿子;

  3.重链:从一个轻儿子或根节点开始沿重儿子走所成的链;

步骤

  在代码里,结合代码更清晰。。。(其实是太懒了)

 有重点需要注意的东西在code中有提到,仔细看。。。。

#include<bits/stdc++.h>
#define maxn 100007
#define le(x) x<<1
#define re(x) x<<1|1
using namespace std;
int n,m,root,mod,a[maxn],head[maxn],fa[maxn],son[maxn],cnt,tag[maxn<<2];
//a:原始点值,fa:父亲节点,son:重儿子,tag:懒标记 
int top[maxn],sz[maxn],id[maxn],dep[maxn],w[maxn],cent,tr[maxn<<2];
//top:所在重链的头结点,sz:子树大小,id:dfs序,dep:深度 
//w:dfs序所对应的值(建线段树),tr:线段树 
struct node{
    int next,to;
}edge[maxn<<2];

template<typename type_of_scan>
inline void scan(type_of_scan &x){
    type_of_scan f=1;x=0;char s=getchar();
    while(s<'0'||s>'9') f=s=='-'?-1:1,s=getchar();
    while(s>='0'&&s<='9') x=(x<<3)+(x<<1)+s-'0',s=getchar();
    x*=f;
}

inline void add(int u,int v){
    edge[++cent]=(node){head[u],v};head[u]=cent;
}
//-----------------------------------------------------线段树红色预警 
void push_up(int p){
    tr[p]=tr[le(p)]+tr[re(p)];
    tr[p]%=mod;
}

void build(int l,int r,int p){
    if(l==r){
        tr[p]=w[l];
        return ;
    }
    int mid=(l+r)>>1;
    build(l,mid,le(p));
    build(mid+1,r,re(p));
    push_up(p);
}

void push_down(int l,int r,int p,int k){
    int mid=l+r>>1;
    tr[le(p)]+=k*(mid-l+1),tr[re(p)]+=k*(r-mid);
    tr[le(p)]%=mod,tr[re(p)]%=mod;
    tag[le(p)]+=k,tag[re(p)]+=k;
    tag[le(p)]%=mod,tag[re(p)]%=mod;
}

void r_add(int nl,int nr,int l,int r,int p,int k){
    if(nl<=l&&nr>=r){
        tr[p]+=k*(r-l+1);tag[p]+=k;
        tr[p]%=mod,tag[p]%=mod;
        return ;
    }
    push_down(l,r,p,tag[p]),tag[p]=0;
    int mid=(l+r)>>1;
    if(nl<=mid) r_add(nl,nr,l,mid,le(p),k);
    if(nr>mid) r_add(nl,nr,mid+1,r,re(p),k);
    push_up(p);
}

int r_query(int nl,int nr,int l,int r,int p){
    int ans=0;
    if(nl<=l&&nr>=r) return tr[p];
    push_down(l,r,p,tag[p]),tag[p]=0;
    int mid=l+r>>1;
    if(nl<=mid) ans+=r_query(nl,nr,l,mid,le(p)),ans%=mod;
    if(nr>mid) ans+=r_query(nl,nr,mid+1,r,re(p)),ans%=mod;
    push_up(p);
    return ans;
}

//-----------------------------------------------------线段树结束
//-----------------------------------------------------开始预处理 

void dfs1(int x){
    sz[x]=1;//sz初始化 
    int max_part=-1;//max_part更新寻找重儿子 
    for(int i=head[x];i;i=edge[i].next){
        int y=edge[i].to;
        if(y==fa[x]) continue;
        fa[y]=x,dep[y]+=dep[x]+1;//更新子节点,准备开始继续dfs1 
        dfs1(y);sz[x]+=sz[y];//更新自身的sz数组 
        if(max_part<sz[y]) son[x]=y,max_part=sz[y];//更新重儿子 
    }
}
/*dfs1功能介绍
1.更新fa数组;
2.更新dep数组;
3.更新sz数组; 
4.更新son数组; 
*/ 

void dfs2(int x,int t){
    id[x]=++cnt,w[cnt]=a[x],top[x]=t;//更新dfs序,dfs序所对的值,重链头节点 
    if(!son[x]) return ;
    dfs2(son[x],t);
    for(int i=head[x];i;i=edge[i].next){
        int y=edge[i].to;
        if(y==fa[x]||y==son[x]) continue;
        dfs2(y,y);
    }
}
/*dfs2功能介绍
1.更新id数组;
2.更新w数组;
3.更新top数组
*/ 

//------------------------------------------------预处理结束 
//------------------------------------------------开始主要操作 

//其实没有说的这么简单,这里重点是理解重链之间的跳跃方式,线段树的优化 
//一个性质:重链上的dfs序是连续的,dfs1在dfs2前的原因就在此 

int road_query(int x,int y){
    int ans=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);//从最下面往上跳 
        ans+=r_query(id[top[x]],id[x],1,n,1);//更新重链 
        ans%=mod;
        x=fa[top[x]];//跳到重链头的fa 
    }
    if(dep[x]>dep[y]) swap(x,y);
    ans+=r_query(id[x],id[y],1,n,1);//已经在同一条重链上,直接加 
    return ans%mod;
}

int tree_query(int x){
    return r_query(id[x],id[x]+sz[x]-1,1,n,1)%mod;
}//一个性质:在同一颗子树上的dfs序是连续的 

void road_add(int x,int y,int k){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        r_add(id[top[x]],id[x],1,n,1,k);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    r_add(id[x],id[y],1,n,1,k);
    return ;
}//类比 

void tree_add(int x,int k){
    r_add(id[x],id[x]+sz[x]-1,1,n,1,k);
    return ;
}//相同的性质 

//-----------------------------------------------树链剖分 

int main(){
    scan(n),scan(m),scan(root),scan(mod);
    for(int i=1;i<=n;i++) scan(a[i]);
    for(int i=1,u,v;i<=n-1;i++)
        scan(u),scan(v),add(u,v),add(v,u);
    dfs1(root),dfs2(root,root),build(1,n,1);
    for(int i=1;i<=m;i++){
        int type,x,y,z;
        scan(type);
        if(type==1) scan(x),scan(y),scan(z),
            road_add(x,y,z);
        else if(type==2) scan(x),scan(y),
            printf("%d
",road_query(x,y));
        else if(type==3) scan(x),scan(z),
            tree_add(x,z);
        else if(type==4) scan(x),
            printf("%d
",tree_query(x));
    }
    return 0;
} 

好了,可以开始调代码了

拓展:

  树链剖分,作为一个优秀的暴力结构,以O(n logn logn)的时间复杂度完成路径查询,在子树查询做到了nlogn级别,所以不得不说其优秀;

  但是,它的作用远不及此:

  1.LCA查询:

    与倍增相同,树链剖分可以用logn的时间复杂度完成LCA查询(跳跃性好像更优),而他的初始化是两遍dfs O(n),理论上更优。

    可以猜测,LCA依旧运用重链跳法,然后比较即可,这里给出示范代码

int Lca(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    return dep[x]>dep[y]?y:x;
}//只要看懂树链剖分的基本操作,这个很简单 

    可以看到,其实代码很短。。。

  2.换根操作:

    设现在的根是root,我们可以发现,换根对于路径上的操作并没有影响,但是子树操作就会影响了,所以我们分类讨论

      设u为我们要查的子树的根节点

      (1)如果root=u,那么子树即为整棵树;

      (2)设 lca 为root和u的LCA,这里可以用上面所讲的树链剖分做,如果lca!=u,那么root并不是u的子节点,所以对于查询并不影响,常规操作即可

      (3)如果lca=u,那么u节点的子树就是整颗树减去u-root这个路径上与u相挨的节点v的子树即可,这里给出logn求点v的方法

//前提条件:要求的节点相挨的节点u,必须是root的LCA 
int find(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);//从最下往上跳 
        if(fa[top[x]]==y) return top[x];//如果y是x所在重链top的父亲节点,那么就可以返回了 
        x=fa[top[x]];//
    }
    if(dep[x]<dep[y]) swap(x,y);//让y最浅 
    return son[y];// 因为在一条重链上,那么重儿子一定是路径上与要求节点相挨的 
}

    整个操作的代码层次感我写的还是比较清楚了

void tree_add(int x,int k){
    if(root==x) r_add(1,n,1,n,1,k);//CASE 1 
    else{
        int lca=Lca(x,root);
        if(lca!=x) r_add(id[x],id[x]+sz[x]-1,1,n,1,k);//CASE 2 
        else{
            int dson=find(x,root);
            r_add(1,n,1,n,1,k);
            r_add(id[dson],id[dson]+sz[dson]-1,1,n,1,-k);
        }//CASE 3 
    }
    return ;
}

ll tree_query(int x){
    if(root==x) return r_query(1,n,1,n,1);//CASE 1 
    else{
        int lca=Lca(x,root);
        if(lca!=x) return r_query(id[x],id[x]+sz[x]-1,1,n,1);//CASE 2 
        else{
            int dson=find(x,root);
            return r_query(1,n,1,n,1)-r_query(id[dson],id[dson]+sz[dson]-1,1,n,1);
        }//CASE 3 
    }
}

推荐评测网站LOJ 。。。(因为洛谷没有换根操作)

AC代码附上

#include<bits/stdc++.h>
#define maxn 100007
#define ol putchar('
')
#define le(x) x<<1
#define re(x) x<<1|1
#define ll long long
using namespace std;
int n,m,head[maxn],cent,dep[maxn],son[maxn],fa[maxn],vis[maxn];
int top[maxn],a[maxn],id[maxn],w[maxn],sz[maxn],cnt,ij,root;
ll tr[maxn<<3],tag[maxn<<3];
struct node{
    int next,to;
}edge[maxn<<3];

template<typename type_of_scan>
inline void scan(type_of_scan &x){
    type_of_scan f=1;x=0;char s=getchar();
    while(s<'0'||s>'9') f=s=='-'?-1:1,s=getchar();
    while(s>='0'&&s<='9') x=(x<<3)+(x<<1)+s-'0',s=getchar();
    x*=f;
}
template<typename type_of_print>
inline void print(type_of_print x){
    if(x<0) putchar('-'),x=-x;
    if(x>9) print(x/10);
    putchar(x%10+'0');
}

inline void add(int u,int v){
    edge[++cent]=(node){head[u],v};head[u]=cent;
}

void push_up(int p){
    tr[p]=tr[le(p)]+tr[re(p)];
}

void push_down(int l,int r,int p,ll k){
    int mid=l+r>>1;
    tr[le(p)]+=1ll*(mid-l+1)*k,
    tr[re(p)]+=1ll*(r-mid)*k,
    tag[le(p)]+=k,tag[re(p)]+=k;
}

void build(int l,int r,int p){
    if(l==r){
        tr[p]=w[l];
        return ;
    }
    int mid=l+r>>1;
    build(l,mid,le(p));
    build(mid+1,r,re(p));
    push_up(p);
}

void r_add(int nl,int nr,int l,int r,int p,int k){
    if(nl<=l&&nr>=r){
        tr[p]+=1ll*(r-l+1)*k;
        tag[p]+=1ll*k;
        return ;
    }
    push_down(l,r,p,tag[p]),tag[p]=0;
    int mid=l+r>>1;
    if(nl<=mid) r_add(nl,nr,l,mid,le(p),k);
    if(nr>mid) r_add(nl,nr,mid+1,r,re(p),k);
    push_up(p);
}

ll r_query(int nl,int nr,int l,int r,int p){
    ll ans=0;
    if(nl<=l&&nr>=r) return tr[p];
    push_down(l,r,p,tag[p]),tag[p]=0;
    int mid=l+r>>1;
    if(nl<=mid) ans+=r_query(nl,nr,l,mid,le(p));
    if(nr>mid) ans+=r_query(nl,nr,mid+1,r,re(p));
    return ans;
}

void dfs1(int x){
    sz[x]=1;int max_part=-1;vis[x]++;
    for(int i=head[x];i;i=edge[i].next){
        int y=edge[i].to;
        if(y==fa[x]) continue;
        fa[y]=x;dep[y]=dep[x]+1;
        dfs1(y);sz[x]+=sz[y];
        if(max_part<sz[y]) son[x]=y,max_part=sz[y];
    }
}

void dfs2(int x,int t){
    id[x]=++cnt;w[cnt]=a[x];top[x]=t;
    if(!son[x]) return ;
    dfs2(son[x],t);
    for(int i=head[x];i;i=edge[i].next){
        int y=edge[i].to;
        if(y==son[x]||fa[x]==y) continue;
        dfs2(y,y);
    }
}

int Lca(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    return dep[x]>dep[y]?y:x;
}//只要看懂树链剖分的基本操作,这个很简单 


//前提条件:要求的节点相挨的节点u,必须是root的LCA 
int find(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);//从最下往上跳 
        if(fa[top[x]]==y) return top[x];//如果y是x所在重链top的父亲节点,那么就可以返回了 
        x=fa[top[x]];//
    }
    if(dep[x]<dep[y]) swap(x,y);//让y最浅 
    return son[y];// 因为在一条重链上,那么重儿子一定是路径上与要求节点相挨的 
}

void tree_add(int x,int k){
    if(root==x) r_add(1,n,1,n,1,k);//CASE 1 
    else{
        int lca=Lca(x,root);
        if(lca!=x) r_add(id[x],id[x]+sz[x]-1,1,n,1,k);//CASE 2 
        else{
            int dson=find(x,root);
            r_add(1,n,1,n,1,k);
            r_add(id[dson],id[dson]+sz[dson]-1,1,n,1,-k);
        }//CASE 3 
    }
    return ;
}

ll tree_query(int x){
    if(root==x) return r_query(1,n,1,n,1);//CASE 1 
    else{
        int lca=Lca(x,root);
        if(lca!=x) return r_query(id[x],id[x]+sz[x]-1,1,n,1);//CASE 2 
        else{
            int dson=find(x,root);
            return r_query(1,n,1,n,1)-r_query(id[dson],id[dson]+sz[dson]-1,1,n,1);
        }//CASE 3 
    }
}

void road_add(int x,int y,ll k){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        r_add(id[top[x]],id[x],1,n,1,k);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    r_add(id[x],id[y],1,n,1,k);
    return ;
}

ll road_query(int x,int y){
    ll ans=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        ans+=r_query(id[top[x]],id[x],1,n,1);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    ans+=r_query(id[x],id[y],1,n,1);
    return ans;
}

int main(){
//    freopen("cin.in","r",stdin);
//    freopen("co.out","w",stdout);
    scan(n);
    for(int i=1;i<=n;i++) scan(a[i]);
    for(int i=2,v;i<=n;i++) scan(v),add(i,v),add(v,i);
    dfs1(1),dfs2(1,1),build(1,n,1);root=1;
    scan(m);
    for(int i=1;i<=m;i++){
        int type,x,y,z;
        scan(type),scan(x);
        if(type==1) root=x;
        else if(type==2) scan(y),scan(z),road_add(x,y,z);
        else if(type==3) scan(z),tree_add(x,z);
        else if(type==4) scan(y),printf("%lld
",road_query(x,y));
        else if(type==5) printf("%lld
",tree_query(x));
    }
    return 0;
}

 

原文地址:https://www.cnblogs.com/waterflower/p/11239971.html