All You Can Code 2008 (Romanian Contest) A

A - Tree Search

思路:

经典树形dp

dp[i][0]表示i的子树中以i为端点的最大链

dp[i][1]表示整棵树中除去i的子树后剩下的部分(再加上i本身)中,以i为端点的最大链

最后答案就是以i为端点的最大链和次大链拼起来;两条链都包含了i,所以要减去一次a[i]。另外要处理一些特殊情况:单独一条链可能更大(次大链去掉i后的部分为负时),或者只有一条链(例如i是非根的叶子时)

代码:

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb push_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define mem(a, b) memset(a, b, sizeof(a))
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
#define fopen freopen("in.txt", "r", stdin);freopen("out.txt", "w", stout);
//head

// Capacity of every vertex-indexed array (main indexes vertices from 1).
constexpr int N = 100005;
// Adjacency lists of the input tree (each edge stored in both directions).
std::vector<int> g[N];
// a[i]: weight of vertex i; ans[i]: value printed when vertex i is queried.
int a[N], ans[N];
// dp[i][0]: best chain starting at i and going down into i's subtree;
// dp[i][1]: best chain starting at i and going into the rest of the tree.
int dp[N][2];
// Bottom-up pass over the subtree rooted at u.
//
// Sets dp[x][0] for every vertex x in that subtree to the maximum weight of a
// chain that starts at x and only descends into x's subtree. x itself is
// always included, so dp[x][0] >= a[x].
//
// u: subtree root. o: u's parent; for the tree root pass any value that is
// not a neighbour of u (main passes the root itself: dfs1(1, 1)).
//
// Uses an explicit stack instead of recursion. On a path-shaped tree the
// recursion depth would equal the number of vertices, which can exhaust the
// call stack on platforms with a small default stack.
void dfs1(int u, int o) {
    // (vertex, parent) pairs in DFS preorder: each vertex appears before all
    // of its descendants.
    std::vector<std::pair<int, int>> order;
    std::vector<std::pair<int, int>> todo;
    todo.emplace_back(u, o);
    while (!todo.empty()) {
        const std::pair<int, int> cur = todo.back();
        todo.pop_back();
        order.push_back(cur);
        for (int v : g[cur.first]) {
            if (v != cur.second) todo.emplace_back(v, cur.first);
        }
    }
    // Reverse preorder handles children before their parent, so every
    // dp[v][0] read below is already final.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int x = it->first;
        const int parent = it->second;
        dp[x][0] = a[x];
        for (int v : g[x]) {
            if (v != parent) dp[x][0] = std::max(dp[x][0], a[x] + dp[v][0]);
        }
    }
}
// Top-down (rerooting) pass over the subtree rooted at u; run after dfs1.
//
// For every vertex x in that subtree it sets:
//   dp[x][1] - maximum weight of a chain that starts at x and leaves through
//              x's parent edge (x itself included, so dp[x][1] >= a[x]);
//   ans[x]   - maximum weight of a chain that passes through x.
//
// u: subtree root. o: u's parent. Call with o == u for the tree root (main
// uses dfs2(1, 1)); in that case dp[u][1] is initialised to a[u]. For a
// non-root u, dp[u][1] must already have been set by the caller.
//
// Uses an explicit stack instead of recursion, because the recursion depth
// would equal the tree height (the number of vertices on a path).
void dfs2(int u, int o) {
    if (u == o) dp[u][1] = a[u];

    std::vector<std::pair<int, int>> todo;  // (vertex, parent) still to process
    todo.emplace_back(u, o);
    while (!todo.empty()) {
        const int x = todo.back().first;
        const int parent = todo.back().second;
        todo.pop_back();

        // Best and second-best chains starting at x over all directions: the
        // upward chain dp[x][1] plus one downward chain per child.
        // best1From is the child best1 descends into, or -1 for the upward chain.
        int best1 = dp[x][1];
        int best1From = -1;
        int best2 = INT_MIN;
        bool haveBest2 = false;
        for (int v : g[x]) {
            if (v == parent) continue;
            const int chain = dp[v][0] + a[x];
            if (chain > best1) {
                best2 = best1;
                haveBest2 = true;
                best1 = chain;
                best1From = v;
            } else if (!haveBest2 || chain > best2) {
                best2 = chain;
                haveBest2 = true;
            }
        }

        // A chain through x either is a single direction (best1) or joins the
        // two best directions. a[x] is counted in both of those, so subtract it
        // once. Writing best1 + (best2 - a[x]) keeps the intermediate sum no
        // larger than the final result, so it cannot overflow when the result fits.
        int through = std::max(a[x], best1);
        if (haveBest2) through = std::max(through, best1 + (best2 - a[x]));
        ans[x] = through;

        for (int v : g[x]) {
            if (v == parent) continue;
            // Best chain from x that avoids v's subtree. If best1 goes into v,
            // fall back to best2, which exists because a child displaced the
            // upward chain.
            const int outside = (v == best1From) ? best2 : best1;
            dp[v][1] = std::max(a[v], outside + a[v]);
            todo.emplace_back(v, x);
        }
    }
}
int main() {
    int n, m, u, v, q;
    scanf("%d %d", &n, &m);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    for (int i = 1; i < n; i++) {
        scanf("%d %d", &u, &v);
        g[u].pb(v);
        g[v].pb(u);
    }
    dfs1(1, 1);
    dfs2(1, 1);
    while(m--) {
        scanf("%d", &q);
        printf("%d
", ans[q]);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/widsom/p/10097238.html