HDU 4616 Game 树形dp

题目链接:

http://acm.hdu.edu.cn/showproblem.php?pid=4616

Game

Time Limit: 3000/1000 MS (Java/Others)
Memory Limit: 65535/32768 K (Java/Others)
#### 问题描述 > Nowadays, there are more and more challenge game on TV such as 'Girls, Rush Ahead'. Now, you participate int a game like this. There are N rooms. The connection of rooms is like a tree. In other words, you can go to any other room by one and only one way. There is a gift prepared for you in Every room, and if you go the room, you can get this gift. However, there is also a trap in some rooms. After you get the gift, you may be trapped. After you go out a room, you can not go back to it any more. You can choose to start at any room ,and when you have no room to go or have been trapped for C times, game overs. Now you would like to know what is the maximum total value of gifts you can get. #### 输入 > The first line contains an integer T, indicating the number of testcases. >   For each testcase, the first line contains one integer N(2 <= N <= 50000), the number rooms, and another integer C(1 <= C <= 3), the number of chances to be trapped. Each of the next N lines contains two integers, which are the value of gift in the room and whether have trap in this rooom. Rooms are numbered from 0 to N-1. Each of the next N-1 lines contains two integer A and B(0 <= A,B <= N-1), representing that room A and room B is connected. >   All gifts' value are bigger than 0.

输出

  For each testcase, output the maximum total value of gifts you can get.

样例输入

2
3 1
23 0
12 0
123 1
0 2
2 1
3 2
23 0
12 0
123 1
0 2
2 1

样例输出

146
158

题意

给你一颗树,每个点有点权,同时一些点还有陷阱(到这个点获得价值后会掉到陷阱里),如果你掉到陷阱k次或无路可走会立马退出,问从任意一点出发,退出时能获得的最大价值,每个点只能走一次(不能走回头路)

题解

首先预处理出两个东西:
1、从某个叶子走到u,掉进陷阱j次的能获得的最大值,次大值。
2、从u走到某个叶子的,掉进陷阱j次能获得的最大值,次大值。
然后枚举每条边,把链拆成两块,考虑左右两块的所有组合。

#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<stack>
#include<ctime>
#include<vector>
#include<cstdio>
#include<string>
#include<bitset>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<functional>
using namespace std;
#define X first
#define Y second
#define mkp make_pair
#define lson (o<<1)
#define rson ((o<<1)|1)
#define mid (l+(r-l)/2)
#define sz() size()
#define pb(v) push_back(v)
#define all(o) (o).begin(),(o).end()
#define clr(a,v) memset(a,v,sizeof(a))
#define bug(a) cout<<#a<<" = "<<a<<endl
#define rep(i,a,b) for(int i=a;i<(b);i++)
#define scf scanf
#define prf printf

typedef long long LL;
typedef vector<int> VI;
typedef pair<int,int> PII;
typedef vector<pair<int,int> > VPII;

const int INF=0x3f3f3f3f;
const LL INFL=0x3f3f3f3f3f3f3f3fLL;
const double eps=1e-8;
const double PI = acos(-1.0);

//start----------------------------------------------------------------------

const int maxn=50030;

int n,c;
LL val[maxn];
int tra[maxn];

struct Edge {
    int v,ne;
    Edge(int v,int ne):v(v),ne(ne) {}
    Edge() {}
} egs[maxn*2];

int head[maxn],tot;

void addEdge(int u,int v) {
    egs[tot]=Edge(v,head[u]);
    head[u]=tot++;
}

///dp[u][j][0]表示在子树中以u为起点,机会为j次能够获得的最大值,dp[u][j][1]为对应的次大值,id[u]记录最大值的更新方向
///dp2[u][j][0]表示在子树中以u为终点,机会为j次能够获得的最大值,dp2[u][j][1]为对应的次大值,id2[u]记录最大值的更新方向
LL dp[maxn][5][2],dp2[maxn][5][2];
int id[maxn][5],id2[maxn][5];
void dfs(int u,int fa) {
    if(tra[u]) dp[u][1][0]=dp2[u][1][0]=val[u];
    else dp[u][0][0]=dp2[u][0][0]=val[u];
    bool child=false;
    for(int p=head[u]; p!=-1; p=egs[p].ne) {
        Edge& e=egs[p];
        if(e.v==fa) continue;
        child=true;
        dfs(e.v,u);
        int v=e.v;
        ///以u为终点
        for(int i=0; i<=c; i++) {
            int t=tra[u];
            if(dp[u][i+t][0]<dp[v][i][0]+val[u]) {
                dp[u][i+t][1]=dp[u][i+t][0];
                dp[u][i+t][0]=dp[v][i][0]+val[u];
                id[u][i+t]=v;
            } else if(dp[u][i+t][1]<dp[v][i][0]+val[u]) {
                dp[u][i+t][1]=dp[v][i][0]+val[u];
            }
        }
        ///以u为起点,这和上面的区别体现在限制为1的初始化上
        ///如果当前点有限制,那么如果你限制为1,明显是会直接停下来的,
        ///不可能再由都没有经过限制的儿子那里转移过来。
        if(tra[u]) {
            dp2[u][1][0]=val[u];
            for(int i=2; i<=c; i++) {
                if(dp2[u][i][0]<dp2[v][i-1][0]+val[u]) {
                    dp2[u][i][1]=dp2[u][i][0];
                    dp2[u][i][0]=dp2[v][i-1][0]+val[u];
                    id2[u][i]=v;
                } else if(dp2[u][i][1]<dp2[v][i-1][0]+val[u]) {
                    dp2[u][i][1]=dp2[v][i-1][0]+val[u];
                }
            }
        } else {
            for(int i=0; i<=c; i++) {
                if(dp2[u][i][0]<dp2[v][i][0]+val[u]) {
                    dp2[u][i][1]=dp2[u][i][0];
                    dp2[u][i][0]=dp2[v][i][0]+val[u];
                    id2[u][i]=v;
                } else if(dp2[u][i][1]<dp2[v][i][0]+val[u]) {
                    dp2[u][i][1]=dp2[v][i][0]+val[u];
                }
            }
        }
    }
}

///枚举每条边,吧链拆分成两部分,利用dp和dp2来更新答案
LL ans;
void dfs2(int u,int fa) {
    for(int p=head[u]; p!=-1; p=egs[p].ne) {
        Edge& e=egs[p];
        if(e.v==fa) continue;
        dfs2(e.v,u);
        int v=e.v;
        LL u1,u0;
        for(int i=0; i<=c; i++) {
            if(id[u][i]==v) u0=dp[u][i][1];
            else u0=dp[u][i][0];
            if(id2[u][i]==v) u1=dp2[u][i][1];
            else u1=dp2[u][i][0];
            for(int j=0; j+i<=c; j++) {
                if(i<c) ans=max(ans,u0+dp2[e.v][j][0]);
                if(j<c) ans=max(ans,u1+dp[e.v][j][0]);
                if(i+j<c) ans=max(ans,u0+dp[e.v][j][0]);
            }
        }
    }
}


void init() {
    clr(head,-1);
    clr(dp,0),clr(dp2,0);
    clr(id,-1),clr(id2,-1);
    tot=0;
}

int main() {
    int tc;
    scf("%d",&tc);
    while(tc--) {
        scf("%d%d",&n,&c);
        init();
        rep(i,0,n) scf("%lld%d",&val[i],&tra[i]);
        rep(i,0,n-1) {
            int u,v;
            scf("%d%d",&u,&v);
            addEdge(u,v);
            addEdge(v,u);
        }

        dfs(0,-1);
        ans=0;
        dfs2(0,-1);

        prf("%lld
",ans);

    }
    return 0;
}

//end-----------------------------------------------------------------------

/*
1
3 1
23 1
12 0
123 1
0 2
2 1
*/

来个精炼版的:

#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<stack>
#include<ctime>
#include<vector>
#include<cstdio>
#include<string>
#include<bitset>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<functional>
using namespace std;
#define X first
#define Y second
#define mkp make_pair
#define lson (o<<1)
#define rson ((o<<1)|1)
#define mid (l+(r-l)/2)
#define sz() size()
#define pb(v) push_back(v)
#define all(o) (o).begin(),(o).end()
#define clr(a,v) memset(a,v,sizeof(a))
#define bug(a) cout<<#a<<" = "<<a<<endl
#define rep(i,a,b) for(int i=a;i<(b);i++)
#define scf scanf
#define prf printf

typedef long long LL;
typedef vector<int> VI;
typedef pair<int,int> PII;
typedef vector<pair<int,int> > VPII;

const int INF=0x3f3f3f3f;
const LL INFL=0x3f3f3f3f3f3f3f3fLL;
const double eps=1e-8;
const double PI = acos(-1.0);

//start----------------------------------------------------------------------

const int maxn=50505;

int n,m;
LL gif[maxn];
int tra[maxn];
VI G[maxn];

///dp[u][j][0]表示以u为终点的,掉过j次陷阱的能获得的最大值
///dp[u][j][1]表示以u为起点的,掉过j次陷阱的能获得的最大值(注意,遇到第j个陷阱的时候回马上停下来,所以更新与上面的有所不同
LL dp[maxn][5][2];
LL ans;
void dfs(int u,int fa) {
    clr(dp[u],0);
    if(tra[u]) {
        dp[u][1][0]=dp[u][1][1]=gif[u];
    } else {
        dp[u][0][0]=dp[u][0][1]=gif[u];
    }
    rep(i,0,G[u].sz()) {
        int v=G[u][i];
        if(v==fa) continue;
        dfs(v,u);

        ///边搜边枚举
        for(int j=0; j<=m; j++) {
            for(int k=0; k+j<=m; k++) {
                if(j<m) ans=max(ans,dp[u][j][0]+dp[v][k][1]);
                if(k<m) ans=max(ans,dp[u][j][1]+dp[v][k][0]);
                if(j+k<m) ans=max(ans,dp[u][j][0]+dp[v][k][0]);
            }
        }

        for(int j=0; j<=m; j++) {
            dp[u][j+tra[u]][0]=max(dp[u][j+tra[u]][0],dp[v][j][0]+gif[u]);
            if(tra[u]&&j==0) dp[u][1][1]=gif[u];
            else dp[u][j+tra[u]][1]=max(dp[u][j+tra[u]][1],dp[v][j][1]+gif[u]);
        }
    }

}

void init() {
    for(int i=0; i<n; i++) G[i].clear();
}

int main() {
    int tc;
    scf("%d",&tc);
    while(tc--) {
        scf("%d%d",&n,&m);
        init();
        rep(i,0,n) scf("%lld%d",&gif[i],&tra[i]);
        rep(i,0,n-1) {
            int u,v;
            scf("%d%d",&u,&v);
            G[u].pb(v);
            G[v].pb(u);
        }
        ans=0;
        dfs(0,-1);
        prf("%lld
",ans);
    }
    return 0;
}

//end-----------------------------------------------------------------------

/*
2
3 1
23 1
12 1
123 0
0 2
2 1
*/
原文地址:https://www.cnblogs.com/fenice/p/5959805.html