// HDU 1561 The more, The Better

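// Tree DP (knapsack on a tree) — Accepted. A group knapsack is run once per node:
// dp[id][X] is the maximum total value obtainable by choosing X nodes from the subtree rooted at node id.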

#include<cstdio>
#include<cstring>
#include<cmath>
#include<ctime>
#include<vector>
#include<algorithm>
using namespace std;

// Array bound for node ids and selection counts: 200 plus a little slack.
const int maxn = 210;

int n;  // number of real nodes, ids 1..n (node 0 is the virtual root)
int m;  // number of nodes to select

int val[maxn];          // val[i]: value gained by selecting node i (val[0] stays 0)
int dp[maxn][maxn];     // dp[id][x]: best value choosing x nodes in id's subtree, -1 if impossible
int flag[maxn];         // knapsack row built while merging a node's children
int tmp[maxn];          // per-child staged improvements to flag
int w[maxn * maxn];     // scratch: values of the current child's options
int c[maxn * maxn];     // scratch: costs (node counts) of the current child's options
int u;                  // number of entries currently held in w / c
std::vector<int> tree[maxn];  // tree[p]: children of node p

// Reset all per-test-case state before reading a new tree.
// Every dp entry starts as -1 ("impossible"), except that choosing zero nodes
// from any subtree (ids 0..n) is always possible with value 0.
void init()
{
    memset(val, 0, sizeof val);
    memset(dp, -1, sizeof dp);
    for (int id = 0; id <= n; ++id)
    {
        tree[id].clear();
        dp[id][0] = 0;
    }
}

// Read n lines "parent value" for nodes 1..n: record each node's value and
// attach it as a child of `parent` (parent 0 is the virtual root).
void read()
{
    int parent, gain;
    for (int id = 1; id <= n; ++id)
    {
        scanf("%d%d", &parent, &gain);
        tree[parent].push_back(id);
        val[id] = gain;
    }
}

// Merge the already-computed children of `now` with a group knapsack of
// capacity `cap`. Each child forms one group whose options are "take j nodes
// from that child's subtree" (cost j, value dp[child][j]) for every feasible j;
// at most one option per group may be chosen.
// Result: flag[j] (0 <= j <= cap) is the best total value obtained by taking
// exactly j nodes from the children's subtrees, or -1 if j is unreachable.
// Uses the global scratch buffers flag, tmp, w, c and the counter u.
static void merge_children(int now, int cap)
{
    memset(flag, -1, sizeof flag);
    flag[0] = 0;

    for (size_t i = 0; i < tree[now].size(); i++)
    {
        const int child = tree[now][i];

        // Collect this child's feasible options; -1 in dp marks "impossible".
        u = 0;
        for (int j = 1; j <= m; j++)
            if (dp[child][j] != -1)
                w[u] = dp[child][j], c[u++] = j;

        // Candidates are computed only from flag as it was before this child
        // and staged in tmp, so two options of the same child never combine.
        memset(tmp, -1, sizeof tmp);
        for (int k = 0; k < u; k++)
            for (int j = cap; j >= c[k]; j--)
                if (flag[j - c[k]] != -1)
                {
                    const int cand = flag[j - c[k]] + w[k];
                    if (cand > flag[j] && cand > tmp[j])
                        tmp[j] = cand;
                }

        for (int j = 1; j <= cap; j++)
            if (tmp[j] > flag[j])
                flag[j] = tmp[j];
    }
}

// Tree DP: fill dp[now][x] for the subtree rooted at `now`.
// For a real node (now != 0), choosing x >= 1 nodes means choosing `now`
// itself plus x - 1 nodes from its children's subtrees, so the children are
// merged with capacity m - 1 and val[now] is added. Node 0 is the virtual
// root: it occupies no slot and adds no value, so its children are merged
// with capacity m. dp[now][0] is expected to be preset to 0 (done by init()).
// NOTE(review): -1 is used as the "impossible" sentinel, which assumes every
// val[] is non-negative — confirm against the problem constraints.
void dfs(int now)
{
    // Leaf: the only non-empty choice is the node itself.
    if (tree[now].empty())
    {
        dp[now][1] = val[now];
        return;
    }

    // Children first: merge_children reads their finished dp rows and reuses
    // global scratch buffers, so all recursion must finish before it starts.
    for (size_t i = 0; i < tree[now].size(); i++)
        dfs(tree[now][i]);

    if (now != 0)
    {
        merge_children(now, m - 1);
        for (int x = 1; x <= m; x++)
            if (flag[x - 1] != -1)
                dp[now][x] = flag[x - 1] + val[now];
    }
    else
    {
        merge_children(now, m);
        for (int x = 1; x <= m; x++)
            dp[now][x] = flag[x];
    }
}

void work()
{
    dfs(0);
    printf("%d
", dp[0][m]);
}

int main()
{
    while (~scanf("%d%d", &n, &m))
    {
        if (!n&&!m) break;
        init();
        read();
        work();
    }
    return 0;
}
// Original post: https://www.cnblogs.com/zufezzt/p/5182113.html