Treap 模板(附 POJ 1442 与 HDU 4557 例题)

原理可以参考 hihocoder 上的讲解,讲得很清楚,这里不再赘述。

模板摘自刘汝佳(lrj)《算法竞赛入门经典——训练指南》。

/**
Treap 实现 名次树
功能: 1.找到排名为k的元素
       2.值为x的元素的名次

初始化:Node* root = NULL;
*/
#include <cstdlib>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;

/**
 * Treap node for a rank tree.
 * Nodes are ordered as a BST on `v`. The random priority `r` decides the shape:
 * insert() rotates a child above its parent when the child's r is larger.
 * `s` caches the subtree size so that rank queries run in O(depth).
 */
struct Node {
    Node* ch[2];    // ch[0] = left subtree, ch[1] = right subtree
    int r;          // random priority, drawn from rand() at construction
    int v;          // stored value
    int s;          // number of nodes in the subtree rooted at this node

    Node(int v) : v(v)
    {
        r = rand();
        s = 1;
        ch[1] = NULL;
        ch[0] = NULL;
    }

    // Order two nodes by their random priority only.
    bool operator<(const Node& rhs) const { return rhs.r > r; }

    // -1 when x equals v; otherwise the child index to descend into
    // (0 when x < v, 1 when x > v).
    int cmp(int x) const
    {
        if (x < v) return 0;
        if (x > v) return 1;
        return -1;
    }

    // Recompute s from the (possibly NULL) children.
    void maintain()
    {
        s = 1 + (ch[0] ? ch[0]->s : 0) + (ch[1] ? ch[1]->s : 0);
    }
} ;

// Rotate the subtree rooted at o.
// d = 0 is a left rotation (the right child rises); d = 1 is a right rotation
// (the left child rises). Afterwards o points at the new subtree root, and the
// sizes of both moved nodes are refreshed (lower node first).
void rotate(Node* &o, int d)
{
    const int side = d ^ 1;          // the side of o whose child moves up
    Node* pivot = o->ch[side];
    o->ch[side] = pivot->ch[d];      // pivot's inner subtree is re-attached to o
    pivot->ch[d] = o;
    o->maintain();                   // o is now below pivot, so refresh it first
    pivot->maintain();
    o = pivot;
}

// Insert x into the treap rooted at o; o may be NULL (empty tree).
// Duplicates are allowed and go to the right subtree, which is why cmp() is
// not used here. After inserting below o, a child whose priority exceeds o's
// is rotated above it, and o's subtree size is refreshed.
void insert(Node* &o, int x)
{
    if (o != NULL) {
        const int dir = (o->v <= x);     // 0: go left, 1: go right (equal keys go right)
        Node* &child = o->ch[dir];
        insert(child, x);
        if (o->r < child->r) rotate(o, dir ^ 1);
    } else {
        o = new Node(x);
    }
    o->maintain();
}

// Remove one node whose value equals x from the treap rooted at o.
// If x is not present, the tree is left unchanged; without the NULL guard the
// search would dereference a NULL child.
// A matching node with two children is rotated down toward its
// higher-priority child until it has at most one child. It is then spliced
// out and freed. Subtree sizes are refreshed on the way back up.
void remove(Node* &o, int x)
{
    if (o == NULL) return;           // x not present: nothing to remove
    int d = o->cmp(x);
    if (d == -1) {
        Node* u = o;
        if (o->ch[0] != NULL && o->ch[1] != NULL) {
            // Rotate the child with the larger priority up; the target node
            // ends up as o->ch[d2].
            int d2 = ((o->ch[0]->r) > (o->ch[1]->r) ? 1 : 0);
            rotate(o, d2);
            remove(o->ch[d2], x);
        } else {
            o = (o->ch[0] == NULL) ? o->ch[1] : o->ch[0];
            delete u;
        }
    } else {
        remove(o->ch[d], x);
    }
    if (o != NULL) o->maintain();
}

// Return 1 if a node with value x exists in the tree rooted at o, else 0.
// insert()/remove() never check whether x is already present, so call this
// first when that matters. o is taken by value, so the caller's root pointer
// is never moved.
int find(Node* o, int x)
{
    for (Node* cur = o; cur != NULL; ) {
        const int d = cur->cmp(x);
        if (d == -1) return 1;
        cur = cur->ch[d];
    }
    return 0;
}

// Return the k-th value of the tree rooted at o (1-based).
// fg = 0 gives the k-th smallest value; fg = 1 gives the k-th largest.
// Returns 0 when o is NULL or k is outside [1, size]. Because of this,
// 0 is ambiguous if 0 can be stored. The tree is not modified.
int kth(Node* &o, int k, int fg)
{
    Node* cur = o;
    while (cur != NULL && k >= 1 && k <= cur->s) {
        const int near = (cur->ch[fg] != NULL) ? cur->ch[fg]->s : 0;  // nodes ranked before cur
        if (k == near + 1) return cur->v;
        if (k <= near) {
            cur = cur->ch[fg];
        } else {
            k -= near + 1;
            cur = cur->ch[fg ^ 1];
        }
    }
    return 0;
}

// Return the largest value strictly less than x in the tree rooted at o,
// or -1 if there is none. -1 doubles as the "not found" marker, so a deeper
// match equal to -1 is treated as "nothing found" by its ancestors.
int prv(Node* &o, int x)
{
    if (o == NULL) return -1;
    if (o->v < x) {
        // o qualifies; a larger qualifying value can only be on the right.
        const int deeper = prv(o->ch[1], x);
        return deeper != -1 ? deeper : o->v;
    }
    return prv(o->ch[0], x);
}

// Return the smallest value strictly greater than x in the tree rooted at o,
// or -1 if there is none.
// -1 doubles as the "not found" marker. A deeper match equal to -1 is
// therefore treated as "nothing found", and the ancestor's value is returned
// instead. Results are only unambiguous when -1 is never stored.
int nxt(Node* &o, int x)
{
    if (o == NULL) return -1;
    if (x >= o->v) return nxt(o->ch[1], x);   // o and its left subtree are all <= x
    int ans = nxt(o->ch[0], x);               // o->v > x: look for a smaller candidate on the left
    return ans == -1 ? o->v : ans;
}

// Merge two trees: move every value of the tree rooted at src into the tree
// rooted at dest, freeing src's nodes. src is NULL afterwards. Both arguments
// are tree roots, and either may be NULL.
// src is traversed post-order, with one insert() into dest per node.
void mergeto(Node* &src, Node* &dest)
{
    if (src == NULL) return;
    mergeto(src->ch[0], dest);
    // The right subtree must be merged too. Guarding this call with
    // ch[0] != NULL silently dropped and leaked it.
    mergeto(src->ch[1], dest);
    insert(dest, src->v);
    delete src;
    src = NULL;
}

// Print the values of the tree rooted at o in ascending (in-order) order,
// each followed by a space. The right spine is walked with a loop; only
// left subtrees recurse.
void print(Node* &o)
{
    for (Node* cur = o; cur != NULL; cur = cur->ch[1]) {
        print(cur->ch[0]);
        printf("%d ", cur->v);
    }
}

// Small demo: insert 3, 3, 4, 5, remove 4, then print the tree in order.
int main()
{
    Node* root = NULL;
    const int values[] = {3, 3, 4, 5};
    for (int i = 0; i < 4; ++i) insert(root, values[i]);
    remove(root, 4);
    print(root);
    return 0;
}

例题:

上文提到的 hihocoder 讲解所附的例题;下面的代码是照着讲解自己写的(结点存放在数组中,并记录父结点)。

//Treap.cpp

#include <stdio.h>
#include <string.h>
#include <stdlib.h>

const int N = 100005;   // capacity of the node pool tp[]

// Array-based treap node. left/right/father are indices into tp[]; -1 means "none".
struct Treap {
    int father, left, right;   // parent, left child and right child indices
    int key, weight;           // BST key and heap priority (rotate() keeps smaller weights above)
    // Reset this slot as a leaf with key k, priority w and parent fa.
    void init(int k, int w, int fa) {
        key = k;
        weight = w;
        father = fa;
        right = -1;
        left = -1;
    }
} tp[N];
int root;        // index of the root node; main() uses -1 for an empty tree
int treap_cnt;   // number of tp[] slots handed out so far

// Claim the next free slot of tp[] as a leaf (key k, priority w, parent fa,
// which defaults to -1 for "no parent") and return its index.
int new_treap(int k, int w, int fa = -1)
{
    const int id = treap_cnt;
    ++treap_cnt;
    tp[id].init(k, w, fa);
    return id;
}

// Left rotation at node a; parent links are kept consistent throughout.
//  - a's right child b takes a's place under a's parent.
//  - a becomes b's left child.
//  - b's former left subtree becomes a's right subtree.
// When a is the root (father == -1) there is no parent slot to patch; the
// caller is responsible for updating `root`.
void left_rotate(int a)
{
    int b = tp[a].right;
    int fa = tp[a].father;
    tp[b].father = fa;
    if (fa != -1) {   // a root has no parent: tp[-1] would be out of bounds
        if (tp[fa].left == a) tp[fa].left = b;   // replace a by b in a's parent
        else tp[fa].right = b;
    }
    tp[a].right = tp[b].left;
    if (tp[b].left != -1) tp[tp[b].left].father = a;
    tp[b].left = a;
    tp[a].father = b;
}

// Right rotation at node a; parent links are kept consistent throughout.
//  - a's left child b takes a's place under a's parent.
//  - a becomes b's right child.
//  - b's former right subtree becomes a's left subtree.
// When a is the root (father == -1) there is no parent slot to patch; the
// caller is responsible for updating `root`.
void right_rotate(int a)
{
    int b = tp[a].left;
    int fa = tp[a].father;
    tp[b].father = fa;
    if (fa != -1) {   // a root has no parent: tp[-1] would be out of bounds
        if (tp[fa].left == a) tp[fa].left = b;
        else tp[fa].right = b;
    }
    tp[a].left = tp[b].right;
    if (tp[b].right != -1) tp[tp[b].right].father = a;
    tp[b].right = a;
    tp[a].father = b;
}

// Plain BST insert of key below node a: smaller keys go left, equal or larger
// keys go right. The new leaf gets a rand() weight and its index is returned.
// a must be a valid node. Heap order on weight is restored separately by
// rotate().
int insert(int a, int key)
{
    int cur = a;
    for (;;) {
        int &slot = (key < tp[cur].key) ? tp[cur].left : tp[cur].right;
        if (slot == -1) {
            slot = new_treap(key, rand(), cur);
            return slot;
        }
        cur = slot;
    }
}

// Restore min-heap order on weight after inserting node a. While a's weight
// is smaller than its parent's, rotate a one level up. If a ends with no
// parent, it becomes the new root.
void rotate(int a)
{
    int fa;
    while ((fa = tp[a].father) != -1 && tp[a].weight < tp[fa].weight) {
        if (tp[fa].left == a) right_rotate(fa);
        else left_rotate(fa);
    }
    if (fa == -1) root = a;
}

// Search for key starting at node a.
//  - If key is stored, return key.
//  - Otherwise, walk up the father links from the last node visited and
//    return the first key smaller than key (the largest key < key on the
//    search path).
//  - Return -2 if there is no such key, including when a == -1.
int find(int a, int key)
{
    int node = a, last = -1;
    while (node != -1 && tp[node].key != key) {
        last = node;
        node = (key < tp[node].key) ? tp[node].left : tp[node].right;
    }
    if (node != -1) return key;
    for (; last != -1; last = tp[last].father) {
        if (tp[last].key < key) return tp[last].key;
    }
    return -2;
}

// Debug dump: in-order traversal printing "index(key) " for every node.
// The right spine is walked with a loop; only left subtrees recurse.
void print(int a)
{
    for (int cur = a; cur != -1; cur = tp[cur].right) {
        print(tp[cur].left);
        printf("%d(%d) ", cur, tp[cur].key);
    }
}


int main(int argc, char const *argv[])
{
    //freopen("in", "r", stdin);
    root = -1;
    treap_cnt = 0;
    int n;
    char op[2];
    int k;
    scanf("%d", &n);
    while (n--) {
        scanf("%s%d", op, &k);
        if (*op == 'I') {
            if (root == -1) root = new_treap(k, rand());
            else rotate(insert(root, k));
        } else {
            printf("%d
", find(root, k));
        }
    }
    return 0;
}
View Code

poj1442,直接套模板,比较简单

/**
Treap 实现 名次树
功能: 1.找到排名为k的元素
       2.值为x的元素的名次
*/
#include <cstdlib>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;

/**
 * Treap node for a rank tree.
 * Nodes are ordered as a BST on `v`. The random priority `r` decides the shape:
 * insert() rotates a child above its parent when the child's r is larger.
 * `s` caches the subtree size so that rank queries run in O(depth).
 */
struct Node {
    Node* ch[2];    // ch[0] = left subtree, ch[1] = right subtree
    int r;          // random priority, drawn from rand() at construction
    int v;          // stored value
    int s;          // number of nodes in the subtree rooted at this node

    Node(int v) : v(v)
    {
        r = rand();
        s = 1;
        ch[1] = NULL;
        ch[0] = NULL;
    }

    // Order two nodes by their random priority only.
    bool operator<(const Node& rhs) const { return rhs.r > r; }

    // -1 when x equals v; otherwise the child index to descend into
    // (0 when x < v, 1 when x > v).
    int cmp(int x) const
    {
        if (x < v) return 0;
        if (x > v) return 1;
        return -1;
    }

    // Recompute s from the (possibly NULL) children.
    void maintain()
    {
        s = 1 + (ch[0] ? ch[0]->s : 0) + (ch[1] ? ch[1]->s : 0);
    }
} ;

// Rotate the subtree rooted at o.
// d = 0 is a left rotation (the right child rises); d = 1 is a right rotation
// (the left child rises). Afterwards o points at the new subtree root, and the
// sizes of both moved nodes are refreshed (lower node first).
void rotate(Node* &o, int d)
{
    const int side = d ^ 1;          // the side of o whose child moves up
    Node* pivot = o->ch[side];
    o->ch[side] = pivot->ch[d];      // pivot's inner subtree is re-attached to o
    pivot->ch[d] = o;
    o->maintain();                   // o is now below pivot, so refresh it first
    pivot->maintain();
    o = pivot;
}

// Insert x into the treap rooted at o; o may be NULL (empty tree).
// Duplicates are allowed and go to the right subtree, which is why cmp() is
// not used here. After inserting below o, a child whose priority exceeds o's
// is rotated above it, and o's subtree size is refreshed.
void insert(Node* &o, int x)
{
    if (o != NULL) {
        const int dir = (o->v <= x);     // 0: go left, 1: go right (equal keys go right)
        Node* &child = o->ch[dir];
        insert(child, x);
        if (o->r < child->r) rotate(o, dir ^ 1);
    } else {
        o = new Node(x);
    }
    o->maintain();
}

// Remove one node whose value equals x from the treap rooted at o.
// If x is not present, the tree is left unchanged; without the NULL guard the
// search would dereference a NULL child.
// A matching node with two children is rotated down toward its
// higher-priority child until it has at most one child. It is then spliced
// out and freed. Subtree sizes are refreshed on the way back up.
void remove(Node* &o, int x)
{
    if (o == NULL) return;           // x not present: nothing to remove
    int d = o->cmp(x);
    if (d == -1) {
        Node* u = o;
        if (o->ch[0] != NULL && o->ch[1] != NULL) {
            // Rotate the child with the larger priority up; the target node
            // ends up as o->ch[d2].
            int d2 = ((o->ch[0]->r) > (o->ch[1]->r) ? 1 : 0);
            rotate(o, d2);
            remove(o->ch[d2], x);
        } else {
            o = (o->ch[0] == NULL) ? o->ch[1] : o->ch[0];
            delete u;
        }
    } else {
        remove(o->ch[d], x);
    }
    if (o != NULL) o->maintain();
}

// Return 1 if a node with value x exists in the tree rooted at o, else 0.
// insert()/remove() never check whether x is already present, so call this
// first when that matters.
// o is deliberately taken BY VALUE. Taking `Node* &o` makes the search loop
// advance the caller's own pointer, so find(root, x) would leave root pointing
// at a descendant, or at NULL when x is absent.
int find(Node* o, int x)
{
    while (o != NULL) {
        int d = o->cmp(x);
        if (d == -1) return 1; // exist
        o = o->ch[d];
    }
    return 0;   // not exist
}

// Return the k-th value of the tree rooted at o (1-based).
// fg = 0 gives the k-th smallest value; fg = 1 gives the k-th largest.
// Returns 0 when o is NULL or k is outside [1, size]. Because of this,
// 0 is ambiguous if 0 can be stored. The tree is not modified.
int kth(Node* &o, int k, int fg)
{
    Node* cur = o;
    while (cur != NULL && k >= 1 && k <= cur->s) {
        const int near = (cur->ch[fg] != NULL) ? cur->ch[fg]->s : 0;  // nodes ranked before cur
        if (k == near + 1) return cur->v;
        if (k <= near) {
            cur = cur->ch[fg];
        } else {
            k -= near + 1;
            cur = cur->ch[fg ^ 1];
        }
    }
    return 0;
}

// Merge two trees: move every value of the tree rooted at src into the tree
// rooted at dest, freeing src's nodes. src is NULL afterwards. Both arguments
// are tree roots, and either may be NULL.
// src is traversed post-order, with one insert() into dest per node.
void mergeto(Node* &src, Node* &dest)
{
    if (src == NULL) return;
    mergeto(src->ch[0], dest);
    // The right subtree must be merged too. Guarding this call with
    // ch[0] != NULL silently dropped and leaked it.
    mergeto(src->ch[1], dest);
    insert(dest, src->v);
    delete src;
    src = NULL;
}

const int N = 30005;
int a[N];
int main()
{
    //freopen("in", "r", stdin);
    int m, n;
    while (~scanf("%d%d", &m, &n)) {
        Node* root = NULL;
        for (int i = 1; i <= m; ++i) {
            scanf("%d", a+i);
        }
        int x = 0, u, cnt = 0;
        for (int i = 0; i < n; ++i) {
            scanf("%d", &u);
            while (cnt < u) {
                insert(root, a[++cnt]);
            }
            printf("%d
", kth(root, ++x, 0));
        }
    }
    return 0;
}
View Code

hdu4557

每个结点多了一个字段 t(登记顺序),结点按 (v, t) 排序,插入、删除和查找时都要同时比较 v 和 t。

#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cstring>
#include <map>
#include <vector>
using namespace std;

/**
 * Treap node keyed by the pair (v, t), compared lexicographically.
 * In main(), v is a person's ability and t is the registration index,
 * used as a tie-breaker.
 */
struct Node {
    Node* ch[2];    // ch[0] = left subtree, ch[1] = right subtree
    int r;          // random priority, drawn from rand() at construction
    int v;          // primary key
    int t;          // secondary key (tie-breaker)
    int s;          // number of nodes in the subtree rooted at this node

    Node(int v, int t) : v(v), t(t)
    {
        r = rand();
        s = 1;
        ch[1] = NULL;
        ch[0] = NULL;
    }

    // Order by random priority r, breaking ties by t.
    bool operator<(const Node& rhs) const
    {
        return (r != rhs.r) ? (r < rhs.r) : (t < rhs.t);
    }

    // -1 when (x, y) == (v, t). Otherwise return the child index to descend
    // into under lexicographic (v, t) order: 0 if smaller, 1 if larger.
    int cmp(int x, int y) const
    {
        if (x != v) return x < v ? 0 : 1;
        if (y != t) return y < t ? 0 : 1;
        return -1;
    }

    // Recompute s from the (possibly NULL) children.
    void maintain()
    {
        s = 1 + (ch[0] ? ch[0]->s : 0) + (ch[1] ? ch[1]->s : 0);
    }
} ;

// Rotate the subtree rooted at o.
// d = 0 is a left rotation (the right child rises); d = 1 is a right rotation
// (the left child rises). Afterwards o points at the new subtree root, and the
// sizes of both moved nodes are refreshed (lower node first).
void rotate(Node* &o, int d)
{
    const int side = d ^ 1;          // the side of o whose child moves up
    Node* pivot = o->ch[side];
    o->ch[side] = pivot->ch[d];      // pivot's inner subtree is re-attached to o
    pivot->ch[d] = o;
    o->maintain();                   // o is now below pivot, so refresh it first
    pivot->maintain();
    o = pivot;
}

// Insert the pair (x, y) into the treap rooted at o (o may be NULL).
// Nodes are ordered lexicographically by (v, t). An exact duplicate goes to
// the right subtree: cmp() returns -1 in that case, and -1 must never be used
// as a child index (o->ch[-1] is out of bounds).
// After inserting below o, a child whose priority exceeds o's is rotated
// above it, and o's subtree size is refreshed.
void insert(Node* &o, int x, int y)
{
    if (o == NULL) {
        o = new Node(x, y);
    } else {
        int d = o->cmp(x, y);
        if (d == -1) d = 1;   // exact duplicate: treat as "not smaller" and go right
        insert(o->ch[d], x, y);
        if ((o->ch[d]->r) > (o->r)) rotate(o, d^1);
    }
    o->maintain();
}

// Remove the node whose key equals (x, y) from the treap rooted at o.
// If no such node exists, the tree is left unchanged; without the NULL guard
// the search would dereference a NULL child.
// A matching node with two children is rotated down toward its
// higher-priority child until it has at most one child. It is then spliced
// out and freed. Subtree sizes are refreshed on the way back up.
void remove(Node* &o, int x, int y)
{
    if (o == NULL) return;           // (x, y) not present: nothing to remove
    int d = o->cmp(x, y);
    if (d == -1) {
        Node* u = o;
        if (o->ch[0] != NULL && o->ch[1] != NULL) {
            // Rotate the child with the larger priority up; the target node
            // ends up as o->ch[d2].
            int d2 = ((o->ch[0]->r) > (o->ch[1]->r) ? 1 : 0);
            rotate(o, d2);
            remove(o->ch[d2], x, y);
        } else {
            o = (o->ch[0] == NULL) ? o->ch[1] : o->ch[0];
            delete u;
        }
    } else {
        remove(o->ch[d], x, y);
    }
    if (o != NULL) o->maintain();
}

// Find the smallest node (by (v, t) order) whose v is >= x in the tree rooted
// at o. Returns that node's t and stores its v in res, or returns -1 if no
// node has v >= x. In that case res is left untouched.
// -1 doubles as the "not found" marker, so a deeper match with t == -1 is
// treated as "nothing found" by its ancestors.
int nxt(Node* &o, int x, int &res)
{
    if (o == NULL) return -1;
    if (!(o->v < x)) {
        // o qualifies; a smaller qualifying node can only be on the left.
        res = o->v;
        const int leftAns = nxt(o->ch[0], x, res);
        if (leftAns != -1) return leftAns;
        return o->t;
    }
    return nxt(o->ch[1], x, res);
}


// Debug dump: print the v of every node of the tree rooted at o, in order,
// each followed by a space. The right spine is walked with a loop; only
// left subtrees recurse.
void print(Node* &o)
{
    for (Node* cur = o; cur != NULL; cur = cur->ch[1]) {
        print(cur->ch[0]);
        printf("%d ", cur->v);
    }
}

// HDU 4557. Read T test cases; each one has n operations.
//  - "A name abi": register `name` with ability abi (keyed by (abi, index)),
//    then print the number of people currently waiting.
//  - Any other operation followed by abi: find the person with the smallest
//    ability >= abi, taking the earliest registration on ties. Print and
//    remove that person, or print "WAIT..." if nobody qualifies.
int main()
{
    // freopen("in", "r", stdin);   // local debugging only; must stay disabled on the judge
    int T = 0;
    if (!(cin >> T)) return 0;
    int cas = 0;
    while (T--) {
        printf("Case #%d:\n", ++cas);
        int n;
        cin >> n;
        char op[10];
        string name;
        int abi, res;
        Node* root = NULL;
        map<int, string> mp;   // registration index -> name
        for (int i = 0; i < n; ++i) {
            scanf("%9s", op);   // bounded read into op[10]
            if (*op == 'A') {
                cin >> name;
                scanf("%d", &abi);
                mp[i] = name;
                insert(root, abi, i);
                printf("%d\n", root->s);
            } else {
                scanf("%d", &abi);
                int ans = nxt(root, abi, res);
                if (ans == -1) {
                    printf("WAIT...\n");
                } else {
                    cout << mp[ans] << endl;
                    remove(root, res, ans);
                }
            }
            // printf("debug:"); print(root); printf("\n");
        }
        // Free whatever is still queued before the next test case.
        while (root != NULL) remove(root, root->v, root->t);
    }
    return 0;
}
View Code



原文地址:https://www.cnblogs.com/wenruo/p/5762899.html