HDU6820 Tree (2020杭电多校第5场1007) 树形dp

题意

有一个n个点构成的树,每条边有一个边权d。求最多有一个点度数超过k的联通子图的边权和最大值。

分析

  1. 首先k=0时答案为0

  2. dp[0][u]代表以u为根的子树中所有点度数都小于等于k时的边权和最大值,且u与它的父节点有连边。dp[1][u]代表以u为根的子树中存在一个点的度数大于k时的边权和最大值,且u与它的父节点有连边。

  3. (dp[0][u]=max_{v_1,v_2,...,v_{k-1}}(sum_{i=v_1,v_2,...,v_{k-1}}dp[0][i])+d),其中vu的儿子节点,du的父亲节点到u的边权值

  4. (dp[1][u]=max(sum_vdp[0][v],max_{v_1,v_2,...,v_{k-1}}(sum_{i=v_1,v_2,...,v_{k-2}}dp[0][i]+dp[1][v_{k-1}]))+d)

  5. (ans=max_u(dp[1][u],max_{v_1,v_2,...,v_{k}}(sum_{i=v_1,v_2,...,v_{k-1}}dp[0][i]+dp[1][v_k])))

  6. 其中(max_{v_1,v_2,...,v_{k}}(sum_{i=v_1,v_2,...,v_{k-1}}dp[0][i]+dp[1][v_k]))可以对dp[0][v]由大到小排序,然后对前k个计算(sum_{i=1}^kdp[1][i]+dp[1][v]-dp[0][v]),对后cnt-k个计算(sum_{i=1}^kdp[1][i]+dp[1][k]-dp[0][v]),复杂度为O(nlogn)

代码

#include<bits/stdc++.h>
using namespace std;
const int maxn=2e5+5;
typedef long long ll;
int n,k;

struct Node{int to,next;ll d;}edge[maxn*2];
int head[maxn],ecnt;
int cnt[maxn];
ll Ans,ans[2][maxn];
int son[maxn];
void init()
{
    memset(head,-1,sizeof(head[0])*(n+5));
    memset(cnt,0,sizeof(cnt[0])*(n+5));
    ecnt=0;
    Ans=0;
}
void addedge(int u,int v,ll d)
{
    edge[ecnt]={v,head[u],d};
    head[u]=ecnt++;
    edge[ecnt]={u,head[v],d};
    head[v]=ecnt++;
    cnt[u]++;cnt[v]++;
}

void dfs(int u,int fa,ll d)
{
    ans[1][u]=ans[0][u]=d;
    vector<ll>v0;
    for(int i=head[u];i!=-1;i=edge[i].next)
    {
        int v=edge[i].to;
        if(v==fa)continue;
        dfs(v,u,edge[i].d);

        ans[1][u]+=ans[0][v];
        v0.push_back(ans[0][v]);
    }

    sort(v0.begin(),v0.end(),greater<ll>());
    for(int i=0;i<min((ll)v0.size(),(ll)k-1);i++)
        ans[0][u]+=v0[i];

    for(int i=head[u],j=1;i!=-1;i=edge[i].next)
        if(edge[i].to!=fa)
            son[j]=edge[i].to,j++;
    sort(son+1,son+cnt[u]+1,[](int i,int j){
        return ans[0][i]>ans[0][j];
    });
    //return
    int nn=min(k-1,cnt[u]);
    ll ans1=d;
    if(k>=2)
    {
        for(int i=1;i<=nn;i++)
            ans1+=ans[0][son[i]];
        for(int i=1;i<=nn;i++)
            ans[1][u]=max(ans[1][u],ans1-ans[0][son[i]]+ans[1][son[i]]);
        ans1-=ans[0][son[nn]];
        for(int i=nn+1;i<=cnt[u];i++)
            ans[1][u]=max(ans[1][u],ans1+ans[1][son[i]]);
    }
    else
        ans[1][u]=max(ans[1][u],ans1);
    //dp
    nn=min(k,cnt[u]);
    ll ans2=0;
    for(int i=1;i<=nn;i++)
        ans2+=ans[0][son[i]];
    for(int i=1;i<=nn;i++)
        Ans=max(Ans,ans2-ans[0][son[i]]+ans[1][son[i]]);
    ans2-=ans[0][son[nn]];
    for(int i=nn+1;i<=cnt[u];i++)
        Ans=max(Ans,ans2+ans[1][son[i]]);
    
    Ans=max(Ans,ans[1][u]);
}

int main()
{
    int t;
    scanf("%d",&t);
    while(t--)
    {
        int u,v;ll d;
        scanf("%d%d",&n,&k);
        init();
        for(int i=1;i<=n-1;i++)
        {
            scanf("%d%d%lld",&u,&v,&d);
            addedge(u,v,d);
        }
        if(k==0)
        {
            printf("0
");
            continue;
        }
        for(int i=2;i<=n;i++)
            cnt[i]--;
        dfs(1,0,0);
        printf("%lld
",Ans);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/intmian/p/13435225.html