【JZOJ3360】【NOI2013模拟】苹果树

题目大意

给你一棵(n)个点的树,每个点有一种颜色;现在有(m)个询问,每次询问你(x)(y)的路径上,若将(a)颜色视作(b)颜色,不同的颜色有几种。

(nleq 50000,mleq 100000)

分析

如果是把问题放到序列上:询问区间([l,r])不同的颜色有几种。这个问题有两个已知的解法:

看这题的数据范围显然是让你莫队了。(雾

树上莫队的第一步,是把树上问题转换为序列问题。我们求出原树的欧拉序,可以发现这个序列有这样的性质:

将一个点在欧拉序中首次出现和第二次出现的位置分别记作(fir_u)(las_u),对于一条路径((x,y))(假定(fir_x<fir_y))。
(lca(x,y)=x),那么这条路径对应欧拉序中的区间([fir_x,fir_y])。但是区间中出现两次的点要去掉,因为它们不属于这条路径。
(lca(x,y) eq x),那么这条路径对应欧拉序中的区间([las_x,fir_y])。同样的要去掉出现两次的点,并且这个区间没有包括上(lca),要将(lca)再单独统计。

这样,树上问题就变成了序列问题。

为了不计算出现两次的点,我们开个标记数组,一个点每次出现,都把标记数组对应位置异或(1),那么一个点在标记数组中的值为(1)时才能被计算,当一个点对应的值变为(0)时又把它的贡献删去,这样问题便迎刃而解。再注意计算(lca)的答案即可。关于将(a)颜色视作(b)颜色的,只需判断区间中是否同时有(a)颜色和(b)颜色,有的话答案减(1),注意(a=b)要特判,不然要炸!

Code

#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int N = 200007;

int n, m, col[N], ans[N], ord[N];
int tot, dfn, st[N], to[N << 1], nx[N << 1], fir[N], las[N], anc[N][17], dep[N];
void add(int u, int v) { to[++tot] = v, nx[tot] = st[u], st[u] = tot; }
void dfs(int u)
{
	fir[u] = ++dfn, ord[dfn] = u;
	for (int i = st[u]; i; i = nx[i]) if (!fir[to[i]]) anc[to[i]][0] = u, dep[to[i]] = dep[u] + 1, dfs(to[i]);
	las[u] = ++dfn, ord[dfn] = u;
}
int getlca(int u, int v)
{
	if (dep[u] < dep[v]) swap(u, v);
	for (int i = 16; i >= 0; i--) if (dep[anc[u][i]] >= dep[v]) u = anc[u][i];
	if (u == v) return u;
	for (int i = 16; i >= 0; i--) if (anc[u][i] != anc[v][i]) u = anc[u][i], v = anc[v][i];
	return anc[u][0];
}
int block, ret, be[N], tag[N], buc[N];
struct note { int l, r, id, a, b, lca; } q[N];
int cmp(note a, note b) { return be[a.l] == be[b.l] ? ((be[a.l] & 1) ? a.r < b.r : a.r > b.r) : a.l < b.l; }
void ins(int c, int v)
{
	if (v == 1) { if (!buc[c]) ret++; buc[c]++; }
	else { buc[c]--; if (!buc[c]) ret--; }
}

int main()
{
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i++) scanf("%d", &col[i]);
	for (int i = 1, u, v; i <= n; i++)
	{
		scanf("%d%d", &u, &v);
		if (u && v) add(u, v), add(v, u);
	}
	dep[1] = 1, dfs(1);
	for (int j = 1; j <= 16; j++) for (int i = 1; i <= n; i++) anc[i][j] = anc[anc[i][j - 1]][j - 1];
	block = sqrt(2 * n);
	for (int i = 1; i <= 2 * n; i++) be[i] = i / block + 1;
	for (int i = 1, x, y, a, b, lca; i <= m; i++)
	{
		scanf("%d%d%d%d", &x, &y, &a, &b);
		if (fir[x] > fir[y]) swap(x, y);
		lca = getlca(x, y);
		if (lca == x) q[i] = (note){fir[x], fir[y], i, a, b, 0};
		else q[i] = (note){las[x], fir[y], i, a, b, lca};
	}
	sort(q + 1, q + m + 1, cmp);
	for (int i = 1, l = 1, r = 0; i <= m; i++)
	{
		while (l < q[i].l) tag[ord[l]] ^= 1, ins(col[ord[l]], tag[ord[l]]), ++l;
		while (l > q[i].l) --l, tag[ord[l]] ^= 1, ins(col[ord[l]], tag[ord[l]]);
		while (r < q[i].r) ++r, tag[ord[r]] ^= 1, ins(col[ord[r]], tag[ord[r]]);
		while (r > q[i].r) tag[ord[r]] ^= 1, ins(col[ord[r]], tag[ord[r]]), --r;
		if (q[i].lca) ins(col[q[i].lca], 1);
		ans[q[i].id] = ret;
		if (q[i].a != q[i].b && buc[q[i].a] && buc[q[i].b]) ans[q[i].id]--;
		if (q[i].lca) ins(col[q[i].lca], 0);
	}
	for (int i = 1; i <= m; i++) printf("%d
", ans[i]);
	return 0;
}
原文地址:https://www.cnblogs.com/zjlcnblogs/p/11178379.html