题目
简化题意
给你一棵树,给你一个访问节点的序列,按照先后顺序去访问序列中的从未经过过的节点,问经过了多少条边。
思路
并查集 $ + LCA$。
用并查集维护每个点是否走过,如果走过了就将该点和他的第一个没被走过的父亲合并。
(LCA) 用来计算距离,在路径上暴跳的时候维护并查集,因为每个点最多被经过一次,复杂度 (O(n))。
Code
#include <cstdio>
#include <cstring>
#include <string>
#include <iostream>
#include <algorithm>
#define MAXN 500001
#define int long long
int n, m, s, pthn, head[MAXN], fa_[MAXN];
int lg[MAXN], fa[MAXN][21], dep[MAXN];
struct Edge {
int next, to;
}pth[MAXN << 1];
void add(int from, int to) {
pth[++pthn].to = to, pth[pthn].next = head[from];
head[from] = pthn;
}
int find(int x) { return fa_[x] == x ? x : fa_[x] = find(fa_[x]); }
void dfs(int u, int father) {
fa[u][0] = father, dep[u] = dep[father] + 1;
for (int i = head[u]; i; i = pth[i].next) {
int x = pth[i].to;
if (x != father) dfs(x, u);
}
}
int lca(int x, int y) {
if (dep[x] < dep[y]) std::swap(x, y);
while (dep[x] > dep[y]) {
x = fa[x][lg[dep[x] - dep[y]] - 1];
}
if (x == y) return x;
for (int k = lg[dep[x]] - 1; k >= 0; --k) {
if (fa[x][k] != fa[y][k]) {
x = fa[x][k];
y = fa[y][k];
}
}
return fa[x][0];
}
int dis(int x, int y, int l) {
return dep[x] + dep[y] - 2 * dep[l];
}
signed main() {
scanf("%lld %lld %lld", &n, &m, &s);
for (int i = 1; i <= n; ++i) fa_[i] = i;
for (int i = 1, u, v; i < n; ++i) {
scanf("%lld %lld", &u, &v);
add(u, v), add(v, u);
}
dfs(1, 0);
for (int i = 1; i <= n; ++i) {
lg[i] = lg[i - 1] + ((1 << lg[i - 1]) == i);
}
for (int j = 1; (1 << j) <= n; ++j) {
for (int i = 1; i <= n; ++i) {
fa[i][j] = fa[fa[i][j - 1]][j - 1];
}
}
int ans = 0;
for (int i = 1, t; i <= m; ++i) {
scanf("%lld", &t);
if (fa_[t] == t) {
int l = lca(s, t);
int x = s, y = t;
while (dep[x] >= dep[l]) {
fa_[x] = find(fa[x][0]);
x = fa_[x];
}
while (dep[y] >= dep[l]) {
fa_[y] = find(fa[y][0]);
y = fa_[y];
}
ans += dis(s, t, l);
s = t;
}
}
std::cout << ans << '
';
return 0;
}