Dsu on tree 学习笔记

( ext{dsu on tree}) 略解

简介

首先 ( ext{dsu on tree}) 和并查集并没有关系,其用来处理一类树上问题,一般有两个特征:

  1. 不带修改
  2. 询问与子树有关

( ext{dsu on tree}) 可以十分方便的在 (O(nlogn)) 的时间复杂度内解决。

大致思路

( ext{dsu on tree}) 利用了重链剖分中重儿子的思想来进行暴力。

例如一道例题 CF600E:求每个子树内出现次数最多的颜色之和

(O(n^2)) 暴力十分显然,但是可以发现一个性质:父子之间的信息共享,而兄弟之间的信息不共享,也就是计算完最后一个子树的信息后,可以不用清空,其信息可以保留下来继续给父亲使用。

所以我们想到使最后一个遍历的子树尽可能大,也就是 重儿子

算法流程

设当前求到 (u) 的答案 (ans_u),算法大致分为 (5) 步:

  1. 计算 轻儿子 (v)(ans_v)
  2. 计算 (u) 重儿子 (son_u)(ans_{son_u}),并将 (son_u) 的信息保留继续使用
  3. 再暴力计算每个轻儿子的信息
  4. 更新 (ans_u)
  5. 如果 (u) 不为重儿子,则暴力删去 (u) 的信息

”暴力计算 (v)“ 指将以 (v) 为根的子树遍历一遍计算信息(也可能因题目而异吧)

复杂度

首先有一个重要的性质:一个节点到根路径上的轻边数不超过 (logn),证明:

由轻重儿子的性质可知:对于 (u) 的任意轻儿子 (v)(siz_v leq frac{siz_u}{2})

因此每经过一条轻边 (siz/2),那么任意点开始往叶子节点走经过轻边数量最多不超过 (logn)

得证

再考虑每个点 (v) 会被计算多少次,按其到根的路径上的轻/重边分为两类讨论:

  1. 对于每条轻边,都需要单独计算一次 (v) 的信息,由以上性质知不超过 (logn)
  2. 对于 (v) 到根路径上的每条重边,是不需要再计算 (v)

所以对于节点 (v),一共会被计算 (logn + 1) 次((1) 为计算 (ans_v)

综上,若计算一个点的信息为 (O(1)),则该算法时间复杂度为 (O(nlogn))

例题

CF600E

第一次打 (Code) 有点丑......

Code

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

#define N 100000

#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define Fo(i, u) for(int i = head[u]; i; i = edge[i].next)

#define ll long long

void read(int &x) {
    char ch = getchar(); x = 0;
    while (ch < '0' || ch > '9') ch = getchar();
    while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + ch - 48, ch = getchar();
}

struct EDGE { int next, to; } edge[N << 1];

int head[N + 1], col[N + 1], h[N + 1], sz[N + 1], son[N + 1], las[N + 1];

ll ans[N + 1];

int n;

int cnt_edge = 1;
void Add(int u, int v) { edge[ ++ cnt_edge ] = (EDGE) { head[u], v }, head[u] = cnt_edge; }
void Link(int u, int v) { Add(u, v), Add(v, u); }

void Dfs1(int u, int la) {
    sz[u] = 1, son[u] = 0;
    Fo(i, u) if (i != la) {
        Dfs1(edge[i].to, i ^ 1);
        if (sz[edge[i].to] > sz[son[u]])
            son[u] = edge[i].to;
        sz[u] += sz[edge[i].to];
    }
}

int max_h = 0;

ll sum = 0;

void Add1(int c, int d) {
    h[c] += d;
    if (h[c] > max_h) max_h = h[c], sum = c;
    else if (h[c] == max_h) sum += c;
}

void Dfs3(int u, int la, int d) {
    Add1(col[u], d);
    Fo(i, u) if (i != la)
        Dfs3(edge[i].to, i ^ 1, d);
}

void Dfs2(int u, int fa, int opt) {
    int v = 0;
    Fo(i, u) if ((v = edge[i].to) != fa && v != son[u])
        Dfs2(v, u, 1);
    if (son[u]) Dfs2(son[u], u, 0);
    Fo(i, u) if ((v = edge[i].to) != fa && v != son[u])
        Dfs3(v, i ^ 1, 1);
    Add1(col[u], 1);
    ans[u] = sum;
    if (opt) {
        Fo(i, u) if ((v = edge[i].to) != fa)
            Dfs3(v, i ^ 1, -1);
        Add1(col[u], -1);
        sum = max_h = 0;
    }
}

int main() {
    read(n);
    fo(i, 1, n) read(col[i]);
    for (int i = 1, x, y; i < n; i ++)
        read(x), read(y), Link(x, y);

    Dfs1(1, 0);
    Dfs2(1, 0, 0);
    
    fo(i, 1, n) printf("%lld ", ans[i]);

    return 0;
}


CF741D

Solution

由回文串的性质可知:区间内之多只有一个字符出现奇数次。

借此可以将统计出现次数转化为异或,可以用大小为 (2^{22}) 的状态表示从根开始的路径上每个字符出现次数的奇偶性。

(dis_{u}) 为从根到 (x) 的路径上字符的状态,那么任意路径 ((u, v)) 的字符状态就可以表示为 (dis_{(u, v)} = dis_u oplus dis_v oplus dis_{lca} oplus dis_{lca}),由异或的性质可知即为 (dis_u oplus dis_v),而距离就是 (dep_u + dep_v - 2dep_{lca})

所以只需要用大小 (2^{22}) 的桶存下每个状态的最深深度,(O(22)) 可以求出一个点对答案的贡献,每个点 (u)(ans_u) 为其子树 (ans) 与经过 (u) 最长符合条件路径的最大值。

剩下的就基本是 ( ext{dsu on tree}) 的模板了。

Code

#include <cstdio>

using namespace std;

#define N 500000
#define M 22
#define inf 10000 

#define fo(i, x, y) for(int i = x, end_##i = y; i <= end_##i; i ++)
#define fd(i, x, y) for(int i = x, end_##i = y; i >= end_##i; i --)
#define Fo(i, u) for(int i = head[u]; i; i = edge[i].next)

void read(int &x) {
    char ch = getchar(); x = 0;
    while (ch < '0' || ch > '9') ch = getchar();
    while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + ch - 48, ch = getchar();
}

struct EDGE { int next, to; } edge[N << 1];

int head[N + 1], col[N + 1], f[1 << M], d[N + 1], sz[N + 1], son[N + 1], fa[N + 1], c[N + 1], ans[N + 1];

int n;

int cnt_edge = 0;
void Add(int u, int v) { edge[ ++ cnt_edge ] = (EDGE) { head[u], v }, head[u] = cnt_edge; }

int max(int x, int y) { return x > y ? x : y; }

void Init() {
    d[1] = 1;
    fo(i, 2, n) c[i] = c[fa[i]] ^ (1 << col[i]), d[i] = d[fa[i]] + 1;
    fd(i, n, 1) {
        if (++ sz[i] > sz[son[fa[i]]])
            son[fa[i]] = i;
        sz[fa[i]] += sz[i];
    }
    fo(i, 2, n) if (i != son[fa[i]])
        Add(fa[i], i);
    fo(i, 0, (1 << M) - 1) f[i] = -inf;
}

int Get_d(int x) {
    int dep = f[x];
    fo(i, 0, M - 1)
        dep = max(dep, f[x ^ (1 << i)]);
    return dep;
}

int Dfs1(int u) {
    int dep = Get_d(c[u]) + d[u];
    Fo(i, u) dep = max(dep, Dfs1(edge[i].to));
    if (son[u]) dep = max(dep, Dfs1(son[u]));
    return dep;
}

void Updata(int x, int dep) { f[x] = max(f[x], dep); }

void Dfs2(int u) {
    Updata(c[u], d[u]);
    Fo(i, u) Dfs2(edge[i].to);
    if (son[u]) Dfs2(son[u]);
}

void Back(int x) { f[x] = -inf; }

void Dfs3(int u) {
    Back(c[u]);
    Fo(i, u) Dfs3(edge[i].to);
    if (son[u]) Dfs3(son[u]);
}

void Solve(int u, int opt) {
    ans[u] = 0;
    Fo(i, u) Solve(edge[i].to, 1), ans[u] = max(ans[u], ans[edge[i].to]);
    if (son[u]) Solve(son[u], 0), ans[u] = max(ans[u], ans[son[u]]);
    ans[u] = max(ans[u], Get_d(c[u]) - d[u]);
    Updata(c[u], d[u]);
    Fo(i, u)
        ans[u] = max(ans[u], Dfs1(edge[i].to) - (d[u] << 1)), Dfs2(edge[i].to);
    if (opt) Dfs3(u);
}

int main() {
    read(n);
    fo(i, 2, n)
        read(fa[i]), col[i] = getchar() - 'a';

    Init();

    Solve(1, 0);

    fo(i, 1, n) printf("%d ", ans[i]);

    return 0;
}
原文地址:https://www.cnblogs.com/zhouzj2004/p/13934448.html