异象石(引理证明)

题面

Adera是Microsoft应用商店中的一款解谜游戏。

异象石是进入Adera中异时空的引导物,在Adera的异时空中有一张地图。

这张地图上有N个点,有N-1条双向边把它们连通起来。

起初地图上没有任何异象石,在接下来的M个时刻中,每个时刻会发生以下三种类型的事件之一:

地图的某个点上出现了异象石(已经出现的不会再次出现;
地图某个点上的异象石被摧毁(不会摧毁没有异象石的点;
向玩家询问使所有异象石所在的点连通的边集的总长度最小是多少。
请你作为玩家回答这些问题。

输入格式

第一行有一个整数N,表示点的个数。

接下来N-1行每行三个整数x,y,z,表示点x和y之间有一条长度为z的双向边。

第N+1行有一个正整数M。

接下来M行每行是一个事件,事件是以下三种格式之一:

”+ x” 表示点x上出现了异象石

”- x” 表示点x上的异象石被摧毁

”?” 表示询问使当前所有异象石所在的点连通所需的边集的总长度最小是多少。

输出格式

对于每个 ? 事件,输出一个整数表示答案。

数据范围

1≤N,M≤10^5,
1≤x,y≤N,
x≠y,
1≤z≤109

输入样例

6
1 2 1
1 3 5
4 1 7
4 5 3
6 4 2
10
+ 3
+ 1
?
+ 6
?
+ 5
?
- 6
- 3
?

输出样例

5
14
17
10

题解

思路和蓝书上一样, 按dfs序对答案进行修改, 思路就不再说了,

主要说下为什么 按dfs这样算 是答案的两倍

证明(数学归纳法):
(f(x,y)=d[x]+d[y]-2*d[lca(x,y)])

当 k = 1, ans=0
当 k = 2, (ans=f(a_1,a_2)+f(a_2,a_1)), 答案显然是两倍
当 k = 3, (ans=f(a_1,a_2)+f(a_2,a_3)+f(a_3,a_1))
     对于n = 2少了个 (f(a_2,a_1)), 多了(f(a_2,a_3)+f(a_3,a_1))
     即多了 (2*d[3]+2*d[lca(a_1,a_2)]-2*d[lca(a_2,a_3)]-2*d[lca(a_1,a_3)])
     我们发现都乘了个 2, 这便是加入a_3之后对答案 两倍 的影响
     由于我们是严格按照 dfs序 算的, 所以 a_3 对于 a_2有
     1.a_3是a_2的子节点, 那么刚才一长串就是 a_2到a_1 和 a_2到a_3的距离的两倍
     2.a_3和a_2在不同的两颗子树上([lca(a_2,a_3) eq a_2]), 刚才的一长串就是 a_2到a_1 和 a_3到(lca(a_1,a_2,a_3))距离的两倍
当 k = n, (ans=sum_{i=1}^{n-1}f(a_1,a_i)+f(a_n,a_1)), 比 k = n - 1 多出的部分也分为
     1.(a_{n-1}是a_n)的子节点, 那么多出部分为....
     2.(a_{n-1}和a_n)在不同的两颗子树上([lca(a_{n-1},a_n) eq a_2]), 多出的部分为....

就简单证明一下, 主要不懂引理, 就不会写(暴力超时), 具体代码如下

#include <bits/stdc++.h>
#define all(n) (n).begin(), (n).end()
#define se second
#define fi first
#define pb push_back
#define mp make_pair
#define sqr(n) (n)*(n)
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define per(i,a,b) for(int i=(a);i>=(b);--i)
#define IO ios::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr)
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
typedef pair<ll, ll> PLL;
typedef vector<int> VI;
typedef double db;

const int N = 1e5 + 5;

int n, m, _, k;
int h[N], ne[N << 1], to[N << 1], co[N << 1], tot;
int dfn[N], cnt, d[N], f[N][30], t;
ll dist[N], ans;
char s[3];
set<PII> st;

void add(int u, int v, int c) {
    ne[++tot] = h[u]; h[u] = tot; to[tot] = v; co[tot] = c;
}

void bfs(int s) {
    queue<int> q;
    q.push(s); d[s] = 1; dist[s] = 0;
    rep (i, 0, t) f[s][i] = 0;

    while (!q.empty()) {
        int x = q.front(); q.pop();
        for (int i = h[x]; i; i = ne[i]) {
            int y = to[i];
            if (d[y]) continue;
            d[y] = d[x] + 1;
            dist[y] = dist[x] + co[i];
            f[y][0] = x;

            for (int j = 1; j <= t; ++j) 
                f[y][j] = f[f[y][j - 1]][j - 1];

            q.push(y);
        }
    }
}

int lca(int x, int y) {
    if (d[x] > d[y]) swap(x, y);
    per (i, t, 0) 
        if (d[f[y][i]] >= d[x]) y = f[y][i];

    if (x == y) return x;

    per (i, t, 0)
        if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];

    return f[x][0]; 
}

void dfs(int u) {
    dfn[u] = ++cnt;
    for (int i = h[u]; i; i = ne[i]) {
        int y = to[i];
        if (dfn[y]) continue;
        dfs(y);
    }
}

void worka() {
    auto it = st.upper_bound({ dfn[k], k });
    int x, y;
    if (it == st.begin() || it == st.end()) x = st.begin()->se, y = st.rbegin()->se;
    else  x = it->se, y = (--it)->se;

    ans += (dist[k] << 1) - ((dist[lca(k, x)] + dist[lca(k, y)] - dist[lca(x, y)]) << 1);
    st.insert({ dfn[k], k });
}

void workb() {
    if (st.size() == 1) { st.clear(); return; }

    auto it = st.lower_bound({ dfn[k], k }), ita = it; ++ita;
    int x, y;
    if (ita == st.end()) y = st.begin()->se;
    else y = ita->se;
    if (it == st.begin()) x = st.rbegin()->se;
    else x = (--it)->se;

    ans -= (dist[k] << 1) - ((dist[lca(k, x)] + dist[lca(k, y)] - dist[lca(x, y)]) << 1);
    st.erase(--ita);
}

int main() {
    IO; cin >> n;
    rep(i, 2, n) {
        int u, v, c; cin >> u >> v >> c;
        add(u, v, c); add(v, u, c);
    }

    t = log2(n - 1) + 1;
    dfs(1); bfs(1);

    cin >> m;
    rep(i, 1, m) {
        cin >> s;
        if (s[0] == '?') cout << (ans >> 1) << '
';
        else {
            cin >> k;
            if (st.empty()) st.insert({ dfn[k], k });
            else if (s[0] == '+') worka();
            else workb();
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/2aptx4869/p/13513541.html