luogu 3384 【模板】树链剖分

题目大意:

一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z

操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和

操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z

操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和

思路:

打个模板调了一年

结果是因为中间结果没取模

耗了2h

zz一样

  1 #include<iostream>
  2 #include<cstdio>
  3 #include<cmath>
  4 #include<cstdlib>
  5 #include<cstring>
  6 #include<algorithm>
  7 #include<vector>
  8 #include<queue>
  9 #define inf 2139062143
 10 #define ll long long
 11 #define MAXN 101010
 12 using namespace std;
 13 inline int read()
 14 {
 15     int x=0,f=1;char ch=getchar();
 16     while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
 17     while(isdigit(ch)) {x=x*10+ch-'0';ch=getchar();}
 18     return x*f;
 19 }
 20 int n,rt,Cnt,nxt[MAXN*2],fst[MAXN],to[MAXN*2],val[MAXN],MOD;
 21 int fa[MAXN],dep[MAXN],bl[MAXN],cnt[MAXN],hsh[MAXN];
 22 struct data{int mx,sum,l,r,tag;}tr[MAXN*3];
 23 void add(int u,int v) {nxt[++Cnt]=fst[u],fst[u]=Cnt,to[Cnt]=v;}
 24 void build(int x)
 25 {
 26     for(int i=fst[x];i;i=nxt[i])
 27     {
 28         if(to[i]==fa[x]) continue;
 29         dep[to[i]]=dep[x]+1;
 30         fa[to[i]]=x;
 31         build(to[i]);
 32         cnt[x]+=cnt[to[i]];
 33     }
 34     cnt[x]++;
 35 }
 36 void Build(int x,int chn)
 37 {
 38     int hvs=0;hsh[x]=++Cnt,bl[x]=chn;
 39     for(int i=fst[x];i;i=nxt[i])
 40         if(fa[x]!=to[i]&&cnt[hvs]<cnt[to[i]]) hvs=to[i];
 41     if(!hvs) return ;
 42     Build(hvs,chn);
 43     for(int i=fst[x];i;i=nxt[i])
 44         if(fa[x]!=to[i]&&hvs!=to[i]) Build(to[i],to[i]);
 45 }
 46 void s_build(int k,int l,int r)
 47 {
 48     tr[k].l=l,tr[k].r=r,tr[k].tag=0;
 49     if(l==r) return ;
 50     int mid=(l+r)>>1;
 51     s_build(k<<1,l,mid);
 52     s_build(k<<1|1,mid+1,r);
 53 }
 54 void pushdown(int k)
 55 {
 56     tr[k<<1].tag+=tr[k].tag,tr[k<<1|1].tag+=tr[k].tag;
 57     (tr[k<<1].sum+=tr[k].tag*(tr[k<<1].r-tr[k<<1].l+1))%=MOD;
 58     (tr[k<<1|1].sum+=tr[k].tag*(tr[k<<1|1].r-tr[k<<1|1].l+1))%=MOD;
 59     tr[k].tag=0;
 60 }
 61 void upd(int k,int a,int b,int x)
 62 {
 63     int l=tr[k].l,r=tr[k].r;
 64     if(l==a&&r==b) {tr[k].tag+=x,(tr[k].sum+=(r-l+1)*x)%=MOD;return ;}
 65     if(tr[k].tag) pushdown(k);
 66     int mid=(l+r)>>1;
 67     if(b<=mid) upd(k<<1,a,b,x);
 68     else if(a>mid) upd(k<<1|1,a,b,x);
 69     else {upd(k<<1,a,mid,x);upd(k<<1|1,mid+1,b,x);}
 70     tr[k].sum=(tr[k<<1].sum+tr[k<<1|1].sum)%MOD;
 71 }
 72 int q_sum(int k,int a,int b)
 73 {
 74     int l=tr[k].l,r=tr[k].r;
 75     if(l==a&&r==b) return tr[k].sum;
 76     if(tr[k].tag) pushdown(k);
 77     int mid=(l+r)>>1;
 78     if(b<=mid) return q_sum(k<<1,a,b)%MOD;
 79     else if(a>mid) return q_sum(k<<1|1,a,b)%MOD;
 80     else return (q_sum(k<<1,a,mid)%MOD+q_sum(k<<1|1,mid+1,b)%MOD)%MOD;
 81 }
 82 int main()
 83 {
 84     int a,b,c,res,T,x;
 85     n=read(),T=read(),rt=read(),MOD=read();
 86     for(int i=1;i<=n;i++) val[i]=(read())%MOD;dep[rt]=1,fa[rt]=rt;
 87     for(int i=1;i<n;i++) {a=read(),b=read();add(a,b);add(b,a);}
 88     build(rt);Cnt=0;
 89     Build(rt,rt);
 90     s_build(1,1,n);
 91     for(int i=1;i<=n;i++) upd(1,hsh[i],hsh[i],val[i]);
 92     while(T--)
 93     {
 94         x=read();
 95         if(x==1)
 96         {
 97             a=read(),b=read(),c=(read())%MOD;
 98             while(bl[a]!=bl[b])
 99             {
100                 if(dep[bl[a]]<dep[bl[b]]) swap(a,b);
101                 upd(1,hsh[bl[a]],hsh[a],c);
102                 a=fa[bl[a]];
103             }
104             upd(1,min(hsh[a],hsh[b]),max(hsh[a],hsh[b]),c);
105         }
106         if(x==2)
107         {
108             res=0,a=read(),b=read();
109             while(bl[a]!=bl[b])
110             {
111                 if(dep[bl[a]]<dep[bl[b]]) swap(a,b);
112                 (res+=q_sum(1,hsh[bl[a]],hsh[a]))%=MOD;
113                 a=fa[bl[a]];
114             }
115             (res+=q_sum(1,min(hsh[a],hsh[b]),max(hsh[a],hsh[b])))%=MOD;
116             printf("%d
",res);
117         }
118         if(x==3)
119         {
120             a=read(),(c=read())%MOD;
121             upd(1,hsh[a],hsh[a]+cnt[a]-1,c);
122         }
123         if(x==4)
124         {
125             a=read();
126             res=q_sum(1,hsh[a],hsh[a]+cnt[a]-1);
127             printf("%d
",res);
128         }
129     }
130 }
View Code
原文地址:https://www.cnblogs.com/yyc-jack-0920/p/7955547.html