DFS序+线段树 hihoCoder 1381 Little Y's Tree(树的连通块的直径和)

 

题目链接

#1381 : Little Y's Tree

时间限制:24000ms
单点时限:4000ms
内存限制:512MB

描述

小Y有一棵n个节点的树,每条边都有正的边权。

小J有q个询问,每次小J会删掉这个树中的k条边,这棵树被分成k+1个连通块。小J想知道每个连通块中最远点对距离的和。

这里的询问是互相独立的,即每次都是在小Y的原树上进行操作。

输入

第一行一个整数n,接下来n-1行每行三个整数u,v,w,其中第i行表示第i条边边权为wi,连接了ui,vi两点。

接下来一行一个整数q,表示有q组询问。

对于每组询问,第一行一个正整数k,接下来一行k个不同的1到n-1之间的整数,表示删除的边的编号。

1<=n,q,Σk<=105, 1<=w<=109

输出

共q行,每行一个整数表示询问的答案。

题解:

  首先考虑给出两个点集,如何求这两个点集合并之后的直径,方法是把两个点集的直径分别求出来,然后对于这4个点,求出两两之间距离的最大值。

  于是可以按dfs序建立线段树,然后求出每个区间的直径。

  而对于一个询问,删掉k条边,每棵子树都对应的dfs序中的若干区间,而且区间总个数不会超过2k,对于每个区间可以在线段树中查询。

  时间复杂度O(nlog^2n)。

 代码:

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const int N = 1e5 + 5;
const int D = 20;

struct Edge {
    int u, v, w;
};

int L[N], R[N], p[N], rt[N][D], dep[N];
ll d[N];
int dfs_clock;
vector<Edge> edges;
vector<int> id[N];
int n, m;

void init_edge() {
    edges.clear();
    for (int i=1; i<=n; ++i) id[i].clear();
    m = 0;
}

void add_edge(int u, int v, int w) {
    edges.push_back((Edge){u, v, w});
    m = edges.size();
    id[u].push_back(m-1);
}

void DFS(int u, int fa) {
    L[u] = ++dfs_clock;
    p[dfs_clock] = u;
    dep[u] = dep[fa] + 1;
    rt[u][0] = fa;
    for (int i: id[u]) {
        Edge &e = edges[i];
        if (e.v == fa) continue;
        d[e.v] = d[u] + e.w;
        DFS(e.v, u);
    }
    R[u] = dfs_clock;
}

void init_LCA() {
    for (int j=1; j<D; ++j) {
        for (int i=1; i<=n; ++i) {
            rt[i][j] = rt[rt[i][j-1]][j-1];
        }
    }
}

int LCA(int u, int v) {
    if (dep[u] < dep[v]) swap(u, v);
    for (int i=0; i<D; ++i) {
        if ((dep[u]-dep[v]) >> i & 1) u = rt[u][i];
    }
    if (u == v) return u;
    for (int i=D-1; i>=0; --i) {
        if (rt[u][i] != rt[v][i]) {
            u = rt[u][i];
            v = rt[v][i];
        }
    }
    return rt[u][0];
}

ll dis(int u, int v) {
    return d[u] + d[v] - 2 * d[LCA(u, v)];
}

struct Node {
    ll d;
    int a, b;
    Node(ll d=0, int a=0, int b=0) : d(d), a(a), b(b) {}
    bool operator < (const Node &rhs) const {
        return d < rhs.d;
    }
};

Node nd[N<<2];

Node better(Node x, Node y) {
    if (x.d == -1) return y;
    if (y.d == -1) return x;
    Node z1 = Node(dis(x.a, y.a), x.a, y.a);
    Node z2 = Node(dis(x.a, y.b), x.a, y.b);
    Node z3 = Node(dis(x.b, y.a), x.b, y.a);
    Node z4 = Node(dis(x.b, y.b), x.b, y.b);
    return max({x, y, z1, z2, z3, z4});
}

#define lch o << 1
#define rch o << 1 | 1

void build(int o, int l, int r) {
    if (l == r) {
        nd[o] = Node(0, p[l], p[l]);
        return ;
    }
    int mid = l + r >> 1;
    build(lch, l, mid);
    build(rch, mid+1, r);
    nd[o] = better(nd[lch], nd[rch]);
}

Node query(int o, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) {
        return nd[o];
    }
    int mid = l + r >> 1;
    Node ret = Node(-1, 0, 0);
    if (ql <= mid) ret = better(ret, query(lch, l, mid, ql, qr));
    if (qr > mid) ret = better(ret, query(rch, mid+1, r, ql, qr));
    return ret;
}

bool cmp(int a, int b) {
    return L[a] < L[b];
}

int x[N], s[N];
vector<int> edge[N];
ll ans;

void prepare() {
    dfs_clock = 0;
    dep[0] = 0;
    d[1] = 0;
    
    DFS(1, 0);
    init_LCA();
    build(1, 1, n);
}

void DFS(int u) {
    Node res = Node(-1, x[u], x[u]);
    int ql = L[x[u]], qr = R[x[u]];
    for (int v: edge[u]) {
        DFS(v);
        res = better(res, query(1, 1, n, ql, L[x[v]]-1));
        ql = R[x[v]] + 1;
    }
    res = better(res, query(1, 1, n, ql, qr));
    ans += res.d;
}

int main() {
    scanf("%d", &n);
    init_edge();
    for (int i=1; i<n; ++i) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        add_edge(u, v, w);
        add_edge(v, u, w);
    }
    
    prepare();

    int q, k;
    scanf("%d", &q);
    while (q--) {
        scanf("%d", &k);
        int idx;
        for (int i=1; i<=k; ++i) {
            scanf("%d", &idx);
            idx--;
            Edge &e = edges[idx*2];
            if (dep[e.u] > dep[e.v]) x[i] = e.u;
            else x[i] = e.v;
        }
        sort(x+1, x+1+k, cmp);

        x[0] = 1;
        int nn = 1;
        s[nn] = 0;
        for (int i=0; i<=k; ++i) edge[i].clear();
        //s[nn]:0~k, x[s[n]]:1 or x[1~k]
        for (int i=1; i<=k; ++i) {
            while (!(L[x[s[nn]]] <= L[x[i]] && R[x[i]] <= R[x[s[nn]]])) nn--;
            edge[s[nn]].push_back(i);
            s[++nn] = i;
        }

        ans = 0;
        DFS(0);
        printf("%lld
", ans);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/Running-Time/p/5914289.html