luogu P1273 有线电视网

嘟嘟嘟

 

一看完题我就知道是树形dp,状态都想出来了:dp[i][j]表示以 i 为根节点的子树,选了 j 个用户时的最大利润。然后我就卡住了……直到看题解的那一刻……

题解说:树形背包。于是我就知道为什么gg了:我总共就做了一道树形背包。

思路大概是这样的:对于一个节点u,枚举他的儿子v,然后dfs(v)后,就是两层循环的dp。

首先看一下转移方程吧:

  dp[i][u][j] = max(dp[i - 1][u][j], dp[i - 1][u][j - k] + dp[i - 1][v][k] - c[u->v])

i 的范围是u的儿子数量,j 的范围是对于u的第x个儿子v,已经走过的用户数量(想想就明白了),k的范围是以 v 为根的子树中用户的数量。 

最终的答案就是dp[1][i] >= 0中,i 的最大值。

 1 #include<cstdio>
 2 #include<iostream>
 3 #include<cmath>
 4 #include<algorithm>
 5 #include<cstring>
 6 #include<cstdlib>
 7 #include<cctype>
 8 #include<vector>
 9 #include<stack>
10 #include<queue>
11 using namespace std;
12 #define enter puts("") 
13 #define space putchar(' ')
14 #define Mem(a, x) memset(a, x, sizeof(a))
15 #define rg register
16 typedef long long ll;
17 typedef double db;
18 const int INF = 0x3f3f3f3f;
19 const db eps = 1e-8;
20 const int maxn = 3e3 + 5;
21 inline ll read()
22 {
23     ll ans = 0;
24     char ch = getchar(), last = ' ';
25     while(!isdigit(ch)) {last = ch; ch = getchar();}
26     while(isdigit(ch)) {ans = ans * 10 + ch - '0'; ch = getchar();}
27     if(last == '-') ans = -ans;
28     return ans;
29 }
30 inline void write(ll x)
31 {
32     if(x < 0) x = -x, putchar('-');
33     if(x >= 10) write(x / 10);
34     putchar(x % 10 + '0');
35 }
36 
37 int n, m;
38 vector<int> v[maxn], c[maxn];
39 int a[maxn];
40 int dp[maxn][maxn];
41 
42 int dfs(int now)
43 {
44     if(now > n - m) {dp[now][1] = a[now]; return 1;}    //当前节点是一个用户 
45     int siz = 0;
46     for(int i = 0; i < (int)v[now].size(); ++i)
47     {
48         int sz = dfs(v[now][i]); siz += sz;
49         for(int j = siz; j > 0; --j)
50             for(int k = 1; k <= sz; ++k) if(j >= k)
51                 dp[now][j] = max(dp[now][j], dp[now][j - k] + dp[v[now][i]][k] - c[now][i]);
52     }
53     return siz;
54 }
55 
56 int main()
57 {
58     n = read(); m = read();
59     for(int i = 1; i <= n; ++i)
60         for(int j = 1; j <= m; ++j) dp[i][j] = -INF;
61     for(int i = 1; i <= n - m; ++i)
62     {
63         int d = read();
64         for(int j = 1; j <= d; ++j)
65         {
66             int y = read(), co = read();
67             v[i].push_back(y); c[i].push_back(co); 
68         }
69     }
70     for(int i = n - m + 1; i <= n; ++i) a[i] = read();
71     dfs(1);
72     for(int i = m; i > 0; --i) if(dp[1][i] >= 0) {write(i); enter; return 0;}
73     return 0;
74 }
View Code
原文地址:https://www.cnblogs.com/mrclr/p/9642800.html