重链剖分

处理大规模树上路径问题,复杂度O(n logn)

重链剖分:将树上路径划分为若干条不相交的链,来把树形结构转化为线性结构,利用线段树处理大规模树上路径问题。

P3384 【模板】轻重链剖分

操作 11: 格式: 1 x y z1 x y z 表示将树从 xx 到 yy 结点最短路径上所有节点的值都加上 zz。

操作 22: 格式: 2 x y2 x y 表示求树从 xx 到 yy 结点最短路径上所有节点的值之和。

操作 33: 格式: 3 x z3 x z 表示将以 xx 为根节点的子树内所有节点值都加上 zz。

操作 44: 格式: 4 x4 x 表示求以 xx 为根节点的子树内所有节点值之和

对于操作1,直接在线段树上修改即可

操作三:两个点一起跳到链的顶端,线段树上更新即可。

/*
支持树上路径统计,修改,查询子树,修改子树点权
*/
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+500;
ll mod;
ll w[N];
struct SEG{

    ll sum[N*10],add[N*10];
    void push_up(int rt){
        sum[rt]=sum[(rt<<1)]+sum[rt<<1 | 1]%mod;
    }

    void push_down(int rt,int m){
        if(add[rt]){
            add[rt<<1]+=add[rt];
            add[rt<<1 | 1]+=add[rt];
            sum[rt<<1]+=((m-(m>>1))*add[rt])%mod;
            sum[rt<<1 | 1]+=((m>>1)*add[rt])%mod;
            add[rt]=0;
        }
    }
    
    #define lson l,mid,rt<<1
    #define rson mid+1,r,rt<<1 | 1

    void build(int l,int r,int rt){
        add[rt]=0;
        if(l==r){sum[rt]=w[l];return;}
        int mid=(l+r)>>1;
        build(lson);

        build(rson);
        push_up(rt);
    }

    void update(int a,int b,ll c,int l,int r,int rt){//更新a到b区间和
        if(a<=l&&b>=r){
            sum[rt]+=(r-l+1)*c%mod;
            add[rt]+=c;
            return ;
        }
        push_down(rt,r-l+1);
        int mid=(r+l)>>1;
        if(a<=mid)update(a,b,c,lson);
        if(b>mid)update(a,b,c,rson);
        push_up(rt);
    }

    ll query(int a,int b,int l,int r,int rt){//查询a到b区间和
        if(a<=l&&b>=r){return sum[rt];}
        push_down(rt,r-l+1);
        ll ans=0;
        int mid=(r+l)>>1;
        if(a<=mid)ans=(ans+query(a,b,lson))%mod;
        if(b>mid)ans=(ans+query(a,b,rson))%mod;
        return ans%mod;
    }

}seg;
int n,ecnt;
int top[N],Tim;ll val[N];
int dep[N],size[N],fa[N],son[N],dfn[N],head[N];
struct edge{
    int v,next;
}e[N*10];
void init(){
    memset(head,-1,sizeof head);
    memset(dep,0,sizeof dep);
 
    ecnt=Tim=0;
}
void add(int u,int v){
    e[ecnt].v=v;e[ecnt].next=head[u];head[u]=ecnt++;
}
void dfs_size(int u,int f){//找出重儿子
    size[u]=1;
    dep[u]=dep[f]+1;
    fa[u]=f;
    int maxsize=-1;
    for(int i=head[u];~i;i=e[i].next){
        int v=e[i].v;
        if(v==f)continue;
        dfs_size(v,u);
        size[u]+=size[v];
        if(size[v]>maxsize){
            maxsize=size[v];
            son[u]=v;
        }
    }
}
void dfs_link(int u,int t){//重链剖分
    dfn[u]=++Tim;
    top[u]=t;
    w[Tim]=val[u];
    if(!son[u])return ;
    dfs_link(son[u],t);
    for(int i=head[u];~i;i=e[i].next){
        int v=e[i].v;
        if(v==fa[u]||v==son[u])continue;
        dfs_link(v,v);
    }
}

void mchain(int x,int y,int z){//x到y的路径添加z
    z%=mod;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        seg.update(dfn[ top[x] ],dfn[x],z,1,n,1);
        x=fa[top[x]];      
    }
    if(dep[x]>dep[y])swap(x,y);
    seg.update(dfn[x],dfn[y],z,1,n,1);

}

ll qchain(int x,int y){//查询x到y的路径和
    ll res=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        res+=seg.query(dfn[ top[x] ],dfn[x],1,n,1);
        res%=mod;
        x=fa[top[x]];      
    }
    if(dep[x]>dep[y])swap(x,y);
    res+=seg.query(dfn[x],dfn[y],1,n,1);
    return res%mod;
}

int lca(int u,int v) {
    while (top[u] != top[v]) {
        if (dep[top[u]] > dep[top[v]]) {
            u = fa[top[u]];
        } else {
            v = fa[top[v]];
        }
    }
    return dep[u] > dep[v] ? v : u;
} 
void mson(int x,int z){//修改x的子树
    seg.update(dfn[x],dfn[x]+size[x]-1,z,1,n,1);
}
ll qson(int x){//查询x节点的子树点权和
    return seg.query(dfn[x],dfn[x]+size[x]-1,1,n,1)%mod;
}

int main(){
    int m,root;
    init();
    scanf("%d %d %d %lld",&n,&m,&root,&mod);
    for(int i=1;i<=n;i++)scanf("%lld",&val[i]);
    for(int i=1;i<n;i++){
        int u,v;scanf("%d %d",&u,&v);
        add(u,v);add(v,u);
    }
    dfs_size(root,-1);
    dfs_link(root,root);
    seg.build(1,n,1);
    while(m--){
        int opt,x,y,z;
        scanf("%d",&opt);
        
        if(opt==1){
            scanf("%d %d %d",&x,&y,&z);
            mchain(x,y,z);
        }
        
        else if(opt==2){
            scanf("%d %d",&x,&y);
            // cout<<"ans :";
            printf("%lld
",qchain(x,y));
        }

        else if(opt==3){
            scanf("%d %d",&x,&y);
            mson(x,y);
        }

        else if(opt==4){
            scanf("%d",&x);
            // cout<<"ans :";

            printf("%lld
",qson(x));
        }

    }

    // system("pause");
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/littlerita/p/13473167.html