树链剖分+线段树 [Codeforces Round #457 (Div. 2) E. Jamie and Tree]

树链剖分+线段树 Codeforces Round #457 (Div. 2) E. Jamie and Tree

题目大意:

给你一棵树,对这棵树有三种操作:

  • 1 v 表示把根节点变成 v 这个节点
  • 2 u v x 表示把含有u和v这两个节点的最小子树的所有节点都加上x
  • 3 v 查v这个所在子树的权值。

题解:

这个有点难处理的就是这个换根之后的更新一棵子树。

但是呢,注意一下这个是怎么更新子树的,这个给了你两个点 u 和 v ,很自然就可以在这两个地方做文章。

含有 (u)(v) 这两个节点的最小子树,那么先求 (LCA(u,v)) ,已知 (u=LCA(u,v)) ,那么 (u) 这个节点到根节点的这个儿子节点不加 (x),其他所有节点都要加上 (x)

那么求 (u) 这个节点到根节点的这个儿子节点呢?

  • 直接暴力找,枚举 (u) 的所有儿子,利用 (dfs) 序判断是否里面
  • 树链剖分,往上跳,找到最后跳的这个点,如果最后一个节点是 (u) 的子儿子,那么就直接是这个节点,如果不是,那么说明是重儿子节点。

怎么求这个 (LCA(u,v)) 这个我不会,看的别人的,学习一下!!!!

对于换根之后的 (u,v) 节点的 (LCA(u,v))(LCA(u,root),LCA(v,root),LCA(u,v)) 三个节点中深度最大的那个点。

最后就是分成两种情况讨论:

  • 如果 (LCA(u,v)=LCA(u,v)) 说明这个新的根节点对这个子树没有影响,那么就直接按照之前的更新,
  • 否则,找到新的 (LCA(u,v)) 到根节点的子儿子 (v),更新整棵树 (+x) ,再更新这个子儿子 (v) 的子树所有节点 (-x)

对于第三个的查询,判断一下v和root的位置,如果root在v的子树,那么分成两段来更新,否则按照原来v的子树直接更新。

这个题目的难点其实就是求 (LCA(u,v)) ,如果知道 (LCA(u,v)) 那么就很好写了。

注意一下特判根节点和 (LCA(u,v)) 是不是相同的。

#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
#define debug(x) cout<<"debug:"<<#x<<" = "<<x<<endl;
using namespace std;
typedef long long ll;
const int maxn = 1e5+10;
int head[maxn],nxt[maxn<<1],to[maxn<<1],cnt;
void add(int u,int v){
    ++cnt,to[cnt]=v,nxt[cnt]=head[u],head[u]=cnt;
    ++cnt,to[cnt]=u,nxt[cnt]=head[v],head[v]=cnt;
}

int id[maxn],top[maxn],tot,rk[maxn];
ll sum[maxn<<2],lazy[maxn<<2],len[maxn<<2],a[maxn];
void push_up(int id){
    sum[id]=sum[id<<1]+sum[id<<1|1];
}
void push_down(int id){
    if (lazy[id]==0) return ;
    sum[id<<1]+=lazy[id]*len[id<<1];
    sum[id<<1|1]+=lazy[id]*len[id<<1|1];
    lazy[id<<1]+=lazy[id];
    lazy[id<<1|1]+=lazy[id];
    lazy[id]=0;
}
void build(int id,int l,int r){
    len[id]=r-l+1;
    if(l==r) {
        sum[id] = a[rk[l]];
        return ;
    }
    int mid=(l+r)>>1;
    build(id<<1,l,mid);
    build(id<<1|1,mid+1,r);
    push_up(id);
}
void update(int id,int l,int r,int x,int y,ll val){
    if(x<=l&&y>=r){
        sum[id]+=len[id]*val;
        lazy[id]+=val;
        return ;
    }
    push_down(id);
    int mid=(l+r)>>1;
    if(x<=mid) update(id<<1,l,mid,x,y,val);
    if(y>mid) update(id<<1|1,mid+1,r,x,y,val);
    push_up(id);
}
ll query(int id,int l,int r,int x,int y){
    if(x<=l&&y>=r) return sum[id];
    push_down(id);
    int mid=(l+r)>>1;
    ll ans = 0;
    if(x<=mid) ans += query(id<<1,l,mid,x,y);
    if(y>mid) ans += query(id<<1|1,mid+1,r,x,y);
    return ans;
}

int fa[maxn],siz[maxn],son[maxn],dep[maxn];
void dfs1(int u,int pre,int d){
    fa[u]=pre,siz[u]=1,son[u]=0,dep[u]=d;
    for(int i=head[u];i;i=nxt[i]){
        int v = to[i];
        if(v == pre) continue;
        dfs1(v,u,d+1);
        siz[u]+=siz[v];
        if(!son[u]||siz[v]>siz[son[u]]) son[u] = v;
    }
}
void dfs2(int u,int tp){
    id[u]=++tot,top[u]=tp,rk[tot]=u;
    if(!son[u]) return ;
    dfs2(son[u],tp);
    for(int i=head[u];i;i=nxt[i]){
        int v = to[i];
        if(v == fa[u]|| v== son[u]) continue;
        dfs2(v,v);
    }
}


int LCA(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x = fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    return x;
}
//  x = root 
int Query(int x,int y){
    int ans = 0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        ans = top[x];
        x = fa[top[x]];
    }
    if(fa[ans]==y) return ans;
    return son[y];
}

int main(){
    int n,q,root = 1;
    scanf("%d%d",&n,&q);
    for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
    for(int i=1;i<n;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        add(u,v);
    }
    dfs1(1,-1,1),dfs2(1,1),build(1,1,n);
    while(q--){
        int op;
        scanf("%d",&op);
        if(op==1){
            int v;
            scanf("%d",&v);
            root = v;
        }
        else if(op==2){
            int u,v,x;
            scanf("%d%d%d",&u,&v,&x);
            int lca1 = LCA(u,v),lca2=LCA(u,root),lca3=LCA(v,root);
            if(dep[lca1]>max(dep[lca2],dep[lca3])){
                update(1,1,n,id[lca1],id[lca1]+siz[lca1]-1,x);
            }
            else{
                int lca = dep[lca2]>dep[lca3]?lca2:lca3;
                int u = Query(root,lca);
                update(1,1,n,1,n,x);
                if(lca!=root) update(1,1,n,id[u],id[u]+siz[u]-1,-x);
            }
        }
        else{
            int v;
            ll ans = 0;
            scanf("%d",&v);
            int lca = LCA(v,root);
            if(v==root) ans = sum[1];
            else if(dep[lca]==dep[v]){
                int u = Query(root,v);
                ans = sum[1] - query(1,1,n,id[u],id[u]+siz[u]-1);
            }
            else{
                // debug("???")
                ans = query(1,1,n,id[v],id[v]+siz[v]-1);
            }
            printf("%lld
", ans);
        }
    }
}



/*
4 100
4 3 5 6
1 2
2 3
3 4
3 1
1 3
*/

原文地址:https://www.cnblogs.com/EchoZQN/p/13473694.html