“科大讯飞杯”第18届上海大学程序设计联赛春季赛暨高校网络友谊赛 G 血压游戏

[血压游戏](https://ac.nowcoder.com/acm/contest/5278/G)

神奇的 tag 数组:它是一个懒标记,记录某个集合中所有值整体还需要减去的量(每向上经过一层加一,但减到 1 为止),这样不必逐个更新,就巧妙地处理了高度带来的损失。

方法一:dsu on tree

类似长链剖分,不过是用 unordered_map 按深度维护信息:合并子树时总是把较小的表并入较大的表,而交换(swap)两个 unordered_map 的复杂度是 O(1)。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int inf = 0x3f3f3f3f;  // conventional "infinity" sentinel; not referenced by this solution
// Debug helper: dbg(a, b) prints "a, b -> <a> <b> " in green, then resets the terminal color.
// The ANSI escape sequences need the ESC character (\033); without the backslash the literal
// text "33[32;1m" would be printed instead. Uses standard __VA_ARGS__ rather than the GNU-only
// named variadic form `x...`.
#define dbg(...) do { std::cout << "\033[32;1m" << #__VA_ARGS__ << " -> "; err(__VA_ARGS__); } while (0)
// End of a dbg() line: restore the default color and flush so the output is visible immediately.
void err() { std::cout << "\033[39;0m" << std::endl; }
// Print each argument followed by a space, recursing until the pack is empty.
template<class T, class... Ts> void err(const T& arg, const Ts&... args) { std::cout << arg << " "; err(args...); }
const int N = 200000 + 5;
int head[N], ver[N<<1], nxt[N<<1], tot;
int dep[N];
int n, rt;
ll a[N], tag[N];
unordered_map<int, ll> mp[N];
// Append the directed edge x -> y to the forward-star list (edge ids start at 1; 0 ends a list).
void add(int x, int y){
    ++tot;
    ver[tot] = y;
    nxt[tot] = head[x];
    head[x] = tot;
}
// Add `cnt` items at depth `d` to vertex x's table (every call site in this file passes cnt >= 1).
// Entries are stored offset by tag[x]; the real count is max(stored - tag[x], 1), because the
// lazy per-level decrement never takes a non-empty group below 1. When the depth already
// exists, its pending decrements are materialised (clamped at 1) before cnt is added.
// A single find() replaces the former count() + two operator[] lookups.
void ins(int x, int d, ll cnt){
    auto& table = mp[x];
    auto it = table.find(d);
    if(it == table.end()){
        table.emplace(d, cnt + tag[x]);
    } else {
        it->second = std::max(it->second - tag[x], 1ll) + cnt + tag[x];
    }
}
void merge(int x, int y){
    if(mp[x].size() < mp[y].size()){
        swap(mp[x], mp[y]);
        swap(tag[x], tag[y]);
    }
    for(auto t : mp[y]){
        if(t.second){
            ins(x, t.first, max(t.second - tag[y], 1ll));
        }
    }
}
// Post-order DFS from x (parent fa): fold every child's table into x, record x's own value at
// its depth, then bump tag[x] so each group held at x is lazily reduced by one (never below 1)
// before it moves up to fa.
// NOTE: recursion depth equals the tree height (up to n on a path); relies on a large stack.
void dfs(int x, int fa){
    dep[x] = dep[fa] + 1;
    for(int e = head[x]; e != 0; e = nxt[e]){
        const int child = ver[e];
        if(child != fa){
            dfs(child, x);
            merge(x, child);
        }
    }
    if(a[x] != 0) ins(x, dep[x], a[x]);
    ++tag[x];
}
int main(){
    scanf("%d%d", &n, &rt);
    for(int i=1;i<=n;i++){
        scanf("%lld", &a[i]);
    }
    for(int i=1;i<n;i++){
        int x, y;scanf("%d%d", &x, &y);
        add(x, y);add(y, x);
    }
    dfs(rt, 0);
    ll res = 0;
    for(auto t : mp[rt]){
        if(t.second) res += max(1ll, t.second - tag[rt]);
    }
    cout << res << endl;
    return 0;
}

方法二:虚树 + 树形 DP

将节点按深度分组,对每一组节点建立虚树,然后在虚树上做树形 DP 求解即可。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int inf = 0x3f3f3f3f;  // conventional "infinity" sentinel; not referenced by this solution
// Debug helper: dbg(a, b) prints "a, b -> <a> <b> " in green, then resets the terminal color.
// The ANSI escape sequences need the ESC character (\033); without the backslash the literal
// text "33[32;1m" would be printed instead. Uses standard __VA_ARGS__ rather than the GNU-only
// named variadic form `x...`.
#define dbg(...) do { std::cout << "\033[32;1m" << #__VA_ARGS__ << " -> "; err(__VA_ARGS__); } while (0)
// End of a dbg() line: restore the default color and flush so the output is visible immediately.
void err() { std::cout << "\033[39;0m" << std::endl; }
// Print each argument followed by a space, recursing until the pack is empty.
template<class T, class... Ts> void err(const T& arg, const Ts&... args) { std::cout << arg << " "; err(args...); }
// Capacities: up to 200000 vertices; M slots hold both directions of every tree edge.
constexpr int N = 200000 + 5;
constexpr int M = 2 * N;
// Original tree as a forward-star adjacency list (its edge counter is `tot` below).
int head[N];
int ver[M];
int nxt[M];
// dfn[v]: DFS entry index of v; rnk[i]: the vertex whose dfn is i; cnt: last index handed out.
int dfn[N];
int rnk[N];
int cnt;
// dep[v]: depth (root = 1); f[v][j]: 2^j-th ancestor of v, 0 once above the root.
int dep[N];
int f[N][20];
// st/top: stack used while building a virtual tree; inq[v] marks vertices of the current depth group.
int st[N];
int top;
int inq[N];
ll a[N];
int n;
int rt;
int tot;
// node[d]: all vertices at depth d.
std::vector<int> node[N];
// Forward-star adjacency storage for the virtual tree, rebuilt for every depth group.
// Nothing is bulk-cleared: tot is reset to 0 per rebuild and head[] is reset only for the
// vertices the builder touches.
struct Graph{
    int head[N];
    int ver[M];
    int nxt[M];
    int tot;
    // Append the directed edge x -> y.
    void add(int x, int y){
        ++tot;
        ver[tot] = y;
        nxt[tot] = head[x];
        head[x] = tot;
    }
}G;
// Append the directed edge x -> y to the original tree's forward-star list (0 ends a list).
void add(int x, int y){
    ++tot;
    ver[tot] = y;
    nxt[tot] = head[x];
    head[x] = tot;
}
// DFS from x (parent fa) over the original tree: assign the entry order dfn/rnk and record, for
// every child, its depth and direct parent f[child][0]. The caller sets dep of the start vertex.
void dfs(int x, int fa){
    ++cnt;
    dfn[x] = cnt;
    rnk[cnt] = x;
    for(int e = head[x]; e; e = nxt[e]){
        const int child = ver[e];
        if(child == fa) continue;
        f[child][0] = x;
        dep[child] = dep[x] + 1;
        dfs(child, x);
    }
}
// Lowest common ancestor by binary lifting over f[][0..19].
int lca(int x, int y){
    if(dep[x] > dep[y]) std::swap(x, y);   // y is now the deeper vertex
    // Lift y up to x's depth.
    for(int j = 19; j >= 0; --j){
        if(dep[f[y][j]] >= dep[x]) y = f[y][j];
    }
    if(x == y) return x;
    // Lift both to the children of their LCA.
    for(int j = 19; j >= 0; --j){
        if(f[x][j] != f[y][j]){
            x = f[x][j];
            y = f[y][j];
        }
    }
    return f[x][0];
}
// Push vertex x onto the virtual-tree construction stack, adding G edges for stack entries
// that can no longer gain children. Assumes vertices arrive in increasing dfn order (get()
// sorts them) and that st[1] holds rt; rt itself is never pushed a second time.
void insert(int x){
    if(x == rt) return;
    int t = lca(x, st[top]);
    if(t != st[top]){
        // Pop every entry whose next-lower entry is still deeper than t (by dfn),
        // linking each popped vertex to the entry beneath it.
        while(top > 1 && dfn[st[top-1]] > dfn[t]){
            G.add(st[top-1], st[top]);
            top --;
        }
        if(dfn[t] > dfn[st[top-1]]){
            // t is a new branching vertex: give it an empty edge list, hang the current
            // top under it, and let t take the top's place.
            G.head[t] = 0;
            G.add(t, st[top]);
            st[top] = t;
        } else {
            // t is already the entry beneath the top: link the top to it and pop.
            G.add(t, st[top--]);
        }
    }
    // x starts with an empty edge list for this rebuild.
    G.head[x] = 0, st[++top] = x;
}
// Evaluate the virtual tree below x for the current depth group. A marked vertex (inq) returns
// its own a[x]. Otherwise every non-zero child value is reduced by the number of levels between
// the child and x, never below 1, and the results are summed.
ll dfs(int x){
    if(inq[x]) return a[x];
    ll total = 0;
    for(int e = G.head[x]; e; e = G.nxt[e]){
        const int child = G.ver[e];
        const ll got = dfs(child);
        if(got == 0) continue;   // empty branch: adds nothing (must not contribute the floor of 1)
        total += std::max(got - (dep[child] - dep[x]), 1ll);
    }
    return total;
}
// Answer contributed by the vertices at depth x. Builds their virtual tree rooted at rt,
// evaluates it, applies the root's own decrement (only when the total is at least 2), and
// clears the inq marks again. Returns 0 when no vertex has that depth.
ll get(int x){
    std::vector<int>& group = node[x];
    if(group.empty()) return 0;
    std::sort(group.begin(), group.end(), [](int u, int v){ return dfn[u] < dfn[v]; });
    top = 1;
    st[top] = rt;
    G.tot = 0;
    G.head[rt] = 0;
    for(int v : group){
        insert(v);
        inq[v] = 1;
    }
    // Whatever is left on the stack forms a single chain; link it.
    for(int i = 1; i < top; ++i) G.add(st[i], st[i + 1]);
    ll res = dfs(rt);
    if(res >= 2) --res;
    for(int v : group) inq[v] = 0;
    return res;
}
int main(){
    scanf("%d%d", &n, &rt);
    for(int i=1;i<=n;i++){
        scanf("%lld", &a[i]);
    }
    for(int i=1;i<n;i++){
        int x, y;scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    dep[rt] = 1;
    dfs(rt, 0);
    for(int i=1;i<=n;i++){
        node[dep[i]].push_back(i);
    }
    for(int j=1;j<20;j++){
        for(int i=1;i<=n;i++){
            f[i][j] = f[f[i][j-1]][j-1];
        }
    }
    ll res = 0;
    for(int i=1;i<=n;i++){
        res += get(i);
    }
    cout << res <<endl;
    return 0;
}
原文地址:https://www.cnblogs.com/1625--H/p/12731705.html