树链剖分模板

#include<bits/stdc++.h>
using namespace std;
// Capacity shared by node arrays, adjacency slots and segment-tree cells.
constexpr int N = 400000 + 10;
using ll = long long;

// Problem input: node count, operation count, root node, modulus.
int n, m, r, p;

// Adjacency lists as linked edges: h[u] = first edge of u, e[i] = edge target,
// ne[i] = next edge of the same source, idx = number of edge slots used.
int h[N], e[N], ne[N], idx;

// Initial node weights, indexed by node.
int w[N];

// Filled by dfs(): parent, subtree size, heavy child, depth.
int fa[N], sz[N], son[N], depth[N];

// Filled by dfs1(): pre[u] = DFS position of u, id[pos] = node at position pos,
// top[u] = head of u's chain, times = last position handed out.
int pre[N], id[N], top[N], times;

// Segment-tree cell covering DFS positions [l, r]: range sum and pending add tag.
struct node {
    int l, r;
    ll sum;
    ll lazy;
};
node tr[N];
// Appends the directed edge a -> b by prefixing it to a's linked adjacency list.
void add(int a,int b){
    const int slot=idx;
    e[slot]=b;        // edge target
    ne[slot]=h[a];    // link to a's previous first edge
    h[a]=slot;        // the new edge becomes a's first edge
    idx=slot+1;
}
// First decomposition pass from u: sets fa[] and depth[] of each child, accumulates
// subtree sizes in sz[], and stores in son[u] the child with the largest subtree
// (the earliest such child in adjacency order on ties; 0 when u is a leaf).
// Relies on node 0 being unused so that sz[0] == 0 acts as "no heavy child yet".
// Recursive: recursion depth equals the height of the tree.
void dfs(int u){
    sz[u]=1;
    for(int edge=h[u];edge!=-1;edge=ne[edge]){
        const int child=e[edge];
        if(child!=fa[u]){
            fa[child]=u;
            depth[child]=depth[u]+1;
            dfs(child);
            sz[u]+=sz[child];
            if(sz[son[u]]<sz[child])
                son[u]=child;
        }
    }
}
// Second decomposition pass: assigns preorder positions, visiting the heavy child
// first, so each heavy chain is a contiguous run of positions and u's subtree is
// [pre[u], pre[u]+sz[u]-1]. x is the head of the chain u belongs to.
// Requires fa[] and son[] from dfs().
void dfs1(int u,int x){
    ++times;
    pre[u]=times;
    id[times]=u;
    top[u]=x;
    const int heavy=son[u];
    if(heavy==0)
        return;                 // leaf: no children to visit
    dfs1(heavy,x);              // heavy child continues the current chain
    for(int edge=h[u];edge!=-1;edge=ne[edge]){
        const int v=e[edge];
        if(v!=heavy&&v!=fa[u])
            dfs1(v,v);          // each light child heads a new chain
    }
}
// Recomputes cell u's sum from its children (cells 2u and 2u+1), reduced mod p.
void pushup(int u){
    const int lc=u<<1;
    const int rc=lc|1;
    tr[u].sum=(tr[lc].sum+tr[rc].sum)%p;
}
void build(int u,int l,int r){
    if(l==r){
        tr[u]={l,r,w[id[l]]%p,0};
    }
    else{
        tr[u]={l,r};
        int mid=l+r>>1;
        build(u<<1,l,mid);
        build(u<<1|1,mid+1,r);
        pushup(u);
    }
}
void pushdown(int u){
    auto & t=tr[u];
    auto & l=tr[u<<1];
    auto & r=tr[u<<1|1];
    if(t.lazy){
        l.lazy+=t.lazy;
        l.lazy%=p;
        l.sum+=(ll)(l.r-l.l+1)*t.lazy;
        l.sum%=p;
        r.lazy+=t.lazy;
        r.sum+=(ll)(r.r-r.l+1)*t.lazy;
        r.lazy%=p;
        r.sum%=p;
        t.lazy=0;
    }
}
void modify(int u,int l,int r,int d){
    if(tr[u].l>=l&&tr[u].r<=r){
        tr[u].sum+=(ll)(tr[u].r-tr[u].l+1)*d;
        tr[u].sum%=p;
        tr[u].lazy+=d;
        tr[u].lazy%=p;
    }
    else{
        pushdown(u);
        int mid=tr[u].l+tr[u].r>>1;
        if(l<=mid)
        modify(u<<1,l,r,d);
        if(r>mid)
        modify(u<<1|1,l,r,d);
        pushup(u);

    }
}
// Returns the sum over DFS positions [l, r] inside cell u's range. Each fully covered
// cell contributes its sum mod p, but partial results are added without a final
// reduction, so callers apply % p themselves.
ll query(int u,int l,int r){
    const node &cur=tr[u];
    if(cur.l>=l&&cur.r<=r)
        return cur.sum%p;
    pushdown(u);              // children must reflect u's pending tag before reading
    const int mid=(cur.l+cur.r)>>1;
    ll total=0;
    if(l<=mid)
        total+=query(u<<1,l,r);
    if(r>mid)
        total+=query(u<<1|1,l,r);
    return total;
}
// Adds z to every node on the tree path between x and y, both ends included.
// While x and y lie on different chains, update the chain segment [top, node] of
// whichever endpoint has the deeper chain head, then jump to that head's parent.
// Once both share a chain, the rest of the path is one contiguous DFS range.
void uppath(int x,int y,int z){
    for(;top[x]!=top[y];x=fa[top[x]]){
        if(depth[top[y]]>depth[top[x]])
            swap(x,y);
        modify(1,pre[top[x]],pre[x],z);
    }
    if(depth[y]<depth[x])
        swap(x,y);
    modify(1,pre[x],pre[y],z);
}

int qpath(int x,int y){
    int res=0;
    while(top[x]!=top[y]){
        if(depth[top[x]]<depth[top[y]])
            swap(x,y);
        res=(res+query(1,pre[top[x]],pre[x]))%p;
        x=fa[top[x]];
    }
    if(depth[x]>depth[y])
        swap(x,y);
    res=(res+query(1,pre[x],pre[y]))%p;
    return res;
}
// Adds y to every node in x's subtree, which is the DFS range [pre[x], pre[x]+sz[x]-1].
void uptree(int x,int y){
    const int first=pre[x];
    const int last=first+sz[x]-1;
    modify(1,first,last,y);
}
int qtree(int x){
    //cout<<pre[x]<<" "<<pre[x]+sz[x]-1<<endl;
    return query(1,pre[x],pre[x]+sz[x]-1)%p;
}
// Reads n, m, root r and modulus p, then n node weights and n-1 undirected edges.
// Decomposes the tree and processes m operations:
//   1 x y z : add z to every node on path x..y
//   2 x y   : print the path sum x..y mod p
//   3 x z   : add z to every node in x's subtree
//   any other opt (expected 4) x : print the subtree sum of x mod p
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);           // don't flush cout before every read
    cin>>n>>m>>r>>p;
    memset(h,-1,sizeof h);      // -1 marks the end of an adjacency list
    for(int i=1;i<=n;i++)
        cin>>w[i];
    for(int i=1;i<n;i++){
        int a,b;
        cin>>a>>b;
        add(a,b);
        add(b,a);
    }
    dfs(r);
    dfs1(r,r);
    build(1,1,n);
    while(m--){
        int opt,x,y,z;
        cin>>opt;
        if(opt==1){
            cin>>x>>y>>z;
            uppath(x,y,z);
        }
        else if(opt==2){
            cin>>x>>y;
            cout<<qpath(x,y)<<'\n';   // '\n' instead of endl: no flush per answer
        }
        else if(opt==3){
            cin>>x>>z;
            uptree(x,z);
        }
        else{
            cin>>x;
            cout<<qtree(x)<<'\n';
        }
    }
    return 0;
}
原文地址 (original post): https://www.cnblogs.com/ctyakwf/p/12972962.html