// HNOI 世界树 (World Tree) — solved with a virtual tree (虚树)

//virtual tree
/*Huyyt*/
#include<bits/stdc++.h>
// mem(a, b): memset the whole object `a` to byte `b` (sizeof(a), so only valid on real arrays, not pointers).
#define mem(a,b) memset(a,b,sizeof(a))
// TS: ad-hoc debug trace.
#define TS cout<<"!!!"<<endl
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
// eps, dir, mod, gakki, MAXQ, INF and LLINF are template boilerplate; the code shown here does not reference them.
const double eps = 1e-8;
const int dir[8][2] = {{0, 1}, {1, 0}, {0, -1}, { -1, 0}, {1, 1}, {1, -1}, { -1, -1}, { -1, 1}};
const int mod = 1e9 + 7, gakki = 5 + 2 + 1 + 19880611 + 1e9;
// MAXN bounds the node count; MAXM bounds the number of undirected edges (each is stored twice, see main).
const int MAXN = 3e5 + 5, MAXM = 3e5 + 5, MAXQ = 100010, INF = 1e9;
const ll LLINF = (1LL << 50);
// Linked adjacency list ("forward star"): Head[u] is the newest edge out of u, nxt[e] the next edge
// of the same source, to[e] the target. Index 0 terminates a list; tot starts at 1, so the first edge is 2.
int to[MAXM << 1], nxt[MAXM << 1], Head[MAXN], tot = 1;
// Prepends the directed edge u -> v to u's adjacency list.
// Self-loops are silently dropped (insert_point can produce them).
inline void addedge(int u, int v) {
        if (u != v) {
                ++tot;
                to[tot] = v;
                nxt[tot] = Head[u];
                Head[u] = tot;
        }
}
// Reads one decimal integer from stdin into x.
// Non-digit characters before the number are skipped; any '-' seen among them
// makes the result negative. If EOF is reached before a digit, x is left as 0
// (previously this looped forever, because the EOF sentinel was truncated into
// a char). getchar() is kept in an int so EOF stays distinguishable, and so
// isdigit() only receives values it is defined for (unsigned char range or EOF).
template <typename T> inline void read(T&x) {
        int cu = getchar();
        x = 0;
        bool fla = false;
        while (cu != EOF && !isdigit(cu)) {
                if (cu == '-') {
                        fla = true;
                }
                cu = getchar();
        }
        while (cu != EOF && isdigit(cu)) {
                x = x * 10 + (cu - '0');
                cu = getchar();
        }
        if (fla) {
                x = -x;
        }
}
// Number of nodes in the input tree.
int n;
// jumpfa[x][i]: the 2^i-th ancestor of x (0 once above the root); dfs1 fills levels 0..19.
int jumpfa[MAXN][21];
// sz[x]: number of nodes in the original subtree rooted at x.
int sz[MAXN];
// dfn[x]: preorder timestamp assigned by dfs1; deep[x]: depth, with the root at depth 1 (set in main).
int dfn[MAXN], deep[MAXN];
// Running counter used to assign dfn.
int cnt = 0;
// Virtual-tree construction stack s[1..top]; s[0] is never written and stays 0 (used as a sentinel in insert_point).
int top;
int s[MAXN];
// k: current query nodes (sorted by dfn, with the root possibly appended); k2: the query nodes in input order.
int k[MAXN], k2[MAXN];
// del: set for query nodes and cleared afterwards; not read anywhere in the code shown.
bool del[MAXN];
// Preprocessing DFS over the original tree rooted at x (parent fa).
// Fills jumpfa (binary lifting), dfn (preorder time), sz (subtree size) and
// deep for every child. Recursion depth equals the height of the tree.
void dfs1(int x, int fa) {
        jumpfa[x][0] = fa;
        dfn[x] = ++cnt;
        sz[x] = 1;
        // The 2^j-th ancestor is the 2^(j-1)-th ancestor of the 2^(j-1)-th ancestor.
        for (int j = 1; j <= 19; ++j) {
                const int half = jumpfa[x][j - 1];
                jumpfa[x][j] = jumpfa[half][j - 1];
        }
        for (int e = Head[x]; e != 0; e = nxt[e]) {
                const int child = to[e];
                if (child == fa) {
                        continue;
                }
                deep[child] = deep[x] + 1;
                dfs1(child, x);
                sz[x] += sz[child];
        }
}
// Lowest common ancestor of x and y via the jumpfa binary-lifting table.
inline int lca(int x, int y) {
        // Make x the deeper endpoint.
        if (deep[y] > deep[x]) {
                swap(x, y);
        }
        // lg = floor(log2(deep[x])), or -1 when deep[x] is 0.
        int lg = -1;
        while ((1 << (lg + 1)) <= deep[x]) {
                ++lg;
        }
        // Lift x up to y's depth.
        for (int i = lg; i >= 0; --i) {
                if (deep[x] - deep[y] >= (1 << i)) {
                        x = jumpfa[x][i];
                }
        }
        if (x == y) {
                return x;
        }
        // Lift both while their ancestors differ; they end up as children of the LCA.
        for (int i = lg; i >= 0; --i) {
                const int px = jumpfa[x][i];
                const int py = jumpfa[y][i];
                if (px != py) {
                        x = px;
                        y = py;
                }
        }
        return jumpfa[x][0];
}
// Orders nodes by their DFS preorder timestamp (required by insert_point).
inline bool cmp(int a, int b) {
        const int ta = dfn[a];
        const int tb = dfn[b];
        return ta < tb;
}
// Pushes x (visited in dfn order) onto the virtual-tree stack s[1..top].
// While doing so, it emits the parent->child edges of the virtual tree and
// inserts the LCA of x and the stack top when that LCA is missing.
inline void insert_point(int x) {
        if (top == 0) {
                s[++top] = x;
                return;
        }
        const int anc = lca(x, s[top]);
        if (anc != s[top]) {
                // Pop every stack node that lies below anc, linking each to the entry under it.
                // At top == 1 this reads s[0], which stays 0 (dfn 0), so the loop stops there.
                while (top >= 1 && dfn[s[top - 1]] >= dfn[anc]) {
                        addedge(s[top - 1], s[top]);
                        --top;
                }
                // anc was not on the stack: attach the last popped chain to it and replace the top.
                if (anc != s[top]) {
                        addedge(anc, s[top]);
                        s[top] = anc;
                }
        }
        s[++top] = x;
}
// Returns the ancestor of x that is depcha levels above it (greedy binary lifting).
int getpoint(int x, int depcha) {
        int node = x;
        int remaining = depcha;
        for (int bit = 19; bit >= 0; --bit) {
                const int step = 1 << bit;
                if (remaining < step) {
                        continue;
                }
                remaining -= step;
                node = jumpfa[node][bit];
        }
        return node;
}
// Number of edges on the tree path between x and y.
int getdis(int x, int y) {
        const int anc = lca(x, y);
        return (deep[x] - deep[anc]) + (deep[y] - deep[anc]);
}
// belong[x]: query node controlling virtual-tree node x (nearest; ties go to the smaller label); 0 = unassigned.
int belong[MAXN];
// ans[c]: number of tree nodes controlled by query node c.
// rest[x]: nodes of x's original subtree that lie in none of x's virtual children's branches.
int ans[MAXN], rest[MAXN];
// Bottom-up pass over the virtual tree. It sets belong[x] to the nearest query
// node found in x's virtual subtree (ties go to the smaller label). It also
// computes rest[x]: sz[x] minus the size of each branch that leads to a virtual child.
void getbelong1(int x, int fa) {
        rest[x] = sz[x];
        for (int e = Head[x]; e; e = nxt[e]) {
                const int child = to[e];
                if (child == fa) {
                        continue;
                }
                // Direct original child of x on the path toward `child`.
                const int branchTop = getpoint(child, deep[child] - deep[x] - 1);
                rest[x] -= sz[branchTop];
                getbelong1(child, x);
                const int cand = belong[child];
                if (cand == 0) {
                        continue;
                }
                const int mine = belong[x];
                if (mine == 0) {
                        belong[x] = cand;
                        continue;
                }
                const int dMine = getdis(x, mine);
                const int dCand = getdis(x, cand);
                if (dCand < dMine || (dCand == dMine && cand < mine)) {
                        belong[x] = cand;
                }
        }
}
// Top-down pass over the virtual tree. Each child adopts its parent's
// controller when that controller is strictly closer, or equally close with a
// smaller label (or when the child has none yet).
void getbelong2(int x, int fa) {
        const int owner = belong[x];
        for (int e = Head[x]; e; e = nxt[e]) {
                const int child = to[e];
                if (child == fa) {
                        continue;
                }
                if (belong[child] == 0) {
                        belong[child] = owner;
                } else {
                        const int dFromParent = getdis(child, owner);
                        const int dOwn = getdis(child, belong[child]);
                        if (dFromParent < dOwn || (dFromParent == dOwn && owner < belong[child])) {
                                belong[child] = owner;
                        }
                }
                getbelong2(child, x);
        }
}
// Accumulates ans[] over the virtual tree. Each virtual node credits its
// controller with rest[x]. Each compressed edge x -> child splits its hidden
// nodes between belong[x] and belong[child]. The split point is found by
// binary lifting from `child` toward x.
void getans(int x, int fa) {
        const int ownX = belong[x];
        ans[ownX] += rest[x];
        for (int e = Head[x]; e; e = nxt[e]) {
                const int child = to[e];
                if (child == fa) {
                        continue;
                }
                getans(child, x);
                const int gap = deep[child] - deep[x];
                if (gap == 1) {
                        // Adjacent in the original tree: no hidden nodes on this edge.
                        continue;
                }
                const int ownC = belong[child];
                // Direct original child of x on the path toward `child`.
                const int branchTop = getpoint(child, gap - 1);
                if (ownX == ownC) {
                        ans[ownX] += sz[branchTop] - sz[child];
                        continue;
                }
                // Highest node strictly below x that is still controlled by ownC.
                int split = child;
                for (int j = 19; j >= 0; --j) {
                        const int up = jumpfa[split][j];
                        if (deep[up] <= deep[x]) {
                                continue;
                        }
                        const int dX = getdis(up, ownX);
                        const int dC = getdis(up, ownC);
                        if (dX > dC || (dX == dC && ownX > ownC)) {
                                split = up;
                        }
                }
                ans[ownX] += sz[branchTop] - sz[split];
                ans[ownC] += sz[split] - sz[child];
        }
}
// Post-order reset of the virtual tree rooted at x. It clears Head[] and
// belong[] for every node so the next query starts from a clean state.
void init(int x, int fa) {
        for (int e = Head[x]; e != 0; e = nxt[e]) {
                if (to[e] != fa) {
                        init(to[e], x);
                }
        }
        // Cleared only after the children were visited, since the loop above walks Head[x].
        belong[x] = 0;
        Head[x] = 0;
}
// Reads the tree and the queries. For every query it prints, in input order,
// how many tree nodes each query node controls: a node is controlled by its
// nearest query node, with ties going to the smaller label.
int main() {
        int u, v;
        read(n);
        // Original tree: each undirected edge is stored in both directions.
        for (int i = 1; i <= n - 1; i++) {
                read(u), read(v);
                addedge(u, v);
                addedge(v, u);
        }
        deep[1] = 1;
        dfs1(1, 0);
        // The original adjacency lists are no longer needed; Head/to/nxt are reused per virtual tree.
        mem(Head, 0);
        int m, number, queryCount;
        read(m);
        while (m--) {
                tot = 1;
                top = 0;
                read(number);
                if (number <= 0) {
                        // Empty query: nothing to report. Without this guard the code below
                        // would print ans[k2[1]] for a node that is not part of this query
                        // and would leave ans[0] dirty for later queries.
                        puts("");
                        continue;
                }
                queryCount = number;
                for (int i = 1; i <= number; i++) {
                        read(k[i]);
                        k2[i] = k[i];
                        del[k[i]] = 1;
                        belong[k[i]] = k[i];
                }
                // The root anchors the virtual tree; add it when it is not a query node.
                if (belong[1] == 0) {
                        k[++number] = 1;
                }
                sort(k + 1, k + number + 1, cmp);
                for (int i = 1; i <= number; i++) {
                        insert_point(k[i]);
                }
                // Flush the remaining stack chain into virtual-tree edges.
                while (top > 1) {
                        addedge(s[top - 1], s[top]);
                        top--;
                }
                getbelong1(1, 0);
                getbelong2(1, 0);
                getans(1, 0);
                printf("%d", ans[k2[1]]);
                for (int i = 2; i <= queryCount; i++) {
                        printf(" %d", ans[k2[i]]);
                }
                puts("");
                // Reset per-query state (virtual tree arrays, flags and answers).
                init(1, 0);
                for (int i = 1; i <= number; i++) {
                        del[k[i]] = 0;
                        ans[k[i]] = 0;
                }
        }
        return 0;
}
// (end of code listing)
// Original post (原文地址): https://www.cnblogs.com/Aragaki/p/10551687.html