树上启发式合并练习题——树上子树颜色询问

题目描述

给你一个包含 (n(1 le n le 10^6)) 个节点的树,节点编号从 (1)(n),根节点的编号为 (1)。每一个节点都有一个颜色,我们用 (c_i) 来表示节点 (i) 的颜色。
接下来有 (m(1 le m le 10^6)) 次询问,每一次询问都会给你两个整数 (u)(c)(1 le u,c le n)),对于每一次询问,你需要回答:以节点 (u) 为根节点的子树中颜色为 (c) 的节点数量。

题解

大多数人都知道 DSU(并查集,Disjoint Set Union)但是什么是 “dsu on tree”(树上启发式合并,直译为“书上的并查集”)?

什么是树上启发式合并(dsu on tree)?

使用 dsu on tree 我们可以回答如下的问题:

(O(n log n)) 时间复杂度内计算所有的节点 v 的子树中存在多少个点满足某一性质。

所以对于这道问题我们就是求解:

给你一棵树,每一个节点都有一个颜色。问题是询问 以节点 v 为根的子树中存在多少个点的颜色为 c

暴力解法 (O(n^2))

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1000010;
vector<int> g[maxn];
int n, m, sz[maxn], c[maxn];
struct Query {
    int c, ans;
} query[maxn];
vector<int> qid[maxn];
int cnt[maxn];
void add(int u, int p, int x) {
    cnt[ c[u] ] += x;
    for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
        int v = *it;
        if (v != p) add(v, u, x);
    }
}
void dfs(int u, int p) {
    add(u, p, 1);
    for (vector<int>::iterator it = qid[u].begin(); it != qid[u].end(); it ++) {
        int id = *it;
        query[id].ans = cnt[ query[id].c ];
    }
    add(u, p, -1);
    for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
        int v = *it;
        if (v != p) dfs(v, u);
    }
}
int main() {
    cin >> n;
    for (int i = 1; i <= n; i ++) cin >> c[i];
    for (int i = 1; i < n; i ++) {
        int a, b;
        cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    cin >> m;
    for (int i = 0; i < m; i ++) {
        int u, c;
        cin >> u >> query[i].c;
        qid[u].push_back(i);
    }
    dfs(1, -1);
    for (int i = 0; i < m; i ++)
        cout << query[i].ans << endl;
    return 0;
}

1. 基于dsu on tree的解法1 (O(n log^2 n))

这个解法采用了 dsu on tree 的思想,将每科子树对应的颜色和数量都存在一个 map 中。父节点复用其重儿子的 map。时间复杂度为遍历的节点数 (n log n) 乘以 map 中获得每个元素的时间 (log n) 等于 (O(n log^2 n))

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1000010;
vector<int> g[maxn];
int n, m, sz[maxn], c[maxn];
struct Query {
    int c, ans;
} query[maxn];
vector<int> qid[maxn];
int cnt[maxn];
map<int, int> mp[maxn];
int mpcnt, mpid[maxn];
void getsz(int u, int p) {  // 计算sz[u] -- 以u为根节点的子树大小(包含节点个数)
    sz[u] ++;
    for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
        int v = *it;
        if (v != p) getsz(v, u);
    }
}
void dfs(int u, int p) {
    int mx = -1, bigSon = -1;   // mx表示重儿子的sz,bigSon表示重儿子的编号
    for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
        int v = *it;
        if (v != p) {
            dfs(v, u);
            if (sz[v] > mx) {
                mx = sz[v];
                bigSon = v;
            }
        }
    }
    if (bigSon == -1) // 叶子节点
        mpid[u] = ++ mpcnt;
    else
        mpid[u] = mpid[bigSon];
    mp[ mpid[u] ][ c[u] ] ++;
    for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
        int v = *it;
        if (v != p && v != bigSon) {
            for (map<int, int>::iterator it2 = mp[ mpid[v] ].begin(); it2 != mp[ mpid[v] ].end(); it2 ++) {
                pair<int, int> x = *it2;
                mp[ mpid[u] ][ x.first ] += x.second;
            }
            mp[ mpid[v] ].clear();
        }
    }
    for (vector<int>::iterator it = qid[u].begin(); it != qid[u].end(); it ++) {
        int id = *it;
        query[id].ans = mp[ mpid[u] ][ query[id].c ];
    }
}
int main() {
    cin >> n;
    for (int i = 1; i <= n; i ++) cin >> c[i];
    for (int i = 1; i < n; i ++) {
        int a, b;
        cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    cin >> m;
    for (int i = 0; i < m; i ++) {
        int u, c;
        cin >> u >> query[i].c;
        qid[u].push_back(i);
    }
    dfs(1, -1);
    for (int i = 0; i < m; i ++)
        cout << query[i].ans << endl;
    return 0;
}

2. 基于dsu on tree的解法2 (O(n log n))

方法2使用vector代替map,公用一个cnt数组,时间复杂度降到 (O(n log n))

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1000010;
vector<int> g[maxn];
int n, m, sz[maxn], c[maxn];
struct Query {
    int c, ans;
} query[maxn];
vector<int> qid[maxn];
int cnt[maxn];
vector<int> vec[maxn];
int veccnt, vecid[maxn];
void getsz(int u, int p) {  // 计算sz[u] -- 以u为根节点的子树大小(包含节点个数)
    sz[u] ++;
    for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
        int v = *it;
        if (v != p) getsz(v, u);
    }
}
void dfs(int u, int p, bool keep) {
    int mx = -1, bigSon = -1;   // mx表示重儿子的sz,bigSon表示重儿子的编号
    for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
        int v = *it;
        if (v != p && sz[v] > mx) {
            mx = sz[v];
            bigSon = v;
        }
    }
    for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
        int v = *it;
        if (v != p && v != bigSon) dfs(v, u, false);
    }
    if (bigSon == -1) // 叶子节点
        vecid[u] = ++ veccnt;
    else {
        dfs(bigSon, u, true);
        vecid[u] = vecid[bigSon];
    }
    vec[ vecid[u] ].push_back(u);
    cnt[ c[u] ] ++;
    for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
        int v = *it;
        if (v != p && v != bigSon) {
            for (vector<int>::iterator it2 = vec[ vecid[v] ].begin(); it2 != vec[ vecid[v] ].end(); it2 ++) {
                int x = *it2;
                vec[ vecid[u] ].push_back(x);
                cnt[ c[x] ] ++;
            }
            vec[ vecid[v] ].clear();
        }
    }
    for (vector<int>::iterator it = qid[u].begin(); it != qid[u].end(); it ++) {
        int id = *it;
        query[id].ans = cnt[ query[id].c ];
    }
    if (!keep) {    // 需要还原
        for (vector<int>::iterator it = vec[ vecid[u] ].begin(); it != vec[ vecid[u] ].end(); it ++) {
            cnt[ c[*it] ] --;
        }
    }
}
int main() {
    cin >> n;
    for (int i = 1; i <= n; i ++) cin >> c[i];
    for (int i = 1; i < n; i ++) {
        int a, b;
        cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    cin >> m;
    for (int i = 0; i < m; i ++) {
        int u, c;
        cin >> u >> query[i].c;
        qid[u].push_back(i);
    }
    dfs(1, -1, false);
    for (int i = 0; i < m; i ++)
        cout << query[i].ans << endl;
    return 0;
}

3. 轻儿子-重儿子分解形式 (O(n log n))

这种格式开了一个 bool 类型的 big 数组,(big[u]) 用于标记当前节点 (u) 是不是某一个节点的重儿子,重儿子不需要还原。 这一步操作真的非常神奇!
虽然都是dsu on tree的实现,但是这种方式比前两种方式要更省空间(省的不知道哪里去了)。

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1000010;
vector<int> g[maxn];
int n, m, sz[maxn], c[maxn];
struct Query {
    int c, ans;
} query[maxn];
vector<int> qid[maxn];
int cnt[maxn];
bool big[maxn];
void add(int u, int p, int x) {
    cnt[ c[u] ] += x;
    for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
        int v = *it;
        if (v != p && !big[v])
            add(v, u, x);
    }
}
void dfs(int u, int p, bool keep) {
    int mx = -1, bigSon = -1;   // mx表示重儿子的sz,bigSon表示重儿子的编号
    for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
        int v = *it;
        if (v != p && sz[v] > mx) {
            mx = sz[v];
            bigSon = v;
        }
    }
    for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
        int v = *it;
        if (v != p && v != bigSon) dfs(v, u, false);
    }
    if (bigSon != -1) {
        dfs(bigSon, u, true);
        big[bigSon] = true;
    }
    add(u, p, 1);
    for (vector<int>::iterator it = qid[u].begin(); it != qid[u].end(); it ++) {
        int id = *it;
        query[id].ans = cnt[ query[id].c ];
    }
    if (bigSon != -1)
        big[bigSon] = 0;
    if (!keep)
        add(u, p, -1);
}
int main() {
    cin >> n;
    for (int i = 1; i <= n; i ++) cin >> c[i];
    for (int i = 1; i < n; i ++) {
        int a, b;
        cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    cin >> m;
    for (int i = 0; i < m; i ++) {
        int u, c;
        cin >> u >> query[i].c;
        qid[u].push_back(i);
    }
    dfs(1, -1, false);
    for (int i = 0; i < m; i ++)
        cout << query[i].ans << endl;
    return 0;
}

参考资料:https://codeforces.com/blog/entry/44351

原文地址:https://www.cnblogs.com/quanjun/p/13916321.html