JZOJ 6866. 【2020.11.16提高组模拟】路径大小差(点分治+树状数组)

JZOJ 6866. 【2020.11.16提高组模拟】路径大小差

题目大意

  • 问树上有多少点对之间路径边权 m a x − m i n = k max-min=k maxmin=k k k k为定值。
  • k ≤ n ≤ 2 ∗ 1 0 5 kleq nleq2*10^5 kn2105.

题解

  • 其实这题比较套路,并不难想。
  • 关于树上路径计数的问题,一般先考虑点分治能不能实现,发现是可以的。
  • 按照一般点分治的套路,找到某个子树重心后,记录每个点到它的路径边权 m a x , m i n max,min max,min,有两种情况,一种是重心为路径的一端,直接枚举判断;另一种是重心在路径中间。
  • 第二种情况,按 m a x max max从小到大排序,枚举一条路径和前面的另一条组合,
  • 因为已经排好序了,所以 m a x max max一定在当前这条路径上,接着再分两种情况,一种是该路径的 m a x − m i n < k max-min<k maxmin<k,那么查找前面 m i n = m a x − k min=max-k min=maxk的数量加入答案;一种是该路径的 m a x − m i n = k max-min=k maxmin=k,则查找前面 m i n ≥ m a x − k mingeq max-k minmaxk的数量加入答案。用树状数组维护。
  • 但是会发现组合的两条路径可能出现在当前根的同一子树中,那么把每棵子树的路径单独求一遍,从答案中减去即可。

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define ll long long
#define N 200010
int n, K;
ll ans = 0;
int last[N], nxt[N * 2], to[N * 2], we[N * 2], len = 0;
int vi[N], si[N], sum[N], s, rt, mi;
int tot = 0, f[N];
struct node {
	int mx, mi, r;
}a[N];
void add(int x, int y, int w) {
	to[++len] = y;
	we[len] = w;
	nxt[len] = last[x];
	last[x] = len;
}
void dfs(int k, int fa) {
	si[k] = 1;
	for(int i = last[k]; i; i = nxt[i]) if(to[i] != fa && !vi[to[i]]) {
		dfs(to[i], k);
		si[k] += si[to[i]];
	}
}
void find(int k, int fa) {
	int mx = s - si[k];
	for(int i = last[k]; i; i = nxt[i]) if(to[i] != fa && !vi[to[i]]) {
		find(to[i], k);
		mx = max(mx, si[to[i]]);
	}
	if(mx < mi) mi = mx, rt = k;
}
void dfs1(int k, int fa, int t0, int t1, int r) {
	if(t1) a[++tot].mx = t1, a[tot].mi = t0, a[tot].r = r;
	for(int i = last[k]; i; i = nxt[i]) if(to[i] != fa && !vi[to[i]]) {
		dfs1(to[i], k, min(t0, we[i]), max(t1, we[i]), r == 0 ? to[i] : r);
	}
}
int cmp(node x, node y) {
	if(x.mx == y.mx) return x.mi < y.mi;
	return x.mx < y.mx;
}
int cmp1(node x, node y) {
	return x.r < y.r;
}
int low(int x) {
	return x & (-x);
}
void ins(int k, int c) {
	for(int i = k; i <= n; i += low(i)) f[i] += c;
}
int ct(int k) {
	int s = 0;
	for(int i = k; i; i -= low(i)) s += f[i];
	return s;
}
void ds(int l, int r, int o) {
	sort(a + l, a + r + 1, cmp);
	for(int i = l; i <= r; i++) {
		if(a[i].mx - a[i].mi == K) {
			ans += (i - l - ct(a[i].mi - 1)) * o;
		}
		else if(a[i].mx - a[i].mi < K) ans += sum[a[i].mx - K] * o;
		sum[a[i].mi]++;
		ins(a[i].mi, 1);
	}
	for(int i = l; i <= r; i++) sum[a[i].mi]--, ins(a[i].mi, -1);
}
void calc(int k) {
	tot = 0;
	dfs1(k, 0, n + 1, 0, 0);
	sort(a + 1, a + tot + 1, cmp);
	for(int i = 1; i <= tot; i++) if(a[i].mx - a[i].mi == K) ans++;
	ds(1, tot, 1);
	sort(a + 1, a + tot + 1, cmp1);
	int la = 1;
	for(int i = 1; i <= tot; i++) {
		if(i == tot || a[i].r != a[i + 1].r) {
			ds(la, i, -1);
			la = i + 1;
		}
	}
}
void solve(int k) {
	dfs(k, 0);
	s = si[k], mi = n + 1;
	find(k, 0);
	calc(rt);
	vi[rt] = 1;
	for(int i = last[rt]; i; i = nxt[i]) if(!vi[to[i]]) solve(to[i]);
}
int main() {
	int i, x, y, w;
	scanf("%d%d", &n, &K);
	for(i = 1; i < n; i++) {
		scanf("%d%d%d", &x, &y, &w);
		add(x, y, w), add(y, x, w);
	}
	solve(1);
	printf("%lld
", ans);
	return 0;
}
原文地址:https://www.cnblogs.com/LZA119/p/14279484.html