虚树学习笔记

虚树学习笔记

以消耗战为例

显然可以树形dp, 但时间复杂度爆炸

观察发现(sum k)的值不是很大,假设只有两个点x, y,它们的公共祖先lca, 树形dp就像分别枚举割断它们到lca的每一条边,事实上我们一下子(ans = min(mn[lca], mn[x] + mn[y]))就可以算出来,这是因为他们之间有大量的无用的点

所以建一棵虚树来保留对答案可能有影响的关键点,询问点和一些lca

具体来说就是类似维护极右链似的, 每次把lca搞到栈里,还是看代码吧

还需要多做些题理解一下

#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
#define ll long long
using namespace std;
const int N = 1005000;
template <typename T>
void read(T &x) {
    x = 0; bool f = 0;
    char c = getchar();
    for (;!isdigit(c);c=getchar()) if (c=='-') f=1;
    for (;isdigit(c);c=getchar()) x=(x<<3)+(x<<1)+(c^48);
    if (f) x=-x;
}
int n, m;
int h[N], ne[N], to[N];
int w[N], dep[N], tot;
inline void add(int x, int y, int z) {
    ne[++tot] = h[x], to[tot] = y;
    w[tot] = z, h[x] = tot;
}
int siz[N], f[N], son[N];
int Top[N], a[N], s[N], top;
ll mn[N];
void dfs1(int x, int fa) {
    siz[x] = 1, f[x] = fa, dep[x] = dep[fa] + 1;
    for (int i = h[x]; i; i = ne[i]) {
        int y = to[i]; if (y == fa) continue;
        mn[y] = min(mn[x], (ll)w[i]);
        dfs1(y, x), siz[x] += siz[y];
        if (siz[y] > siz[son[x]]) son[x] = y;
    }
}

int dfn[N], num;
void dfs2(int x, int topf) {
    Top[x] = topf, dfn[x] = ++num;
    if (!son[x]) return;
    dfs2(son[x], topf);
    for (int i = h[x]; i; i = ne[i]) 
        if (!dfn[to[i]]) dfs2(to[i], to[i]);
}

int Lca(int x, int y) {
    while (Top[x] != Top[y]) {
        if (dep[Top[x]] < dep[Top[y]]) swap(x, y);
        x = f[Top[x]];
    }
    return dep[x] < dep[y] ? x : y;
}

bool cmp(int a, int b) {
    return dfn[a] < dfn[b];
}

vector <int> v[N];
inline void add_e(int x,int y) {
    v[x].push_back(y);
}

void ins(int x) {
    if (top == 1) return (void)(s[++top] = x);
    int lca = Lca(x, s[top]); 
    if (lca == s[top]) return;
    while (top > 1 && dfn[s[top-1]] >= dfn[lca]) add_e(s[top-1], s[top]), top--;
    if (lca != s[top]) add_e(lca, s[top]), s[top] = lca;
    s[++top] = x;
}

ll dp(int x) {
    if (!v[x].size()) return mn[x];
    ll sum = 0;
    for (int i = 0;i < v[x].size(); i++)  sum += dp(v[x][i]);
    v[x].clear(); return min(sum, mn[x]);
}

int main() {
    read(n);
    for (int i = 1;i < n; i++) {
        int x, y, z; read(x), read(y), read(z);
        add(x, y, z); add(y, x, z);
    }
    mn[1] = 1ll << 50, dfs1(1, 0), dfs2(1, 1);
    read(m);
    while (m--) {
        int k; read(k);
        for (int i = 1;i <= k; i++) read(a[i]);
        sort(a + 1, a + k + 1, cmp);
        s[top = 1] = 1;
        for (int i = 1;i <= k; i++) ins(a[i]);
        while (top > 0) add_e(s[top - 1], s[top]), top--;
        printf ("%lld
", dp(1));
    }
    return 0;
}
原文地址:https://www.cnblogs.com/Hs-black/p/12271692.html