POJ 2486 Apple Tree [树状DP]

题目:一棵树,每个结点上都有一些苹果,且相邻两个结点间的距离为1。一个人从根节点(编号为1)开始走,一共可以走k步,问最多可以吃多少苹果。

思路:这里给出数组的定义:

dp[0][x][j] 为从结点x开始走,一共走j步,且j步之后又回到x点时最多能吃到的苹果数。

dp[1][x][j] 为从结点x开始走,一共走j步最多能吃到的苹果数(不必再回到x点)。之所以要定义上面的一种状态是因为在求第二种状态时需要用到。

下面介绍递推公式。

对于结点x,假设它目前要访问的孩子为y,则1...(y-1)已经遍历过。此时有:

dp[0][x][j+2] = max(dp[0][x][j], dp[0][x][m] + dp[0][y][j-m])  

注:dp[0][x][m]在每次dp后都会进行更新,此时的dp[0][x][m]实际上只是遍历过孩子结点1...(y-1)的情况。等号左边j之所以要加2是因为右面的总距离没有考虑从点x到y以及从y再回到x的距离,在这里要加上。

dp[1][x][j+1] = max(dp[1][x][j+1], dp[0][x][m] + dp[1][y][j-m])

注:遍历y结点,且不再回来。j加1表示只需要走一次从x到y的边。

dp[1][x][j+2] = max(dp[1][x][j+2], dp[1][x][m] + dp[0][y][j-m])

注:遍历y结点且又回到结点x,则必然是从1...(y-1)中的某个结点有去无回。

另外在dp的过程中,对于结点x的每个孩子都需要枚举走过的步数。而这个步数的枚举从小到大还是从大到小结果是不一样的。这就要看到前面递推公式里的定义,要求dp[0/1][x][m]存储的是遍历y前面的孩子结点的dp值。因此步数要从大到小枚举,这样每次值更新后都不会影响到下次dp。不然会wa。当然,也可以在每次dp前将dp[0/1][x][m]这两个值存储起来,这样就不用考虑这个问题了。

 1 #include<stdio.h>
 2 #include<string.h>
 3 #include<algorithm>
 4 #define maxn 105
 5 #define maxk 205
 6 using namespace std;
 7 int n, k;
 8 int w[maxn];
 9 bool vis[maxn];
10 struct node
11 {
12     int v, next;
13 }edge[maxn<<1];
14 int num_edge, head[maxn];
15 void init_edge()
16 {
17     num_edge = 0;
18     memset(head, -1, sizeof(head));
19 }
20 void addedge(int a,int b)
21 {
22     edge[num_edge].v = b;
23     edge[num_edge].next = head[a];
24     head[a] = num_edge++;
25 }
26 int dp[2][maxn][maxk];
27 void getdp(int x)
28 {
29     vis[x] = 1;
30     for (int i = 0; i <= k; i++)
31         dp[0][x][i] = dp[1][x][i] = w[x];
32     for (int i = head[x]; i != -1; i = edge[i].next)
33     {
34         int v = edge[i].v;
35         if (vis[v]) continue;
36         getdp(v);
37         for (int j = k; j >= 0; j--)
38             for (int m = 0; m <= j; m++)
39             {
40                 dp[0][x][j+2] = max(dp[0][x][j+2], dp[0][x][m] + dp[0][v][j-m]);
41                 dp[1][x][j+1] = max(dp[1][x][j+1], dp[0][x][m] + dp[1][v][j-m]);
42                 dp[1][x][j+2] = max(dp[1][x][j+2], dp[1][x][m] + dp[0][v][j-m]);
43             }
44     }
45 }
46 int main()
47 {
48     while (~scanf("%d%d",&n,&k))
49     {
50         init_edge();
51         memset(vis, 0, sizeof(vis));
52         for (int i = 1; i <= n; i++)
53             scanf("%d",&w[i]);
54         for (int i = 1; i < n; i++)
55         {
56             int a, b;
57             scanf("%d%d",&a,&b);
58             addedge(a, b);
59             addedge(b, a);
60         }
61         getdp(1);
62         printf("%d
", dp[1][1][k]);
63     }
64     return 0;
65 }
原文地址:https://www.cnblogs.com/fenshen371/p/3281092.html