HDU 5834 [树形dp]

/*
题意:n个点组成的树,点和边都有权值,当第一次访问某个点的时候获得利益为点的权值
每次经过一条边,丢失利益为边的权值。问从第i个点出发,获得的利益最大是多少。
输入:
测试样例组数T
n
n个数每个点的权值
n-1条无向边 a b c a和b是点的标号,c是边的权值。
思路:
注意题目只强调是从某个点出发,并不一定要回到该点。
考虑树形DP。
先随便定义一个树根。然后对于某个点,我们需要维护它子树方向的两个值
1.从该点出发,最终要回到该点的利益的最大值
2.从该点出发,最终不必回到该点的利益的最大值,以及最后一次从该点出发的是哪个儿子节点。
3.从该点出发,最终不必回到该点的利益的第二大值,以及最后一次从该点出发的是哪个儿子节点。
先求出回到该点的利益的最大值,然后枚举它的儿子节点,作为最后出去的节点,即该节点不返回。
第二和第三个值把枚举所有儿子作为最后一次出发的情况然后排序记录下前两个就可以了。
以上是第一次DFS所做的工作。
然后对于每个点维护三个值,进行第二次DFS。
1.从该点出发所有方向(子树方向和父亲方向)回到该点的利益的最大值。
2.从该点出发所有方向,不必回到该点的利益的最大值,并保存最后一个出发的儿子节点的标号。
3.从该点出发所有方向,不必回到该点的利益的第二大值,并保存最后一个出发的儿子节点的标号。
第1个值的维护只需要将子树方向和减掉本身所在子树所影响的父亲方向的值加起来即可。
对于第二个和第三个值,我们可以将子树方向上最大的两个值加上父亲方向回来所带来的利益,与
父亲方向作为最后一次出发的点不回来的利益这三个值中求取两个最大的即可。
之所以要维护第二大的值,是因为当父亲节点最终的答案(即可以不回来)的最后一次访问的点恰
好是我们要求取的它的儿子节点的时候,我们可以用次大的值来确定当该儿子节点的父亲方向作为
最后一次访问的方向的时候的利益。
最后所有的ans[i][1]就是答案。
总结“
1.树形DP维护的时候经常需要维护次大值甚至第三大值,这是由于父亲方向的最优问题是否关系到
儿子节点所决定的。
*/








#include<bits/stdc++.h>
#define N 100050
using namespace std;
struct edge{
    int id;
    edge *next;
    long long w;
};
struct st{
    st(int a,long long b){
        id=a;
        ans=b;
    }
    long long ans;
    int id;
};
bool cmp(st a,st b){
    return a.ans>b.ans;
}
int ednum;
edge edges[N*2];
edge *adj[N];
long long v[N],son[N][3],ans[N][3];
int fa[N],id[N][3];
inline void addedge(int a,int b,long long w){
    edge *tmp=&edges[ednum++];
    tmp->w=w;
    tmp->id=b;
    tmp->next=adj[a];
    adj[a]=tmp;
}
void dfs(int pos){
    son[pos][0]+=v[pos];
    for(edge *it=adj[pos];it;it=it->next){
        if(fa[pos]!=it->id){
            fa[it->id]=pos;
            dfs(it->id);
            if(son[it->id][0]-2*it->w>0){
                son[pos][0]+=son[it->id][0]-2*it->w;
            }
        }
    }
    son[pos][1]=son[pos][0];
    vector<st>mv;
    mv.push_back(st(0,son[pos][0]));
    mv.push_back(st(0,son[pos][0]));
    for(edge *it=adj[pos];it;it=it->next){
        if(fa[it->id]==pos){
            if(son[it->id][1]-it->w>0){
                long long tn=son[pos][0]+son[it->id][1]-it->w;
                if(son[it->id][0]-2*it->w>0)tn-=son[it->id][0]-2*it->w;
                if(tn>son[pos][0]){
                    mv.push_back(st(it->id,tn));
                }
            }
        }
    }
    sort(mv.begin(),mv.end(),cmp);
    son[pos][1]=mv[0].ans;
    id[pos][1]=mv[0].id;
    son[pos][2]=mv[1].ans;
    id[pos][2]=mv[1].id;
}
void dfs2(int pos){
    if(pos==1){
        for(int i=0;i<3;i++){
            ans[pos][i]=son[pos][i];
        }
    }
    for(edge *it=adj[pos];it;it=it->next){
        if(fa[it->id]==pos){
            long long tans=ans[pos][0]-2*it->w;
            if(son[it->id][0]-2*it->w>0){
                tans-=son[it->id][0];
                tans+=2*it->w;
            }
            tans=max(0LL,tans);
            ans[it->id][0]=son[it->id][0]+tans;
            vector<st>mv;
            mv.push_back(st(0,ans[it->id][0]));
            mv.push_back(st(0,ans[it->id][0]));
            for(int i=1;i<3;i++){
                mv.push_back(st(id[it->id][i],son[it->id][i]+tans));
            }
            if(id[pos][1]!=it->id){
                long long tm=ans[pos][1];
                tm+=son[it->id][0];
                if(son[it->id][0]-2*it->w>0){
                    tm-=son[it->id][0]-2*it->w;
                }
                mv.push_back(st(id[pos][1],tm-it->w));
            }
            else{
                long long tm=ans[pos][2];
                tm+=son[it->id][0];
                if(son[it->id][0]-2*it->w>0){
                    tm-=son[it->id][0]-2*it->w;
                }
                mv.push_back(st(id[pos][1],tm-it->w));
            }
            sort(mv.begin(),mv.end(),cmp);
            ans[it->id][1]=mv[0].ans;
            id[it->id][1]=mv[0].id;
            ans[it->id][2]=mv[1].ans;
            id[it->id][2]=mv[1].id;
            dfs2(it->id);
        }
    }
}
int main()
{
    int t;
    scanf("%d",&t);
    int cas=0;
    while(t--){
        cas++;
        printf("Case #%d:
",cas);
        int n;
        scanf("%d",&n);
        for(int i=0;i<=n;i++){
            for(int j=0;j<3;j++){
                son[i][j]=id[i][j]=ans[i][j]=0;
            }
            adj[i]=NULL;
            fa[i]=0;
        }
        ednum=0;
        for(int i=1;i<=n;i++)scanf("%lld",v+i);
        for(int i=1;i<n;i++){
            int a,b;long long w;
            scanf("%d%d%lld",&a,&b,&w);
            addedge(a,b,w);
            addedge(b,a,w);
        }
        dfs(1);
        dfs2(1);
        for(int i=1;i<=n;i++){
            printf("%lld
",ans[i][1]);
        }
    }
}
原文地址:https://www.cnblogs.com/tun117/p/5834536.html