HRBUST

HRBUST - 2358

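思路: dfs序 + 树状数组 (Approach: DFS order + Fenwick tree. Numbering the nodes in DFS order makes every subtree a contiguous index range [l, r]; a Fenwick tree with range-add / point-query then counts how many marked nodes lie on a node's root path, and together with a binary-lifting LCA this gives the distance between two nodes with the marked nodes skipped.)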

代码:

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
// Contest shorthand macros (comments are stripped before macro expansion,
// so trailing // comments on #define lines are safe).
#define fi first                // pair member shorthands
#define se second
#define pi acos(-1.0)           // the constant pi, evaluated at each use
#define LL long long
//#define mp make_pair
#define pb push_back
// Presumably segment-tree child-call arguments; expansion needs rt, l, m, r in scope.
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
// sizeof(a) is the whole object only when `a` is a real array, not a pointer.
#define mem(a, b) memset(a, b, sizeof(a))
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
// Redirect stdin/stdout to in.txt/out.txt (the stream is `stdout`; the
// previous spelling `stout` did not compile when the macro was used).
// NOTE(review): this macro's name collides with the standard C `fopen`;
// defining it after including standard headers is not allowed by the C++
// standard and any later `fopen(...)` call would expand to this text.
// Consider renaming it (e.g. FILE_IO).
#define fopen freopen("in.txt", "r", stdin);freopen("out.txt", "w", stdout);
//head

// Capacity of every per-node table (node count bound plus slack).
constexpr int N = 100000 + 10;

vector<int> g[N];   // g[p]: children of node p
int anc[20][N];     // anc[k][x]: 2^k-th ancestor of x (binary lifting)
int deep[N];        // deep[x]: depth of x below the root
int l[N];           // l[x]: DFS entry index of x
int r[N];           // r[x]: largest DFS index inside x's subtree
int bit[N];         // Fenwick tree storage, 1-based
int n;              // node count of the current test case
int now = 1;        // running DFS index counter
bool vis[N];        // vis[x]: x has been marked
// Fenwick-tree point update: adds `a` at 1-based index x and at every
// covering cell above it, stopping once the index exceeds n.
void add(int x, int a) {
    for (; x <= n; x += x & -x) {
        bit[x] += a;
    }
}
// Fenwick-tree prefix sum over indices 1..x. When updates are issued as
// add(lo, +v) / add(hi + 1, -v) pairs, this is the current value at point x.
int query(int x) {
    int total = 0;
    for (; x != 0; x -= x & -x) {
        total += bit[x];
    }
    return total;
}
// Pre-order traversal of the subtree rooted at u.
//  - l[x]: entry index, assigned in visiting order starting from the
//    current value of the global counter `now` (u itself gets `now`).
//  - r[x]: largest index inside x's subtree, so the subtree is [l[x], r[x]].
//  - deep[x] = deep[parent] + 1 (deep[u] must already be set).
//  - anc[k][x]: 2^k-th ancestor of x, built from the parent's row
//    (anc[*][u] must be prepared by the caller).
//
// An explicit stack replaces recursion. A path-shaped tree of ~1e5 nodes
// would otherwise nest ~1e5 calls and can overflow the call stack. Children
// are visited in g[] order, so every value matches the recursive version.
void dfs(int u) {
    // Frame: (node, position of the next child of that node to visit).
    vector<pair<int, size_t>> stk;
    l[u] = now;
    stk.emplace_back(u, 0);
    while (!stk.empty()) {
        int cur = stk.back().first;
        size_t &idx = stk.back().second;
        if (idx < g[cur].size()) {
            int v = g[cur][idx++];
            anc[0][v] = cur;
            // Ancestors of v are already processed, so their rows are complete.
            for (int i = 1; i < 20; i++) anc[i][v] = anc[i-1][anc[i-1][v]];
            now++;
            deep[v] = deep[cur] + 1;
            l[v] = now;
            // `idx` may dangle after this push; it is not used again.
            stk.emplace_back(v, 0);
        } else {
            // All children done: `now` is the last index used in cur's subtree.
            r[cur] = now;
            stk.pop_back();
        }
    }
}
// Lowest common ancestor of u and v by binary lifting over anc[][] and
// deep[] (both filled by dfs beforehand).
int lca(int u, int v) {
    // Let u be the deeper of the two.
    if (deep[v] > deep[u]) swap(u, v);
    // Raise u to v's depth with the largest jumps that do not overshoot.
    for (int k = 19; k >= 0; --k) {
        int up = anc[k][u];
        if (deep[up] >= deep[v]) u = up;
    }
    if (u != v) {
        // Raise both while their ancestors differ; they end just below the LCA.
        for (int k = 19; k >= 0; --k) {
            if (anc[k][u] == anc[k][v]) continue;
            u = anc[k][u];
            v = anc[k][v];
        }
        u = anc[0][u];
    }
    return u;
}
// Path length between u and v where marked nodes are not counted.
// query(l[x]) is the number of marked nodes on x's root path (x included),
// so deep[x] - query(l[x]) is x's depth with marked nodes skipped; the
// usual depth(u) + depth(v) - 2 * depth(lca) formula is applied to that.
int dis(int u, int v) {
    const int w = lca(u, v);
    const int du = deep[u] - query(l[u]);
    const int dv = deep[v] - query(l[v]);
    const int dw = deep[w] - query(l[w]);
    return du + dv - 2 * dw;
}
int main() {
    int T, q, p, ty, x, y;
    for (int i = 0; i < 20; i++) anc[i][1] = 1;
    scanf("%d", &T);
    while(T--) {
        scanf("%d", &n);
        for (int i = 1; i <= n; i++) deep[i] = 0, g[i].clear(), vis[i] = false;
        for (int i = 2; i <= n; i++) {
            scanf("%d", &p);
            g[p].pb(i);
        }
        now = 1;
        dfs(1);
        scanf("%d", &q);
        for (int i = 1; i <= n; i++) bit[i] = 0;
        while(q--) {
            scanf("%d", &ty);
            if(ty == 1) {
                scanf("%d %d", &x, &y);
                if(vis[x] || vis[y]) printf("-1
");
                else printf("%d
", dis(x, y));
            }
            else {
                scanf("%d", &x);
                if(!vis[x]) {
                    vis[x] = true;
                    add(l[x], 1);
                    add(r[x]+1, -1);
                }
            }
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/widsom/p/9549698.html