bzoj3910 火车

Description

A 国有n 个城市,城市之间有一些双向道路相连,并且城市两两之间有唯一

路径。现在有火车在城市 a,需要经过m 个城市。火车按照以下规则行驶:每次

行驶到还没有经过的城市中在 m 个城市中最靠前的。现在小 A 想知道火车经过

这m 个城市后所经过的道路数量。

Input

第一行三个整数 n、m、a,表示城市数量、需要经过的城市数量,火车开始

时所在位置。

接下来 n-1 行,每行两个整数 x和y,表示 x 和y之间有一条双向道路。

接下来一行 m 个整数,表示需要经过的城市。

Output

一行一个整数,表示火车经过的道路数量。

Sample Input

5 4 2
1 2
2 3
3 4
4 5
4 3 1 5

Sample Output

9

Hint

N<=500000 ,M<=400000

(Lca) + 并查集

​ 一开始想到用(lca)求走过的边数了,但是将点标记的时候我用的是暴力向上跳,结果(Tle)了。

​ 看了题解发现可以用并查集维护这个点是否被走过了。具体看代码。

#include <iostream>
#include <cstdio>
#include <cctype>

using namespace std;

inline long long read() {
    long long s = 0, f = 1; char ch;
    while(!isdigit(ch = getchar())) (ch == '-') && (f = -f);
    for(s = ch ^ 48;isdigit(ch = getchar()); s = (s << 1) + (s << 3) + (ch ^ 48));
    return s * f;
}

const int N = 5e5 + 5, M = 4e5 + 5;
int n, m, x, cnt;
long long ans;
int a[M], f[N], fa[N][21], dep[N], head[N];
struct edge { int to, nxt; } e[N << 1]; 

void add(int x, int y) {
    e[++cnt].nxt = head[x]; head[x] = cnt; e[cnt].to = y;
} 

int find(int x) {
    return x == f[x] ? x : f[x] = find(f[x]);
}

void get_dep(int x, int Fa) {
    for(int i = head[x]; i; i = e[i].nxt) {
        int y = e[i].to; if(y == Fa) continue;
        dep[y] = dep[x] + 1; fa[y][0] = x;
        get_dep(y, x);
    }
}

void make_fa() {
    for(int i = 1;i <= 20; i++) {
        for(int j = 1;j <= n; j++) {
            fa[j][i] = fa[fa[j][i - 1]][i - 1];
        }
    }
}

int LCA(int x, int y) {
    if(dep[x] < dep[y]) swap(x, y);
    for(int i = 20;i >= 0; i--) {
        if(dep[x] - dep[y] >= (1 << i)) x = fa[x][i];
    } 
    if(x == y) return x;
    for(int i = 20;i >= 0; i--) {
        if(fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
    }
    return fa[x][0];
}

void init() {
    n = read(); m = read(); x = read();
    for(int i = 1;i <= n; i++) f[i] = i;
    for(int i = 1, x, y;i <= n - 1; i++) {
        x = read(); y = read(); 
        add(x, y); add(y, x);
    }
    for(int i = 1;i <= m; i++) a[i] = read();
}

void work() {
    get_dep(x, 0); make_fa();
    for(int i = 1;i <= m; i++) {
        int tx = find(x), ty = find(a[i]);
        if(tx == ty) continue;
        int lca = LCA(x, a[i]);
        ans += dep[x] + dep[a[i]] - (dep[lca] << 1);
        lca = find(lca);
        int tmp = tx;
        while(find(tmp) != lca) {
            int father = find(tmp);
            f[father] = lca; tmp = fa[father][0];
        }
        tmp = ty;
        while(find(tmp) != lca) {
            int father = find(tmp);
            f[father] = lca; tmp = fa[father][0];
        }
        x = a[i];
    }
    printf("%lld", ans);
}

int main() {

    init();
    work();

    return 0;
}
原文地址:https://www.cnblogs.com/czhui666/p/13568169.html