[hiho1035] 自驾旅行III

题意:给你一棵N个结点的树,现有一个人和一辆车,每条边有人走和车走的两个权值,给出m个关键点,问走完所有关键点的最小代价

题解:
树形DP
dp[u][0]表示人下,人上,dp[u][0]=dp[v][0]
dp[u][1]表示人下,人不一定上,dp[u][1]=∑(k-1)(dp[v][0]+2*w1)+min(dp[v][1]+w1)
dp[u][2]表示人车下,人车上,dp[u][2]=∑min(dp[v][1]+2*w1,dp[v][2]+2*w2)=t
dp[u][3]表示人车下,人上车不一定上,dp[u][3]=∑(k-1)(dp[v][2]+2*w2)+min(dp[v][3]+w1+w2)
dp[u][4]表示人车下,人车不一定上,两种情况:
1、最后一次有车,dp[u][4]= ∑(k-1)t+min(dp[v][1]+w1,dp[v][4]+w2)
2、最后一次无车,dp[u][4]= ∑(k-2)t+dp[v][3]+w1+w2+dp[v][1]+w1 (这种情况两个儿子不能在同一棵子树上)

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#define ll long long
using namespace std;

const int N = 1000010;

int n,m,e_num;
int nxt[N*2],to[N*2],w1[N*2],w2[N*2],h[N];
ll inf=1ll<<62,dp[N][5];
bool vis[N];

int gi() {
  int x=0,o=1; char ch=getchar();
  while(ch!='-' && (ch<'0' || ch>'9')) ch=getchar();
  if(ch=='-') o=-1,ch=getchar();
  while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();
  return o*x;
}

void add(int x, int y, int z1, int z2) {
  nxt[++e_num]=h[x],to[e_num]=y,w1[e_num]=z1,w2[e_num]=z2,h[x]=e_num;
}

void dfs(int u, int fa) {
  ll tmp1=0,tmp2=0,tmp3=0,tmp=0;
  ll a,b,a1=inf,a2=inf,b1=inf,b2=inf,ta=-1,tb=-1,t;
  for(int i=h[u]; i; i=nxt[i]) {
    int v=to[i];
    if(v==fa) continue;
    dfs(v,u);
    if(!vis[v]) continue;
    vis[u]=1;
    dp[u][0]+=dp[v][0]+2*w1[i];//只有人走,人必须回
    t=min(dp[v][0]+2*w1[i],dp[v][2]+2*w2[i]);
    dp[u][2]+=t;//人和车都走,都回
    tmp1=min(tmp1,dp[v][1]-dp[v][0]-w1[i]);
    tmp2=min(tmp2,dp[v][3]+w1[i]+w2[i]-t);
    tmp3=min(tmp3,min(dp[v][1]+w1[i],dp[v][4]+w2[i])-t);
    a=dp[v][1]+w1[i]-t,b=w2[i]+dp[v][3]+w1[i]-t;
    if(a<a1) a2=a1,a1=a,ta=v;
    else if(a<a2) a2=a;
    if(b<b1) b2=b1,b1=b,tb=v;
    else if(b<b2) b2=b;
    if(ta!=-1 && tb!=-1) {
      if(ta!=tb) tmp=a1+b1;
      else tmp=min(a1+b2,a2+b1);//不能在同一个子树内
    }
  }
  dp[u][1]=dp[u][0]+tmp1;
  dp[u][3]=dp[u][2]+tmp2;
  dp[u][4]=dp[u][2]+min(tmp3,tmp);
  dp[u][4]=min(dp[u][4],dp[u][3]);
}

int main() {
  n=gi();
  for(int i=1; i<n; i++) {
    int x=gi(),y=gi(),z1=gi(),z2=gi();
    add(x,y,z1,z2),add(y,x,z1,z2);
  }
  m=gi();
  for(int i=1; i<=m; i++) vis[gi()]=1;
  dfs(1,0);
  printf("%lld", dp[1][4]);
  return 0;
}
原文地址:https://www.cnblogs.com/HLXZZ/p/7544249.html