[POJ1155]TELE(树形背包dp)
看到这道题的第一眼我把题目看成了TLE
这道题是树形背包dp的经典例题
题目描述(大概的):
给你一棵树,每条边有一个cost,每个叶节点有一个earn
要求在earn的和大于等于cost的和的前提下问最多能连接到多少个叶节点
思路:
这道题卡了我0.5month
(因为我太懒了)
核心思路
用dp[x][k]表示x为根的子树里连接到k个叶节点时最大的利润(earn和-cost和)
那么for嵌套顺序应当是
1 for(int s=cd[x];s;s--/*int d=cd[t];d;d-- <-- It is wrong*/) 2 { 3 for(int d=min(cd[t],s);d;d--/*int s=cd[x];s>=d;s-- <-- It is wrong*/) 4 { 5 if(dp[x][s-d]!=-2147483647){dp[x][s]=max(dp[x][s],dp[x][s-d]+dp[t][d]-e[i].v);} 6 } 7 }
没错,也就是外层枚举x的体积,内层枚举t的体积,都是从大到小
至于为什么?
因为t的所有可能体积的dp你只能将其中一个装进x的dp里
听说这叫分组背包?反正我看着也差不多
此后就可以开始快乐地dfs了
(为什么看不少人都用滚动数组呢...dp[3001][3001]也不大啊)
剩下还可以加一些小小的优化
看代码吧
1 #include<cstdio> 2 int max(int a,int b){return a>b?a:b;} 3 int min(int a,int b){return a<b?a:b;} 4 struct sumireko 5 { 6 int to,ne,v; 7 }e[3001]; 8 int he[3001],cnt; 9 void addline(int f,int t,int vin) 10 { 11 e[++cnt].to=t; 12 e[cnt].ne=he[f]; 13 e[cnt].v=vin; 14 he[f]=cnt; 15 return; 16 } 17 18 //bool v[3001]; 19 //写完才注意到是有向边,所以v意义不大,完全可以全部抛弃 20 //但删完之后反而变慢了 21 //越简化越慢也是有毒 22 int n,m,cd[3001];//其中cd数组与下面的findd()一起使用,用于找出每个节点的子树有几个叶节点,从而省点时间(大概) 23 24 int findd(int x) 25 { 26 if(x>(n-m))return cd[x]=1; 27 for(int i=he[x];i;i=e[i].ne) 28 { 29 int t=e[i].to; 30 cd[x]+=findd(t); 31 } 32 return cd[x]; 33 } 34 35 int dp[3001][3001]; 36 37 void dfs(int x) 38 { 39 for(int i=he[x];i;i=e[i].ne) 40 { 41 int t=e[i].to; 42 dfs(t); 43 for(int s=cd[x];s;s--/*int d=cd[t];d;d-- <-- It is wrong*/) 44 { 45 for(int d=min(cd[t],s);d;d--/*int s=cd[x];s>=d;s-- <-- It is wrong*/) 46 { 47 if(dp[x][s-d]!=-2147483647)dp[x][s]=max(dp[x][s],dp[x][s-d]+dp[t][d]-e[i].v); 48 } 49 } 50 } 51 return; 52 } 53 int main() 54 { 55 scanf("%d%d",&n,&m); 56 for(int i=1;i<=n-m;i++) 57 { 58 int k,tin,vin; 59 scanf("%d",&k); 60 for(int j=1;j<=k;j++) 61 { 62 scanf("%d%d",&tin,&vin); 63 addline(i,tin,vin); 64 } 65 } 66 findd(1); 67 for(int i=1;i<=n;i++) 68 { 69 for(int j=1;j<=cd[i];j++) 70 { 71 dp[i][j]=-2147483647;//初始化 72 } 73 } 74 for(int i=n-m+1;i<=n;i++) 75 { 76 scanf("%d",&dp[i][1]); 77 } 78 dfs(1); 79 for(int i=cd[1];i>=0;i--) 80 { 81 if(dp[1][i]>=0) 82 { 83 printf("%d",i); 84 return 0; 85 } 86 } 87 return 0; 88 }