HDU 4616 Game(经典树形dp+最大权值和链)

http://acm.hdu.edu.cn/showproblem.php?pid=4616

题意:
给出一棵树,每个顶点有权值,还有存在陷阱,现在从任意一个顶点出发,并且每个顶点只能经过一次,如果经过了c个陷阱就不能再走了,计算最大能获得的权值和。

思路:
有点像树链剖分,对于一个以u为根的子树,因为每个顶点只能经过一次,那我们只能选择它的一个子树往下走。就像是把这棵树分成许多链,最后再连接起来。

这道题目麻烦的地方是陷阱的处理,用d【u】【j】【0/1】表示以u为根的某一子节点经过j个陷阱后到达u的最大权值和,0/1表示起点是否有陷阱。

假设当前到达u时经过了k个陷阱,分下面几种情况进行讨论:

①如果k==c,那么起点和终点至少有一个是陷阱(可能有些人会认为终点一定会是陷阱,这样是没错的,因为起点和终点时相对的,你也可以把起点看做终点)。

②如果k<c,那么起点和终点是否是陷阱是任意的,可以有也可以没有。

 1 #include<iostream>
 2 #include<algorithm>
 3 #include<cstring>
 4 #include<cstdio>
 5 #include<sstream>
 6 #include<vector>
 7 #include<stack>
 8 #include<queue>
 9 #include<cmath>
10 #include<map>
11 #include<set>
12 using namespace std;
13 typedef long long ll;
14 typedef pair<int,int> pll;
15 const int INF = 0x3f3f3f3f;
16 const int maxn = 50000 + 5;
17 
18 int n,c;
19 int ans;
20 int val[maxn], trap[maxn];
21 int d[maxn][5][2];
22 vector<int> G[maxn];
23 
24 void dfs(int u, int fa)
25 {
26     d[u][trap[u]][trap[u]]=val[u];
27 
28     for(int i=0;i<G[u].size();i++)
29     {
30         int v=G[u][i];
31         if(v==fa)  continue;
32         dfs(v,u);
33         
34         //计算以u为根的子树所能获得的最大值,也就是将子树的链进行连接
35         for(int j=0;j<=c;j++)
36         {
37             for(int k=0;j+k<=c;k++)
38             {
39                 if(j!=c)   ans=max(ans,d[u][j][0]+d[v][k][1]);
40                 if(k!=c)   ans=max(ans,d[u][j][1]+d[v][k][0]);
41                 if(j+k<c)  ans=max(ans,d[u][j][0]+d[v][k][0]);  //起点和终点都可以为非陷阱
42                 if(k+j<=c) ans=max(ans,d[u][j][1]+d[v][k][1]);  //起点和终点都可以为陷阱
43             }
44         }
45 
46 
47         for(int j=0;j+trap[u]<=c;j++)  //更新以u的根的子树中权值最大的链
48         {
49             d[u][j+trap[u]][0]=max(d[u][j+trap[u]][0],d[v][j][0]+val[u]);
50             //这儿要注意一下,如果j=0时,要么就不能从有陷阱的起点出发
51             if(j!=0) d[u][j+trap[u]][1]=max(d[u][j+trap[u]][1],d[v][j][1]+val[u]);
52         }
53     }
54 }
55 
56 int main()
57 {
58     //freopen("in.txt","r",stdin);
59     int T;
60     scanf("%d",&T);
61     while(T--)
62     {
63         scanf("%d%d",&n,&c);
64         for(int i=0;i<n;i++)  G[i].clear();
65 
66         for(int i=0;i<n;i++)
67             scanf("%d%d",&val[i],&trap[i]);
68 
69         for(int i=1;i<n;i++)
70         {
71             int u,v;
72             scanf("%d%d",&u,&v);
73             G[u].push_back(v);
74             G[v].push_back(u);
75         }
76 
77         ans=0;
78         memset(d,0,sizeof(d));
79         dfs(0,-1);
80         printf("%d
",ans);
81     }
82     return 0;
83 }
原文地址:https://www.cnblogs.com/zyb993963526/p/7223861.html