重链剖分

不会,先抄抄看。

来自oiwiki

我们先给出一些定义:

fa 表示节点 在树上的父亲。
dep 表示节点 在树上的深度。
siz 表示节点 的子树的节点个数。
son 表示节点 的 重儿子 。
top 表示节点 所在 重链 的顶部节点(深度最小)。
tid 表示节点 的 时间戳 ,也是其在线段树中的编号。
rnk 表示时间戳所对应的节点编号,有 rnk(top(x))=x 。
我们进行两遍 DFS 预处理出这些值,其中第一次 DFS 求出 ,第二次 DFS 求出 。

给出一种代码实现:

//fa,dep,siz,son
void dfs1(int o, int fat) {
  son[o] = -1;
  siz[o] = 1;
  for (int j = h[o]; j; j = nxt[j])
    if (!dep[p[j]]) {
      dep[p[j]] = dep[o] + 1;
      fa[p[j]] = o;
      dfs1(p[j], o);
      siz[o] += siz[p[j]];
      if (son[o] == -1 || siz[p[j]] > siz[son[o]]) son[o] = p[j];
    }
}
//top,tid,rnk
void dfs2(int o, int t) {
  top[o] = t;
  cnt++;
  tid[o] = cnt;
  rnk[cnt] = o;
  if (son[o] == -1) return;
  dfs2(son[o], t);  //优先对重儿子进行dfs,可以保证同一条重链上的点时间戳连续
  for (int j = h[o]; j; j = nxt[j])
    if (p[j] != son[o] && p[j] != fa[o]) dfs2(p[j], p[j]);
}

对树求重边优先遍历,同一重链中的节点的dfn序是连续的。

每次向下走一条轻边,子树大小至少降低一半。所以一条路径上至多有log条重链。

维护路径:
TREE - PATH - SUM(u, v)
while u,v 不在同一条链上
if u 所在链的链顶的深度小于 v 所在链的链顶的深度
swap(u, v)
将 u 到 u 所在链的链顶 之间的结点权值求和,累加到计数器中
u = u 所在链链顶的父节点
将 u, v 之间的结点的权值求和累加,返回计数器的值

维护子树:
好像跟树剖没有关系,每个节点记录dfs返回这个节点时的最后一个dfs位置,那么这个区间就是整棵子树。

求lca:
不在同一条链上就不断跳深度大的,在同一条链上返回深度小的。

#include <algorithm>
#include <cstdio>
#include <cstring>
#define lc o << 1
#define rc o << 1 | 1
const int maxn = 60010;
const int inf = 2e9;
int n, a, b, w[maxn], q, u, v;
int cur, h[maxn], nxt[maxn], p[maxn];
int siz[maxn], top[maxn], son[maxn], dep[maxn], fa[maxn], tid[maxn], rnk[maxn], cnt;
char op[10];
inline void add_edge(int x, int y) {
    cur++;
    nxt[cur] = h[x];
    h[x] = cur;
    p[cur] = y;
}
struct SegTree {
    int sum[maxn * 4], maxx[maxn * 4];
    void build(int o, int l, int r) {
        if (l == r) {
            sum[o] = maxx[o] = w[rnk[l]];
            return;
        }
        int mid = (l + r) >> 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        sum[o] = sum[lc] + sum[rc];
        maxx[o] = std::max(maxx[lc], maxx[rc]);
    }
    int query1(int o, int l, int r, int ql, int qr)  // max
    {
        if (l > qr || r < ql) return -inf;
        if (ql <= l && r <= qr) return maxx[o];
        int mid = (l + r) >> 1;
        return std::max(query1(lc, l, mid, ql, qr), query1(rc, mid + 1, r, ql, qr));
    }
    int query2(int o, int l, int r, int ql, int qr)  // sum
    {
        if (l > qr || r < ql) return 0;
        if (ql <= l && r <= qr) return sum[o];
        int mid = (l + r) >> 1;
        return query2(lc, l, mid, ql, qr) + query2(rc, mid + 1, r, ql, qr);
    }
    void update(int o, int l, int r, int x, int t) {
        if (l == r) {
            maxx[o] = sum[o] = t;
            return;
        }
        int mid = (l + r) >> 1;
        if (x <= mid)
            update(lc, l, mid, x, t);
        else
            update(rc, mid + 1, r, x, t);
        sum[o] = sum[lc] + sum[rc];
        maxx[o] = std::max(maxx[lc], maxx[rc]);
    }
} st;
void dfs1(int o, int fat) {
    son[o] = -1;
    siz[o] = 1;
    for (int j = h[o]; j; j = nxt[j])
        if (!dep[p[j]]) {
            dep[p[j]] = dep[o] + 1;
            fa[p[j]] = o;
            dfs1(p[j], o);
            siz[o] += siz[p[j]];
            if (son[o] == -1 || siz[p[j]] > siz[son[o]]) son[o] = p[j];
        }
}
void dfs2(int o, int t) {
    top[o] = t;
    cnt++;
    tid[o] = cnt;
    rnk[cnt] = o;
    if (son[o] == -1) return;
    dfs2(son[o], t);
    for (int j = h[o]; j; j = nxt[j])
        if (p[j] != son[o] && p[j] != fa[o]) dfs2(p[j], p[j]);
}
int querymax(int x, int y) {
    int ret = -inf, fx = top[x], fy = top[y];
    while (fx != fy) {
        if (dep[fx] >= dep[fy])
            ret = std::max(ret, st.query1(1, 1, n, tid[fx], tid[x])), x = fa[fx];
        else
            ret = std::max(ret, st.query1(1, 1, n, tid[fy], tid[y])), y = fa[fy];
        fx = top[x];
        fy = top[y];
    }
    if (x != y) {
        if (tid[x] < tid[y])
            ret = std::max(ret, st.query1(1, 1, n, tid[x], tid[y]));
        else
            ret = std::max(ret, st.query1(1, 1, n, tid[y], tid[x]));
    } else
        ret = std::max(ret, st.query1(1, 1, n, tid[x], tid[y]));
    return ret;
}
int querysum(int x, int y) {
    int ret = 0, fx = top[x], fy = top[y];
    while (fx != fy) {
        if (dep[fx] >= dep[fy])
            ret += st.query2(1, 1, n, tid[fx], tid[x]), x = fa[fx];
        else
            ret += st.query2(1, 1, n, tid[fy], tid[y]), y = fa[fy];
        fx = top[x];
        fy = top[y];
    }
    if (x != y) {
        if (tid[x] < tid[y])
            ret += st.query2(1, 1, n, tid[x], tid[y]);
        else
            ret += st.query2(1, 1, n, tid[y], tid[x]);
    } else
        ret += st.query2(1, 1, n, tid[x], tid[y]);
    return ret;
}
int main() {
    scanf("%d", &n);
    for (int i = 1; i < n; i++)
        scanf("%d%d", &a, &b), add_edge(a, b), add_edge(b, a);
    for (int i = 1; i <= n; i++) scanf("%d", w + i);
    dep[1] = 1;
    dfs1(1, -1);
    dfs2(1, 1);
    st.build(1, 1, n);
    scanf("%d", &q);
    while (q--) {
        scanf("%s%d%d", op, &u, &v);
        if (!strcmp(op, "CHANGE")) st.update(1, 1, n, tid[u], v);
        if (!strcmp(op, "QMAX")) printf("%d
", querymax(u, v));
        if (!strcmp(op, "QSUM")) printf("%d
", querysum(u, v));
    }
    return 0;
}
原文地址:https://www.cnblogs.com/Yinku/p/11309569.html