[湖南集训]谈笑风生

题解

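套路地用线段树合并:对每个点建一棵以深度为下标、以 $siz_b-1$ 为权值的线段树,自底向上把子树的线段树合并上来。询问 $(p,k)$ 的答案为 $\min(dep_p-1,k)\cdot(siz_p-1)$,再加上 $p$ 子树中深度在 $[dep_p+1,\,dep_p+k]$ 内所有点 $b$ 的 $siz_b-1$ 之和。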
注意:递归 $\text{dfs}$ 可能爆栈,所以用 $\text{bfs}$ 处理。
合并时要新开节点,否则会修改子树原有的信息(子节点的线段树之后还要用于回答询问)。

$\text{Code}$

#include<cstdio>
#include<iostream>
#define LL long long
using namespace std;

const int N = 300005; // array bound: 3e5 vertices plus slack
int n, q, h[N], tot;

// Adjacency lists kept as singly linked lists of edges:
// h[x] is the index of the most recently added edge leaving x,
// e[i].nxt is the index of the edge added before it (0 ends the list).
struct edge
{
	int to, nxt;
} e[N << 1];

// Prepends the directed edge x -> y to x's list.
inline void add(int x, int y)
{
	++tot;
	e[tot].to = y;
	e[tot].nxt = h[x];
	h[x] = tot;
}

// Reads the next integer from stdin into x.
// Non-digit characters before the number are skipped; if any of them is '-',
// the value is negated. If EOF is reached before a digit is seen, x is set to 0
// instead of looping forever. getchar()'s result is kept in an int so that EOF
// stays distinguishable from every real character.
inline void read(int &x)
{
	x = 0; int f = 1; int ch = getchar();
	while (ch != EOF && (ch < '0' || ch > '9')) f = (ch == '-' ? -1 : f), ch = getchar();
	while (ch >= '0' && ch <= '9') x = x * 10 + (ch - '0'), ch = getchar();
	x *= f;
}

// Per-vertex data filled in by the tree traversal: parent, subtree size,
// depth, and the root index of the vertex's segment tree.
int fa[N], siz[N], dep[N], rt[N];
// Number of segment-tree nodes allocated so far; index 0 is the null node.
// Not named `size`: with `using namespace std;` a global `size` is ambiguous
// with std::size under C++17 and the file fails to compile.
int node_cnt;
// Node pool: sum of the values in the node's position range, plus child indices.
struct Tree{LL sum; int ls, rs;}seg[60 * N];

// Adds v at position x of the tree rooted at p, which covers [l, r].
// Missing nodes along the root-to-leaf path are allocated; p is set if it was 0.
void insert(int &p, int l, int r, int x, int v)
{
	if (!p) p = ++node_cnt;
	seg[p].sum += v;
	if (l == r) return;
	int mid = (l + r) >> 1;
	if (x <= mid) insert(seg[p].ls, l, mid, x, v);
	else insert(seg[p].rs, mid + 1, r, x, v);
}

// Returns the root of a tree holding the position-wise sum of trees x and y.
// Wherever both inputs have a node, a fresh node is allocated, so merge never
// writes to any node of x or y; a subtree present in only one input is shared.
int merge(int x, int y)
{
	if (!x || !y) return x | y;
	int p = ++node_cnt;
	seg[p].sum = seg[x].sum + seg[y].sum;
	seg[p].ls = merge(seg[x].ls, seg[y].ls);
	seg[p].rs = merge(seg[x].rs, seg[y].rs);
	return p;
}
// Returns the sum of the values at positions [x, y] in the tree rooted at p,
// which covers [l, r]. Absent children contribute nothing.
LL query(int p, int l, int r, int x, int y)
{
	// The node's interval lies entirely inside the query range.
	if (x <= l && r <= y) return seg[p].sum;
	const int mid = (l + r) >> 1;
	const int lc = seg[p].ls, rc = seg[p].rs;
	LL left = (x <= mid && lc) ? query(lc, l, mid, x, y) : 0;
	LL right = (y > mid && rc) ? query(rc, mid + 1, r, x, y) : 0;
	return left + right;
}

// Vertices in BFS order: d[1..n], every parent before its children.
int d[N];

// Iteratively computes fa, dep, siz and the per-vertex segment trees.
// Vertex 1 is the root with depth 1. Afterwards, the tree rt[x] holds, at
// position t, the sum of siz[y] - 1 over all y in x's subtree with dep[y] == t.
// A BFS is used instead of a recursive DFS to avoid stack overflow on deep trees.
void bfs()
{
	int head = 0, tail = 1;
	d[1] = 1, dep[1] = 1;
	while (head < tail)
	{
		int x = d[++head];
		for(int i = h[x]; i; i = e[i].nxt)
		{
			if (e[i].to == fa[x]) continue;
			fa[e[i].to] = x, d[++tail] = e[i].to, dep[e[i].to] = dep[x] + 1;
		}
	}
	// Reverse BFS order handles every child before its parent.
	for(int j = n; j; j--)
	{
		int x = d[j]; siz[x] = 1;
		for(int i = h[x]; i; i = e[i].nxt)
		{
			if (e[i].to == fa[x]) continue;
			siz[x] += siz[e[i].to];
		}
		// Insert x's own value BEFORE merging: merge shares child subtrees,
		// so an insert afterwards could write into the children's trees.
		insert(rt[x], 1, n, dep[x], siz[x] - 1);
		for(int i = h[x]; i; i = e[i].nxt)
		{
			if (e[i].to == fa[x]) continue;
			rt[x] = merge(rt[x], rt[e[i].to]);
		}
	}
}

int main()
{
	read(n), read(q);
	for(register int i = 1, u, v; i < n; i++) read(u), read(v), add(u, v), add(v, u);
	bfs();
	for(register int i = 1, p, k; i <= q; i++)
	{
		read(p), read(k);
		printf("%lld
", 1LL * min(dep[p] - 1, k) * (siz[p] - 1) + query(rt[p], 1, n, dep[p] + 1, dep[p] + k));
	}
}
原文地址:https://www.cnblogs.com/leiyuanze/p/14328639.html