树的统计

题目描述

思路

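树链剖分模板题：先用两次 DFS 求出每个结点的父亲、深度、子树大小、重儿子以及所在重链的链顶，并按 DFS 序给结点重新编号；再在该编号上建立线段树，维护区间和与区间最大值。路径查询时沿重链逐段向上跳并查询线段树，单点修改则直接更新线段树对应位置。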

代码

#include <cstdio>
#include <cstring>
#define lc k<<1
#define rc k<<1|1
#define max(a, b) ((a) > (b) ? a : b)
#define FOR(a, b) for(int i=a;i<=b;++i)

// Capacity of every per-node array; nodes are indexed from 1.
const int MAX = 3e4 + 5;
// n: node count (first value read in main). ot/oa: digit stack used by write().
// inf: magnitude used to initialise the running maximum to -inf before a query.
int n, ot, oa[20], inf = 0x3f3f3f3f;
// Not referenced anywhere in the visible code.
char it[10];
// Adjacency lists built by add(): head[x] is the first edge slot of x, ver[e]
// the endpoint stored in slot e, nt[e] the next slot in the same list (0 ends
// the list), and ht the number of slots used so far.
int head[MAX], ver[MAX << 1], nt[MAX << 1], ht;
// wt[i]: weight of node i. zt: number of operations processed by main.
int wt[MAX], zt;
// Buffer for the operation keyword read in main.
char str[10];
// Filled by dfs1: parent, child with the largest subtree (0 if none),
// subtree size, and depth of each node.
int fa[MAX], son[MAX], size[MAX], dep[MAX];
// Filled by dfs2: dfn[x] is the visit order of x, dt the visit counter,
// tr[] the inverse of dfn[], and top[x] the chain head recorded for x.
int dfn[MAX], dt, tr[MAX], top[MAX];
// Segment tree over positions 1..n in dfn order: range maximum and range sum.
int mx[MAX << 2], sum[MAX << 2];
// Accumulators that query() updates; main resets them before each operation.
int ans_mx, ans_sum;
// Label printed by showArray() in front of the array contents.
char showStr[100];
/*
 * Reads the next decimal integer from standard input.
 *
 * Every character before the first digit is skipped; if any skipped character
 * is '-', the result is negated. Digits are then consumed up to and including
 * the first non-digit character.
 *
 * Returns the parsed value, or 0 if end of input is reached before any digit.
 */
inline int read() {
	int s = 0, f = 1;
	// getchar() returns int; keeping it as int lets EOF be told apart from
	// every valid character regardless of whether plain char is signed.
	int ch = getchar();
	while (ch != EOF && (ch < '0' || ch > '9')) {
		if (ch == '-') f = -1;
		ch = getchar();
	}
	// EOF is outside '0'..'9', so this loop also terminates at end of input.
	while (ch >= '0' && ch <= '9') {
		s = s * 10 + (ch - '0');
		ch = getchar();
	}
	return s * f;
}

/*
 * Prints x in decimal to standard output, without a trailing newline.
 *
 * Uses the global digit stack oa[] with ot as its height; ot is 0 again when
 * the function returns.
 */
inline void write(int x) {
	ot = 0;
	if (x == 0) { putchar('0'); return; }
	// Work on the magnitude as unsigned so that the most negative int can be
	// negated without signed overflow.
	unsigned int mag = (unsigned int)x;
	if (x < 0) {
		putchar('-');
		mag = 0u - mag;
	}
	// Digits come out least significant first; push them, then pop to print.
	while (mag) oa[++ot] = (int)(mag % 10u) + '0', mag /= 10u;
	while (ot) putchar(oa[ot--]);
}
/* Inserts the directed edge x -> y at the front of x's adjacency list. */
void add(int x, int y) {
	const int slot = ++ht;  // allocate the next edge slot
	ver[slot] = y;
	nt[slot] = head[x];     // link to the previous first edge of x
	head[x] = slot;
}

/*
 * First traversal: for node x reached from parent u, records fa[x], dep[x]
 * and size[x], and sets son[x] to the child with the largest subtree
 * (the first such child wins ties; son[x] stays 0 when x has no child).
 */
void dfs1(int x, int u) {
	fa[x] = u;
	dep[x] = dep[u] + 1;
	size[x] = 1;
	for (int e = head[x]; e != 0; e = nt[e]) {
		const int child = ver[e];
		if (child != u) {
			dfs1(child, x);
			size[x] += size[child];
			// size[0] is 0, so the first child always beats the empty choice.
			if (size[son[x]] < size[child]) {
				son[x] = child;
			}
		}
	}
}

/*
 * Second traversal: numbers node x in visit order (dfn/tr), records u as the
 * head of x's chain, continues the same chain into son[x] first, and then
 * starts a new chain at every neighbour that has not been numbered yet.
 */
void dfs2(int x, int u) {
	++dt;
	dfn[x] = dt;
	tr[dt] = x;
	top[x] = u;
	const int heavy = son[x];
	if (heavy != 0) {
		dfs2(heavy, u);
	}
	for (int e = head[x]; e != 0; e = nt[e]) {
		const int nxt = ver[e];
		// The parent and son[x] already have a number, so they are skipped.
		if (dfn[nxt] == 0) {
			dfs2(nxt, nxt);
		}
	}
}

void build(int k, int l, int r) {
	if (l == r) {
		sum[k] = mx[k] = wt[tr[l]];
		return;
	}
	int mid = l + r >> 1;
	build(lc, l, mid);
	build(rc, mid + 1, r);
	sum[k] = sum[lc] + sum[rc];
	mx[k] = max(mx[lc], mx[rc]);
}

/*
 * Point update: sets position x to value y inside node k covering [l, r],
 * then recomputes the maximum and sum of every node on the way back up.
 */
void change(int k, int l, int r, int x, int y) {
	if (l == r && x == l) {
		mx[k] = y;
		sum[k] = y;
		return;
	}
	const int half = (l + r) >> 1;
	const int left = k << 1;
	const int right = left | 1;
	if (x > half) {
		change(right, half + 1, r, x, y);
	} else {
		change(left, l, half, x, y);
	}
	mx[k] = mx[left] > mx[right] ? mx[left] : mx[right];
	sum[k] = sum[left] + sum[right];
}

/*
 * Range query over positions [x, y] inside node k covering [l, r].
 *
 * Results are accumulated rather than returned: ans_mx is raised to the
 * largest value found and ans_sum is increased by the range sum, so the
 * caller must reset both before starting a new query.
 */
void query(int k, int l, int r, int x, int y) {
	// Node fully inside the requested range: use its stored aggregates.
	if (x <= l && r <= y) {
		ans_mx = max(ans_mx, mx[k]);
		ans_sum += sum[k];
		return;
	}
	int mid = (l + r) >> 1;
	if (x <= mid) query(lc, l, mid, x, y);
	if (y > mid) query(rc, mid + 1, r, x, y);
}

/* Exchanges the values of x and y. */
void swap(int &x, int &y) {
	const int saved = y;
	y = x;
	x = saved;
}
void ask(int x, int y) {
	int fx = top[x], fy = top[y];
	while (fx != fy) {
		if (dep[fx] < dep[fy]) swap(x, y), swap(fx, fy);
		query(1, 1, n, dfn[fx], dfn[x]);
		x = fa[fx], fx = top[x];
	}
	if (dep[x] > dep[y]) swap(x, y);
	query(1, 1, n, dfn[x], dfn[y]);
}
/* Debug helper: prints the label in showStr followed by arr[1..n] on one line. */
void showArray(int arr[]) {
	printf("%s", showStr);
	int idx = 1;
	while (idx <= n) {
		printf("%2d ", arr[idx]);
		++idx;
	}
	puts("");
}
void show() {
	printf("n: %d
", n);
	for (int i = 1; i <= n; ++i) {
		printf("%d:", i);
		for (int j = head[i]; j; j = nt[j]) {
			printf("%d ", ver[j]);
		}
		puts("");
	}
	printf("zt: %d
", zt);
	strcpy(showStr, "weight:"),showArray(wt);
	strcpy(showStr, "dep   :"),showArray(dep);
	strcpy(showStr, "fa    :"),showArray(fa);
	strcpy(showStr, "size  :"),showArray(size);
	strcpy(showStr, "son   :"),showArray(son);
	strcpy(showStr, "dfn   :"),showArray(dfn);
	strcpy(showStr, "top   :"),showArray(top);
	strcpy(showStr, "tr    :"),showArray(tr);
}
int main() {
	n = read();
	for (int i = 1, a, b; i < n; ++i) {
		a = read(), b = read();
		add(a, b), add(b, a);
	}
	for (int i = 1; i <= n; ++i) wt[i] = read();
	zt = read();
	dfs1(1, 0);
	dfs2(1, 1);
	// show();
	build(1, 1, n);
	for (int i = 1, x, y; i <= zt; ++i) {
		scanf("%s", str);
		x = read(), y = read();
		ans_sum = 0, ans_mx = -inf;
		if (strcmp(str, "QMAX") == 0) {
			ask(x, y);
			write(ans_mx);
			puts("");
		}
		else if (strcmp(str, "QSUM") == 0) {
			ask(x, y);
			write(ans_sum);
			puts("");
		}
		else change(1, 1, n, dfn[x], y); 
	}
	return 0;
}
原文地址:https://www.cnblogs.com/liuzz-20180701/p/11521434.html