Count on a tree

题目大意:树上第k小的数。

思路:和区间第k小的数的做法差不多,不过要求一下lca,比较麻烦。

  1 #include<bits/stdc++.h>
  2 #define fi first
  3 #define se second
  4 #define pb push_back
  5 #define mk make_pair
  6 #define pii pair<int,int>
  7 #define read(x) scanf("%d",&x)
  8 #define lread(x) scanf("%lld",&x)
  9 using namespace std;
 10 
 11 typedef long long ll;
 12 const int N=1e5+7;
 13 const int M=1e5+1;
 14 const int inf=0x3f3f3f3f;
 15 const ll INF=0x3f3f3f3f3f3f3f3f;
 16 
 17 int n,m,top,root[N],fa[N],a[N],hs[N],vs[N<<1],dp[N<<1][25],tot,f[N],d[N];
 18 vector<int> e[N];
 19 struct seg_tree
 20 {
 21     int cnt=0;
 22     struct node
 23     {
 24         int l,r,sum;
 25     }a[N*20];
 26     void updata(int l,int r,int &x,int y,int v)
 27     {
 28         a[++cnt]=a[y]; x=cnt; a[x].sum++;
 29         if(l==r)
 30             return;
 31         int mid=(l+r)>>1;
 32         if(v<=hs[mid])
 33             updata(l,mid,a[x].l,a[y].l,v);
 34         else
 35             updata(mid+1,r,a[x].r,a[y].r,v);
 36     }
 37     int query(int l,int r,int x,int y,int z,int t,int k)
 38     {
 39         if(l==r)
 40             return hs[l];
 41         int mid=(l+r)>>1;
 42         int ans=a[a[x].l].sum+a[a[y].l].sum-a[a[z].l].sum-a[a[t].l].sum;
 43         if(ans>=k)
 44             return query(l,mid,a[x].l,a[y].l,a[z].l,a[t].l,k);
 45         else
 46             return query(mid+1,r,a[x].r,a[y].r,a[z].r,a[t].r,k-ans);
 47 
 48     }
 49 }seg;
 50 void dfs(int u,int p)
 51 {
 52     fa[u]=p;
 53     seg.updata(1,top,root[u],root[p],a[u]);
 54     for(int v:e[u])
 55         if(v!=p) dfs(v,u);
 56 }
 57 void dfs(int u,int p,int s)
 58 {
 59     vs[++tot]=u;
 60     d[u]=s;
 61     f[u]=tot;
 62     for(int i=0;i<e[u].size();i++)
 63     {
 64         int to=e[u][i];
 65         if(to==p) continue;
 66         dfs(to,u,s+1);
 67         vs[++tot]=u;
 68     }
 69 }
 70 void work_rmq()
 71 {
 72     for(int i=1;i<=tot;i++) dp[i][0]=vs[i];
 73     int up=log(tot)/log(2);
 74     for(int j=1;j<=up;j++)
 75     {
 76         int t=(1<<j)-1;
 77         for(int i=1;i+t<=tot;i++)
 78         {
 79             int a=dp[i][j-1],b=dp[i+(1<<(j-1))][j-1];
 80             if(d[a]<=d[b]) dp[i][j]=a;
 81             else dp[i][j]=b;
 82         }
 83     }
 84 }
 85 int get_lca(int u,int v)
 86 {
 87     int l,r;
 88     if(f[u]<f[v]) l=f[u],r=f[v];
 89     else l=f[v],r=f[u];
 90     int j=log(r-l+1)/log(2);
 91     int a=dp[l][j],b=dp[r-(1<<j)+1][j];
 92     if(d[a]<d[b]) return a;
 93     else return b;
 94 }
 95 int main()
 96 {
 97     read(n); read(m);
 98     for(int i=1;i<=n;i++)
 99         read(a[i]),fa[i]=i,hs[++top]=a[i];
100     for(int i=1;i<n;i++)
101     {
102         int f,t;
103         read(f); read(t);
104         e[f].push_back(t);
105         e[t].push_back(f);
106     }
107     sort(hs+1,hs+top+1);
108     top=unique(hs+1,hs+top+1)-hs-1;
109     dfs(1,0);
110     dfs(1,0,0);
111     work_rmq();
112     for(int i=1;i<=m;i++)
113     {
114         int u,v,k;
115         read(u); read(v); read(k);
116         int lca=get_lca(u,v);
117         int x=root[u],y=root[v];
118         int z=root[lca];
119         int t=fa[lca]==lca ? 0:root[fa[lca]];
120         int ans=seg.query(1,top,x,y,z,t,k);
121         printf("%d
",ans);
122     }
123     return 0;
124 }
125 /*
126 */
原文地址:https://www.cnblogs.com/CJLHY/p/8465829.html