SPOJ COT Count on a tree

先对点权离散化，再在树上建主席树（可持久化线段树）：结点 u 的版本 Tree(u) 是在其父结点的版本上插入 u 的点权得到的，因此 Tree(u) 记录了从根到 u 的路径上的所有点权。主席树具有可加减性。

那么 u 到 v 路径上的点权线段树 = Tree(u) + Tree(v) − Tree(lca(u,v)) − Tree(fa(lca(u,v)))。下面的代码使用等价写法：Tree(u) + Tree(v) − 2·Tree(lca(u,v))，再单独把 lca 自身的点权加回来。

不懂的画个图就清楚了。

把建可持久化线段树（插入）的部分改成了迭代实现，效率更高。

/*********************************************************
*            ------------------                          *
*   author AbyssFish                                     *
**********************************************************/
#include<cstdio>
#include<iostream>
#include<string>
#include<cstring>
#include<queue>
#include<vector>
#include<stack>
#include<map>
#include<set>
#include<algorithm>
#include<cmath>
#include<numeric>
#include<climits>
using namespace std;

const int MAX_N = 1e5+5;
// ceil(log2(MAX_N)) + 1: nodes created per persistent insert, and also enough
// levels (floor(log2(2*MAX_N)) + 1) for the sparse table over the Euler tour.
const int MAXD = 18;


// Forward-star adjacency lists.
//   hd[u] - index of the first edge leaving u, or -1 if none
//   to[e] - head vertex of edge e
//   nx[e] - next edge with the same tail as e, or -1
//   ec    - number of edges stored so far
// Edge arrays are 2*MAX_N long because each undirected tree edge is stored twice.
int hd[MAX_N], to[MAX_N<<1], nx[MAX_N<<1], ec;

/* Prepend the directed edge u -> v to u's list. */
void add_edge(int u,int v)
{
    to[ec] = v;
    nx[ec] = hd[u];
    hd[u] = ec++;
}

/* Reset the graph for vertices 1..n: every list becomes empty (-1) and the edge counter is cleared. */
void init_g(int n)
{
    if (n > 0) {
        // All-0xff bytes give -1 in every element; size from the element type
        // instead of assuming sizeof(int) == 4.
        std::memset(hd + 1, 0xff, sizeof(hd[0]) * static_cast<size_t>(n));
    }
    ec = 0;
}

// Loop header over the edge indices i of vertex u; ~i is zero exactly at the -1 terminator.
#define eachedge int i = hd[u]; ~i; i = nx[i]
// Loop-body prologue: v is the neighbour across edge i; skips the parent f.
#define ifvalid int v = to[i]; if(v == f) continue;


// Weight bookkeeping (vertices are 1-based):
//   we[v]  - raw weight of vertex v as read from input
//   wes[v] - compressed rank of we[v], in 1..ns
//   mpw[k] - raw weight whose compressed rank is k (inverse of the compression)
//   rw     - scratch index buffer used by compress() while sorting
int we[MAX_N], wes[MAX_N], mpw[MAX_N], rw[MAX_N];

// Key array consulted by cmp_id; must be set before cmp_id is used as a comparator.
int *c_cmp;
/* Strict weak order on indices: i before j iff c_cmp[i] < c_cmp[j]. */
bool cmp_id(int i,int j){ return c_cmp[i] < c_cmp[j]; }

/*
 * Coordinate-compress a[0..n-1] into dense 1-based ranks.
 *   n  - number of values
 *   a  - input values (read only)
 *   r  - scratch buffer of n ints; left holding the indices of a in ascending value order
 *   b  - output: b[i] = rank of a[i]; equal values share one rank
 *   mp - output: mp[k] = value whose rank is k, for k in 1..(return value)
 * Returns the number of distinct values, or 0 when n <= 0.
 * When n > 0, leaves the global c_cmp pointing at a.
 */
int compress(int n, int *a, int *r, int *b, int *mp)
{
    if (n <= 0) return 0;   // nothing to rank; previously read r[0]/a[r[0]] uninitialized

    for (int i = 0; i < n; i++) r[i] = i;
    c_cmp = a;
    std::sort(r, r + n, cmp_id);

    int k = 1;
    b[r[0]] = 1;
    mp[1] = a[r[0]];
    for (int i = 1; i < n; i++) {
        int j = r[i];
        if (a[j] != a[r[i-1]]) mp[++k] = a[j];   // a new distinct value opens the next rank
        b[j] = k;
    }
    return k;
}

/*------------------------------------------*/

// Euler-tour state written by get_path():
//   dep[u]  - depth of vertex u
//   pid[u]  - tour step at which u is first visited
//   path[t] - vertex recorded at tour step t (1-based; n vertices give 2n-1 steps)
//   dfs_clk - number of tour steps recorded so far
int dep[MAX_N], pid[MAX_N];
int path[MAX_N<<1];
int dfs_clk;

void get_path(int u,int f,int d)
{
    dep[u] = d;
    path[++dfs_clk] = u;
    pid[u] = dfs_clk;
    for(eachedge){
        ifvalid
        get_path(v,u,d+1);
        path[++dfs_clk] = u;
    }
}

/*
 * Sparse table for range-"minimum" queries over a 1-based array of elements,
 * where the minimum is the element with the smallest key (on ties the left
 * operand wins, exactly as std::min with a comparator). lca_init builds it over
 * the Euler tour with dep as the key, so a query yields the shallowest vertex.
 */
struct SparseTable
{
    int mxk[MAX_N<<1];        // mxk[i] = floor(log2(i)) for i >= 1; mxk[0] = -1
    int d[MAX_N<<1][MAXD];    // d[i][j] = best element of r[i .. i+2^j-1]; needs floor(log2(2N))+1 columns
    int *key;                 // key array captured by init(); RMQ no longer depends on the global c_cmp

    /* The better of two elements: b only if key[b] < key[a], otherwise a. */
    int better(int a, int b) const { return key[b] < key[a] ? b : a; }

    /*
     * Build over r[1..n], ranking elements by mp[element].
     * Also leaves the global c_cmp pointing at mp, as the original did.
     */
    void init(int *mp,int *r, int n)
    {
        key = mp;
        c_cmp = mp;
        mxk[0] = -1;
        for(int i = 1; i <= n; i++){
            d[i][0] = r[i];
            mxk[i] = mxk[i-1] + ((i & (i-1)) == 0);   // floor(log2) steps up exactly at powers of two
        }
        for(int j = 1; j <= mxk[n]; j++){
            int len = 1<<j, half = 1<<(j-1);
            for(int i = 1; i + len - 1 <= n; i++){
                d[i][j] = better(d[i][j-1], d[i+half][j-1]);
            }
        }
    }

    /* Best element in the half-open position range [l, r); requires l < r. */
    int RMQ(int l,int r)
    {
        int k = mxk[r-l];
        return better(d[l][k], d[r-(1<<k)][k]);
    }
}rmq;

/*
 * Prepare LCA queries for the tree rooted at u: record the Euler tour
 * (root at depth 1, parent placeholder 0) and build the depth-keyed
 * sparse table over it.
 */
void lca_init(int u)
{
    const int kNoParent = 0, kRootDepth = 1;
    dfs_clk = 0;
    get_path(u, kNoParent, kRootDepth);
    rmq.init(dep, path, dfs_clk);
}

/*
 * Lowest common ancestor of u and v: the shallowest vertex on the Euler tour
 * between their first visits, both ends inclusive.
 */
int q_lca(int u, int v)
{
    const int first = std::min(pid[u], pid[v]);
    const int last = std::max(pid[u], pid[v]);
    return rmq.RMQ(first, last + 1);   // RMQ takes a half-open range
}

/*------------------------------------------*/

// Declares `md`, the midpoint of the current value range [l, r] (l and r must be in scope).
#define Tvar int md = (l+r)>>1;

// Number of distinct compressed weights; the persistent tree covers ranks [1, ns].
int ns;

// Node of the persistent (path-copied) segment tree over compressed weight ranks.
struct Node
{
    Node *lc,*rc;   // children; the sentinel nil stands in for empty subtrees
    int s;          // how many inserted weights fall in this node's rank range
}meo[MAXD*MAX_N], *root[MAX_N];   // meo: node pool (MAXD nodes per vertex); root[u]: version holding the root-to-u path weights

// meo[0] is the sentinel empty node; main() initializes it to {nil, nil, 0}.
Node * const nil =meo;
// Bump allocator: last pool slot handed out; build() takes new nodes from ++freeNode.
Node * freeNode = nil;


void build(int w,Node **o,int l = 1,int r = ns)
{
    while(l < r){
        *(++freeNode) = **o;
        *o = freeNode;
        (*o)->s++;
        int md = (l+r)>>1;
        if(w <= md){
            o = &((*o)->lc);
            r = md;
        }
        else {
            o = &((*o)->rc);
            l = md+1;
        }
    }
    *(++freeNode) = **o;
    *o = freeNode;
    (*o)->s++;
}


void dfs_build(int u,int f = 0)
{
    root[u] = root[f];
    build(wes[u],root+u);
    for(eachedge){
        ifvalid
        dfs_build(v,u);
    }
}


void q_kth()
{
    int u, v, k;
    scanf("%d%d%d",&u,&v,&k);
    int lca = q_lca(u,v);
    Node *x = root[u], *y = root[v], *z = root[lca];
    int l = 1, r = ns, p = wes[lca];
    while(l < r){
        Tvar
        int s = x->lc->s + y->lc->s - (z->lc->s<<1) + (l<=p && p<=md);

        if(k <= s){
            x = x->lc; y = y->lc; z = z->lc;
            r = md;
        }
        else {
            k -= s;
            x = x->rc; y = y->rc; z = z->rc;
            l = md+1;
        }
    }
    printf("%d
",mpw[l]);
}


//#define LOCAL
int main()
{
#ifdef LOCAL
    freopen("in.txt","r",stdin);
#endif
    //cout<<log2(MAX_N);
    //cout<<floor(log2(MAX_N<<1));
    //cout<<(1<<(MAXD-1))<<' '<<(MAX_N<<1);
    *nil = {nil,nil,0};
    root[0] = nil;
    int N, M;
    scanf("%d%d",&N,&M);
    init_g(N);
    for(int i = 1; i <= N; i++){
        scanf("%d",we+i);
    }
    for(int i = 1; i < N; i++){
        int u, v; scanf("%d%d",&u,&v);
        add_edge(u,v);
        add_edge(v,u);
    }
    ns = compress(N,we+1,rw,wes+1,mpw);
    lca_init(1);
    dfs_build(1);
    while(M--){
        q_kth();
    }
    return 0;
}
原文地址:https://www.cnblogs.com/jerryRey/p/5038682.html