这题是Silver T3的扩展。
有了ST3的提示,我们知道只要确定一条链上某种颜色的个数即可。
显然这是可以在线做的(虽然并不需要)。
看到网上的题解大多是用主席树写的。我就介绍一种用线段树合并的做法,其实思路是一样的。
在每个节点维护一棵线段树,每个叶节点代表从根到此节点的一种颜色的个数。
首先这课线段树肯定是要动态开点的。
对于每个节点,先将它自己的颜色插入它对应的线段树。
for (int i = 1; i <= n; i++) {
cin >> t[i];
add(rt[i], 1, n, t[i]);
}
然后我们递归。每递归到一个节点 (x),我们就将它父亲的线段树合并到它的线段树(感觉就像主席树换了种写法。。。)
对于每组询问 ((a, b, c)),我们通过线段树查询(a),(b) 和 (lca(a, b))可以求出从根分别到它们的路径上 (c) 的个数,然后一加一减就可以知道 (a->b) 的路径上 (c) 的个数。
这道题就做完了。
完整代码:
#include <iostream>
#include <cstdio>
using namespace std;
const int N = 100010;
struct seg_tree{
int val, lson, rson;
}st[N * 100];
int len, rt[N];
struct node{
int pre, to;
}edge[2 * N];
int head[N], tot;
int n, m;
int t[N];
int depth[N], f[N][25];
int lca(int x, int y) {
if (depth[x] < depth[y]) swap(x, y);
int d = depth[x] - depth[y];
for (int i = 0; i <= 20; i++) {
if ((1 << i) & d) {
x = f[x][i];
d -= (1 << i);
}
}
if (x == y) return x;
for (int i = 20; i >= 0; i--) {
if (f[x][i] != f[y][i]) {
x = f[x][i];
y = f[y][i];
}
}
return f[x][0];
}
void ad(int u, int v) {
edge[++tot] = node{head[u], v};
head[u] = tot;
}
void add(int &x, int l, int r, int v) {
if (!x) x = ++len;
if (l == r) {
st[x].val++;
return;
}
int mid = (l + r) >> 1;
if (v <= mid) add(st[x].lson, l, mid, v);
else add(st[x].rson, mid + 1, r, v);
st[x].val = st[st[x].lson].val + st[st[x].rson].val;
}
int merge(int u, int v) {
if (!u) return v;
if (!v) return u;
st[u].val += st[v].val;
st[u].lson = merge(st[u].lson, st[v].lson);
st[u].rson = merge(st[u].rson, st[v].rson);
return u;
}
void dfs(int x, int fa) {
for (int i = 1; i <= 20; i++) {
f[x][i] = f[f[x][i - 1]][i - 1];
}
for (int i = head[x]; i; i = edge[i].pre) {
int y = edge[i].to;
if (y == fa) continue;
f[y][0] = x;
depth[y] = depth[x] + 1;
rt[y] = merge(rt[y], rt[x]);
dfs(y, x);
}
}
int ask(int x, int l, int r, int p) {
if (!x) return 0;
if (l == r) return st[x].val;
int mid = (l + r) >> 1;
if (p <= mid) return ask(st[x].lson, l, mid, p);
else return ask(st[x].rson, mid + 1, r, p);
}
int main() {
cin >> n >> m;
for (int i = 1; i <= n; i++) {
cin >> t[i];
add(rt[i], 1, n, t[i]);
}
for (int i = 1, a, b; i < n; i++) {
cin >> a >> b;
ad(a, b);
ad(b, a);
}
dfs(1, 0);
for (int i = 1, a, b, c; i <= m; i++) {
cin >> a >> b >> c;
int LCA = lca(a, b);
if (ask(rt[a], 1, n, c) + ask(rt[b], 1, n, c) - ask(rt[LCA], 1, n, c) - ask(rt[f[LCA][0]], 1, n, c) > 0) cout << 1;
else cout << 0;
}
return 0;
}