[POJ1155]TELE(树形背包dp)

[POJ1155]TELE(树形背包dp)

看到这道题的第一眼我把题目看成了TLE

这道题是树形背包dp的经典例题

题目描述(大概的):

给你一棵树,每条边有一个cost,每个叶节点有一个earn

要求在earn的和大于等于cost的和的前提下问最多能连接到多少个叶节点

思路:

这道题卡了我0.5month

(因为我太懒了)

核心思路

用dp[x][k]表示x为根的子树里连接到k个叶节点时最大的利润(earn和-cost和)

那么for嵌套顺序应当是

1 for(int s=cd[x];s;s--/*int d=cd[t];d;d-- <-- It is wrong*/)
2             {
3                 for(int d=min(cd[t],s);d;d--/*int s=cd[x];s>=d;s-- <-- It is wrong*/)
4                 {
5                     if(dp[x][s-d]!=-2147483647){dp[x][s]=max(dp[x][s],dp[x][s-d]+dp[t][d]-e[i].v);}
6                 }
7             }

没错,也就是外层枚举x的体积,内层枚举t的体积,都是从大到小

至于为什么?

因为t的所有可能体积的dp你只能将其中一个装进x的dp里

听说这叫分组背包?反正我看着也差不多

此后就可以开始快乐地dfs了

(为什么看不少人都用滚动数组呢...dp[3001][3001]也不大啊)

剩下还可以加一些小小的优化

看代码吧

 1 #include<cstdio>
 2 int max(int a,int b){return a>b?a:b;}
 3 int min(int a,int b){return a<b?a:b;}
 4 struct sumireko
 5 {
 6     int to,ne,v;
 7 }e[3001];
 8 int he[3001],cnt;
 9 void addline(int f,int t,int vin)
10 {
11     e[++cnt].to=t;
12     e[cnt].ne=he[f];
13     e[cnt].v=vin;
14     he[f]=cnt;
15     return;
16 }
17 
18 //bool v[3001];
19 //写完才注意到是有向边,所以v意义不大,完全可以全部抛弃
20 //但删完之后反而变慢了
21 //越简化越慢也是有毒
22 int n,m,cd[3001];//其中cd数组与下面的findd()一起使用,用于找出每个节点的子树有几个叶节点,从而省点时间(大概)
23 
24 int findd(int x)
25 {
26     if(x>(n-m))return cd[x]=1;
27     for(int i=he[x];i;i=e[i].ne)
28     {
29         int t=e[i].to;
30         cd[x]+=findd(t);
31     }
32     return cd[x];
33 }
34 
35 int dp[3001][3001];
36 
37 void dfs(int x)
38 {
39     for(int i=he[x];i;i=e[i].ne)
40     {
41         int t=e[i].to;
42         dfs(t);
43         for(int s=cd[x];s;s--/*int d=cd[t];d;d-- <-- It is wrong*/)
44         {
45             for(int d=min(cd[t],s);d;d--/*int s=cd[x];s>=d;s-- <-- It is wrong*/)
46             {
47                 if(dp[x][s-d]!=-2147483647)dp[x][s]=max(dp[x][s],dp[x][s-d]+dp[t][d]-e[i].v);
48             }
49         }
50     }
51     return;
52 }
53 int main()
54 {
55     scanf("%d%d",&n,&m);
56     for(int i=1;i<=n-m;i++)
57     {
58         int k,tin,vin;
59         scanf("%d",&k);
60         for(int j=1;j<=k;j++)
61         {
62             scanf("%d%d",&tin,&vin);
63             addline(i,tin,vin);
64         }
65     }
66     findd(1);
67     for(int i=1;i<=n;i++)
68     {
69         for(int j=1;j<=cd[i];j++)
70         {
71             dp[i][j]=-2147483647;//初始化
72         }
73     }
74     for(int i=n-m+1;i<=n;i++)
75     {
76         scanf("%d",&dp[i][1]);
77     }
78     dfs(1);
79     for(int i=cd[1];i>=0;i--)
80     {
81         if(dp[1][i]>=0)
82         {
83             printf("%d",i);
84             return 0;
85         }
86     }
87     return 0;
88 }
原文地址:https://www.cnblogs.com/rikurika/p/10009481.html