关于wqs二分(凸优化) 实数二分和整数二分的一些讨论

https://loj.ac/problem/2478

以上面这题为例,这道题斜率是不递增的,并且都是整数。

实数二分很爽,但是效率不高,对于斜率都是整数的,我们可以采用整数二分,但是需要注意一点细节:

wqs二分,是找一个斜率,使得第k个成为最优点。

但是,因为斜率可能出现一段相同的情况,因此可能有一个斜率,使得一干点同时成为最优,我们要第k个被包含在其中,这样就没有办法恰好找到了。

此时,在dp的过程,应该保证权值和相同时,选的点数尽可能多,这样找到一个斜率ans使得最优点恰好(>=k),真正答案就是(dp值+ans*k)

“权值和相同时,选的点数尽可能多”是必要的,因为假设k在一段斜率相同的中间,而dp出的点数总是在k的左边,这样就会取到下一个斜率,就不是最优了。

参考代码:

#include<bits/stdc++.h>
#define fo(i, x, y) for(int i = x, _b = y; i <= _b; i ++)
#define ff(i, x, y) for(int i = x, _b = y; i <  _b; i ++)
#define fd(i, x, y) for(int i = x, _b = y; i >= _b; i --)
#define ll long long
#define pp printf
#define hh pp("
")
using namespace std;

#define V vector<int>
#define pb push_back
#define si size()

const int N = 3e5 + 5;

int n, k, x, y, z;
int fi[N * 2], nt[N * 2], to[N * 2], v[N * 2], tot;

void link(int x, int y, int z) {
	nt[++ tot] = fi[x], to[tot] = y, v[tot] = z, fi[x] = tot;
}

#define pii pair<ll, int>
#define fs first
#define se second

int fa[N];
pii f[N][3], g[N][3];

ll w;

pii mv(pii a, pii b) {
	if(a.fs == b.fs) return a.se > b.se ? a : b;
	return a.fs > b.fs ? a : b;
}

pii operator + (pii a, pii b) {
	return pii(a.fs + b.fs, a.se + b.se);
}

void dg(int x) {
	f[x][0] = w < 0 ? pii(0, 0) : pii(w, 1);
	f[x][1] = pii(w, 1);
	f[x][2] = pii(w, 1);
	
	for(int i = fi[x]; i; i = nt[i]) {
		int y = to[i], z = v[i];
		if(y == fa[x]) continue;
		fa[y] = x;
		dg(y);
		
		g[x][2] = f[x][2] + f[y][0];
		
		g[x][1] = f[x][1] + f[y][0];
		pii d = f[x][2] + f[y][1];
		d.se --; d.fs += z - w;
		g[x][1] = mv(g[x][1], d);
		
		g[x][0] = f[x][0] + f[y][0];
		d = f[x][1] + f[y][1];
		d.se --; d.fs += z - w;
		g[x][0] = mv(g[x][0], d);
		
		f[x][0] = mv(g[x][0], mv(g[x][1], g[x][2]));
		f[x][1] = mv(g[x][1], g[x][2]);
		f[x][2] = g[x][2];
	}
}

int count(ll _w) {
	w = _w;
	dg(1);
	return f[1][0].se;
}

int main() {
	scanf("%d %d", &n, &k); k ++;
	fo(i, 1, n - 1) {
		scanf("%d %d %d", &x, &y, &z);
		link(x, y, z); link(y, x, z);
	}
	ll ans = 0;
	for(ll l = -3e11, r = 3e11; l <= r; ) {
		ll m = l + r >> 1;
		if(count(m) >= k) ans = m, r = m - 1; else l = m + 1;
	}
	count(ans);
	pp("%lld
", f[1][0].fs - k * ans);
}
原文地址:https://www.cnblogs.com/coldchair/p/12609384.html