NOI.AC 722: tree

就贴个代码

#include <cstdio>
#include <algorithm>

// 64-bit accumulator type used by both segment trees.
typedef long long LL;
// MN: capacity of vertex-indexed arrays (vertices are numbered 1..N).
const int MN = 200005;
// MS: capacity of segment-tree node arrays (heap layout, root node = 1).
const int MS = 524289;

// Adjacency list as linked edge records: h[x] is the most recently added edge leaving x,
// nxt[e] the next edge leaving the same vertex (0 terminates the list), to[e] the target
// of edge e. main inserts every undirected edge in both directions, hence 2 * MN records.
int N, Q, h[MN], nxt[MN * 2], to[MN * 2], tot;
// Push the directed edge x -> y onto the front of x's edge list.
inline void ins(int x, int y) {
	++tot;
	to[tot] = y;
	nxt[tot] = h[x];
	h[x] = tot;
}

// Heavy-light decomposition state, indexed by vertex id unless noted:
//   dep - depth; dep[0] stays 0, so a root visited with fz == 0 gets depth 1
//   faz - parent vertex (the fz passed to DFS0)
//   siz - subtree size
//   son - heavy child: the earliest child (in edge-list order) with the largest subtree, 0 for a leaf
//   top - head of the heavy chain containing the vertex
//   ldf - DFS position of the vertex; its subtree occupies positions [ldf[u], rdf[u]]
//   rdf - last DFS position inside the vertex's subtree
//   idf - indexed by DFS position: idf[ldf[u]] == u
//   dfc - running DFS position counter used by DFS1
int dep[MN], faz[MN], siz[MN], son[MN], top[MN], ldf[MN], rdf[MN], idf[MN], dfc;
// First HLD pass over the subtree of u, whose parent is fz: fills faz, dep, siz and son.
// Recursive, so the call depth equals the subtree height.
// NOTE(review): a path-shaped tree near MN vertices may need an enlarged stack.
void DFS0(int u, int fz) {
	faz[u] = fz;
	dep[u] = dep[fz] + 1;
	siz[u] = 1;
	for (int e = h[u]; e != 0; e = nxt[e]) {
		const int v = to[e];
		if (v != fz) {
			DFS0(v, u);
			siz[u] += siz[v];
			// Strictly larger only, so ties keep the earlier child.
			if (siz[v] > siz[son[u]]) son[u] = v;
		}
	}
}
// Second HLD pass: numbers vertices in DFS order, visiting the heavy child first so each
// heavy chain occupies consecutive positions. t is the head of the chain u belongs to;
// every light child starts a new chain headed by itself. After the loop, dfc is the last
// position used inside u's subtree, which is stored in rdf[u].
void DFS1(int u, int t) {
	top[u] = t;
	ldf[u] = ++dfc;
	idf[dfc] = u;
	if (son[u] != 0) DFS1(son[u], t);
	for (int e = h[u]; e; e = nxt[e]) {
		const int v = to[e];
		if (v != faz[u] && v != son[u]) DFS1(v, v);
	}
	rdf[u] = dfc;
}
// Lowest common ancestor via heavy-light chains: repeatedly lift the endpoint whose chain
// head is deeper (u on ties) to the parent of that head until both share a chain. The
// shallower of the two is then the LCA.
inline int LCA(int u, int v) {
	while (top[u] != top[v]) {
		int &deeper = dep[top[u]] >= dep[top[v]] ? u : v;
		deeper = faz[top[deeper]];
	}
	if (dep[u] < dep[v]) return u;
	return v;
}
// Number of edges on the tree path between u and v.
inline int Dist(int u, int v) {
	const int w = LCA(u, v);
	return (dep[u] - dep[w]) + (dep[v] - dep[w]);
}

// Segment-tree navigation shorthands. They expand against whatever `i`, `l`, `r`
// are in scope at the use site (node index and its covered range).
#define li ((i) << 1)           // left child node index
#define ri (((i) << 1) | 1)     // right child node index
#define mid (((l) + (r)) >> 1)  // last position covered by the left child
#define ls li, l, mid           // (node, l, r) argument pack for the left child
#define rs ri, mid + 1, r       // (node, l, r) argument pack for the right child
namespace T1 {
	int len[MS];
	LL sa[MS], sb[MS], tgk[MS], tgb[MS], tgv[MS];
	bool tg[MS];
	inline void P(int i, LL k, LL b, LL v) {
		sa[i] += (LL)(len[i] + 1) * len[i] / 2 * k + len[i] * b + sb[i] * v;
		tgk[i] += k, tgb[i] += b, tgv[i] += v;
		tg[i] = 1;
	}
	inline void PushDown(int i) {
		if (tg[i]) {
			P(li, tgk[i], tgb[i], tgv[i]);
			P(ri, tgk[i], tgb[i] + len[li] * tgk[i], tgv[i]);
			tgk[i] = tgb[i] = tgv[i] = 0;
			tg[i] = 0;
		}
	}
	void Build(int i, int l, int r) {
		len[i] = r - l + 1;
		if (l == r) { sb[i] = dep[idf[l]]; return ; }
		Build(ls), Build(rs);
		sb[i] = sb[li] + sb[ri];
	}
	void Mdf1(int i, int l, int r, int a, int b, LL x, LL y) {
		if (r < a || b < l) return ;
		if (a <= l && r <= b) return P(i, x, y + (l - a) * x, 0);
		PushDown(i);
		Mdf1(ls, a, b, x, y), Mdf1(rs, a, b, x, y);
		sa[i] = sa[li] + sa[ri];
	}
	void Mdf2(int i, int l, int r, int a, int b, LL x) {
		if (r < a || b < l) return ;
		if (a <= l && r <= b) return P(i, 0, 0, x);
		PushDown(i);
		Mdf2(ls, a, b, x), Mdf2(rs, a, b, x);
		sa[i] = sa[li] + sa[ri];
	}
	LL Qur(int i, int l, int r, int a, int b) {
		if (r < a || b < l) return 0ll;
		if (a <= l && r <= b) return sa[i];
		PushDown(i);
		return Qur(ls, a, b) + Qur(rs, a, b);
	}

	inline void ChainAdd(int x, int y, LL k, LL b) {
		int len = Dist(x, y) + 1;
		while (top[x] != top[y]) {
			if (dep[top[x]] < dep[top[y]]) {
				b += (len + 1) * k, k = -k;
				std::swap(x, y);
			}
			Mdf1(1, 1, N, ldf[top[x]], ldf[x], -k, b + (dep[x] - dep[top[x]] + 2) * k);
			b += (dep[x] - dep[top[x]] + 1) * k;
			len -= dep[x] - dep[top[x]] + 1;
			x = faz[top[x]];
		}
		if (dep[x] > dep[y]) {
			b += (len + 1) * k, k = -k;
			std::swap(x, y);
		}
		Mdf1(1, 1, N, ldf[x], ldf[y], k, b);
	}
	
	inline LL ChainQur(int x, int y) {
		LL Sum = 0;
		while (top[x] != top[y]) {
			if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
			Sum += Qur(1, 1, N, ldf[top[x]], ldf[x]);
			x = faz[top[x]];
		}
		if (dep[x] > dep[y]) std::swap(x, y);
		return Sum + Qur(1, 1, N, ldf[x], ldf[y]);
	}
}

namespace T2 {
	// Segment tree over DFS positions holding one "tag" value per vertex, with lazy range
	// addition. Besides the plain sum, it keeps sums weighted by subtree size and by depth.
	int len[MS];  // number of positions covered by node i
	LL s1[MS];    // sum of tags in the range
	LL s2[MS];    // sum of tag * siz[vertex] in the range
	LL s3[MS];    // sum of tag * dep[vertex] in the range
	LL sb[MS];    // sum of siz[vertex] over the range (fixed after Build)
	LL sc[MS];    // sum of dep[vertex] over the range (fixed after Build)
	LL tg[MS];    // pending addition not yet pushed to the children

	// Add x to the tag of every position covered by node i.
	inline void P(int i, LL x) {
		tg[i] += x;
		s1[i] += x * len[i];
		s2[i] += x * sb[i];
		s3[i] += x * sc[i];
	}
	// Recompute node i's three sums from its children.
	inline void PushUp(int i) {
		s1[i] = s1[li] + s1[ri];
		s2[i] = s2[li] + s2[ri];
		s3[i] = s3[li] + s3[ri];
	}
	// Hand node i's pending addition to both children.
	inline void PushDown(int i) {
		if (tg[i] == 0) return;
		P(li, tg[i]);
		P(ri, tg[i]);
		tg[i] = 0;
	}
	// Set up lengths and the static siz/dep sums for node i covering [l, r].
	void Build(int i, int l, int r) {
		len[i] = r - l + 1;
		if (l == r) {
			const int u = idf[l];
			sb[i] = siz[u];
			sc[i] = dep[u];
			return;
		}
		Build(ls);
		Build(rs);
		sb[i] = sb[li] + sb[ri];
		sc[i] = sc[li] + sc[ri];
	}
	// Add x to the tags of positions [a, b].
	void Mdf(int i, int l, int r, int a, int b, LL x) {
		if (b < l || r < a) return;
		if (l >= a && r <= b) {
			P(i, x);
			return;
		}
		PushDown(i);
		Mdf(ls, a, b, x);
		Mdf(rs, a, b, x);
		PushUp(i);
	}
	// Sum over positions [a, b] of statistic t: 1 -> tag, 2 -> tag*siz, otherwise tag*dep.
	LL Qur(int i, int l, int r, int a, int b, int t) {
		if (b < l || r < a) return 0;
		if (l >= a && r <= b) {
			if (t == 1) return s1[i];
			if (t == 2) return s2[i];
			return s3[i];
		}
		PushDown(i);
		return Qur(ls, a, b, t) + Qur(rs, a, b, t);
	}
	// Add v to the tag of every vertex from x up to the vertex whose parent is 0 (inclusive).
	inline void ChainAdd(int x, LL v) {
		while (x != 0) {
			Mdf(1, 1, N, ldf[top[x]], ldf[x], v);
			x = faz[top[x]];
		}
	}
	// Sum statistic t (see Qur) over the vertices from x up to the vertex whose parent is 0.
	inline LL ChainQur(int x, int t) {
		LL total = 0;
		while (x != 0) {
			total += Qur(1, 1, N, ldf[top[x]], ldf[x], t);
			x = faz[top[x]];
		}
		return total;
	}
}

int main() {
	int op, x, y, z; LL v;
	scanf("%d", &N);
	for (int i = 1; i < N; ++i) scanf("%d%d", &x, &y), ins(x, y), ins(y, x);
	DFS0(1, 0), DFS1(1, 1);
	T1::Build(1, 1, N);
	T2::Build(1, 1, N);
	scanf("%d", &Q);
	while (Q--) {
		scanf("%d%d", &op, &x);
		if (op == 1) {
			scanf("%d%lld", &y, &v);
			T1::ChainAdd(x, y, 0, v);
		}
		if (op == 2) {
			scanf("%lld", &v);
			T1::Mdf1(1, 1, N, ldf[x], rdf[x], 0, v);
		}
		if (op == 3) {
			scanf("%d%d%lld", &y, &z, &v);
			int l1 = LCA(x, y), l2 = LCA(x, z), l3 = LCA(y, z);
			int pos = dep[l1] > dep[l2] ? dep[l1] > dep[l3] ? l1 : l3 : dep[l2] > dep[l3] ? l2 : l3;
			int dist = Dist(z, pos);
			LL b = (dist - 1) * v;
			T1::ChainAdd(pos, x, v, b);
			T1::ChainAdd(pos, y, v, b);
			T1::Mdf1(1, 1, N, ldf[pos], ldf[pos], 0, -dist * v);
		}
		if (op == 4) {
			scanf("%d%lld", &z, &v);
			if (ldf[x] < ldf[z] && ldf[z] <= rdf[x]) {
				T1::Mdf2(1, 1, N, ldf[x], rdf[x], v);
				T1::Mdf1(1, 1, N, ldf[x], rdf[x], 0, (Dist(z, x) - dep[x]) * v);
				T2::ChainAdd(z, -2 * v);
				T2::ChainAdd(x, 2 * v);
			} else {
				T1::Mdf2(1, 1, N, ldf[x], rdf[x], v);
				T1::Mdf1(1, 1, N, ldf[x], rdf[x], 0, (Dist(z, x) - dep[x]) * v);
			}
		}
		if (op == 5) {
			LL Ans = T1::Qur(1, 1, N, ldf[x], ldf[x]);
			Ans += T2::ChainQur(x, 1);
			printf("%lld
", Ans);
		}
		if (op == 6) {
			scanf("%d", &y);
			int l = LCA(x, y), len = dep[x] + dep[y] - 2 * dep[l] + 1;
			LL Ans = T1::ChainQur(x, y);
			LL Ql1 = T2::ChainQur(l, 1), Qx1 = T2::ChainQur(x, 1), Qy1 = T2::ChainQur(y, 1);
			LL Ql3 = T2::ChainQur(l, 3), Qx3 = T2::ChainQur(x, 3), Qy3 = T2::ChainQur(y, 3);
			Ans += Ql1 * len;
			Ans += (Qx1 - Ql1) * (dep[x] + 1) - (Qx3 - Ql3);
			Ans += (Qy1 - Ql1) * (dep[y] + 1) - (Qy3 - Ql3);
			printf("%lld
", Ans);
		}
		if (op == 7) {
			LL Ans = T1::Qur(1, 1, N, ldf[x], rdf[x]);
			Ans += T2::ChainQur(x, 1) * siz[x];
			if (ldf[x] < rdf[x]) Ans += T2::Qur(1, 1, N, ldf[x] + 1, rdf[x], 2);
			printf("%lld
", Ans);
		}
	}
	return 0;
}
原文地址:https://www.cnblogs.com/PinkRabbit/p/11617513.html