hiho一下第76周《Suzhou Adventure》

我是菜鸡,我是菜鸡,我是菜鸡。。。。重要的事说三遍

算是第一次做树形dp的题吧,不太难。。

园林构成一棵树,root为1,Hi从root出发,有k个园林必须玩,每个园林游玩后会得到权值w[i],最多玩M个园林。

经过的园林必须玩,问可得到的最大权值和。

题目链接:http://hihocoder.com/problemset/problem/1104

经典的树形dp,dp[i][j],表示以i为根的子树选j个结点可得的最大权值和。

这样dp[root][num] = g[son_num][num-1] + w[root];

g[son_num][num-1]类似于对root的儿子结点的子树做背包,g[i][j] = max( g[i-1][j-p] + dp[child_id][p] );

题解不错:http://hihocoder.com/discuss/question/2743

对于必须玩的k个点,他的处理方法我没有看懂。。。。discussion里的不错。。。游玩结点i,那么其祖先结点fa[i]必要也游玩,将这些点设为must,

将must构成的子树缩成一个点,rebuild a new tree,就成了上面说的最普通的树形dp了。

  1 #include<bits/stdc++.h>
  2 
  3 using namespace std;
  4 const int maxn = 100 + 10;
  5 int n, k, m;
  6 
  7 vector<int> G[maxn], G1[maxn];
  8 int fa[maxn], w[maxn], must[maxn], vis[maxn];
  9 int f[maxn][maxn];
 10 
 11 void init(){
 12     for( int i = 0; i < maxn; ++i ){
 13         G[i].clear();
 14         G1[i].clear();
 15     }
 16     memset( fa, -1, sizeof(fa) );
 17     memset( f, -1, sizeof(f) );
 18     memset( must, 0, sizeof(must) );
 19 }
 20 
 21 void dfs( int u, int pa ){
 22     fa[u] = pa;
 23     for( int i = 0; i < G[u].size(); ++i ){
 24         int v = G[u][i];
 25         if( v == pa )
 26             continue;
 27         dfs( v, u );
 28     }
 29 }
 30 
 31 void dfs1( int u, int pa ){
 32     if( pa != -1 ){
 33         if( vis[pa] && !vis[u] )
 34             G1[1].push_back(u);
 35         else if( !vis[pa] && !vis[u] )
 36             G1[pa].push_back(u);
 37     }
 38 
 39     for( int i = 0; i < G[u].size(); ++i ){
 40         int v = G[u][i];
 41         if( v == pa )
 42             continue;
 43         dfs1( v, u );
 44     }
 45 }
 46 
 47 int dp( int root, int  num ){
 48     if( num == 0 )
 49         return 0;
 50     if( f[root][num] != -1 )
 51         return f[root][num];
 52 
 53     int g[maxn][maxn], son_num = G1[root].size();
 54     memset( g, 0, sizeof(g) );
 55     for( int i = 1; i <= son_num; ++i ){
 56         int child = G1[root][i-1];
 57         for( int j = 0; j < num; ++j ){
 58             for( int p = 0; p <= j; ++p ){
 59                 g[i][j] = max( g[i][j], g[i-1][j-p] + dp( child, p ) );
 60             }
 61         }
 62     }
 63 
 64     f[root][num] = g[son_num][num-1] + w[root];
 65     return f[root][num];
 66 }
 67 
 68 void print( int num ){
 69     for( int i = 0; i < G1[num].size(); ++i )
 70         cout << G1[num][i] << "   ";
 71     cout << endl;
 72 }
 73 
 74 void solve(){
 75     dfs(1, -1);
 76 
 77     //must节点压缩
 78     int ans = 0, cnt = 0;
 79     memset( vis, 0, sizeof(vis) );
 80     for( int i = 1; i <= n; ++i ){
 81         if(must[i]){
 82             int u = i;
 83             while(u != -1 && !vis[u]){
 84                 vis[u] = 1;
 85                 cnt++, ans += w[u];
 86                 u = fa[u];
 87             }
 88         }
 89     }
 90     //cout << "ans: " << ans << endl;
 91 
 92     if( cnt > m ){
 93         printf("-1
");
 94         return;
 95     }
 96 
 97     //rebuild tree
 98     dfs1( 1, -1 );
 99 
100     w[1] = ans;
101     //cout << "m-cnt: " << m - cnt << endl;
102     ans = dp( 1, m-cnt+1 );
103     //cout << "f: " << f[8][3] << endl;
104 
105     printf("%d
", ans);
106 }
107 
108 int main(){
109     //freopen("1.in", "r", stdin);
110     init();
111     scanf("%d%d%d", &n, &k, &m);
112     for( int i = 1; i <= n; ++i ){
113         scanf("%d", &w[i]);
114     }
115 
116     int t;
117     for( int i = 1; i <= k; ++i ){
118         scanf("%d", &t);
119         must[t] = 1;
120     }
121 
122     int a, b;
123     for( int i = 1; i <= n-1; ++i ){
124         scanf("%d%d", &a, &b);
125         G[a].push_back(b), G[b].push_back(a);
126     }
127 
128     solve();
129 
130     return 0;
131 }
View Code

当然,也可以不需要g数组,直接当一维背包来做

dp[root][x] = max( dp[root][x], dp[root][x-y] + dp[child_id][y] );

  1 #include<bits/stdc++.h>
  2 
  3 using namespace std;
  4 const int maxn = 100 + 10;
  5 int n, k, m;
  6 
  7 vector<int> G[maxn], G1[maxn];
  8 int fa[maxn], w[maxn], must[maxn], vis[maxn];
  9 int f[maxn][maxn];
 10 
 11 void init(){
 12     for( int i = 0; i < maxn; ++i ){
 13         G[i].clear();
 14         G1[i].clear();
 15     }
 16     memset( fa, -1, sizeof(fa) );
 17     memset( f, -1, sizeof(f) );
 18     memset( must, 0, sizeof(must) );
 19 }
 20 
 21 void dfs( int u, int pa ){
 22     fa[u] = pa;
 23     for( int i = 0; i < G[u].size(); ++i ){
 24         int v = G[u][i];
 25         if( v == pa )
 26             continue;
 27         dfs( v, u );
 28     }
 29 }
 30 
 31 void dfs1( int u, int pa ){
 32     if( pa != -1 ){
 33         if( vis[pa] && !vis[u] )
 34             G1[1].push_back(u);
 35         else if( !vis[pa] && !vis[u] )
 36             G1[pa].push_back(u);
 37     }
 38 
 39     for( int i = 0; i < G[u].size(); ++i ){
 40         int v = G[u][i];
 41         if( v == pa )
 42             continue;
 43         dfs1( v, u );
 44     }
 45 }
 46 
 47 void dp(int root, int pa){
 48     f[root][1] = w[root];
 49     for( int i = 0; i < G1[root].size(); ++i ){
 50         int v = G1[root][i];
 51         if( v == pa )
 52             continue;
 53         dp( v, root );
 54         for( int x = m; x >= 1; --x ){
 55             for( int y = 0; y < x; ++y ){
 56                 f[root][x] = max( f[root][x], f[root][x-y] + f[v][y] );
 57             }
 58         }
 59     }
 60 }
 61 
 62 void solve(){
 63     dfs(1, -1);
 64 
 65     //must½ÚµãѹËõ
 66     int ans = 0, cnt = 0;
 67     memset( vis, 0, sizeof(vis) );
 68     for( int i = 1; i <= n; ++i ){
 69         if(must[i]){
 70             int u = i;
 71             while(u != -1 && !vis[u]){
 72                 vis[u] = 1;
 73                 cnt++, ans += w[u];
 74                 u = fa[u];
 75             }
 76         }
 77     }
 78     //cout << "ans: " << ans << endl;
 79 
 80     if( cnt > m ){
 81         printf("-1
");
 82         return;
 83     }
 84 
 85     //rebuild tree
 86     dfs1( 1, -1 );
 87 
 88     w[1] = ans;
 89     dp(1, -1);
 90     printf("%d
", f[1][m-cnt+1]);
 91 }
 92 
 93 int main(){
 94     //freopen("1.in", "r", stdin);
 95     init();
 96     scanf("%d%d%d", &n, &k, &m);
 97     for( int i = 1; i <= n; ++i ){
 98         scanf("%d", &w[i]);
 99     }
100 
101     int t;
102     for( int i = 1; i <= k; ++i ){
103         scanf("%d", &t);
104         must[t] = 1;
105     }
106 
107     int a, b;
108     for( int i = 1; i <= n-1; ++i ){
109         scanf("%d%d", &a, &b);
110         G[a].push_back(b), G[b].push_back(a);
111     }
112 
113     solve();
114 
115     return 0;
116 }
View Code
原文地址:https://www.cnblogs.com/zhazhalovecoding/p/5059242.html