XTU1170 Coin 线段树

题意:存在一个固定编号[1,10^9]的盒子群,每个盒子里面都有一个数Ai(Ai>=0),刚开始的时候不知道各个盒子里面的数是多大。有Q组更新,分别表示[Li, Ri]内最小的元素是多大,现在问Q组操作后,盒子中数字总和最小可以为多少?

解法:标准题解是并查集树状数组的解法,谢勇教练也提到可以使用线段树来解。这里就是使用的线段树来求解。由于该题需要离散化点,这里有个地方特别要注意就是不能够建一棵点树,因为如果建立了一棵点树,那么意味着离散化后的叶子区间是单个点,各个相邻的叶子节点之间的区间信息无法保存,也就是这里错了很多次。正确的做法是如果有某组更新是[Li, Ri] Ci,那么就是更新[Li, Ri+1) Ci,离散化所有的Li 和 Ri+1,而不去管Ri的值,这样区间树中叶子节点和叶子节点之间就能够通过共一个节点来解决不连续的问题。注意建立区间树或者点树这个概念是相对于离散化之前如何构造区间的定义,如果离散化之前使用点树模型,离散化之后使用区间树模型也是无用的。具体做法是将所有的更新保留起来然后对坐标点离散化,然后按照先大后小的原则对区间进行更新,如果一个覆盖没有更新节点那么返回false,最后统计总和即可。

线段树的每一个节点中保留一个变量val。
val = -1表示该节点所统计的区间还没有被更新;
val = -2表示该节点统计的区间中同时存在被更新或者未被更新的区间;
val > 0 表示该节点统计的区间被更新为最少是val。

代码如下:

#include <cstdlib>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <map>
using namespace std;

const int MAXN = int(1e9);
typedef long long LL;

struct QQ{
    int l, r, val;
    friend bool operator < (const QQ & a, const QQ & b) {
        return a.val > b.val;
    }
}q[100005];
int val[200005];
map<int,int>mp;

struct Node {
    int l, r, tag;
}e[800005];

void build(int p, int l, int r) {
    e[p].l = l, e[p].r = r, e[p].tag = -1;
    if (r - l > 1) {
        int mid = (l + r) >> 1;
        build(p<<1, l, mid);
        build(p<<1|1, mid, r);
    }
}

void push_up(int p) {
    if (e[p<<1].tag == -2 || e[p<<1|1].tag == -2) {
        e[p].tag = -2;    
    } else if (e[p<<1].tag == e[p<<1|1].tag) {
        e[p].tag = e[p<<1].tag;
    } else {
        e[p].tag = -2;    
    }
}

bool modify(int p, int l, int r, int val) {    
//    printf("L = %d, R = %d, val = %d\n", e[p].l, e[p].r, val);
    bool ret;
    if (l == e[p].l && r == e[p].r && e[p].tag != -2) {
        if (e[p].tag == -1 || val == e[p].tag) {
            ret = true;
            e[p].tag = val;
        } else {
            ret = false;    
        }
    } else {
        if (e[p].tag > -1) {
            if (e[p].tag != val) {
                ret = false;
            } else {
                ret = true;    
            }
        } else {
            int mid = (e[p].l + e[p].r) / 2;
            if (r <= mid) {
                ret = modify(p<<1, l, r, val);
            } else if (l >= mid) {
                ret = modify(p<<1|1, l, r, val);
            } else {
                ret = modify(p<<1, l, mid, val) | modify(p<<1|1, mid, r, val);    
            }
        }
    }
    if (ret) {
        push_up(p>>1);    
    }
    return ret;
}

LL query(int p) {
//    printf("L = %d, R = %d, tag = %d\n", e[p].l, e[p].r, e[p].tag);
    if (e[p].r - e[p].l == 1) { // 递归到叶子节点的时候退出
        if (e[p].tag > -1) {
            return 1LL * (val[e[p].r] - val[e[p].l]) * e[p].tag;
        } else {
            return 0;    
        }
    } else {
        if (e[p].tag > -1) {
            return 1LL * (val[e[p].r] - val[e[p].l]) * e[p].tag;
        } else if (e[p].tag == -2){
            return query(p<<1) + query(p<<1|1);
        } else {
            return 0;
        }
    }
}

int main() {
    int T, N, cnt;
    scanf("%d", &T);
    while (T--) {
        mp.clear();
        int flag = true;
        cnt = 1;
        scanf("%d", &N);
        for (int i = 0; i < N; ++i) {
            scanf("%d %d %d", &q[i].l, &q[i].r, &q[i].val);
            val[cnt++] = q[i].l, val[cnt++] = q[i].r + 1;
        }
        sort(q, q + N);
        val[0] = 1;
        sort(val, val + cnt);
        cnt = unique(val, val + cnt) - val;
        val[cnt++] = MAXN;
        for (int i = 0; i < cnt; ++i) {
        //    printf("val[%d] = %d\n", i, val[i]);
            mp[val[i]] = i;
        }
        build(1, 0, cnt-1);
        for (int i = 0; i < N; ++i) {
        //    printf("LL = %d, RR = %d\n", mp[q[i].l], mp[q[i].r + 1]);
            if (!modify(1, mp[q[i].l], mp[q[i].r+1], q[i].val)) {
                flag = false;
                break;
            }
        }
        if (!flag) {
            puts("Error");
        } else {
            printf("%I64d\n", query(1));
        }
    }
    return 0;    
}

以下是一种使用并查集的写法,非常巧妙的运用到了并查集的思想,将连续的线段依次指向右边的线段,这样在更新的时候能够直接跳过重复的地方。使用树状数组统计一次区间覆盖上有没有更新。

代码如下:

#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <map>
using std::map;
using std::sort;
using std::unique;

const int MAXN = int(1e9);
int LIM;
map<int,int>mp;
typedef long long LL;

struct QQ{
    int a, b, c;
    friend bool operator < (const QQ & a, const QQ & b) {
        return a.c > b.c; // 按照价格从大到小进行排列
    }
}q[100005];

int val[200005];
int st[200005];
int BIT[200005];
int sz[200005];
int list[200005];

int find(int x) {
    return st[x] = x == st[x] ? x : find(st[x]);
}

void merge(int x, int y) {
    st[x] = y;
}

int lowbit(int x) {
    return x & (-x);
}

int ADD(int x, int val) {
    for (int i = x; i <= LIM; i += lowbit(i)) {
        BIT[i] += val;
    }
}

int SUM(int x) {
    int ret = 0;
    for (int i = x; i > 0; i -= lowbit(i)) {
        ret += BIT[i];
    }
    return ret;
}

int main() {
    int T, Q;
    scanf("%d", &T);
    while (T--) {
        mp.clear();
        LL ret = 0;
        LIM = 0;
        bool flag = true;
        scanf("%d", &Q);
        for (int i = 0; i < Q; ++i) {
            scanf("%d %d %d", &q[i].a, &q[i].b, &q[i].c);
            val[LIM++] = q[i].a, val[LIM++] = ++q[i].b;
        }
        sort(val, val+LIM);
        LIM = unique(val, val+LIM)-val;
        for (int i = 1; i <= LIM; ++i) {
            mp[val[i-1]] = i;
            sz[i] = val[i]-val[i-1]; // 求出一条条线段
            st[i] = i;
            BIT[i] = 0;
        }
        sort(q, q+Q);
        int cnt, j, a, b;
        for (int i = 0; i < Q && flag; i = j) {
            cnt = 0, j = i+1;
            while (j < Q && (q[j].c == q[j-1].c)) ++j; // 将相同的值统一处理
            for (int k = i; k < j; ++k) {
                a = mp[q[k].a], b = mp[q[k].b];
                for (int t = find(a); t < b; t = find(t+1)) {
                    list[cnt++] = t;
                    st[t] = find(t+1);
                    ret += 1LL * q[i].c * sz[t];
                }
            } 
            for (int k = 0; k < cnt; ++k) {
                ADD(list[k], 1);    
            }
            for (int k = i; k < j; ++k) {
                printf("");
                if (SUM(b) == SUM(a-1)) {
                    flag = false;
                    break;
                }
            }
            for (int k = 0; k < cnt; ++k) {
                ADD(list[k], -1);
            }
        }
        if (!flag) {
            puts("Error");
        } else {
            printf("%I64d\n", ret);
        }
    }
    return 0;
} 
原文地址:https://www.cnblogs.com/Lyush/p/3079214.html