NOIP2015 d2T3 二分+树上前缀和

NOIP2015 day2 T3

题目大意:给出一棵树以及若干点对,现要求使树中某一条边的权值变为0,使得最远点对的距离最小。

sol:
​ 用到了很多的思想、技巧和常用结论。
二分答案(x)(x)为最远点对的距离。
​ 由此将问题转化为关于(x)单调的判定性问题(C(x)):是否可以使某一条边的权值变为0,使得所有点对的距离都不超过(x)
​ 我们在二分之前倍增预处理出所有点对之间的距离。对于本来就不超过(x)的,不用管;对于原本超过(x)的,我们就必须对这条路径上的某条边进行操作了。
​ 是哪条边呢?贪心一下不难得出,应该是所有超过(x)的路径的交集里面最长的那一条边。
​ 如何求出它们的交集呢?某UO钩博客上给出了两种解法:

​ 1,模拟。把路径求交集的正确姿势当然是两两交起来:显然,树上路径的交集仍然是连续的路径
如何求两两路径的交集呢?设这两个点对是((a,b))((c,d))
那么交集路径的两个顶点只可能是:(LCA(a,b),LCA(a,c),LCA(a,d),LCA(b,c),LCA(b,d),LCA(c,d))
这样,一共就有15种可能。枚举之,找出最大的合法交集即可。

​ 2,树上前缀和。一种很朴素的想法,就是统计每条边的被覆盖的次数,覆盖次数等于路径集大小的边的集合就是交集。但我们不可能一条一条边地打标记。我们需要一种高效的维护这个信息的方法。
​ 我们回顾一下NOIP 2012借教室。一个序列,给出若干组((l,r,a)​)([l,r]​)区间里的数都减(a​)(a​)为正整数。问按输入顺序操作这些区间,最多可以操作多少,使得序列中不存在负数。答案显然是单调的,即操作的区间越多,约束的条件(不存在负数)越难满足。所以,我们也可以二分答案,转化为判定性问题。如何高效地判定操作完前(i​)个区间是否满足约束呢?奇妙的技巧是维护(a​)的“前缀和”。黄学长这样描述:比如一开始数列(a​)({0,0,0,0,0,0}​),给出区间({3,5,2}​),则将(a_3​)+=(2​)(a_6​)-=(2​),就变成({0,0,2,0,0,-2}​),前缀和变成({0,0,2,2,2,0}​),把前缀和当成元素,这样就实现了对这个区间的操作。降维思想在这里巧妙地运用了。这样,就可以在(O(n)​)的时间复杂度内完成判定,总的复杂度就是(O(nlog n)​)。借教室就是相当于“统计序列中每个元素被减的次数(即对它有影响的(a​)的和)”,所以,原题中就相当于把借教室的方法作树上的拓展。
​ 做法如下:对于每个点维护一个间接值(v),我们对这个(v)直接进行加减操作,再维护一个类似于关于(v)的“前缀和”的东西(在树中,更准确的说法应该是“子树和”),就能维护需要的信息了。对于每个点对((a,b)),我们把(v_a)++,(v_b)++,(v_{LCA(a,b)})-=2。然后我们对每个点再弄一个(s),表示这个点到它的父亲这条边被覆盖的次数。则(s_i = v_i + Sigma_{k in son_i}{s_k})。这样,我们就可以按照DFS序维护出(s)了。这是一个非常棒的技巧,需要记忆。

第一个做法带两个log,第二个做法带一个log。据说还有不带log的线性做法,但leader武爷爷不屑于研究,这个做法也就石沉大海无人知晓了(其实另一篇UO钩博客上有啊。。

另外注意这个二分由于单调性是递增的,所以按我的惯用写法应该输出右界。
第一份代码码错了两个地方。详见注释。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

 #define rep(i,a,b) for (int i = a; i <= b; i++)
 #define dep(i,a,b) for (int i = a; i >= b; i--)
 #define fill(a,x) memset(a, x, sizeof(a))

 const int N = 300000 + 5, M = N*2, D = 22;

 int n, m, up, cnt = 0, a[M], b[M], fa[N][D], sum[N][D], len[M], v[N], s[N], pre[N], dep[N], dfsid[N];

 struct Edge { int to, w, pre; } e[M];
 void ine(int a, int b, int w) {
 	cnt++;
 	e[cnt].to = b, e[cnt].w = w, e[cnt].pre = pre[a];
 	pre[a] = cnt;
 }
 void ine2(int a, int b, int w) {
 	ine(a, b, w);
 	ine(b, a, w);
 }
 #define reg(i,x) for (int i = pre[x]; i; i = e[i].pre)

 int read() {
    int x = 0; char ch = getchar();
    while (ch < '0' || ch > '9') ch = getchar();
    while (ch >= '0' && ch <= '9') {
    	x = x*10 + ch - '0';
    	ch = getchar();
    }
    return x;
 }

 int cur = 0;
 void change(int x, int dept) {
 	dep[x] = dept;
 	reg(i,x) {
 		int y = e[i].to;
 		if (y == fa[x][0]) continue;
 		fa[y][0] = x;
 		sum[y][0] = e[i].w;
 		change(y, dept + 1);
 	}
 	dfsid[++cur] = x;
 }

 void init() {
 	int tmp = 1; up = 0;
 	while ((1<<up) <= n) up++;
 	rep(j,1,up) rep(i,1,n) {  // 一开始把ij循环写反了… 显然不能搞反啊
 		int dad = fa[i][j-1];
 		fa[i][j] = fa[dad][j-1];
 		sum[i][j] = sum[i][j-1] + sum[dad][j-1];
 	}
 }

 int lca, tot;
 void query(int a, int b) {
 	tot = 0;
 	if (dep[a] < dep[b]) swap(a, b);
 	int del = dep[a] - dep[b];
 	rep(i,0,up) if ((1<<i)&del) tot += sum[a][i], a = fa[a][i];
 	if (a == b) { lca = a; return; }
 	dep(i,up,0) if (fa[a][i] != fa[b][i]) {
 		tot += (sum[a][i] + sum[b][i]);
 		a = fa[a][i], b = fa[b][i];
 	}
 	if (a != b) { lca = fa[a][0]; tot += (sum[a][0] + sum[b][0]); } else lca = a;
 	return;
 }

 int LCA(int a, int b) {
 	query(a, b);
 	return lca;
 }

 int get_len(int a, int b) {
 	query(a, b);
 	return tot;
 }

 bool judge(int x) {
 	fill(s, 0);
 	int over = 0;
 	rep(i,1,m) if (len[i] > x) {
 		s[a[i]]++; s[b[i]]++;
 		s[LCA(a[i],b[i])] -= 2;
 		over++;
 	}
 	rep(i,1,n) { int t = dfsid[i]; s[fa[t][0]] += s[t]; } // 之前写成了+=v[t]… 前!缀!和!勿理智障如我
 	int dec = 0;
 	rep(i,1,n) if (s[i] == over) dec = max(dec, sum[i][0]);
 	rep(i,1,m) if (len[i] - dec > x) return false;
 	return true;
 }


int main()
{
	n = read(), m = read();
	int u, v, w;
	fill(pre, 0);
	rep(i,1,2*n) e[i].pre = 0;
	rep(i,1,n-1) {
		u = read(), v = read(), w = read();
		ine2(u, v, w);
	}

    fa[1][0] = 0; sum[1][0] = 0;
	change(1, 0);
	init();

    int maxl = 0;
	rep(i,1,m) {
		a[i] = read(), b[i] = read();
		len[i] = get_len(a[i], b[i]);
		maxl = max(maxl, len[i]);
	}

	int l = -1, r = maxl + 1;
	while (l + 1 < r) {
		int mid = (l + r)>>1;
		if (judge(mid)) r = mid; else l = mid;
	}

	printf("%d
", r);

	return 0;
}

原文地址:https://www.cnblogs.com/yearwhk/p/5882921.html