Hihocoder 1035 [树形dp]

/*
题意:
不要低头,不要放弃,不要气馁,不要慌张。
PS:人生第一道自己独立做出来的树形dp...
给一棵树,标号1到n,每条边有两个权值,步行时间和驾车时间。车在1号点。
给m个必须访问的关键点,求从1号点出发,访问所有关键点一遍的最小时间。
注意车可以停在任意地方,但是只有1号点有一辆车,人最后也可以停留在任意点。
思路:
1.子树方向(注意dp1 dp2 dp4都是保证人一定要返回该点的最优解)
dp1代表该点起始有车,并且从该点出发访问完该点子树上所有的关键点车和人都返回的该点的最优解。
注意dp1并不是访问每个子树都要驾车,只要保证该子树访问完人和车都在该点就可以,也就是也可以步行出发并返回。
dp2代表该点起始有车,从该点出发人必须返回,但是车不一定返回的最优解。
dp4代表该点起始无车,从该点出发步行并且返回该点的最优解。
2.父亲节点方向
显而易见,最后人一定会停留在某点。
dp1代表该点起始有车(从父亲方向来的车),并且访问完所有关键节点最后车和人都返回该点的最优解。
d2维护的是该点起始有车(从父亲方向来的车),访问完所有关键节点人返回,车不一定返回的最优解和次优解。并且记录最后车不返回的子树。
dp4代表该点起始无车(从父亲方向无车来),步行访问完所有关键节点的最优解。
总之最后总的思想就是,枚举最后人停留的点,不断找ans的最小值。

好像跟网上聚聚们写的思路不太一样...

坑:
sum=min(sum,sum-it->w+it->c-dp1[it->id]+dp2[it->id]);
*/













#include<bits/stdc++.h>
#define N 1000050
using namespace std;
long long ans;
struct st{
    st(){}
    st(int a,long long b){
        id=a;val=b;
    }
    int id;
    long long val;
};
bool cmp(st a,st b){
    if(a.val!=b.val)return a.val<b.val;
    return a.id<b.id;
}
vector<st>mv;
st d2[N][2];
long long inf=0x3f3f3f3f3f3f3f3f;
struct edge{
    int id;
    long long w,c;
    edge *next;
};
int ednum;
edge edges[N<<1];
edge *adj[N];
inline void addedge(int a,int b,long long c,long long w){
    edge *tmp=&edges[ednum++];
    tmp->id=b;
    tmp->w=w;
    tmp->c=c;
    tmp->next=adj[a];
    adj[a]=tmp;
}
int im[N],fa[N],siz[N];
long long dp1[N],dp2[N],dp4[N];
void dfs(int pos){
    siz[pos]=im[pos];
    for(edge *it=adj[pos];it;it=it->next){
        if(!fa[it->id]){
            fa[it->id]=pos;
            dfs(it->id);
            siz[pos]+=siz[it->id];
        }
    }
    if(siz[pos]==1&&im[pos]){
        dp1[pos]=dp2[pos]=dp4[pos]=0;
    }
    else if(siz[pos]){
        long long sum=0,sum1=0;
        for(edge *it=adj[pos];it;it=it->next){
            if(fa[it->id]!=pos)continue;
            if(siz[it->id]){
                sum+=min(dp1[it->id]+2*it->w,dp4[it->id]+2*it->c);
                sum1+=dp4[it->id]+2*it->c;
            }
        }
        dp1[pos]=sum;
        dp4[pos]=sum1;
        long long aaa=sum;
        for(edge *it=adj[pos];it;it=it->next){
            if(fa[it->id]!=pos)continue;
            if(siz[it->id]){
                if(dp1[it->id]+2*it->w < dp4[it->id]+2*it->c)
                    aaa=min(aaa,min(sum - it->w + it->c , sum - it->w - dp1[it->id] + dp2[it->id] + it->c));
                else
                    aaa=min(aaa,min(sum - it->c - dp4[it->id] + it->w + dp1[it->id],sum - it->c - dp4[it->id] + it->w +dp2[it->id]));
            }
        }
        dp2[pos]=aaa;
    }
}
void dfs2(int pos,long long ww,long long cc){
    mv.clear();
    if(pos==1){
        ans=min(ans,dp1[pos]);
        long long sum=dp1[pos];
        for(edge *it=adj[pos];it;it=it->next){
            sum=dp1[pos];
            if(fa[it->id]==pos&&siz[it->id]){
                if(dp1[it->id] + 2 * it->w < dp4[it->id] + 2*it->c )
                    sum=min(sum,dp1[pos] - it->w + it->c - dp1[it->id] + dp2[it->id] );
                else
                    sum=min(sum,dp1[pos]-it->c+it->w-dp4[it->id]+dp2[it->id]);
                mv.push_back(st(it->id,sum));
            }
        }
        mv.push_back(st(0,dp1[pos]));
        mv.push_back(st(0,dp1[pos]));
        sort(mv.begin(),mv.end(),cmp);
        d2[pos][0]=mv[0];
        d2[pos][1]=mv[1];
        ans=min(ans,d2[pos][0].val);
        ans=min(ans,dp4[pos]);
    }
    else{
        long long ss=dp1[fa[pos]];
        if(dp1[pos]+2*ww<dp4[pos]+2*cc)ss-=dp1[pos]+2*ww;
        else ss-=dp4[pos]+2*cc;
        ss+=cc;
        if(d2[fa[pos]][0].id!=pos){
            long long sss=d2[fa[pos]][0].val;
            if(dp1[pos]+2*ww<dp4[pos]+2*cc)sss-=dp1[pos]+2*ww;
            else sss-=dp4[pos]+2*cc;
            sss+=cc;
            ss=min(ss,sss);
        }
        else{
            long long sss=d2[fa[pos]][1].val;
            if(dp1[pos]+2*ww<dp4[pos]+2*cc)sss-=dp1[pos]+2*ww;
            else sss-=dp4[pos]+2*cc;
            sss+=cc;
            ss=min(ss,sss);
        }
        ss=min(ss,dp4[fa[pos]]-cc-dp4[pos]);
        long long bb=dp1[fa[pos]];
        if(dp1[pos]+2*ww<dp4[pos]+2*cc)bb-=dp1[pos]+2*ww;
        else bb-=dp4[pos]+2*cc;
        bb+=ww;
        dp1[pos]=bb;
        for(edge *it=adj[pos];it;it=it->next){
            if(fa[it->id]==pos&&siz[it->id]){
                dp1[pos]+=min(dp1[it->id]+2*it->w,dp4[it->id]+2*it->c);
            }
        }
        ans=min(ans,dp1[pos]);
        long long sum=dp1[pos];
        for(edge *it=adj[pos];it;it=it->next){
            sum=dp1[pos];
            if(fa[it->id]==pos&&siz[it->id]){
                if(dp1[it->id]+2*it->w<dp4[it->id]+2*it->c)
                    sum=min(sum,dp1[pos]-it->w+it->c-dp1[it->id]+dp2[it->id]);
                else
                    sum=min(sum,dp1[pos]-it->c+it->w-dp4[it->id]+dp2[it->id]);
                mv.push_back(st(it->id,sum));
            }
        }
        mv.push_back(st(0,dp1[pos]));
        mv.push_back(st(0,dp1[pos]));
        sort(mv.begin(),mv.end(),cmp);
        d2[pos][0]=mv[0];
        d2[pos][1]=mv[1];
        dp4[pos]+=ss;
        ans=min(ans,dp4[pos]);
        ans=min(ans,d2[pos][0].val);
    }
    for(edge *it=adj[pos];it;it=it->next){
        if(fa[it->id]==pos&&siz[it->id]){
            dfs2(it->id,it->w,it->c);
        }
    }
}
int main()
{
    fa[1]=-1;
    int n;
    scanf("%d",&n);
    int a,b;
    long long w,c;
    for(int i=1;i<=n;i++){
        dp1[i]=dp2[i]=dp4[i]=inf;
    }
    for(int i=1;i<n;i++){
        scanf("%d%d%lld%lld",&a,&b,&c,&w);
        addedge(a,b,c,w);
        addedge(b,a,c,w);
    }
    int m;
    scanf("%d",&m);
    for(int i=1;i<=m;i++){
        scanf("%d",&a);
        im[a]=1;
    }
    dfs(1);
    ans=inf;
    dfs2(1,0,0);
    printf("%lld
",ans);
}
原文地址:https://www.cnblogs.com/tun117/p/5928055.html