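/*
 * 【luogu P2590 [ZJOI2008] 树的统计】题解
 * Solution write-up for Luogu P2590 [ZJOI2008] "树的统计" (tree statistics),
 * solved with heavy-light decomposition (树剖) plus a segment tree.
 *
 * 题目链接 (problem link): https://www.luogu.org/problemnew/show/P2590
 * 我想学树剖QAQ — written while learning heavy-light decomposition.
 */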

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
// Capacity for vertex-indexed arrays.
// NOTE(review): assumed to exceed the problem's maximum n — confirm against the statement.
const int maxn = 31000;
// Heavy-light decomposition data, indexed by vertex (1-based):
//   fa   - parent (0 for the root)     dep - depth (root has depth 1)
//   size - subtree size                son - heavy child (0 when there is none)
//   top  - head of the vertex's heavy chain
//   seg  - position of the vertex in segment-tree order; seg[0] is the running position counter
//   rev  - inverse of seg: rev[pos] is the vertex placed at position pos
// NOTE(review): with `using namespace std;` under C++17 an unqualified `size` is ambiguous
// with std::size; refer to this array as `::size`.
int fa[maxn], dep[maxn], size[maxn], son[maxn], top[maxn], seg[maxn], rev[maxn<<2];
// Segment-tree node arrays (node k has children 2k and 2k+1) need up to 4 * maxn slots.
// mx is indexed by node exactly like sum, so it must be just as large (it was only maxn).
// num holds each vertex's input weight.
int sum[maxn<<2], num[maxn], mx[maxn<<2];
// Adjacency list in forward-star form: head[u] is u's first edge index, 0 ends a list.
struct edge{
	int to, next;
}e[maxn<<2];
int head[maxn<<2], cnt;
// summ / maxx accumulate the result of query(); n = vertex count, m = command count.
int summ, maxx, n, m;
// Range query on the segment tree: folds the aggregates of every position in
// [L, R] into the globals summ (added) and maxx (max-ed).
// k is the current node and covers positions [l, r].
// The caller is responsible for resetting summ / maxx before the first call.
void query(int k, int l, int r, int L, int R) // range query
{
	if(L > r || R < l) return;              // node lies completely outside [L, R]
	if(L <= l && r <= R)                    // node lies completely inside [L, R]
	{
		summ += sum[k];
		maxx = std::max(maxx, mx[k]);
		return;
	}
	int mid = (l + r) >> 1;
	if(mid >= L) query(k<<1, l, mid, L, R);
	if(mid + 1 <= R) query((k<<1)+1, mid+1, r, L, R);
}
// Point update: stores val at position pos, then recomputes the sum / max
// aggregates of every node on the way back up. Node k covers [l, r].
void change(int k, int l, int r, int val, int pos) // point update
{
	if (pos < l || pos > r) return;     // pos is not inside this node's range
	if (l == r)                         // leaf; the guard above forces l == r == pos
	{
		sum[k] = mx[k] = val;
		return;
	}
	const int mid = (l + r) >> 1;
	const int lc = k << 1;
	const int rc = lc | 1;
	if (pos <= mid)
		change(lc, l, mid, val, pos);
	else
		change(rc, mid + 1, r, val, pos);
	sum[k] = sum[lc] + sum[rc];
	mx[k] = std::max(mx[lc], mx[rc]);
}
void dfs1(int u, int f)
{
	int v;
	size[u] = 1;
	fa[u] = f;
	dep[u] = dep[f]+1;
	for(int i = head[u]; v = e[i].to, i; i = e[i].next)
	{
		if(v != f)
		{
			dfs1(v,u);
			size[u] += size[v];
			if(size[v] > size[son[u]])
			son[u] = v;
		}
	}
}
// Second heavy-light pass: hands out segment-tree positions so that each heavy
// chain occupies a contiguous range, visiting the heavy child before light ones.
// u itself must already be placed (seg[u], rev[seg[u]], top[u] set) by the caller.
// seg[0] is the running position counter. Parameter f is not used.
void dfs2(int u, int f)
{
	const int heavy = son[u];
	if (heavy != 0)
	{
		// The heavy child extends u's chain and takes the next position.
		const int pos = ++seg[0];
		seg[heavy] = pos;
		rev[pos] = heavy;
		top[heavy] = top[u];
		dfs2(heavy, u);
	}
	for (int i = head[u]; i != 0; i = e[i].next)
	{
		const int v = e[i].to;
		if (top[v] != 0) continue;   // already placed (the parent, or the heavy child)
		// Any other neighbour is a light child and heads a new chain.
		const int pos = ++seg[0];
		seg[v] = pos;
		rev[pos] = v;
		top[v] = v;
		dfs2(v, u);
	}
}
// Builds the segment tree for positions [l, r] rooted at node k.
// Each leaf takes the weight of the vertex placed at its position (num[rev[pos]]).
// Each internal node stores the sum and the maximum of its two children.
void build(int k, int l, int r)
{
	if (l == r)
	{
		sum[k] = num[rev[l]];
		mx[k] = sum[k];
		return;
	}
	const int mid = (l + r) >> 1;
	const int lc = k << 1;
	const int rc = lc + 1;
	build(lc, l, mid);
	build(rc, mid + 1, r);
	sum[k] = sum[lc] + sum[rc];
	mx[k] = std::max(mx[lc], mx[rc]);
}
// Reads one decimal integer from stdin.
// - Any non-digit characters before the number are skipped.
// - A '-' anywhere in that skipped prefix makes the result negative.
// - The character right after the last digit is consumed.
// - Returns 0 if end of input is reached before any digit; the old version looped forever there.
inline int read()
{
	int c = getchar();          // int, not char, so EOF stays distinguishable from data
	int k = 1;
	while (c != EOF && (c < '0' || c > '9'))
	{
		if (c == '-') k = -1;
		c = getchar();
	}
	if (c == EOF) return 0;
	int res = 0;
	while (c >= '0' && c <= '9')
	{
		res = res * 10 + (c - '0');
		c = getchar();
	}
	return res * k;
}
// Pushes the directed edge u -> v onto the front of u's adjacency list.
// Edges are numbered from 1 so that index 0 can terminate every list.
inline void add(int u, int v)
{
	const int id = ++cnt;
	e[id].to = v;
	e[id].next = head[u];
	head[u] = id;
}
// Records the undirected tree edge {u, v} as two directed arcs.
inline void insert(int u, int v)
{
	add(u, v);   // arc u -> v
	add(v, u);   // arc v -> u
}
inline void ask(int u, int v)
{
	int fu = top[u], fv = top[v];
	while(fu != fv)
	{
		if(dep[fu]<dep[fv]) swap(u,v), swap(fu,fv);
		query(1,1,seg[0],seg[fu],seg[u]);
		u = fa[fu], fu = top[u];
	}
	if(dep[u]>dep[v]) swap(u,v);
	query(1,1,seg[0],seg[u],seg[v]);
}
int main()
{
	n = read();
	// n - 1 tree edges. Read both endpoints into named locals so the read order
	// is explicit (function-argument evaluation order is unspecified).
	for(int i = 1; i < n; i++)
	{
		const int a = read();
		const int b = read();
		insert(a, b);
	}
	for(int i = 1; i <= n; i++)
		num[i] = read();

	// Heavy-light decomposition rooted at vertex 1.
	// The root takes position 1 and heads its own chain; seg[0] counts positions.
	dfs1(1,0);
	seg[0] = seg[1] = rev[1] = top[1] = 1;
	dfs2(1,0);
	build(1,1,seg[0]);

	m = read();
	char opt[10];
	int u, v;
	for(int i = 1; i <= m; i++)
	{
		// The command word is stored from opt[1]; the width keeps it inside the buffer.
		if(scanf("%8s", opt+1) != 1) break;
		u = read(); v = read();
		if(opt[1] == 'C')
			// 'C...' command (presumably CHANGE): set vertex u's weight to v.
			change(1,1,seg[0],v,seg[u]);
		else
		{
			// Path query on u..v. opt[2] == 'M' prints the maximum; otherwise the sum.
			// (Presumably QMAX / QSUM — confirm against the statement.)
			summ = 0;
			maxx = -0x7fffffff;
			ask(u,v);
			if(opt[2] == 'M')
				printf("%d\n",maxx);
			else
				printf("%d\n",summ);
		}
	}
	return 0;
}
// 原文地址 (original post): https://www.cnblogs.com/MisakaAzusa/p/9223310.html