HDU 4747 Mex 线段树

题意:

(S)为一个自然数集合,定义函数(mex(S))为集合中没有出现的最小自然数。
给出一个长度为为(n)序列(a),设(S_{l,r})表示由(a_l sim a_r)构成的集合。
求:

[sumlimits_{1 leq l leq r leq n}mex(S_{l,r}) ]

分析:

有这样一个事实:往集合(S)中任意加入一个元素,(mex(S))的值不会变小。
固定区间左端点来统计答案。
首先计算一下(mex{S_{1,1}},mex{S_{1,2}}, cdots, mex{S_{1,n}}),所以这是一个非递减的序列。
假设现在计算出(mex{S_{i,i}},mex{S_{i,i+1}}, cdots, mex{S_{i,n}}),考虑区间左端点向右移动。
相当于从这些集合中都删去了一个(a_i),如果有一个最小的(j>i)(a_i=a_j),那么删去(a_i)([j,n])这段区间没有影响,因为这段区间对应的集合没有改变。
然后考虑区间([i+1,j-1]),找到(mex)值大于(a_i)的区间,把它们的值都变为(a_i)
因为集合中少了(a_i),所以根据(mex)函数的定义,(mex)值为(a_i)
而且由于区间是非递减的,所以(mex)值大于(a_i)的区间也是连续的,用线段树维护即可。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

typedef long long LL;
const int maxn = 200000 + 10;
const int maxnode = maxn * 4;

int n;
int a[maxn], b[maxn], tot;
int pos[maxn], nxt[maxn];

bool vis[maxn];
int mex[maxn];

//Segment Tree
LL sum[maxnode];
int setv[maxnode], minv[maxnode], maxv[maxnode];

void pushup(int o) {
    sum[o] = sum[o<<1] + sum[o<<1|1];
    minv[o] = min(minv[o<<1], minv[o<<1|1]);
    maxv[o] = max(maxv[o<<1], maxv[o<<1|1]);
}

void build(int o, int L, int R) {
    if(L == R) {
        sum[o] = minv[o] = maxv[o] = mex[L];
        return;
    }
    int M = (L + R) / 2;
    build(o<<1, L, M);
    build(o<<1|1, M+1, R);
    pushup(o);
}

void pushdown(int o, int L, int R) {
    if(setv[o] != -1) {
        int lc = o<<1, rc = o<<1|1;
        setv[lc] = setv[rc] = setv[o];
        minv[lc] = minv[rc] = setv[o];
        maxv[lc] = maxv[rc] = setv[o];
        int M = (L + R) / 2;
        sum[lc] = (LL)setv[o] * (M - L + 1);
        sum[rc] = (LL)setv[o] * (R - M);
        setv[o] = -1;
    }
}

void update(int o, int L, int R, int qL, int qR, int v) {
    if(qL <= L && R <= qR && minv[o] > v) {
        setv[o] = minv[o] = maxv[o] = v;
        sum[o] = (LL)v * (R - L + 1);
        return;
    }
    pushdown(o, L, R);
    int M = (L + R) / 2;
    if(qL <= M && maxv[o<<1] > v) update(o<<1, L, M, qL, qR, v);
    if(qR > M && maxv[o<<1|1] > v) update(o<<1|1, M+1, R, qL, qR, v);
    pushup(o);
}

LL query(int o, int L, int R, int qL, int qR) {
    if(qL <= L && R <= qR) return sum[o];
    pushdown(o, L, R);
    int M = (L + R) / 2;
    LL ans = 0;
    if(qL <= M) ans += query(o<<1, L, M, qL, qR);
    if(qR > M) ans += query(o<<1|1, M+1, R, qL, qR);
    return ans;
}

int main()
{
    while(scanf("%d", &n) == 1 && n) {
        for(int i = 1; i <= n; i++) {
            scanf("%d", a + i);
            if(a[i] >= maxn) a[i] = maxn - 1;
            b[i] = a[i];
        }
        sort(b + 1, b + 1 + n);
        tot = unique(b + 1, b + 1 + n) - b - 1;
        for(int i = 1; i <= n; i++)
            a[i] = lower_bound(b + 1, b + 1 + tot, a[i]) - b;
        for(int i = 1; i <= tot; i++) pos[i] = n + 1;
        for(int i = n; i > 0; i--) {
            nxt[i] = pos[a[i]];
            pos[a[i]] = i;
        }

        memset(vis, false, sizeof(vis));
        int p = 0;
        for(int i = 1; i <= n; i++) {
            vis[b[a[i]]] = true;
            while(vis[p]) p++;
            mex[i] = p;
        }

        memset(setv, -1, sizeof(setv));
        build(1, 1, n);
        LL ans = sum[1];
        for(int i = 2; i <= n; i++) {
            int j = nxt[i - 1];
            if(j > i) update(1, 1, n, i, j - 1, b[a[i-1]]);
            ans += query(1, 1, n, i, n);
        }

        printf("%lld
", ans);
    }

    return 0;
}
原文地址:https://www.cnblogs.com/AOQNRMGYXLMV/p/5343903.html