闇の連鎖 树上LCA + 树上差分

题面

传说中的暗之连锁被人们称为 Dark。

Dark 是人类内心的黑暗的产物,古今中外的勇者们都试图打倒它。

经过研究,你发现 Dark 呈现无向图的结构,图中有 N 个节点和两类边,一类边被称为主要边,而另一类被称为附加边。

Dark 有 N – 1 条主要边,并且 Dark 的任意两个节点之间都存在一条只由主要边构成的路径。

另外,Dark 还有 M 条附加边。

你的任务是把 Dark 斩为不连通的两部分。

一开始 Dark 的附加边都处于无敌状态,你只能选择一条主要边切断。

一旦你切断了一条主要边,Dark 就会进入防御模式,主要边会变为无敌的而附加边可以被切断。

但是你的能力只能再切断 Dark 的一条附加边。

现在你想要知道,一共有多少种方案可以击败 Dark。

注意,就算你第一步切断主要边之后就已经把 Dark 斩为两截,你也需要切断一条附加边才算击败了 Dark。

输入格式

第一行包含两个整数 N 和 M。

之后 N – 1 行,每行包括两个整数 A 和 B,表示 A 和 B 之间有一条主要边。

之后 M 行以同样的格式给出附加边。

输出格式

输出一个整数表示答案。

数据范围

N≤100000,M≤200000,数据保证答案不超过231−1

输入样例:

4 1 
1 2 
2 3 
1 4 
3 4 
```text
输出样例:
```text
3

题解

附加边会产生环,

对于环外的边直接斩断, 附加边随便

对于环上的边, 如果这条边在两个环以上, 那么要站短多条附加边才可, 贡献为零

对于环上, 只在一个环上, 那么就斩掉这条边, 再把响应附加边斩断, 贡献为1

那就统计每条边在多少个环上呗, 附加边连接 (x, y), 那么 x, y 到 LCA(x, y)上的边环数+1

我们直接树上差分, 让 x, y的 value + 1, LCA(x, y)的 value -= 2, 深度遍历算就行了

#include <bits/stdc++.h>
#define all(n) (n).begin(), (n).end()
#define se second
#define fi first
#define pb push_back
#define mp make_pair
#define sqr(n) (n)*(n)
#define rep(i,a,b) for(int i=a;i<=(b);++i)
#define per(i,a,b) for(int i=a;i>=(b);--i)
#define IO ios::sync_with_stdio(0); cin.tie(0);
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
typedef pair<ll, ll> PLL;
typedef vector<int> VI;
typedef double db;

const int N = 1e5 + 5;

int n, m, _, k, t;
int h[N], to[N << 1], ne[N << 1], co[N << 1], tot;
int f[N][20], d[N], dist[N], cnt[N];
queue<int> q;

void add(int u, int v, int c) {
    ne[++tot] = h[u]; h[u] = tot; to[tot] = v; co[tot] = c;
}

void bfs(int s) {
    q.push(s); d[s] = 1; dist[s] = 0;
    rep (i, 0, t) f[s][i] = 0;

    while (!q.empty()) {
        int x = q.front(); q.pop();
        for (int i = h[x]; i; i = ne[i]) {
            int y = to[i];
            if (d[y]) continue;
            d[y] = d[x] + 1;
            dist[y] = dist[x] + co[i];
            f[y][0] = x;

            for (int j = 1; j <= t; ++j) 
                f[y][j] = f[f[y][j - 1]][j - 1];
            
            q.push(y);
        } 
    }
}

int lca(int x, int y) {
    if (d[x] > d[y]) swap(x, y);
    per (i, t, 0) 
        if (d[f[y][i]] >= d[x]) y = f[y][i];
    
    if (x == y) return x;
    
    per (i, t, 0)
        if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];

    return f[x][0]; 
}

void dfs(int u, int fa) {
    for (int i = h[u]; i; i = ne[i]) {
        int y = to[i];
        if (y == fa) continue;
        dfs(y, u);
        
        //cout << u << ' ' << y << ' ' << cnt[y] << '
';
        
        if (cnt[y] == 0) k += m;
        else if (cnt[y] == 1) ++k, ++cnt[u];
        else cnt[u] += cnt[y];
    }
}

int main() {
    IO;
    cin >> n >> m;
        
    tot = 0;
    rep (i, 1, n) h[i] = d[i] = 0;
    t = log2(n - 1) + 1;

    rep (i, 2, n) {
        int u, v; cin >> u >> v;
        add(u, v, 0); add(v, u, 0);
    }

    bfs(1);

    rep (i, 1, m) {
        int x, y; cin >> x >> y;
        ++cnt[x], ++cnt[y], cnt[lca(x, y)] -= 2;
    }
        
    dfs(1, 0);
    cout << k;
    return 0;
}
原文地址:https://www.cnblogs.com/2aptx4869/p/13277147.html