HDU 3899 简单树形DP

题意:一棵树,给出每个点的权值和每条边的长度,

点j到点i的代价为点j的权值乘以连接i和j的边的长度。求点x使得所有点到点x的代价最小,输出

虽然还是不太懂树形DP是什么意思,先把代码贴出来把。

这道题目的做法是:先进行一次DFS,以每个节点为根,求出它下面节点到它的数量和。

再进行一次DFS,以每个节点为根,求出它下面节点到它的花费总和。

source code:

#pragma comment(linker, "/STACK:16777216") //for c++ Compiler
#include <stdio.h>
#include <iostream>
#include <cstring>
#include <cmath>
#include <stack>
#include <queue>
#include <vector>
#include <algorithm>
#define ll long long
#define Max(a,b) (((a) > (b)) ? (a) : (b))
#define Min(a,b) (((a) < (b)) ? (a) : (b))
#define Abs(x) (((x) > 0) ? (x) : (-(x)))
using namespace std;

const int INF = 0x3f3f3f3f;
const int MAXN = 100001;

struct edge{
    int v, w;
    edge(){}
    edge(int a, int b) : v(a), w(b){}
};

vector <edge> node[MAXN];
int n,sum[MAXN],a[MAXN];
//a[MAXN]表示每一个节点的权值,sum[MAXN]表示子树的权值和
ll dist[MAXN];
bool vis[MAXN];

void init(){
    for(ll i = 0; i <= n; ++i)
        node[i].clear();
    memset(vis, 0, sizeof(vis));
    memset(dist,0,sizeof(dist));
}

void dfs(ll u,ll dis){
    vis[u] = true;
    dist[1] += dis * a[u];
    sum[u] = a[u];
    ll size = node[u].size();
    for(ll i = 0; i < size; ++i){
        ll v = node[u][i].v;
        if(vis[v]) continue;
        dfs(v, dis + node[u][i].w);
        sum[u] += sum[v];
    }
}

void dfs1(ll u){
    vis[u] = true;
    ll size = node[u].size();
    for(ll i = 0; i < size; ++i){
        ll v = node[u][i].v;
        ll w = node[u][i].w;
        if(vis[v]) continue;
        dist[v] = dist[u] - sum[v] * w + (sum[1] - sum[v]) * w;
        //dist[v] = dist[u] + (sum[1] - 2 * sum[v]) * w;
        dfs1(v);
    }
}

int main(){
    ll i, j, x, y, w;
    while(EOF != scanf("%I64d",&n)){
        init();
        for(i = 1; i <= n; ++i)   scanf("%I64d",&a[i]);
        for(i = 1; i < n; ++i){
            scanf("%I64d %I64d %I64d",&x,&y,&w);
            node[x].push_back(edge(y,w));
            node[y].push_back(edge(x,w));
        }
        dfs(1,0);
        memset(vis,0,sizeof(vis));
        dfs1(1);
        ll ans = dist[1];
        for(i = 2; i <= n; ++i){
            ans = min(ans, dist[i]);
        }
        cout << ans << endl;
        //printf("%I64d
",ans);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/wushuaiyi/p/4090661.html