POJ 2010 Moo University

按第一关键字排序后从大到小枚举中位数，问题就变成判断“左边前 K 小的和 + 这个中位数 + 右边前 K 小的和 <= F”（K = N/2）；第一个满足条件的中位数即为答案。其中前 K 小的和可以用 treap 维护。

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <vector>
#include <algorithm>
using namespace std;

// Treap node: binary-search-tree ordered on the value v, heap-ordered on the
// random priority r (insert/remove rotate so a parent's r is >= its children's).
// Each node caches the size and the value sum of the subtree it roots.
struct node
{
    node *ch[2];    // ch[0] = left child, ch[1] = right child (NULL if absent)
    int sz;         // number of nodes in this subtree
    int v;          // stored value (the key)
    int r;          // random priority
    long long sum;  // sum of v over this subtree; long long because a large
                    // subtree can exceed INT_MAX (signed overflow is UB)
    explicit node(int v = 0) : v(v)
    {
        r = rand();
        sz = 1;
        ch[0] = ch[1] = NULL;
        sum = v;
    }
    // Child index to follow when searching for k: 1 (right) if k > v,
    // otherwise 0 (left) -- equal keys are sent left.
    int cmp(int k) const
    {
        return k > v;
    }
    // Recompute sz and sum from this node and its (possibly NULL) children.
    // Must be called whenever a child pointer of this node changes.
    void maintain()
    {
        sz = 1;
        sum = v;
        if(ch[0] != NULL)
        {
            sz += ch[0]->sz;
            sum += ch[0]->sum;
        }
        if(ch[1] != NULL)
        {
            sz += ch[1]->sz;
            sum += ch[1]->sum;
        }
    }
};

// Rotate the subtree rooted at `o`: the child on side d^1 is lifted to the
// root and the old root becomes that child's side-d child (d == 0 rotates
// left, d == 1 rotates right). `o` is updated to point at the new root.
void rotate(node *&o, int d)
{
    node *oldRoot = o;
    node *newRoot = oldRoot->ch[d ^ 1];

    // The lifted node's inner subtree moves across to the old root.
    oldRoot->ch[d ^ 1] = newRoot->ch[d];
    newRoot->ch[d] = oldRoot;

    // Refresh bottom-up: the old root is now a child of the new root.
    oldRoot->maintain();
    newRoot->maintain();

    o = newRoot;
}

// Insert value v into the treap rooted at `o` (duplicates are kept).
// After descending, a child whose priority exceeds its parent's is rotated
// up so the priorities stay heap-ordered; subtree stats are then refreshed.
void insert(node *&o, int v)
{
    if(o != NULL)
    {
        int side = o->cmp(v);
        node *&child = o->ch[side];
        insert(child, v);
        if(child->r > o->r)
            rotate(o, side ^ 1);  // lift the child on `side` to the root
    }
    else
    {
        o = new node(v);
    }
    o->maintain();
}

// Remove one occurrence of value v from the treap rooted at `o` and free its
// node. Does nothing if v is not present. `o` is updated if the root changes.
void remove(node *&o, int v)
{
    if(o == NULL) return;  // v is not in this subtree
    if(v == o->v)
    {
        if(o->ch[0] == NULL || o->ch[1] == NULL)
        {
            // At most one child: splice it into our place and free the node
            // (the original version unlinked it without deleting -> leak).
            node *victim = o;
            o = o->ch[0] == NULL ? o->ch[1] : o->ch[0];
            delete victim;
        }
        else
        {
            // Two children: lift the higher-priority child, which pushes the
            // target down to side d, then keep removing it there.
            int d = o->ch[0]->r < o->ch[1]->r ? 0 : 1;
            rotate(o, d);
            remove(o->ch[d], v);
        }
    }
    else
    {
        remove(o->ch[o->cmp(v)], v);
    }
    if(o != NULL) o->maintain();
}

// Return the sum of the k smallest values stored in the treap rooted at `o`.
// k <= 0 or an empty tree yields 0; if k exceeds the tree size, the sum of
// all values is returned (the previous version dereferenced NULL in these
// cases, e.g. k == 0 when n == 1). Iterative, so no deep recursion.
long long query(node *o, int k)
{
    long long total = 0;
    while(o != NULL && k > 0)
    {
        int leftSize = o->ch[0] == NULL ? 0 : o->ch[0]->sz;
        if(k <= leftSize)
        {
            // All k smallest values live in the left subtree.
            o = o->ch[0];
            continue;
        }
        // Take the whole left subtree plus this node, then continue right.
        if(o->ch[0] != NULL) total += o->ch[0]->sum;
        total += o->v;
        k -= leftSize + 1;
        o = o->ch[1];
    }
    return total;
}

// Free every node of the treap rooted at `o` (children first).
// Accepts an empty tree (NULL); previously del(NULL) dereferenced NULL.
void del(node *o)
{
    if(o == NULL) return;
    del(o->ch[0]);
    del(o->ch[1]);
    delete o;
}


// Reads test cases of "n c f" followed by c pairs (x, y). Sorted by x,
// the candidate median a[i] is enumerated from the largest x down; it is
// feasible if its y plus the half = n/2 smallest y on each side is <= f.
// Prints the first feasible median's x, or -1 if none exists.
// Assumes n is odd (median + n/2 on each side), as the algorithm requires.
int main()
{
    int n, c, f;
    while(scanf("%d%d%d", &n, &c, &f) == 3)
    {
        vector<pair<int, int> > a;
        a.reserve(c > 0 ? c : 0);
        for(int i = 0; i < c; i++)
        {
            int x = 0, y = 0;
            scanf("%d%d", &x, &y);
            a.push_back(make_pair(x, y));
        }

        // Fewer candidates than slots: no selection exists (and the index
        // arithmetic below would go out of bounds).
        if(n > c)
        {
            printf("-1\n");
            continue;
        }

        sort(a.begin(), a.end());

        const int half = n / 2;
        node *r = NULL, *l = NULL;  // r: values right of the median, l: left
        for(int i = 0; i < half; i++)
            insert(r, a[c - i - 1].second);
        for(int i = c - half - 1; i >= 0; i--)
            insert(l, a[i].second);

        int ans = -1;
        for(int i = c - half - 1; i >= half; i--)
        {
            remove(l, a[i].second);  // l now holds exactly a[0..i-1]
            long long cost = a[i].second;
            // Skip the queries when half == 0 (n == 1): querying 0 elements
            // is pointless and r is empty.
            if(half > 0)
                cost += (long long)query(l, half) + (long long)query(r, half);
            if(cost <= f)
            {
                ans = a[i].first;
                break;
            }
            insert(r, a[i].second);  // a[i] joins the right side for next i
        }

        printf("%d\n", ans);
        // Either tree may be empty (NULL), e.g. when n == 1.
        if(l != NULL) del(l);
        if(r != NULL) del(r);
    }
    return 0;
}
原文地址：https://www.cnblogs.com/BMan/p/3843497.html