[Luogu] P5021 赛道修建

(Link)

Description

(C)城将要举办一系列的赛车比赛。在比赛前,需要在城内修建(m)条赛道。

(C)城一共有(n)个路口,这些路口编号为(1,2,…,n),有(n−1)条适合于修建赛道的双向通行的道路,每条道路连接着两个路口。其中,第(i)条道路连接的两个路口编号为(a_i)(b_i)​,该道路的长度为(l_i)​。借助这(n-1)条道路,从任何一个路口出发都能到达其他所有的路口。

一条赛道是一组互不相同的道路(e_1,e_2,…,e_k),满足可以从某个路口出发,依次经过道路(e_1,e_2,…,e_k)​(每条道路经过一次,不允许调头)到达另一个路口。一条赛道的长度等于经过的各道路的长度之和。为保证安全,要求每条道路至多被一条赛道经过。

目前赛道修建的方案尚未确定。你的任务是设计一种赛道修建的方案,使得修建的(m)条赛道中长度最小的赛道长度最大(即(m)条赛道中最短赛道的长度尽可能大)

Solution

首先看到最小的最大,就想到要二分。我们二分(m)条赛道中长度最小的赛道长度(mid),那么(check)的就是长度(ge{mid})的赛道最多能不能(ge{m}),如果可以,那么(l=mid+1),否则(r=mid-1)

这个还是可以想到的,就是(check)不太好写。注意到对于某个节点(x)和它的一个子节点(y),一定是选若干条以(y)为链顶的链,将剩下的最短的链(l_1)和对应最短的满足(len_1+len_2ge{mid})的链(l_2)拼成一条赛道,然后如果最后还剩下一条链(l_3),长度就一定是最大的,将它和(x ightarrow{y})这条边拼在一起,构成一条赛道。

这其实就是贪心的思想。因为(x ightarrow{y})这条边一定能且只能和一条以(y)为顶的链构成赛道,然后如果把某条(l_1)对应的(l_2)替换成更大的,可能反而会找不到一条链(l_3),满足(len_3+len_{x ightarrow{y}}ge{mid})

具体实现用(multiset)

Code

#include <bits/stdc++.h>

using namespace std;

int n, m, tot, res, sum, l = 1e9, r, hd[50005], to[100005], nxt[100005], w[100005];

multiset < int > g[50005];

int read()
{
	int x = 0, fl = 1; char ch = getchar();
	while (ch < '0' || ch > '9') { if (ch == '-') fl = -1; ch = getchar();}
	while (ch >= '0' && ch <= '9') {x = (x << 1) + (x << 3) + ch - '0'; ch = getchar();}
	return x * fl;
}

void add(int x, int y, int z)
{
	tot ++ ;
	to[tot] = y;
	nxt[tot] = hd[x];
	w[tot] = z;
	hd[x] = tot;
	return;
}

int dfs(int x, int fa, int d)
{
	g[x].clear();
	for (int i = hd[x]; i; i = nxt[i])
	{
		int y = to[i], z = w[i];
		if (y == fa) continue;
		int now = dfs(y, x, d) + z;
		if (now >= d) sum ++ ;
		else g[x].insert(now);
	}
	int mx = 0;
	while (g[x].size())
	{
		int cnt = *g[x].begin();
		if (g[x].size() == 1) return max(mx, cnt); 
		multiset < int > :: iterator it = g[x].lower_bound(d - cnt);
		if (it == g[x].begin()) it ++ ;
		g[x].erase(g[x].begin());
		if (it == g[x].end()) mx = max(mx, cnt);
		else sum ++ , g[x].erase(it);
	}
	return mx;
}

int check(int x)
{
	sum = 0;
	dfs(1, 0, x);
	return (sum >= m);
}

int main()
{
	n = read(); m = read();
	for (int i = 1; i <= n - 1; i ++ )
	{
		int x = read(), y = read(), z = read();
		add(x, y, z); add(y, x, z);
		l = min(l, z); r += z;
	}
	while (l <= r)
	{
		int mid = (l + r) >> 1;
		if (check(mid)) l = mid + 1, res = mid;
		else r = mid - 1;
	}
	printf("%d
", res);
	return 0;
}
原文地址:https://www.cnblogs.com/andysj/p/13956109.html