树上路径(树链剖分)

来源:https://ac.nowcoder.com/acm/contest/22131/C

时间限制:C/C++ 2秒,其他语言4秒
空间限制:C/C++ 262144K,其他语言524288K
64bit IO Format: %lld

题目描述

给出一个n个点的树,1号节点为根节点,每个点有一个权值
你需要支持以下操作
1.将以u为根的子树内节点(包括u)的权值加val
2.将(u, v)路径上的节点权值加val
3.询问(u, v)路径上节点的权值两两相乘的和

输入描述:

第一行两个整数n, m,表示树的节点个数以及操作个数

接下来一行n个数,表示每个节点的权值

接下来n - 1行,每行两个整数(u, v),表示(u, v)之间有边

接下来m行
开始有一个数opt,表示操作类型
若opt = 1,接下来两个整数表示u, val
若opt = 2,接下来三个整数表示(u, v), val
若opt = 3,接下来两个整数表示(u, v)
含义均如题所示

输出描述:

对于每个第三种操作,输出一个数表示答案,对10^9+7取模
示例1

输入

3 8
5 3 1
1 2
1 3
3 1 2
3 1 3
3 2 3
1 1 2
2 1 3 2
3 1 2
3 1 3
3 2 3

输出

15
5
23
45
45
115
 
 
第一个和第二个操作板子就能解决, 第三个需要一些推导
$(x_1+x_2+x_3+...+x_n) ^ 2 = (x_1^2+x_2^2+x_3^2+...+x_i^2+...+x_n^2+...+x_1x_2+x_1x_3+...+x_ix_j+...+x_{n-1}x_n)$
$(sum_{i = l}^r x_i)^2 - sum_{i = l}^r x_i^2 = sum_{i = l}^rsum_{j = i wedge j eq i}^r 2x_ix_j$
后面的$sum_{i = l}^rsum_{j = i wedge j eq i}^r 2x_ix_j$即为我们想要的答案
于是我们可以维护区间元素的平方和跟一般和, sum1数组表示一般和, sum2数组表示平方和
但我们在维护的过程中会发现, 平方和似乎不是那么好维护, 因为题目里给定的操作还要给每个元素加上一个数值。
这时候就要思考lazy标记的作用了。
$sum_{l}^r(x+b)^2 \= sum_{l}^r(x^2+2bx+b^2) \= sum_{l}^rx^2 + 2bsum_{l}^rx + sum_{l}^rb^2 \= sum_{l}^rx^2 +2 bsum_{l}^rx + (r - l + 1)b^2$
至此, 结果就显而易见了
但是要注意模mod的时候不能使用除法, 而是要用逆元替代(千万要注意)
 
 
点击查看代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define IOS ios::sync_with_stdio(false), cin.tie(0), cout.tie(0)
using ll = long long;

constexpr int MAXN = 1e5 + 3, mod = 1e9 + 7;

int h[MAXN], e[MAXN << 1], ne[MAXN << 1], w[MAXN << 2], wt[MAXN << 2], Size[MAXN], mx[MAXN], idx, root;
int fa[MAXN], deep[MAXN], son[MAXN], id[MAXN], top[MAXN], n, cnt;
ll sum1[MAXN << 2], sum2[MAXN << 2], lazy[MAXN << 2], nid, inv2;
ll qpow(ll x, int n)
{
	ll ans = 1;
	while (n)
	{
		if (n & 1)
		{
			ans = ans * x % mod;
		}
		x = x * x % mod;
		n >>= 1;
	}
	return ans;
}
void addedge(int u, int v, int c = 0)
{
	++idx;
	e[idx] = v;
	ne[idx] = h[u];
	h[u] = idx;
}

void pushup(int rt)
{
	sum1[rt] = (sum1[rt << 1] + sum1[rt << 1 | 1]) % mod;
	sum2[rt] = (sum2[rt << 1] + sum2[rt << 1 | 1]) % mod;
}
void unionLazy(int rt)
{
	lazy[rt << 1] += lazy[rt];
	lazy[rt << 1 | 1] += lazy[rt];
}
void calLazy(int rt, int len)
{
	(sum2[rt << 1] += 2 * lazy[rt] * sum1[rt << 1] + (len - (len >> 1)) * lazy[rt] * lazy[rt]) %= mod;
	(sum1[rt << 1] += (len - (len >> 1)) * lazy[rt]) %= mod;

	(sum2[rt << 1 | 1] += 2 * lazy[rt] % mod * sum1[rt << 1 | 1] % mod + (len >> 1) * lazy[rt] * lazy[rt]) %= mod;
	(sum1[rt << 1 | 1] += (len >> 1) * lazy[rt]) %= mod;
}

void pushdown(int rt, int len)
{
	if (lazy[rt] == 0)
		return;

	calLazy(rt, len);
	unionLazy(rt);

	lazy[rt] = 0;
}

void build(int l, int r, int rt)
{
	lazy[rt] = 0;
	if (l == r)
	{
		sum1[rt] = wt[l];
		sum2[rt] = wt[l] * wt[l];
		return;
	}
	int mid = l + r >> 1;
	build(l, mid, rt << 1);
	build(mid + 1, r, rt << 1 | 1);
	pushup(rt);
}

ll query(int a, int b, int op, int l, int r, int rt)
{
	if (a <= l && r <= b)
		return op == 1 ? sum1[rt] : sum2[rt];
	pushdown(rt, r - l + 1);
	int mid = l + r >> 1;
	ll ans = 0;
	if (a <= mid)
	{
		ans = (ans + query(a, b, op, l, mid, rt << 1)) % mod;
	}
	if (b > mid)
	{
		ans = (ans + query(a, b, op, mid + 1, r, rt << 1 | 1)) % mod;
	}
	return ans;
}

void update(int a, int b, ll c, int l, int r, int rt)
{

	if (a <= l && r <= b)
	{
		(sum2[rt] += 2 * c * sum1[rt] + (r - l + 1) * c * c) %= mod;
		(sum1[rt] += (r - l + 1) * c) %= mod;
		lazy[rt] += c;
		return;
	}

	pushdown(rt, r - l + 1);
	int mid = l + r >> 1;
	if (a <= mid)
		update(a, b, c, l, mid, rt << 1);
	if (b > mid)
		update(a, b, c, mid + 1, r, rt << 1 | 1);
	pushup(rt);
}

ll pathquery(int x, int y)
{
	ll ans1 = 0, ans2 = 0;
	while (top[x] != top[y])
	{
		if (deep[top[x]] < deep[top[y]])
			swap(x, y);
		ans1 = (ans1 + query(id[top[x]], id[x], 1, 1, n, 1)) % mod;
		ans2 = (ans2 + query(id[top[x]], id[x], 2, 1, n, 1)) % mod;
		x = fa[top[x]];
	}

	if (deep[x] > deep[y])
		swap(x, y);
	ans1 = (ans1 + query(id[x], id[y], 1, 1, n, 1)) % mod;
	ans2 = (ans2 + query(id[x], id[y], 2, 1, n, 1)) % mod;
	return (ans1 * ans1 % mod - ans2 + mod) * inv2 % mod;
}

void lcaadd(int x, int y, ll c)
{
	while (top[x] != top[y])
	{
		if (deep[top[x]] < deep[top[y]])
			swap(x, y);
		update(id[top[x]], id[x], c, 1, n, 1);

		x = fa[top[x]];
	}
	if (deep[x] > deep[y])
		swap(x, y);
	update(id[x], id[y], c, 1, n, 1);
}

void sonadd(int x, ll c)
{
	update(id[x], id[x] + Size[x] - 1, c, 1, n, 1);
}

ll sonquery(int x, int op)
{
	return query(id[x], id[x] + Size[x] - 1, op, 1, n, 1) % mod;
}

void dfs1(int x, int f, int dep)
{
	deep[x] = dep;
	fa[x] = f;
	Size[x] = 1;
	int maxson = -1;
	for (int i = h[x]; i; i = ne[i])
	{
		int y = e[i];
		if (y != f)
		{
			dfs1(y, x, dep + 1);

			Size[x] += Size[y];
			if (maxson < Size[y])
			{
				son[x] = y;
				maxson = Size[y];
			}
		}
	}
}

void dfs2(int x, int topf)
{
	id[x] = ++cnt;
	wt[cnt] = w[x]; //必须根据访问顺序另开一个数组保存, 否则初始化答案会出错
	top[x] = topf;
	if (!son[x])
		return;
	dfs2(son[x], topf);
	for (int i = h[x]; i; i = ne[i])
	{
		int y = e[i];
		if (y != fa[x] && y != son[x])
			dfs2(y, y);
	}
}

int main()
{
	IOS;

	inv2 = qpow(2, mod - 2);
	int u, v, c, k, m, rt = 1;
	int x, y, z, op, val;

	cin >> n >> m;
	for (int i = 1; i <= n; ++i)
		cin >> w[i];

	for (int i = 1; i < n; ++i)
	{
		cin >> u >> v;
		addedge(u, v);
		addedge(v, u);
	}
	dfs1(rt, 0, 1);
	dfs2(rt, rt);
	build(1, n, 1);

	while (m--)
	{
		cin >> op >> x >> y;
		if (op == 1)
		{
			sonadd(x, y);
		}
		else if (op == 2)
		{
			cin >> val;
			lcaadd(x, y, val);
		}
		else
		{
			cout << pathquery(x, y) << '
';
		}
	}

	return 0;
}
原文地址:https://www.cnblogs.com/daremo/p/15488901.html