P2633 Count on a tree(主席树)

只是转化成树上问题, 同样是动态开点维护

#include<bits/stdc++.h>
#define getsz(p) (p?p->sz:0)
#define getlsz(p) (p?getsz(p->ls):0)
#define getl(p) (p?p->ls:0)
#define getr(p) (p?p->rs:0)
using namespace std;
typedef long long ll;
const int N=4e5+10;
int a[N];
int depth[N],f[N][25];
int h[N],ne[N],e[N],idx;
int n;
void add(int a,int b){
    e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
int st[N];
struct node{
    int l,r;
    int sz;
    node *ls,*rs;
    void update(){
        sz = getsz(ls) + getsz(rs);
    }
}*rt[N],pool[N*30];
vector<int> num;
node * copynode(node *rt){
    node *p=pool+(++idx);
    pool[idx]=*rt;
    return p;
}
node * newnode(int l,int r){
    node *p=pool+(++idx);
    p->l=l,p->r=r;
    return p;
}
node *insert(node *rt,int l,int r,int x){
    node *p;
    if(rt) p=copynode(rt);
    else p=newnode(l,r);
    p->sz++;
    int mid=l+r>>1;
    if(p->l==x&&p->r==x){
        return p;
    }
    if(x<=mid)
        p->ls=insert(p->ls,l,mid,x);
    else
        p->rs=insert(p->rs,mid+1,r,x);

    return p;
}
int find(int x){
    return lower_bound(num.begin(),num.end(),x)-num.begin()+1;
}
void dfs(int u,int fa){
    st[u]=1;
    rt[u]=insert(rt[fa],1,n,find(a[u]));
    int i;
    for(i=1;i<=20;i++){
        if(depth[u]<=(1<<i))
            break;
        f[u][i]=f[f[u][i-1]][i-1];
    }
    for(i=h[u];i!=-1;i=ne[i]){
        int j=e[i];
        if(j==fa||st[j])
            continue;
        depth[j]=depth[u]+1;
        f[j][0]=u;
        dfs(j,u);
    }
}
int lca(int a,int b){
    if(depth[a]<depth[b])
        swap(a,b);
    int i;
    for(i=20;i>=0;i--){
        if(depth[f[a][i]]>=depth[b]){
            a=f[a][i];
        }
    }
    if(a==b)
        return a;
    for(i=20;i>=0;i--){
        if(f[a][i]!=f[b][i]){
            a=f[a][i];
            b=f[b][i];
        }
    }
    return f[a][0];
}
int query(node* pL, node* pR, node* p0, node* p1, int k){
    if(pR && pR->l==pR->r) return pR->l;
    if(pL && pL->l==pL->r) return pL->l;

    int k1 = getlsz(pL) + getlsz(pR) - getlsz(p0) - getlsz(p1);
    if(k1 >= k) return query(getl(pL), getl(pR), getl(p0), getl(p1), k);
    else return query(getr(pL), getr(pR), getr(p0), getr(p1), k - k1);
}
int main(){
    ios::sync_with_stdio(false);
    int m;
    cin>>n>>m;
    int i;
    memset(h,-1,sizeof h);
    for(i=1;i<=n;i++){
        cin>>a[i];
        num.push_back(a[i]);
    }
    for(i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        add(u,v);
        add(v,u);
    }
    sort(num.begin(),num.end());
    num.erase(unique(num.begin(),num.end()),num.end());
    n=(int)num.size();
    int last=0;
    depth[1]=1;
    dfs(1,0);
    while(m--){
        int u,v,k;
        cin>>u>>v>>k;
        u^=last;
        int p=lca(u,v);
        last=num[query(rt[u],rt[v],rt[p],rt[f[p][0]],k)-1];
        cout<<last<<endl;
    }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/ctyakwf/p/13436573.html