BZOJ1827 [Usaco2010 Mar]gather 奶牛大集会

题意:给定一棵树,求出树上的一点,使得树上的全部点到该点的距离之和最小。


思路:暴力显然是O(N^2)等死对吧。

我们首先将无根树转化为有根树,然后一边dfs求出f[i],size[i].

f[i]表示以i为根的子树中全部的点到i的距离之和,size[i]表示以i为根的子树的点数。


以下開始脑洞大开:

如今对于我们一開始的那个root,我们已经知道了答案。问题就是怎样高速的推知别的点作为根时的答案。

我们又一次进行一次dfs,当找到x时,我们用dp[fa[x]]+padis[x]*size[fa[x]]更新答案。

我们记录一下当前的dp[x],以及size[x].

每找到一个儿子son,向下dfs时,我们令dp[x]=dp[fa[x]]+size[fa[x]]*padis[x]+dp[x]-dp[son]-size[son]*padis[son],size[x]=size[fa[x]]+size[x]-size[son],然后再向下dfs.

不要问我为什么。。。


我的代码用的是更加脑洞大开的方法。。。


Code:

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

#define N 100010
int head[N], next[N << 1], end[N << 1], len[N << 1];
void addedge(int a, int b, int _len) {
	static int q = 1;
	len[q] = _len;
	end[q] = b;
	next[q] = head[a];
	head[a] = q++;
}

int num[N];

long long dp[N];
int pa[N], padis[N], size[N];
void dfs(int x, int fa) {
	size[x] = num[x];
	for(int j = head[x]; j; j = next[j])
		if (end[j] != fa)
			pa[end[j]] = x, padis[end[j]] = len[j], dfs(end[j], x);
	for(int j = head[x]; j; j = next[j])
		if (end[j] != fa)
			dp[x] += dp[end[j]] + (long long)size[end[j]] * len[j], size[x] += size[end[j]];
}

long long res = 1LL << 60;
int presize[N], sufsize[N], addsize[N], sav[N], top;
long long pre[N], suf[N], add[N];
void work(int x) {
	long long ans = add[x] + (long long)addsize[x] * padis[x] + dp[x];
	if (ans < res)
		res = ans;
		
	register int i, j;
	top = 0;
	for(j = head[x]; j; j = next[j])
		if (end[j] != pa[x])
			sav[++top] = end[j];
	
	presize[0] = pre[0] = 0, sufsize[top + 1] = suf[top + 1] = 0;
	for(i = 1; i <= top; ++i)
		presize[i] = presize[i - 1] + size[sav[i]], pre[i] = pre[i - 1] + dp[sav[i]] + (long long)size[sav[i]] * padis[sav[i]];
	for(i = top; i >= 1; --i)
		sufsize[i] = sufsize[i + 1] + size[sav[i]], suf[i] = suf[i + 1] + dp[sav[i]] + (long long)size[sav[i]] * padis[sav[i]];
	for(i = 1; i <= top; ++i) {
		addsize[sav[i]] = addsize[x] + num[x] + presize[i - 1] + sufsize[i + 1];
		add[sav[i]] = add[x] + (long long)addsize[x] * padis[x] + pre[i - 1] + suf[i + 1];
	}
	
	for(j = head[x]; j; j = next[j])
		if (end[j] != pa[x])
			work(end[j]);
}
int main() {
	int n;
	scanf("%d", &n);
	
	register int i, j;
	for(i = 1; i <= n; ++i)
		scanf("%d", &num[i]);
	
	int a, b, x;
	for(i = 1; i < n; ++i) {
		scanf("%d%d%d", &a, &b, &x);
		addedge(a, b, x);
		addedge(b, a, x);
	}
	
	dfs(1, -1);
	work(1);
	
	printf("%lld", res);
	
	return 0;
}


原文地址:https://www.cnblogs.com/lcchuguo/p/4488964.html