树形dp——2020-camp-day3-G

/*
以某个K点为根,建立一棵包含K个点的最小树,
处理出这棵树内每个点到最远点的距离dis[i]
 
处理出树外点到这棵树的最近的点pa[i] 
树内点:Sum-dis[i]
树外点:Sum-dis[pa[i]]+deep[i]-deep[pa[i]]
*/
#include<bits/stdc++.h>
#include<vector>
using namespace std;
#define ll long long 
#define N 500005

struct Edge{
    ll to,nxt,w;
}e[N<<1];
int head[N],tot,in[N];
ll n,K,flag[N],root,Sum;
void init(){
    memset(head,-1,sizeof head);
    tot=0;
}
void add(ll u,ll v,ll w){
    e[tot].to=v;e[tot].w=w;e[tot].nxt=head[u];head[u]=tot++;
}

int sizeK[N];
void getsize(int u,int pre){
    sizeK[u]=flag[u];
    for(int i=head[u];i!=-1;i=e[i].nxt){
        int v=e[i].to;
        if(v==pre)continue;
        getsize(v,u);
        sizeK[u]+=sizeK[v];
    }
}

ll deep[N];
void getdeep(int u,int pre){
    for(int i=head[u];i!=-1;i=e[i].nxt){
        int v=e[i].to;
        if(v==pre)continue;
        if(sizeK[v]!=0)Sum+=e[i].w*2;
        deep[v]=deep[u]+e[i].w;
        getdeep(v,u);
    }
}

ll len[N],dis[N];
void getlen(int u,int pre){
    for(int i=head[u];i!=-1;i=e[i].nxt){
        int v=e[i].to;
        if(v==pre)continue;
        getlen(v,u);
        if(flag[v])
            len[u]=max(len[u],len[v]+e[i].w);
    }
}
void getdis(int u,int pre,ll up){
    if(flag[u]==0)return;
    ll mx1=0,id1,mx2=0,id2;
    for(int i=head[u];i!=-1;i=e[i].nxt){
        int v=e[i].to;
        if(v==pre||!flag[v])continue;
        if(len[v]+e[i].w>mx1){
            mx2=mx1,id2=id1;
            mx1=len[v]+e[i].w;id1=v;
        }
        else if(len[v]+e[i].w>mx2){
            mx2=len[v]+e[i].w,id2=v;
        }
    }
    
    if(up>mx1){
        mx2=mx1,id2=id1;
        mx1=up;id1=-1;
    }
    else if(up>mx2){
        mx2=up,id2=-1;
    }
    dis[u]=mx1;
    
    for(int i=head[u];i!=-1;i=e[i].nxt){
        int v=e[i].to;
        if(v==pre)continue;
        if(id1==v)
            getdis(v,u,mx2+e[i].w);
        else 
            getdis(v,u,mx1+e[i].w);
    }
}

struct Node{
    int from,now;
    Node(){}
    Node(int from,int now):from(from),now(now){} 
};
int vis[N],pa[N];
ll ans[N];

int main(){
    cin>>n>>K;
    init();
    for(int i=1;i<n;i++){
        ll u,v,w;        
        scanf("%lld%lld%lld",&u,&v,&w);
        add(u,v,w);add(v,u,w);
        in[u]++;in[v]++;
    }
    
    for(int i=1;i<=K;i++){
        int x;scanf("%d",&x);
        flag[x]=1;root=x;
    }
    
    getsize(root,0);
    getdeep(root,0);
    for(int i=1;i<=n;i++){
        if(sizeK[i]==0)flag[i]=0;
        else flag[i]=1;
    }
    getlen(root,0);
    getdis(root,0,0);
    
    queue<Node>q;
    for(int i=1;i<=n;i++)
        if(flag[i]){
            q.push(Node(i,i));
            vis[i]=1;
            ans[i]=Sum-dis[i];
        }
    while(q.size()){
        Node cur=q.front();q.pop();
        int u=cur.now;
        for(int i=head[u];i!=-1;i=e[i].nxt){
            int v=e[i].to;
            if(vis[v])continue;
            q.push(Node(cur.from,v));
            vis[v]=1;
            ans[v]=deep[v]-deep[cur.from]+Sum-dis[cur.from];
        }
    }
    
    for(int i=1;i<=n;i++)cout<<ans[i]<<"
";
} 
原文地址:https://www.cnblogs.com/zsben991126/p/12197506.html