Luogu P3384 【模板】树链剖分

传送门~

树链剖分,顾名思义,就是把树分成链。

通过这个方法,可以优化对树上两点间路径某一点子树的修改和查询的操作,等。

流程

$dfs1()$

在这个函数中,要处理出每个节点的:

  • 深度dep[]
  • 父亲fa[]
  • 大小siz[]
  • 重儿子编号hson[]

一个节点的siz[],是包括它自己、它的儿子、它儿子的儿子……一共的节点数量。

所谓的重儿子,就是一个节点的儿子中,siz[]最大的那一个。

叶子节点没有儿子,所以也没有重儿子。

这个函数就是普通的遍历整棵树,每到一个点记录dpth[],siz[]初始值为1。

对于所连的每一条边,标记儿子的fa[]并递归$dfs$,回溯时将子树大小加给父亲的siz[],并判断是否更新重儿子hson[]。

void dfs1(int u){
    dpth[u] = dpth[fa[u]]+1;
    siz[u] = 1;
    for(int i = head[u];i;i = nxt[i]){
        int v = to[i];
        if(v == fa[u])continue;
        fa[v] = u;
        dfs1(v);
        siz[u] += siz[v];
        if(siz[v] > siz[hson[u]]) hson[u] = v;
    }
}

知道了重儿子,如果沿着树走一遍,并且每次都先走重儿子,就可以把整个树拆成从大到小的很多链!

$dfs2()$

在这个函数中,要处理出每个节点的:

  • 新编号,也就是先走重儿子的$dfs$序dfn[]
  • 新编号对应的原编号point[]
  • 所在链的顶端的点top[]

每到一个点,首先直接记录它的dfn[]、point[]、top[](有点类似father)。

如果没有重儿子,则是叶子节点,直接$return$;

如果有重儿子,则这条链还没有结束,继续递归,top[]不变;

如果除了重儿子还有其他的儿子,则每个轻儿子都是一条新链,继续递归,且新链的top[]是这个轻儿子。

void dfs2(int u,int t){
    dfn[u] = ++cnt;
    point[cnt] = u;
    top[u] = t;
    if(!hson[u])return;
    dfs2(hson[u],t);
    for(int i = head[u];i;i = nxt[i]){
        int v = to[i];
        if(v == fa[u] || v == hson[u])continue;
        dfs2(v,v);
    }
}

这样,“链”的部分就算是处理好了$-w-$

接下来,根据新的编号dfn[],建一棵线段树。

线段树需要的函数:$build$、$pushdown$、$modify$、$query$

下面来看需要支持的操作:

  1. 将树从x到y结点最短路径上所有节点的值都加上z
  2. 求树从x到y结点最短路径上所有节点的值之和
  3. 将以x为根节点的子树内所有节点值都加上z
  4. 求以x为根节点的子树内所有节点值之和

树上路径

求树上路径,听起来很像$LCA$……实际上,树链剖分的确可以求出$LCA$。

需要再写两个函数:$getmodify$$getquery$。这里以查询操作为例。

如果x,y的top[]不同,则它们一定不在同一条链上。

和倍增$LCA$的做法相似,每次将top[]较深,即所在链的顶端比较靠下的节点,跳到它的top[]的父亲,这样就上升到了上边的链。

查询这一段路径的值(dfn[top[x]],dfn[x]),并重复操作,直到x,y在同一条链上为止。

同一链上,深度较小的点一定dfn[]较小,也就是在线段树上的编号较小。通过swap,使较浅的作为x,并查询(dfn[x],dfn[y])

int getquery(int x,int y) {
    int ans = 0;
    while(top[x] != top[y]) {
        if(dpth[top[x]] < dpth[top[y]]) swap(x,y);
        ans += query(dfn[top[x]],dfn[x],1,n,1);
        x = fa[top[x]];
    }
    if(dpth[x] > dpth[y]) swap(x,y);
    ans += query(dfn[x],dfn[y],1,n,1);
    return ans;
}

子树

根据dfs序,可以知道,一棵子树内的序号是连续的。

那么,对应到线段树上,即为(dfn[x],dfn[x]+siz[x]-1)。

完整代码如下

#include<iostream>
#include<cstdio>
#include<cstring>
#define MogeKo qwq
using namespace std;
const int maxn = 1e5+10;
int n,m,rt,opt,x,y,z,mod,cnt;
int to[maxn<<1],head[maxn<<1],nxt[maxn<<1];
long long sum[maxn<<2],lazy[maxn<<2];
int dfn[maxn],dpth[maxn],siz[maxn],hson[maxn],fa[maxn],top[maxn],point[maxn];
long long w[maxn];

void add(int x,int y) {
    to[++cnt] = y;
    nxt[cnt] = head[x];
    head[x] = cnt;
}

void dfs1(int u) {
    dpth[u] = dpth[fa[u]]+1;
    siz[u] = 1;
    for(int i = head[u]; i; i = nxt[i]) {
        int v = to[i];
        if(v == fa[u])continue;
        fa[v] = u;
        dfs1(v);
        siz[u] += siz[v];
        if(siz[v] > siz[hson[u]]) hson[u] = v;
    }
}

void dfs2(int u,int t) {
    dfn[u] = ++cnt;
    point[cnt] = u;
    top[u] = t;
    if(!hson[u])return;
    dfs2(hson[u],t);
    for(int i = head[u]; i; i = nxt[i]) {
        int v = to[i];
        if(v == fa[u] || v == hson[u])continue;
        dfs2(v,v);
    }
}

void build(int l,int r,int now) {
    if(l == r) {
        sum[now] = w[point[l]] %mod;
        return;
    }
    int mid = l+r>>1;
    build(l,mid,now<<1),build(mid+1,r,now<<1|1);
    sum[now] = (sum[now<<1] + sum[now<<1|1]) %mod;
}

void pushdown(int l,int r,int now) {
    sum[now] += (r-l+1)*lazy[now]%mod;
    (lazy[now<<1] += lazy[now]) %=mod;
    (lazy[now<<1|1] += lazy[now]) %=mod;
    lazy[now] = 0;
}

void modify(int L,int R,int l,int r,int c,int now) {
    if(L == l && R == r) {
        lazy[now] += c;
        return;
    }
    (sum[now] += (R-L+1)*c ) %=mod;
    int mid = l+r>>1;
    if(R <= mid) modify(L,R,l,mid,c,now<<1);
    else if(L >= mid+1) modify(L,R,mid+1,r,c,now<<1|1);
    else modify(L,mid,l,mid,c,now<<1),modify(mid+1,R,mid+1,r,c,now<<1|1);
}

long long query(int L,int R,int l,int r,int now) {
    if(L == l && R == r) {
        return (sum[now]+lazy[now]*(r-l+1))%mod;
    }
    pushdown(l,r,now);
    int mid = l+r>>1;
    if(R <= mid) return query(L,R,l,mid,now<<1);
    else if(L >= mid+1) return query(L,R,mid+1,r,now<<1|1);
    else return (query(L,mid,l,mid,now<<1) + query(mid+1,R,mid+1,r,now<<1|1)) %mod;
}

void getmodify(int x,int y,int c) {
    while(top[x] != top[y]) {
        if(dpth[top[x]] < dpth[top[y]]) swap(x,y);
        modify(dfn[top[x]],dfn[x],1,n,c,1);
        x = fa[top[x]];
    }
    if(dpth[x] > dpth[y]) swap(x,y);
    modify(dfn[x],dfn[y],1,n,c,1);
}

long long getquery(int x,int y) {
    long long ans = 0;
    while(top[x] != top[y]) {
        if(dpth[top[x]] < dpth[top[y]]) swap(x,y);
        (ans += query(dfn[top[x]],dfn[x],1,n,1)) %=mod;
        x = fa[top[x]];
    }
    if(dpth[x] > dpth[y]) swap(x,y);
    (ans += query(dfn[x],dfn[y],1,n,1)) %=mod;
    return ans;
}

int main() {
    scanf("%d%d%d%d",&n,&m,&rt,&mod);
    for(int i = 1; i <= n; i++)
        scanf("%lld",&w[i]);
    for(int i = 1; i <= n-1; i++) {
        scanf("%d%d",&x,&y);
        add(x,y),add(y,x);
    }
    cnt = 0;
    dfs1(rt),dfs2(rt,rt);
    build(1,n,1);
    for(int i = 1; i <= m; i++) {
        scanf("%d",&opt);
        if(opt == 1) {
            scanf("%d%d%d",&x,&y,&z);
            getmodify(x,y,z%mod);
        }
        if(opt == 2) {
            scanf("%d%d",&x,&y);
            printf("%lld
",getquery(x,y));
        }
        if(opt == 3) {
            scanf("%d%d",&x,&z);
            modify(dfn[x],dfn[x]+siz[x]-1,1,n,z%mod,1);
        }
        if(opt == 4) {
            scanf("%d",&x);
            printf("%lld
",query(dfn[x],dfn[x]+siz[x]-1,1,n,1));
        }
    }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/mogeko/p/11229910.html