Count on a tree「主席树」

Count on a tree「主席树」

题目描述

给定一棵 (n) 个节点的树,每个点有一个权值。有 (m)个询问,每次给你 (u,v,k),你需要回答 (u) (xor) (last)(v) 这两个节点间第(k) 小的点权。

其中 (last) 是上一个询问的答案,定义其初始为 (0),即第一个询问的 (u) 是明文。

输入格式

第一行两个整数 (n,m)

第二行有 (n) 个整数,其中第 (i) 个整数表示点 (i) 的权值。

后面 (n−1) 行每行两个整数 (x,y),表示点 (x)到点 (y) 有一条边。

最后 (m) 行每行两个整数 (u,v,k),表示一组询问。

输出格式

(m) 行,每行一个正整数表示每个询问的答案。

输入输出样例

输入 #1

8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2

输出 #1

2
8
9
105
7

说明/提示

【数据范围】
对于 (100%) 的数据,(1≤n,m≤10^{5})

思路分析

又是因为一道题而对算法有了更深刻的认识

  • 区间第k小,显然主席树(其他的方法本蒟蒻暂且还不会)
  • 关键点是,主席树虽然说是树,但它到底还是线段树的变形,线段树处理的是什么?一个序列啊,可这道题确是在树上进行操作,还能用吗?这就需要对主席树有透彻的理解(这也是我看很多题解认真思考一直没明白的地方)

关于主席树:

  • 作为可持续化衍生出的数据结构,主席树的核心之一就是多个历史版本,这个应该不用多说了,关键是另一个核心——前缀和
  • 实现:每个节点根据前一个节点建立。然后利用前缀和思想拿第(r)棵树 - 第(l-1)棵树得到([l,r])区间的信息,操作在这棵新树上操作即可。
  • 那么如果是一个线性的序列,我们只需要将其离散化后对每个点建立权值线段树即可,然后根据前缀和思想即可求解(推荐这个板子题P3834 【模板】可持久化线段树 2(主席树)
  • 还是回到这个问题,树上进行操作该如何处理?考虑树的结构,树……边……节点……(突然明了)我们对每个节点在它父亲基础上建树不就行了!。这样每一个节点所保存的信息就是它自身到根节点这条链上的信息了,妙啊!。
  • 接下来类似于序列的处理,将树上分为几个链,用到树链剖分,那x-y路径的信息就是(siz[u]+siz[v]−siz[lca(u,v)]−siz[fa[lca(u,v)]])

Code

#include<bits/stdc++.h>
#define N 200010
using namespace std;
inline int read(){
	int x = 0,f =  1;
	char ch = getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
struct Tree{
	int l, r, sum;
}tree[N<<5];
struct edge{
	int next, to;
}e[N * 2];
int n, m,tot,cnt,dex,last;
int h[N], a[N], b[N], fa[N], dep[N];
int son[N], top[N], size[N], rt[N];
int find(int x){
	return lower_bound(b + 1, b + 1 + cnt, x) - b;
}
void add(int u, int v){
    e[++tot].next = h[u];
    e[tot].to = v;
    h[u] = tot;
}
void dfs1(int x, int fath, int depth){ //两个dfs跑树剖
    size[x] = 1, fa[x] = fath,dep[x] = depth;
    int maxSon = 0;
    for(int i = h[x]; i != 0; i = e[i].next)
        if(e[i].to != fath){
            dfs1(e[i].to, x, depth + 1);
            size[x] += size[e[i].to];
            if(size[e[i].to] > maxSon){
                maxSon = size[e[i].to];
                son[x] = e[i].to;
            }
        }
}

void dfs2(int x, int head){
    top[x] = head;
    if(!son[x]) return;
    dfs2(son[x], head);
    for(int i = h[x]; i != 0; i = e[i].next)
        if(e[i].to != fa[x] && e[i].to != son[x])
            dfs2(e[i].to, e[i].to);
}

int lca(int x, int y){//树剖求lca
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        x = fa[top[x]];
    }
    if(dep[x] < dep[y]) return x;
    return y;
}

int build(int l, int r){
    int p = ++dex;//开点
    int mid = (l + r) >> 1;
    if(l == r) return p; //叶子节点
    tree[p].l = build(l, mid);
    tree[p].r = build(mid + 1, r);
    return p;
}

int update(int pre, int l, int r, int val){
    int p = ++dex;//还是开点
    int mid = (l + r) >> 1;
    tree[p].l = tree[pre].l;//复制上一个版本
    tree[p].r = tree[pre].r;
    tree[p].sum = tree[pre].sum + 1;//这个当然不能复制啦
    if(l == r) return p;
    if(val <= mid) tree[p].l = update(tree[pre].l, l, mid, val);
    else tree[p].r = update(tree[pre].r, mid + 1, r, val);
    return p;
}

void dfs(int x){
    rt[x] = update(rt[fa[x]], 1, cnt, find(a[x]));
    for(int i = h[x]; i != 0; i = e[i].next){
        if(e[i].to != fa[x]){
        	dfs(e[i].to);
        }
    }
}

int query(int s1, int s2, int fa, int pa, int l, int r, int rank){
    int size = tree[tree[s2].l].sum + tree[tree[s1].l].sum - tree[tree[fa].l].sum - tree[tree[pa].l].sum;
    int mid = (l + r) >> 1;
    if(l == r) return l;
    if(rank <= size) return query(tree[s1].l, tree[s2].l, tree[fa].l, tree[pa].l, l, mid, rank);
    else return query(tree[s1].r, tree[s2].r, tree[fa].r, tree[pa].r, mid + 1, r, rank - size);
}

int main(){
    n = read(),m = read();
    for(int i = 1; i <= n; i++){
        a[i] = read();
        b[++cnt] = a[i];
    }
    sort(b + 1, b + 1 + cnt);
    cnt = unique(b + 1, b + 1 + cnt)-b-1; //记得离散化
    for(int i = 1; i < n; i++){
        int u = read(), v = read();
        add(u, v),add(v, u);
    }
    rt[0] = build(1, cnt);
    dfs1(1, 0, 1), dfs2(1, 1), dfs(1);
    for(int i = 1; i <= m; i++){
        int u = read() ^ last, v = read(), rank = read();
        int  head = lca(u, v);
        int id = query(rt[u], rt[v], rt[head], rt[fa[head]], 1, cnt, rank);
        printf("%d
",b[id]);
        last = b[id];
    }
    return 0;
}

原文地址:https://www.cnblogs.com/hhhhalo/p/13368387.html