Codeforces 1179D 树形DP 斜率优化

题意:给你一颗树,你可以在树上添加一条边,问添加一条边之后的简单路径最多有多少条?简单路径是指路径中的点只没有重复。

思路:添加一条边之后,树变成了基环树。容易发现,以基环上的点为根的子树的点中的简单路径没有增加。所以,问题相当于转化为找一个基环,使得以基环上的点为根的子树Σ(i从1到n) sz[i]  * (sz[i] - 1) / 2最小。我们把式子转化一下变成求(sz[i]的平方和 - n) / 2。相当于我们需要求sz[i]的平方和。但是,我们并不知道哪个是基环,怎么求sz呢?我们发现一个性质:添加的边连接的两点一定是树中度数为1的点,否则,我们一定可以缩小平方和。所以,根据这个性质,我们可以进行树形dp。设dp[i]为以i为根的子树中,选择从i到子树中的某个叶子节点的路径为基环上的点,可以获得的最小的平方和。dp[i] = min(dp[son] + (sz[i] - sz[son]) ^ 2)。

我们假设选择的基环是u -> lca(u, v) -> v ,假设fu为u到lca(u, v)的路径中lca(u, v)的前面一个节点,fv同理,那么平方和为ans = dp[fu] + dp[fv] + (n - sz[fu] - sz[fv]) ^ 2。所以,我们在深搜的时候,找到所有孩子的dp值和sz,枚举是哪两个孩子来更新平方和,这样最坏情况是O(n ^ 2)的,会超时。发现状态转移方程中有fu和fv的乘积项,我们可以考虑斜率优化。把方程移项: dp[fv] = 2 * (n - sz[fu]) * sz[fv] + (ans - dp[fu] - 2 * n * sz[fu])。那么相当于是以sz[fv]为横坐标,dp[fv]为纵坐标,斜率为2 * (n - sz[fu])的直线,要ans最小,需要截距最小。我们把sz从小到大排序,用单调队列维护一个下凸包,之后在单调队列里二分即可。注意的细节:1,二分之后需要判断合不合法,不能fu和fv相等了。2:斜率优化只考虑的fu和fv不等的情况,我们需要特判一下从最优的叶子结点直接连到当前结点的这种情况。

代码:

#pragma comment(linker, "/stack:200000000")
#include <bits/stdc++.h>
#define LL long long
#define pll pair<LL, LL>
using namespace std;
const int maxn = 500010;
vector<int> G[maxn];
LL sz[maxn], dp[maxn];
pll q[maxn], a[maxn], b[maxn];
int l, r;
LL n, ans;
int tot;
map<pll, int> mp;
void add(int x, int y) {
	G[x].push_back(y);
	G[y].push_back(x);
}
bool check(pll x, pll y, pll z) {
	if((y.second - x.second) * (z.first - y.first) < (y.first - x.first) * (z.second - y.second))
		return 1;
	else
		return 0;
}
int binary_search(pll x, LL k) {
	if(l == r) return l;
	int L = l, R = r;
	while(L < R) {
		int mid = (L + R) >> 1;
		if((q[mid + 1].second - q[mid].second) <= k * (q[mid + 1].first - q[mid].first)) L = mid + 1;
		else R = mid; 
	}
	return L;
}
void dfs(int x, int fa) {
	sz[x] = 1;
	for (auto y : G[x]) {
		if(y == fa) continue;
		dfs(y, x);
		sz[x] += sz[y];
	}
	for (auto y : G[x]) {
		if(y == fa) continue;
		dp[x] = min(dp[x], (sz[x] - sz[y]) * (sz[x] - sz[y]) + dp[y]);
	}
	tot = 0;
	for (auto y : G[x]) {
		if(y == fa) continue;
		b[++tot] = a[y];
	}
	sort(b + 1, b + 1 + tot);
	mp.clear();
	if(tot > 1) {
		l = 1, r = 0;
		for (int i = 1; i <= tot; i++) {
			mp[b[i]]++;
			while(l < r && !check(q[r - 1], q[r], b[i])) r--;
			q[++r] = b[i];
		}
		for (int i = 1; i <= tot; i++) {
			int pos = binary_search(b[i], 2 * (n - b[i].first));
			if(q[pos] == b[i]) {
				if(mp[b[i]] > 1) {
					ans = min(ans, b[i].second + q[pos].second + (n - b[i].first - q[pos].first) * (n - b[i].first - q[pos].first));
				} else {
					if(pos < r) ans = min(ans, b[i].second + q[pos + 1].second + (n - b[i].first - q[pos + 1].first) * (n - b[i].first - q[pos + 1].first));
					else ans = min(ans, b[i].second + q[pos - 1].second + (n - b[i].first - q[pos - 1].first) * (n - b[i].first - q[pos - 1].first));
				}
			} else {
				ans = min(ans, b[i].second + q[pos].second + (n - b[i].first - q[pos].first) * (n - b[i].first - q[pos].first)); 
			}
		}
	}
	for (int i = 1; i <= tot; i++) {
		ans = min(ans, b[i].second + (n - b[i].first) * (n - b[i].first));
	}
	if(fa != -1 && G[x].size() == 1) dp[x] = sz[x] * sz[x];
	a[x] = make_pair(sz[x], dp[x]);
}
int main() {
	int x, y;
	memset(dp, 0x3f, sizeof(dp));
//	freopen("1179Din.txt", "r", stdin);
//	freopen("1179D1out.txt", "w", stdout);
	scanf("%lld", &n);
	for (int i = 1; i < n; i++) {
		scanf("%d%d", &x, &y);
		add(x, y);
	}
	ans = 1e18;
	dfs(1, -1);
	ans = min(ans, dp[1]);
	ans -= n;
	ans /= 2;
	ans = 2ll * n * (n - 1) / 2ll - ans;
	printf("%lld
", ans);
} 

  

原文地址:https://www.cnblogs.com/pkgunboat/p/11142535.html