SegmentTreeBeats 简单学习笔记

SegmentTreeBeats 简单学习笔记

​ 有一天补 ( ext{CF}) 做到一个题,转化一波题意以后变成要求维护一个序列 (a)

  1. 对于 (i in [l,r], a_i =a_i+x)

  2. 对于 (i in [l,r], a_i =min(a_i, x))

  3. (sum_{i=l}^r a_i)

    ​ 其实就是 ( ext{Segment Tree Beats}) 的模板题,也就是那年吉老师营员交流课件的例题, 用线段树维护区间最大值 (mx) ,区间次大值 (se) ,区间和 (sum) ,区间最大值出现次数 (cnt) ,加法标记 (tag)

    ​ 对于第二种操作,如果一个区间 (mx leq x) 那么无事发生,可以跳过其所有子区间。如果 (se < x < mx) ,那么 (sum = sum - (mx-x) imes cnt, mx = x) ,注意这里子区间的 (mx, sum) 并没有更改,相当于 (mx) 同时作为一个修改标记,当前区间比子区间的 (mx) 小时,要进行 (sum = sum - (mx-mx[fa]) imes cnt)( ext{pushdown}) 操作,打完标记之后就可以跳过了。对于其它情况,暴力对其子区间求解。

    ​ 到这一步位置算法流程不难理解,但算法的复杂度证明比较难懂,目前 (mathcal O(nlog n)) 的证明我还不会,只能理解 (mathcal O(n log^2 n)) 的证明,在这里写一下简要证明:

    ​ 定义势能函数 (Phi) 为线段树中 (mx) 不等于其父亲节点 (mx) 的节点数量,考虑一次第二操作过程的任一终止节点 (v) 。如果 (v)(Phi) 有贡献,假设这一类节点的数量为 (A) ,到达这些节点的复杂度为 (mathcal O (Alog n)) ,结束后这些节点都对势能没贡献了,也就是说用了 (mathcal O(Alog n)) 的时间让势能减小了 (A)

    ​ 如果 (v)(Phi) 没贡献,记 (u)(v) 的父亲,(u) 的另外一儿子为 (c) ,那么 (mx[u] = mx[v], se[u] eq se[v]) ,也就是说 (se[u] = mx[c]) 。那么 (c) 的子树一定会被访问, 并在访问结束后 (c)(Phi) 没有贡献,假设这一类节点数量为 (A) ,同样也用 (mathcal O(Alog n)) 的时间让势能减小了 (A) 。也就是说对于修改操作,实际上是每减小一个势能用了 (mathcal O(log n)) 的代价。

    ​ 考虑修改操作,每次只会修改 (mathcal O(log n)) 节点,最多使势能增加 (mathcal O(log n)) 所以总复杂度是 (mathcal O(nlog^2 n))

    code: Codeforces 1290 E

    /*program by mangoyang*/
    #pragma GCC optimize("Ofast", "inline")
    #include<bits/stdc++.h>
    #define inf (0x7f7f7f7f)
    #define Max(a, b) ((a) > (b) ? (a) : (b))
    #define Min(a, b) ((a) < (b) ? (a) : (b))
    typedef long long ll;
    using namespace std;
    template <class T>
    inline void read(T &x){
        int ch = 0, f = 0; x = 0;
        for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
        for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
        if(f) x = -x;
    }
    #define int ll
    const int N = 150005;
    int a[N], b[N], ans[N], n;
    namespace Seg{
        #define lson (u << 1)
        #define rson (u << 1 | 1)
        #define mid ((l + r) >> 1)
        int mx[N<<2], se[N<<2], sz[N<<2], cnt[N<<2], sum[N<<2], tag[N<<2];
        inline void clear(){
            memset(mx, 0, sizeof(mx));
            memset(se, 0, sizeof(se));
            memset(sz, 0, sizeof(sz));
            memset(cnt, 0, sizeof(cnt));
            memset(sum, 0, sizeof(sum));
            memset(tag, 0, sizeof(tag));
        }
        inline void update(int u){
            if(mx[lson] > mx[rson])
                mx[u] = mx[lson], cnt[u] = cnt[lson];
            else
                mx[u] = mx[rson], cnt[u] = cnt[rson];
            if(mx[lson] == mx[rson]) cnt[u] += cnt[lson];
            se[u] = max(se[lson], se[rson]);
            if(mx[lson] != mx[rson]){
                int x = min(mx[lson], mx[rson]);
                se[u] = max(se[u], x);
            }
            sum[u] = sum[lson] + sum[rson];
            sz[u] = sz[lson] + sz[rson];
        }
        inline void pushdown(int u){
            if(tag[u]){
                if(mx[lson]) mx[lson] += tag[u];
                if(se[lson]) se[lson] += tag[u];
                if(mx[rson]) mx[rson] += tag[u];
                if(se[rson]) se[rson] += tag[u];
                sum[lson] += tag[u] * sz[lson];
                sum[rson] += tag[u] * sz[rson];
                tag[lson] += tag[u];
                tag[rson] += tag[u];
                tag[u] = 0;
            }
            if(mx[lson] > mx[u]){
                sum[lson] -= (mx[lson] - mx[u]) * cnt[lson];
                mx[lson] = mx[u];
            }
            if(mx[rson] > mx[u]){
                sum[rson] -= (mx[rson] - mx[u]) * cnt[rson];
                mx[rson] = mx[u];
            }
        }
        inline void ins(int u, int l, int r, int pos, int x){
            if(l == r){
                mx[u] = sum[u] = x;
                sz[u] = cnt[u] = 1;
                return;
            }
            pushdown(u);
            if(pos <= mid) ins(lson, l, mid, pos, x);
            else ins(rson, mid + 1, r, pos, x);
            update(u);
        }
        inline void gao(int u, int l, int r, int L, int R, int x){
            if(l >= L && r <= R){
                if(mx[u] <= x) return;
                if(se[u] < x){
                    sum[u] -= (mx[u] - x) * cnt[u];
                    mx[u] = x;
                    return;
                }
                pushdown(u);
                gao(lson, l, mid, L, R, x);
                gao(rson, mid + 1, r, L, R, x);
                update(u);
                return;
            }
            pushdown(u);
            if(L <= mid) gao(lson, l, mid, L, R, x);
            if(mid < R) gao(rson, mid + 1, r, L, R, x);
            update(u);
        }
        inline void add(int u, int l, int r, int L, int R){
            if(l >= L && r <= R){
                if(mx[u]) mx[u]++;
                if(se[u]) se[u]++;
                sum[u] += sz[u], tag[u]++;
                return;
            }
            pushdown(u);
            if(L <= mid) add(lson, l, mid, L, R);
            if(mid < R) add(rson, mid + 1, r, L, R);
            update(u);
        }
        inline int query(int u, int l, int r, int L, int R){
            if(l >= L && r <= R) return sz[u];
            int res = 0; pushdown(u);
            if(L <= mid) res += query(lson, l, mid, L, R);
            if(mid < R) res += query(rson, mid + 1, r, L, R);
            return res;
        }
    }
    signed main(){
        read(n);
        for(int i = 1; i <= n; i++) read(a[i]);
        for(int i = 1; i <= n; i++) b[a[i]] = i;
        for(int i = 1; i <= n; i++){
            Seg::add(1, 1, n, b[i] + 1, n);
            int sz = Seg::query(1, 1, n, 1, b[i]);
            if(sz) Seg::gao(1, 1, n, 1, b[i], sz);
            Seg::ins(1, 1, n, b[i], Seg::sz[1] + 1);
            ans[i] = Seg::sum[1] + Seg::sz[1];
        }
        reverse(a + 1, a + n + 1);
        for(int i = 1; i <= n; i++) b[a[i]] = i;
        Seg::clear();
            for(int i = 1; i <= n; i++){
            Seg::add(1, 1, n, b[i] + 1, n);
            int sz = Seg::query(1, 1, n, 1, b[i]);
            if(sz) Seg::gao(1, 1, n, 1, b[i], sz);
            Seg::ins(1, 1, n, b[i], Seg::sz[1] + 1);
            ans[i] -= Seg::sz[1] * (Seg::sz[1] + 1) - Seg::sum[1];
        }
        for(int i = 1; i <= n; i++)
            printf("%lld
    ", ans[i]);
        return 0;
    }
    
原文地址:https://www.cnblogs.com/mangoyang/p/12567727.html