AcWing 1073. 树的中心

题目传送门

一、思路分析

这个问题是 树形DP 中的一类 经典模型,常被称作 换根DP

同样,先来想一下如何暴力求解该问题:先 枚举 目标节点,然后求解该节点到其他节点的 最远距离

时间复杂度为 \(O(n^2)\),对于本题的 数据规模,十分极限,经测试只能过 \(7/11\),代码:

#include <bits/stdc++.h>

using namespace std;
//直接暴力换根
//暴力办法,可以通过 7/11个数据
const int N = 10010;
const int M = N * 2;
const int INF = 0x3f3f3f3f;

int n;
int h[N], e[M], w[M], ne[M], idx;
int d[N];       //记录以每个结点为根出发可以走的最长距离

//邻接表模板
void add(int a, int b, int c) {
    e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}

//暴力版本
void dfs(int u, int father) {
    for (int i = h[u]; ~i; i = ne[i]) {     //遍历u节点的每一条出边
        int j = e[i];
        if (j == father) continue;          //不走回头路
        dfs(j, u);
        int dist = d[j] + w[i];             //i表示u->j的边
        d[u] = max(d[u], dist);             //获取最长距离
    }
}

int main() {
    //邻接表初始化
    memset(h, -1, sizeof h);
    cin >> n;
    for (int i = 1; i < n; i++) {//n-1条边
        int a, b, c;
        cin >> a >> b >> c;
        add(a, b, c), add(b, a, c);
    }
    int res = INF;
    //从每一个点出发,分别求一个此点到其它各点的最长距离,然后求一个min
    for (int i = 1; i <= n; i++) {
        memset(d, 0, sizeof d);
        dfs(i, -1);
        res = min(res, d[i]);
    }
    //输出
    printf("%d\n", res);
    return 0;
}

考虑如何优化求解该问题的方法

思考一下:在确定树的 拓扑结构 后单独求一个节点的 最远距离 时,会在该树上去比较哪些 路径 呢?

  • 从当前节点往下,直到子树中某个节点的最长路径

  • 从当前节点往上走到其父节点,再从其父节点出发且不回到该节点的最长路径

此处就要引入 换根DP 的思想了

换根DP 一般分为三个步骤:

  1. 指定任意一个根节点
  2. 一次\(dfs\)遍历,统计出当前子树内的节点对当前节点的贡献
  3. 一次\(dfs\)遍历,统计出当前节点的父节点对当前节点的贡献,然后合并统计答案

那么我们就要先 \(dfs\) 一遍,预处理出当前子树对于根的最大贡献(距离)次大贡献(距离)

处理 次大贡献(距离) 的原因是:

如果 当前节点 是其 父节点子树最大路径 上的点,则 父节点子树最大贡献 不能算作对该节点的贡献

因为我们的路径是 简单路径,不能 走回头路

然后我们再 \(dfs\) 一遍,求解出每个节点的父节点对他的贡献(即每个节点往上能到的最远路径),两者比较,取一个 \(max\)即可

d1[u]:存下u节点向下走的最长路径的长度
d2[u]:存下u节点向下走的第二长的路径的长度
p1[u]:存下u节点向下走的最长路径是从哪一个节点下去的
p2[u]:存下u节点向下走的第二长的路径是从哪一个节点走下去的
up[u]:存下u节点向上走的最长路径的长度

二、实现代码

#include<bits/stdc++.h>

using namespace std;

const int N = 10010;
const int M = N << 1;
const int INF = 0x3f3f3f3f;

int n;
int h[N], e[M], w[M], ne[M], idx;
int d1[N];  //下行最长距离
int d2[N];  //下行次长距离
int up[N];  //上行最长距离
int p1[N];  //下行最长距离是走的哪一个节点获得的

//邻接表模板
void add(int a, int b, int c) {
    e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}

//统计出当前子树内的节点对当前节点的贡献
void dfs_down(int u, int father) {
    for (int i = h[u]; ~i; i = ne[i]) {     //遍历每条出边
        int j = e[i];                       //连接的节点j
        if (j == father) continue;          //不走回头路
        dfs_down(j, u);                //换根递归计算以j为根的子树情况
        //dfs1的结果其实已经记录到d1[j]里
        if (d1[j] + w[i] >= d1[u]) {        //如果可以获得更大的距离
            d2[u] = d1[u];                  //原来的最长变为次长
            d1[u] = d1[j] + w[i];           //最长=子结点最长+边权
            p1[u] = j;                      //记录最长是通过子结点j获取
        } else if (d1[j] + w[i] > d2[u])    //如果可以更新次长
            d2[u] = d1[j] + w[i];           //更新次长
    }
}

//统计出当前节点的父节点对当前节点的贡献,然后合并统计答案
void dfs_up(int u, int father) {
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == father) continue;
        //j是u的子节点,这里在求j向上走的最长路。
        //分两种情况,如果u向下的最长路经过j,则用次长路更新;否则用最长路更新。
        if (p1[u] == j) up[j] = max(up[u], d2[u]) + w[i];   //用次大更新
        else up[j] = max(up[u], d1[u]) + w[i];              //用最大更新
        //讨论以j为根情况
        dfs_up(j, u);
    }
}

int main() {
    //初始化邻接表
    memset(h, -1, sizeof h);
    cin >> n;
    for (int i = 1; i < n; i++) {//n-1条边
        int a, b, c;
        cin >> a >> b >> c;
        add(a, b, c), add(b, a, c);
    }
    //换根DP,两次DFS
    dfs_down(1, -1);
    dfs_up(1, -1);

    //遍历每一个节点,找出它的最大上行距离和最大下行距离,然后取最小值
    int res = INF;
    for (int i = 1; i <= n; i++) res = min(res, max(d1[i], up[i]));
    //输出
    printf("%d\n", res);
    return 0;
}
原文地址:https://www.cnblogs.com/littlehb/p/15786805.html