思路:
经典树形dp
dp[i][0]表示i的子树中以i为端点的最大链(链的权值为链上所有点权之和)
dp[i][1]表示整棵树中除去i的子树后剩下的部分(再加上i本身)里以i为端点的最大链
最后,经过i的最大链就是以i为端点的最大链和次大链拼起来(候选链包括经过每个儿子向下的链以及dp[i][1]这条向上的链;两条链都包含i,拼接时要减去一次a[i])。特殊情况:只有一条链,或者单独的最大链比拼接后更大时,答案就取最大链本身。
代码:
// Maximum-weight path through every node of a node-weighted tree
// (tree DP + rerooting).
//
// Input : n m, then weights a[1..n], then n-1 edges "u v", then m query nodes.
// Output: for each query q, the maximum total weight of a simple path that
//         contains q, each value followed by a single space.
#include <algorithm>
#include <climits>
#include <cstdio>
#include <vector>

namespace {

// Sentinel meaning "no second candidate chain exists".
constexpr long long kNoChain = LLONG_MIN;

// Returns every node reachable from `root` in DFS preorder (a parent always
// precedes its children) and fills parent[] with parent[root] = 0.
// Iterative on purpose: a path-shaped tree with ~1e5 nodes would overflow the
// call stack with a recursive DFS.
std::vector<int> preorder_from(int root, const std::vector<std::vector<int>>& adj,
                               std::vector<int>& parent) {
    std::vector<int> order;
    order.reserve(adj.size());
    std::vector<int> pending{root};
    parent[root] = 0;  // node ids start at 1, so 0 never matches a neighbour
    while (!pending.empty()) {
        const int u = pending.back();
        pending.pop_back();
        order.push_back(u);
        for (const int v : adj[u]) {
            if (v != parent[u]) {
                parent[v] = u;
                pending.push_back(v);
            }
        }
    }
    return order;
}

// For every node u in 1..n, returns the maximum weight of a simple path that
// contains u (result[0] is unused).
// adj: 1-indexed adjacency list of a tree; rooted at node 1 (assumes n >= 1 and
//      that the edges form a connected tree).
// a:   1-indexed node weights; may be negative.
// Sums are kept in long long so that long paths of large weights cannot overflow.
std::vector<long long> max_path_through_each(int n, const std::vector<long long>& a,
                                             const std::vector<std::vector<int>>& adj) {
    std::vector<int> parent(n + 1, 0);
    const std::vector<int> order = preorder_from(1, adj, parent);

    // down[u] (dp[u][0]): best chain that starts at u and stays inside u's subtree.
    // Children are finished before their parent by walking preorder backwards.
    std::vector<long long> down(n + 1, 0);
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int u = *it;
        long long best_below = 0;  // 0 = the chain may stop at u itself
        for (const int v : adj[u]) {
            if (v != parent[u]) best_below = std::max(best_below, down[v]);
        }
        down[u] = a[u] + best_below;
    }

    // up[u] (dp[u][1]): best chain that starts at u and otherwise stays outside
    // u's subtree. Parents are finished before children by walking preorder forwards.
    std::vector<long long> up(n + 1, 0);
    std::vector<long long> ans(n + 1, 0);
    up[1] = a[1];
    for (const int u : order) {
        // Two best chains that start at u (each already includes a[u]):
        // the upward chain up[u], and a[u] + down[v] for every child v.
        long long best = up[u];
        long long second = kNoChain;
        int best_child = 0;  // 0 means `best` is the upward chain
        for (const int v : adj[u]) {
            if (v == parent[u]) continue;
            const long long cand = a[u] + down[v];
            if (cand > best) {
                second = best;
                best = cand;
                best_child = v;
            } else if (cand > second) {
                second = cand;
            }
        }

        // Joining two chains counts a[u] twice, so subtract it once. best >= a[u]
        // always holds (up[u] >= a[u]), so best alone covers the "single chain" case.
        ans[u] = best;
        if (second != kNoChain) ans[u] = std::max(ans[u], best + second - a[u]);

        for (const int v : adj[u]) {
            if (v == parent[u]) continue;
            // Best chain from u that avoids v's branch. If v supplied `best`, the
            // upward chain was demoted into `second`, so `second` is valid here.
            const long long other = (v == best_child) ? second : best;
            up[v] = a[v] + std::max(0LL, other);
        }
    }
    return ans;
}

}  // namespace

int main() {
    int n = 0;
    int m = 0;
    if (std::scanf("%d %d", &n, &m) != 2 || n < 1) return 0;

    std::vector<long long> a(n + 1, 0);
    for (int i = 1; i <= n; ++i) {
        if (std::scanf("%lld", &a[i]) != 1) return 0;
    }

    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        int u = 0;
        int v = 0;
        if (std::scanf("%d %d", &u, &v) != 2) return 0;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    const std::vector<long long> ans = max_path_through_each(n, a, adj);

    // Queries are assumed to be valid node ids in 1..n.
    while (m-- > 0) {
        int q = 0;
        if (std::scanf("%d", &q) != 1) break;
        std::printf("%lld ", ans[q]);
    }
    return 0;
}