bzoj4033 [HAOI2015]树上染色(树形)

Description
有一棵点数为N的树,树边有边权。给你一个在0~N之内的正整数K,你要在这棵树中选择K个点,将其染成黑色,并
将其他的N-K个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的收益。
问收益最大值是多少。

Input
第一行两个整数N,K。
接下来N-1行每行三个正整数fr,to,dis,表示该树中存在一条长度为dis的边(fr,to)。
输入保证所有点之间是联通的。
N<=2000,0<=K<=N

Output
输出一个正整数,表示收益的最大值。

Sample Input
5 2
1 2 3
1 5 1
2 3 1
2 4 2

Sample Output
17

【样例解释】
将点1,2染黑就能获得最大收益。

分析:
发现自己完全理解不了概率dp
所以就跑来玩树了

设计状态:
f[i][j]表示i这棵子树中,选择j个黑点的最大价值

(想要直接看正解的同学请跳过这一段)
我一开始以为这道题的dp思路有点像ta
枚举每一个子树中的黑点个数:
f[i][j]=max{f[i][j-k]+f[son][k]+新贡献}
这里写图片描述
但是这个新贡献要怎么计算呢
我觉得可以记录一个值g[i][j]表示
在f[i][j]的前提下,i这棵子树中,所有黑点到达根的距离之和
这样我们就可以计算了:
g[i][k]*j+g[son][j]*k

这样转是能转移了,复杂度是n^3勉勉强强,
但是g的转移呢,好像还要枚举now的颜色
这样的话,f和g就要再开一维,而且我还忘了白点的贡献
哇,真tomato刺激,真tomato难码,真tomato会wa

好吧,看来我一开始想的这个方法可行性不高
之好再想一个:

>>>——————————————————————————————

状态不变
f[i][j]表示i这棵子树中,选择j个黑点的最大价值
我们直接考虑每条边的贡献
设当前边x
在x两边的黑点和白点在互相连通的时候,就一定会经过x
那么x的贡献就是
x*(x子树内的黑点个数×子树外的黑点个数+子树内白点的个数×子树外白点的个数)

枚举黑点的数量,显然有转移方程:
f[now][j]=max{f[now][j-k]+f[son][k]+当前边的贡献}

tip

开ll
子树内的黑点个数k是我们枚举的
在计算子树外的黑点和白点个数时,是默认一共有K个黑点
所以
子树外黑点个数:K-k
子树外白点个数:n-size[son]-(K-k)

交上去1000ms的时候WA了
把样程交上去,TLE。。。好像是因为9月新添的数据比较刁钻
大概是全是负边的情况

只要初始值是0,就是wa
初始值是-INF,就是TLE
不知道是什么样的数据能卡成这样
太刁钻了

这里写代码片
#include<cstdio>
#include<cstring>
#include<iostream>
#define ll long long

using namespace std;

const ll INF=1e15;
const int N=2002;
ll f[N][N];
int n,st[N],tot=0,K,size[N];
struct node{
    int x,y,v,nxt;
};
node way[N<<1];

void add(int u,int w,int z)
{
    tot++;
    way[tot].x=u;way[tot].y=w;way[tot].nxt=st[u];way[tot].v=z;st[u]=tot;
    tot++;
    way[tot].x=w;way[tot].y=u;way[tot].nxt=st[w];way[tot].v=z;st[w]=tot;
}

void dfs(int now,int fa)
{
    size[now]=1;
    for (int i=st[now];i;i=way[i].nxt)
        if (way[i].y!=fa)
        {
            dfs(way[i].y,now);
            size[now]+=size[way[i].y];
        }
}

void doit(int now,int fa)
{
    for (int i=st[now];i;i=way[i].nxt)
        if (way[i].y!=fa)
        {
            doit(way[i].y,now);
            int y=way[i].y;
            for (int j=min(size[now],K);j>=0;j--)
                for (int k=0;k<=min(j,size[y]);k++)
                {
                    ll f1=(ll)way[i].v*(k*(K-k));
                    ll f2=(ll)way[i].v*(size[y]-k)*(n-size[y]-(K-k));
                    f[now][j]=max(f[now][j],f[now][j-k]+f[y][k]+f1+f2);
                }
        }
}

int main()
{
    scanf("%d%d",&n,&K);
    for (int i=1;i<n;i++)
    {
        int u,w,z;
        scanf("%d%d%d",&u,&w,&z);
        add(u,w,z);
    }
    dfs(1,0);
    doit(1,0);
    printf("%lld",f[1][K]);
    return 0;
}
原文地址:https://www.cnblogs.com/wutongtong3117/p/7673133.html