HDU4616 树形DP+三次深搜

这题和之前那个HDU2616有着奇妙的异曲同工之处。。都是要求某个点能够到达的最大权重的地方。。。

但是,这题加了个限制,要求最多只能够踩到C个陷阱,一单无路可走或者命用光了,就地开始清算总共得分之和。

于是首先考虑,C的大小只有4,那么可以进行非常方便的状态转移,即将之前2616中的各个矩阵都加一维,设定为走到这一步的时候,可以踩得陷阱的个数——如果可以踩得陷阱的个数是0就意味着不能够行走,于是直接规定,所有是0条命的都自动为0。

之后按照上题的方式进行列举,有所不同的是,需要考虑下数组传递参数的方式,我的做法是将临时变量开到全局空间,这样就就可以保证不会爆栈什么的了。

考虑每个点,如果该电有陷阱,那么所有的转移,都必须按是少了一条命的结果,否则就是直接转移,同事我们认为,由于没有点数小于等于0的点,所以,必然可以得出,在同一个点上,当两个人的命数量不相等时,必然会有命多的“最大得分不小于命少的最大得分”。于是只需要普通的更新就是了。

#include<iostream>
#include<vector>
#include<string.h>
#include<stdio.h>
using namespace std;

const long long MAXN = 50233;
vector<int>G[MAXN];
const long long LIMIT = 4;
long long child[MAXN][LIMIT];
long long child_left[MAXN][LIMIT];
long long child_right[MAXN][LIMIT];
long long tmp[MAXN][LIMIT];
bool trap[MAXN];
long long arr[MAXN];
long long n, c;

long long max(long long a, long long b)
{
    return  a > b ? a : b;
}

void get_child(int now, int last)
{
    int len = G[now].size();
    for (int i = 0; i<LIMIT; ++i)child[now][i] = 0;
    for (int i = 0; i<len; ++i)
    {
        int tar = G[now][i];
        if (tar == last)continue;

        get_child(tar, now);
        if (trap[now])for (int i = 1; i<LIMIT; ++i)child[now][i] = max(child[now][i], child[tar][i - 1]);
        else for (int i = 1; i<LIMIT; ++i)child[now][i] = max(child[now][i], child[tar][i]);
    }
    for (int i = 1; i<LIMIT; ++i)child[now][i] += arr[now];
}
void get_left(int now, int last)
{
    int len = G[now].size();
    if (last == -1)
    {
        memset(child_left[now], 0, sizeof(child_left));
        memset(tmp[now], 0, sizeof(tmp[now]));
    }
    else
    {
        if (trap[now]) for (int i = 1; i<LIMIT; ++i) child_left[now][i] = tmp[now][i] = tmp[last][i - 1];
        else for (int i = 0; i<LIMIT; ++i)child_left[now][i] = tmp[now][i] = tmp[last][i];

    }    for (int i = 1; i<LIMIT; ++i)child_left[now][i] += arr[now];


    long long ttmp[LIMIT]; memset(ttmp, 0, sizeof(ttmp));
    for (int i = 0; i<len; ++i)
    {
        int tar = G[now][i];
        if (tar == last)continue;
        for (int j = 1; j<LIMIT; ++j)tmp[now][j] = max(tmp[now][j], ttmp[j]);
        for (int j = 1; j<LIMIT; ++j)tmp[now][j] += arr[now];
        get_left(tar, now);
        for (int j = 1; j<LIMIT; ++j)tmp[now][j] -= arr[now];
        if (trap[now])    for (int j = 1; j<LIMIT; ++j)ttmp[j] = max(ttmp[j], child[tar][j - 1]);
        else for (int j = 1; j<LIMIT; ++j)    ttmp[j] = max(ttmp[j], child[tar][j]);
    }

}
void get_right(int now, int last)
{
    int len = G[now].size();

    if (last == -1)
    {
        memset(child_right[now], 0, sizeof(child_right));
        memset(tmp[now], 0, sizeof(tmp[now]));
    }
    else
    {
        if (trap[now]) for (int i = 1; i<LIMIT; ++i) child_right[now][i] = tmp[now][i] = tmp[last][i - 1];
        else for (int i = 0; i<LIMIT; ++i)child_right[now][i] = tmp[now][i] = tmp[last][i];

    }for (int i = 1; i<LIMIT; ++i)child_right[now][i] += arr[now];


    long long ttmp[LIMIT]; memset(ttmp, 0, sizeof(ttmp));
    for (int i = len - 1; i >= 0; --i)
    {
        int tar = G[now][i];
        if (tar == last)continue;
        for (int j = 1; j<LIMIT; ++j)tmp[now][j] = max(tmp[now][j], ttmp[j]);
        for (int j = 1; j<LIMIT; ++j)tmp[now][j] += arr[now];
        get_right(tar, now);
        for (int j = 1; j<LIMIT; ++j)tmp[now][j] -= arr[now];
        if (trap[now])    for (int j = 1; j<LIMIT; ++j)ttmp[j] = max(ttmp[j], child[tar][j - 1]);
        else for (int j = 1; j<LIMIT; ++j)    ttmp[j] = max(ttmp[j], child[tar][j]);
    }

}


void init()
{
    cin >> n >> c;
    for (int i = 0; i <= n; ++i)G[i].clear();
    for (int i = 0; i<n; ++i)    cin >> arr[i] >> trap[i];
    for (int i = 1; i<n; ++i)
    {
        int a, b; cin >> a >> b;
        G[a].push_back(b);
        G[b].push_back(a);
    }get_child(0, -1);
    get_left(0, -1);
    get_right(0, -1);
    long long ans = 0;
    for (int i = 0; i<n; ++i)
    {
        ans = max(ans, max(max(child_left[i][c], child_right[i][c]), child[i][c]));
    }
    cout << ans << endl;
}
int main()
{
    cin.sync_with_stdio(false);
    int tt; cin >> tt;
    while (tt--)init();

    return 0;
}
原文地址:https://www.cnblogs.com/rikka/p/7632933.html