【POJ2104】K-th Number

【POJ2104】K-th Number

题面

virtual judge

题解

其实就是一道主席树(sb)

但是为了学习整体二分的需要就用整体二分写了。。。

所以主要利用此题讲一下整体二分到底是个啥(以下部分参考李煜东《算法竞赛进阶指南》):

两个例子

(Eg1)

给定一个正整数序列(A)及固定的整数(S),执行(M)此操作

每次查询(l)~(r)间不大于(S)的数或将(A[x])改为(y)

很简单吧。。。

用树状数组维护一下就好了吧。。。

(Eg2)

给定一个正整数序列,求此序列的第(K)小数是多少。

看到这里也许你觉得我是个傻逼。。。直接排一遍序就好了啊

但是,为了引入整体二分使问题复杂化,我们采用第二种方法:

二分答案,设当前二分值为(mid),统计有多少个数(leq mid),记为(cnt)

1.若(K leq cnt),则说明K小数值一定(in)([l, mid]),可在左半区间继续二分

2.若(K > cnt),则最小数一定(in)([mid+1, r]),等价于在值域([mid+1,r])下寻找(K-cnt)小的数

复杂度(N) (logSIZE)

回到(POJ2104),要求(M)个形如“求序列(A)(l)(r)个数中第(k)小的数”,这样做(M)次显然是不行的

而这样做(M)次中会有大量冗余状态,于是就有了---整体二分

整体二分

对于此题,

我们套用(Eg2)的做法

尝试在序列(A)中值域([MINA,MAXA])二分答案(mid)

记区间(l_i)(r_i)中小于等于(mid)的数有(c_i)

然后将这些询问分类:

1.若(k_i) (leq) (c_i),则说明第(i)个询问的答案在([MINA,mid])

2.若(k_i>c_i),则说明第(i)个询问的答案在([mid+1,MAXA])中,且等价于在值域([mid+1,MAXA])中查询第(k_i-c_i)小的数

然后分别把上面两类分为子序列(LA)(RA),分开处理即可

对于统计(c_i)可以利用(Eg1)中的树状数组维护

具体实现还是看代码吧,感觉越讲越懵啊。。。

代码

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;

inline int gi() {
    register int data = 0, w = 1;
    register char ch = 0;
    while (ch != '-' && (ch > '9' || ch < '0')) ch = getchar();
    if (ch == '-') w = -1 , ch = getchar();
    while (ch >= '0' && ch <= '9') data = data * 10 + (ch ^ 48), ch = getchar();
    return w * data;
}
const int MAX_N = 200005, INF = 1e9; 
struct rec {int op, x, y, z; } q[MAX_N << 1], lq[MAX_N << 1], rq[MAX_N << 1]; 
int N, M, tot, c[MAX_N], ans[MAX_N]; 
inline int lb(int x) { return x & -x; } 
void add(int x, int v) { while (x <= N) c[x] += v, x += lb(x); } 
int sum(int x) { int res = 0; while (x > 0) res += c[x], x -= lb(x); return res; } 
int X[MAX_N], cnt; 
void Div(int lval, int rval, int st, int ed) { 
    if (st > ed) return ; 
    if (lval == rval) {  
        for (int i = st; i <= ed; i++) 
            if (q[i].op > 0) ans[q[i].op] = lval; 
        return ; 
    } 
    int mid = (lval + rval) >> 1; 
    int lt = 0, rt = 0; 
    for (int i = st; i <= ed; i++) { 
    	if (q[i].op == 0) { 
    	    if (q[i].y <= mid) add(q[i].x, 1), lq[++lt] = q[i]; 
    	    else rq[++rt] = q[i]; 
        } else { 
            int res = sum(q[i].y) - sum(q[i].x - 1); 
            if (res >= q[i].z) lq[++lt] = q[i]; 
            else q[i].z -= res, rq[++rt] = q[i]; 
        } 
    } 
    for (int i = st; i <= ed; i++) { 
        if (q[i].op == 0 && q[i].y <= mid) add(q[i].x, -1); 
    } 
    for (int i = 1; i <= lt; i++) q[st + i - 1] = lq[i]; 
    for (int i = 1; i <= rt; i++) q[st + lt + i - 1] = rq[i]; 
    Div(lval, mid, st, st + lt - 1); 
    Div(mid + 1, rval, st + lt, ed); 
} 
int main () { 
    N = gi(), M = gi(); 
    for (int i = 1; i <= N; i++) { 
        int v = gi(); 
        q[++tot].op = 0, q[tot].x = i, X[++cnt] = q[tot].y = v; 
    } 
    sort(&X[1], &X[cnt + 1]); cnt = unique(&X[1], &X[cnt + 1]) - X - 1; 
    for (int i = 1; i <= N; i++) q[i].y = lower_bound(&X[1], &X[cnt + 1], q[i].y) - X; 
    for (int i = 1; i <= M; i++) { 
        q[++tot].op = i, q[tot].x = gi(), q[tot].y = gi(), q[tot].z = gi(); 
    } 
    Div(1, N, 1, tot); 
    for (int i = 1; i <= M; i++) printf("%d
", X[ans[i]]); 
    return 0; 
} 

另附主席树代码

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;

inline int gi() {
    register int data = 0, w = 1;
    register char ch = 0;
    while (ch != '-' && (ch > '9' || ch < '0')) ch = getchar();
    if (ch == '-') w = -1 , ch = getchar();
    while (ch >= '0' && ch <= '9') data = data * 10 + (ch ^ 48), ch = getchar();
    return w * data;
}
#define MAX_N 200005
struct Node {
    int ls, rs, val; 
} t[MAX_N << 5];
int cnt = 0, rt[MAX_N << 5];
void build(int &o, int l, int r) {
    o = ++cnt;
    if (l == r) return ;
    int mid = (l + r) >> 1; 
    build(t[o].ls, l, mid); 
    build(t[o].rs, mid + 1, r); 
} 
void insert(int &o, int pre, int l, int r, int x) { 
    o = ++cnt;
    t[o].ls = t[pre].ls; t[o].rs = t[pre].rs; t[o].val = t[pre].val; 
    t[o].val++; 
    if (l == r) return ; 
    int mid = (l + r) >> 1; 
    if (x <= mid) insert(t[o].ls, t[pre].ls, l, mid, x);
    else insert(t[o].rs, t[pre].rs, mid + 1, r, x); 
}
int query(int u, int v, int l, int r, int k) {
    if (l == r) return l;
    int sz = t[t[u].ls].val - t[t[v].ls].val;
    int mid = (l + r) >> 1; 
    if (sz < k) return query(t[u].rs, t[v].rs, mid + 1, r, k - sz);
    else return query(t[u].ls, t[v].ls, l, mid, k); 
}
int N, M, a[MAX_N];
int X[MAX_N]; 
int main () {
    N = gi(), M = gi();
    for (int i = 1; i <= N; i++) X[i] = a[i] = gi();
    sort(&X[1], &X[N + 1]);
    int size = unique(&X[1], &X[N + 1]) - X - 1; 
    build(rt[0], 1, N); 
    for (int i = 1; i <= N; i++) {
        int x = lower_bound(&X[1], &X[size + 1], a[i]) - X;
        insert(rt[i], rt[i - 1], 1, N, x); 
    }
    while (M--) {
        int l = gi(), r = gi(), v = gi();
        printf("%d
", X[query(rt[r], rt[l - 1], 1, N, v)]); 
    }
    return 0; 
} 

原文地址:https://www.cnblogs.com/heyujun/p/10121892.html