树链剖分

推荐一篇博客

我的模板(洛谷p3384

#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cctype>
#include<cmath>
#include<cstdlib>
#include<queue>
#include<ctime>
#include<vector>
#include<set>
#include<map>
#include<stack>
using namespace std;
long long n,m,r,p;
long long d[2100000],a[110000],col[2100000];
vector<long long>v[110000];
long long son[110000],size[110000],fin[110000],acc[110000];
long long fa[110000],no[110000],dep[110000];
long long cnt;
inline void read(long long &x){
      long long f=1;x=0;
      char s=getchar();
      while(s<'0'||s>'9'){if(s=='-')f=-1;s=getchar();}
      while(s>='0'&&s<='9'){x=x*10+(s-'0');s=getchar();}
      x*=f;
}
inline void dfs1(long long u,long long f){
      long long i,j,k,maxn=-1;
      fa[u]=f;
      size[u]=1;
      for(i=0;i<v[u].size();i++)
         if(v[u][i]!=f){
             long long x=v[u][i];
             dep[x]=dep[u]+1;
             dfs1(x,u);
            size[u]+=size[x];
            if(size[x]>maxn){
              maxn=size[x];
              son[u]=x;
            }
         }
}
inline void dfs2(long long u,long long ac){
      no[u]=++cnt;
      fin[u]=no[u]+size[u]-1;
      acc[u]=ac;
      if(!son[u])return;
      dfs2(son[u],ac);
      long long i,j,k;
      for(i=0;i<v[u].size();i++)
        if(v[u][i]!=fa[u]&&v[u][i]!=son[u]){
              long long x=v[u][i];
              dfs2(x,x);
          }
}
inline void build(long long le,long long ri,long long wh,long long k,long long pl){
      d[wh]=(d[wh]+=k)%p;
      if(le==ri)return;
      long long mid=(le+ri)>>1;
      if(mid>=pl)
        build(le,mid,wh*2,k,pl);
        else
          build(mid+1,ri,wh*2+1,k,pl);
}
inline void update(long long le,long long ri,long long x,long long y,long long wh,long long k){
      if(le>=x&&ri<=y){
          col[wh]=(col[wh]+=k)%p;
          d[wh]=(d[wh]+=(ri-le+1)*k)%p;
          return;
      }
      long long mid=(le+ri)>>1;
      if(col[wh]){
          col[wh*2]=(col[wh*2]+=col[wh])%p;
          col[wh*2+1]=(col[wh*2+1]+=col[wh])%p;
          d[wh*2]=(d[wh*2]+=(mid-le+1)*col[wh])%p;
          d[wh*2+1]=(d[wh*2+1]+=(ri-mid)*col[wh])%p;
        col[wh]=0;
      }
      if(mid>=x)update(le,mid,x,y,wh*2,k);
      if(mid<y)update(mid+1,ri,x,y,wh*2+1,k);
      d[wh]=(d[wh*2]+d[wh*2+1])%p;
}
inline long long que(long long le,long long ri,long long x,long long y,long long wh){
      if(le>=x&&ri<=y){
          return d[wh]%p;
      }
      long long mid=(le+ri)>>1,ans=0;
      if(col[wh]){
          col[wh*2]=(col[wh*2]+=col[wh])%p;
          col[wh*2+1]=(col[wh*2+1]+=col[wh])%p;
          d[wh*2]=(d[wh*2]+=(mid-le+1)*col[wh])%p;
          d[wh*2+1]=(d[wh*2+1]+=(ri-mid)*col[wh])%p;
        col[wh]=0;
      }
      if(mid>=x)(ans+=que(le,mid,x,y,wh*2))%p;
      if(mid<y)(ans+=que(mid+1,ri,x,y,wh*2+1))%p;
      d[wh]=(d[wh*2]+d[wh*2+1])%p;
      return ans;
}
inline void solve1(long long x,long long y,long long z){
      z%=p;
      while(acc[x]!=acc[y]){
          if(dep[acc[x]]<dep[acc[y]])swap(x,y);
          update(1,n,no[acc[x]],no[x],1,z);
          x=fa[acc[x]];
      }
      if(no[x]>no[y])swap(x,y);
      update(1,n,no[x],no[y],1,z);
}
inline void solve2(long long x,long long y){
      long long ans=0;
      while(acc[x]!=acc[y]){
          if(dep[acc[x]]<dep[acc[y]])swap(x,y);
          ans=(ans+=que(1,n,no[acc[x]],no[x],1))%p;
          x=fa[acc[x]];
      }
      if(no[x]>no[y])swap(x,y);
      ans=(ans+=que(1,n,no[x],no[y],1))%p;
      printf("%lld ",ans%p);
}
inline void solve3(long long x,long long y){
      update(1,n,no[x],fin[x],1,y);
}
inline void solve4(long long x){
      printf("%lld ",que(1,n,no[x],fin[x],1)%p);
}
int main()
{     long long i,j,k,x,y,z;
      read(n),read(m),read(r),read(p);
      for(i=1;i<=n;i++){
        read(a[i]);
      }
      for(i=1;i<n;i++){
          read(x),read(y);
          v[x].push_back(y);
          v[y].push_back(x);
      }
      dep[r]=1;
      dfs1(r,-1);
      dfs2(r,r);
      for(i=1;i<=n;i++){
        build(1,n,1,a[i],no[i]);
      }
      for(i=1;i<=m;i++){
          read(k);
          if(k==1){
              read(x),read(y),read(z);
              solve1(x,y,z);
          }
          else if(k==2){
              read(x),read(y);
              solve2(x,y);
          }
          else if(k==3){
              read(x),read(y);
              solve3(x,y);
          }
          else {
              read(x);
              solve4(x);
          }
      }
      return 0;
}

原文地址:https://www.cnblogs.com/yzxverygood/p/8457361.html