P3384 【模板】轻重链剖分

题目链接:https://www.luogu.com.cn/problem/P3384

板子题(入门博客):https://www.cnblogs.com/ivanovcraft/p/9019090.html

#include<bits/stdc++.h>
using namespace std;
// Upper bound on the number of vertices (problem limit 1e5, plus slack).
constexpr int maxn = 100000 + 5;
// 64-bit integer used for all values/sums that are taken modulo `mod`.
using ll = long long;

// Forward-star adjacency list: edge e points to stm[e].to, and stm[e].next is
// the index of the edge previously added from the same source vertex
// (-1 terminates the list; head[] is filled with -1 by init()).
struct st{
    int to;
    int next;
}stm[maxn*2];
int head[maxn];
int cnt;   // number of edges stored so far (next free slot in stm)

// Prepends the directed edge u -> v to u's edge list.
void add(int u,int v){
    const int e=cnt++;
    stm[e].to=v;
    stm[e].next=head[u];
    head[u]=e;
}
ll a[maxn];        // initial value of each vertex, 1-indexed (read in main)
ll mod;            // modulus applied to every sum (read in main)
int tsize[maxn];   // subtree size, filled by dfs1
int tson[maxn];    // heavy child, filled by dfs1 (0 means "no child")
int top[maxn];     // head vertex of the heavy chain containing the vertex
int rk[maxn];      // rk[dfs position] = vertex
int id[maxn];      // id[vertex] = dfs position
int dep[maxn];     // depth, root is 0
int pre[maxn];     // parent, root's parent is 0
int co;            // running dfs-position counter used by dfs2
// Resets per-test state for a tree with vertices 1..n: empties every
// adjacency list, clears heavy-child marks, and zeroes both counters.
void init(int n){
    cnt=0;
    co=0;
    for(int v=n;v>=1;--v){
        head[v]=-1;
        tson[v]=0;
    }
}
// Segment-tree node over dfs positions 1..n:
//   sum  - sum of the covered range, kept modulo `mod`
//   lazy - pending range increment not yet pushed to the children
struct trem{
    ll sum;
    ll lazy;
}tre[maxn*4];
void dfs1(int now,int fa,int d){
    tsize[now]=1;
    pre[now]=fa;
    dep[now]=d;
//  cout<<now<<endl;
    for(int i=head[now];~i;i=stm[i].next){
        int to=stm[i].to;
        if(to==fa)continue;
        dfs1(to,now,d+1);
        tsize[now]+=tsize[to];
        if(tsize[to]>tsize[tson[now]])tson[now]=to;
    }
}
// Second pass: assigns dfs positions, visiting the heavy child first so every
// heavy chain occupies a contiguous range of positions. Top is the head of the
// chain `now` belongs to; each light child starts a new chain.
void dfs2(int now,int Top){
    id[now]=++co;
    rk[co]=now;
    top[now]=Top;
    const int heavy=tson[now];
    if(!heavy)return;   // leaf
    dfs2(heavy,Top);
    for(int e=head[now];e!=-1;e=stm[e].next){
        const int child=stm[e].to;
        if(child!=pre[now]&&child!=heavy)dfs2(child,child);
    }
}
void pushup(int rt){
    tre[rt].sum=(tre[rt<<1].sum+tre[rt<<1|1].sum)%mod;
    return ;
}
void pushdown(int rt,int l,int r){
    if(tre[rt].lazy){
        tre[rt<<1].lazy=(tre[rt<<1].lazy+tre[rt].lazy)%mod;
        tre[rt<<1|1].lazy=(tre[rt<<1|1].lazy+tre[rt].lazy)%mod;
        int mid=(l+r)/2;
        tre[rt<<1].sum=(tre[rt<<1].sum+tre[rt].lazy*(mid-l+1)%mod)%mod;
        tre[rt<<1|1].sum=(tre[rt<<1|1].sum+tre[rt].lazy*(r-mid)%mod)%mod;
        tre[rt].lazy=0;
    }
}
void build(int l,int r,int rt){
    if(l==r){
        tre[rt].sum=a[rk[l]];
        return ;
    }
    int mid=(l+r)/2;
    build(l,mid,rt*2);
    build(mid+1,r,rt*2+1);
    pushup(rt);
    return ;
}
// Returns the sum over dfs positions [l, r] modulo `mod`. Node rt covers
// [lm, rm]; pending tags are pushed down before descending.
ll query(int l,int r,int lm,int rm,int rt){
    if(lm>=l&&rm<=r)return tre[rt].sum;   // fully covered: cached answer
    pushdown(rt,lm,rm);
    const int mid=(lm+rm)/2;
    ll total=0;
    if(l<=mid)total+=query(l,r,lm,mid,rt*2);
    if(mid<r)total+=query(l,r,mid+1,rm,rt*2+1);
    return total%mod;
}
// Adds num to every dfs position in [l, r]. Node rt covers [lm, rm].
// The increment is first normalized into [0, mod), and lazy tags are kept
// reduced modulo `mod`. The original did `lazy += num` without reduction, so
// repeated full-cover updates let the tag grow without bound. pushdown's
// `lazy * length` product then overflowed long long (undefined behaviour).
void update(int l,int r,int lm,int rm,ll num,int rt){
    num%=mod;
    if(num<0)num+=mod;
    if(lm>=l&&rm<=r){
        tre[rt].lazy=(tre[rt].lazy+num)%mod;
        tre[rt].sum=(tre[rt].sum+1ll*(rm-lm+1)*num%mod)%mod;
        return ;
    }
    int mid=(lm+rm)/2;
    pushdown(rt,lm,rm);
    if(mid>=l)update(l,r,lm,mid,num,rt<<1);
    if(mid<r)update(l,r,mid+1,rm,num,rt<<1|1);
    pushup(rt);
}
void fun(int pos,int pos1,ll num,int n){
    int x=pos;
    int y=pos1;
    int f=top[x];
    int f1=top[y];
    while(f!=f1){
        if(dep[f]>=dep[f1]){
            update(id[f],id[x],1,n,num,1);
            x=pre[f];
            f=top[x];
        }
        else{
            update(id[f1],id[y],1,n,num,1);
            y=pre[f1];
            f1=top[y];
        }
    }
    if(id[x]<=id[y]){
        update(id[x],id[y],1,n,num,1);
    }
    else{
        update(id[y],id[x],1,n,num,1);
    }
    return ;
}
// Returns the sum of vertex values on the tree path pos -> pos1, modulo `mod`.
// It climbs chains the same way fun() does and accumulates each chain
// segment's query.
ll fun1(int pos,int pos1,int n){
    int u=pos;
    int v=pos1;
    ll total=0;
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]]){
            total+=query(id[top[v]],id[v],1,n,1);
            v=pre[top[v]];
        }else{
            total+=query(id[top[u]],id[u],1,n,1);
            u=pre[top[u]];
        }
    }
    total+=query(min(id[u],id[v]),max(id[u],id[v]),1,n,1);
    return total%mod;
}
int main(){
    int n,m,k;
    int u,v;
    scanf("%d%d%d%lld",&n,&m,&k,&mod);
    init(n);
    for(int i=1;i<=n;i++)scanf("%lld",&a[i]);
    for(int i=1;i<n;i++){
        scanf("%d%d",&u,&v);
        add(u,v);
        add(v,u);
    }    
    dfs1(k,0,0);
    dfs2(k,k);
    build(1,n,1);
    int op,pos,pos1;
    ll num;
    for(int i=1;i<=m;i++){
        scanf("%d",&op);
        if(op==1){
            scanf("%d%d%lld",&pos,&pos1,&num);
            fun(pos,pos1,num,n);
        }
        else if(op==2){
            scanf("%d%d",&pos,&pos1);
            printf("%lld
",fun1(pos,pos1,n));
        }
        else if(op==3){
            scanf("%d%lld",&pos,&num);
            update(id[pos],id[pos]+tsize[pos]-1,1,n,num,1);
        }
        else{
            scanf("%d",&pos);
            printf("%lld
",query(id[pos],id[pos]+tsize[pos]-1,1,n,1));
        }
    }
    return 0;
}

 

原文地址:https://www.cnblogs.com/Zhi-71/p/12733757.html