hdu 4616 Game

http://acm.hdu.edu.cn/showproblem.php?pid=4616

要记录各种状态的段  a[2][4]

a[0][j]表示以trap为起点一共有j个trap的最优值 

a[1][j]表示不以trap为起点一共有j个trap的最优值

dp[x][i][j] 表示以x为根节点的子树从各个叶子到x节点的各状态最优值

每到一个节点 要枚举经过此节点的所有符合要求的段中最优的(需要合并段)

代码:

#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<cmath>
#include<set>
#include<map>
#include<stack>
#include<vector>
#include<algorithm>
#include<queue>
#include<bitset>
#include<deque>
#include<numeric>

#pragma comment(linker, "/STACK:1024000000,1024000000")

using namespace std;

typedef long long ll;
typedef unsigned int uint;
typedef pair<int,int> pp;
const double eps=1e-9;
const int INF=0x3f3f3f3f;
const ll MOD=1000000007;
const int N=100005;
int head[N],I;
struct node
{
    int j,next;
}edge[N*2];
int value[N],trap[N];
int dp[N][2][4];
int ans,C;
void add(int i,int j)
{
    edge[I].j=j;
    edge[I].next=head[i];
    head[i]=I++;
}
void init(int n)
{
    for(int i=0;i<n;++i)
    scanf("%d %d",&value[i],&trap[i]);
    memset(head,-1,sizeof(head));I=0;
    for(int i=1;i<n;++i)
    {
        int l,r;
        scanf("%d %d",&l,&r);
        add(l,r);
        add(r,l);
    }
}
void copyArr(int (*b)[4],int (*a)[4])
{
    for(int i=0;i<2;++i)
    for(int j=0;j<4;++j)
    b[i][j]=a[i][j];
}
void clArr(int (*b)[4])
{
    for(int i=0;i<2;++i)
    for(int j=0;j<4;++j)
    b[i][j]=-1;
    b[0][0]=b[1][0]=0;
}
void update(int (*b)[4],int x)
{
    if(trap[x]==0)
    {
        for(int i=0;i<2;++i)
        for(int j=0;j<4;++j)
        if(b[i][j]!=-1)
        b[i][j]+=value[x];
        b[0][0]=0;
    }else
    {
        for(int i=0;i<2;++i)
        for(int j=3;j>0;--j)
        {
            if(b[i][j-1]!=-1)
            b[i][j]=b[i][j-1]+value[x];
        }
        b[0][0]=0;
        b[1][0]=0;
    }
}
void print(int (*b)[4])
{
    for(int i=0;i<2;++i)
    {
        for(int j=0;j<4;++j)
        printf("%4d ",b[i][j]);printf("
");
    }printf("
");
}
void findAns(int (*b)[4],int (*v1)[4],int (*v2)[4],int x)
{
    int c=C-trap[x];
    int tmp=0;
    for(int i=0;i<2;++i)
    for(int j=0;j<4;++j)
    {
        for(int l=0;l<2;++l)
        for(int r=0;r<4;++r)
        {
            if(j+r>c) continue;
            if(j+r==c)
            {
                if(i+l==2) continue;
                if(i!=l)
                {
                    if(i==0&&j==0) continue;
                    if(l==0&&r==0) continue;
                }
            }
            if(v1[l][r]!=b[l][r])
            tmp=max(tmp,max(0,v1[l][r])+max(0,b[i][j]));
            else
            tmp=max(tmp,max(0,v2[l][r])+max(0,b[i][j]));
        }
    }
    ans=max(ans,tmp+value[x]);
}
void dfs(int pre,int x,int (*a)[4])
{
    int b[2][4];
    copyArr(b,a);
    update(b,x);
    int v1[2][4],v2[2][4];
    clArr(v1);clArr(v2);
    for(int t=head[x];t!=-1;t=edge[t].next)
    {
        int l=edge[t].j;
        if(l==pre) continue;
        dfs(x,l,b);
        for(int i=0;i<2;++i)
        for(int j=0;j<4;++j)
        {
            v2[i][j]=max(v2[i][j],dp[l][i][j]);
            if(v1[i][j]<v2[i][j])
            swap(v1[i][j],v2[i][j]);
        }
    }
    copyArr(dp[x],v1);
    update(dp[x],x);
    for(int i=0;i<2;++i)
    for(int j=0;j<4;++j)
    {
        v2[i][j]=max(v2[i][j],a[i][j]);
        if(v1[i][j]<v2[i][j])
        swap(v1[i][j],v2[i][j]);
    }
    findAns(a,v1,v2,x);
    for(int t=head[x];t!=-1;t=edge[t].next)
    {
        int l=edge[t].j;
        if(l==pre) continue;
        findAns(dp[l],v1,v2,x);
    }
}
int main()
{
    //freopen("data.in","r",stdin);
    int T;
    scanf("%d",&T);
    while(T--)
    {
        int n;
        scanf("%d %d",&n,&C);
        init(n);
        int a[2][4];
        clArr(a);
        memset(dp,-1,sizeof(dp));
        ans=0;
        dfs(-1,0,a);
        printf("%d
",ans);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/liulangye/p/3217417.html