[JZOJ5455]【NOIP2017提高A组冲刺11.6】拆网线

Description

企鹅国的网吧们之间由网线互相连接,形成一棵树的结构。现在由于冬天到了,供暖部门缺少燃料,于是他们决定去拆一些网线来做燃料。但是现在有K只企鹅要上网和别人联机游戏,所以他们需要把这K只企鹅安排到不同的机房(两只企鹅在同一个机房会吵架),然后拆掉一些网线,但是需要保证每只企鹅至少还能通过留下来的网线和至少另一只企鹅联机游戏。
所以他们想知道,最少需要保留多少根网线?
 

Input

第一行一个整数T,表示数据组数;
每组数据第一行两个整数N,K,表示总共的机房数目和企鹅数目。
第二行N-1个整数,第i个整数Ai表示机房i+1和机房Ai有一根网线连接(1≤Ai≤i)。

Output

每组数据输出一个整数表示最少保留的网线数目。
 

Sample Input

2
4 4
1 2 3
4 3
1 1 1

Sample Output

2
2
 

Data Constraint

对于30%的数据:N≤15;
对于50%的数据:N≤300;
对于70%的数据:N≤2000;
对于100%的数据:2≤K≤N≤100000,T≤10。

最优的情况是一条边正好站两个企鹅,这样会使得保留下来的边最少。

我们怎么求呢?

设ans为在这棵树中满足一条边被两个点站的点对的个数*2, 即点数。

那么如果ans >= k,直接输出(k+1)/2.

如果ans < k, 那么企鹅的站位一定有出现菊花图的样子,我们至多可以满足ans个点找到自己的匹配,剩下的(k-ans)个企鹅只能和别的企鹅链接形成菊花图的样子。

这样它自己站一条边, 所以输出ans / 2 + (k - ans)。

接下来的问题是如何找出ans。

考虑树形DP, 设f[i][0/1]为i的子树中,i这个节点选/不选的最大的两个点相互匹配的点数。

那么显然有f[u][0] += f[v][1];

f[u][1] = max(f[u][0] - f[v][1] + f[v][0] + 2),

解释一下:把式子变一下 $large f[u][1]=(sum f[v'][1])-f[v][1]+f[v][0]+2$

因为这个点和枚举的v形成了一个匹配, 所以+2.

写起来不是很难。


#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
inline int read() {
    int res=0;char c=getchar();bool f=0;
    while(!isdigit(c)) {if(c=='-')f=1;c=getchar();}
    while(isdigit(c))res=(res<<3)+(res<<1)+(c^48),c=getchar();
    return f?-res:res;
}
int T, n, k;
struct edge {
    int nxt, to;
}ed[200005];
int head[100005], cnt;
inline void add(int x, int y) {
    ed[++cnt] = (edge){head[x], y};
    head[x] = cnt;
}
int f[100005][2];

void dfs(int x, int fa)
{
    for (int i = head[x] ; i ; i = ed[i].nxt)
    {
        int to = ed[i].to;
        if (to == fa) continue;
        dfs(to, x);
        f[x][0] += f[to][1];
    }
    for (int i = head[x] ; i ; i = ed[i].nxt)
    {
        int to = ed[i].to;
        if (to == fa) continue;
        f[x][1] = max(f[x][1], f[x][0] - f[to][1] + f[to][0] + 2);
    }
}

int main()
{
    freopen("tree.in", "r", stdin);
    freopen("tree.out", "w", stdout);
    T = read();
    while(T--)
    {
        cnt = 0;
        memset(head, 0, sizeof head);
        memset(f, 0, sizeof f);
        n = read(), k = read();
        for (int i = 1 ; i <= n - 1 ; i ++)
        {
            int x = read();
            add(x, i + 1), add(i + 1, x);
        }
        dfs(1, 0);
        int ans = max(f[1][1], f[1][0]);
        if (ans >= k) printf("%d
", (k + 1) / 2);
        else printf("%d
", ans / 2 + (k - ans));
    }
    return 0;
}
原文地址:https://www.cnblogs.com/BriMon/p/9439578.html