可持久化线段树

可持久化线段树

简介

可持久化数据结构又称函数式数据结构,其思路来自于函数式编程。在函数式编程中,变量的值是不允许改变的,因而每一次插入元素都必须创建一个新的版本。

设想一棵二叉树:

        [1]
    [2]     [3]
 [4]  [5] [6]  

现在为了插入一个新节点,我们必须新建一棵树

        (1)
    (2)     (3)
 (4)  (5) (6)  (7)

不难发现,很多元素被重复使用了。如果将重复的元素合并,就得到这样一棵树:

        [1]    --->  (1)
    [2]     [3]   [2]   (3)
 [4]  [5] [6]         [6]  (7)

新建的元素其实只有O(h),如果是一棵平衡树或线段树,新建元素就是O(lgn)

应用

可持久化线段树是解决区间问题的锐利武器。考虑第i棵和第j棵线段树Ti,Tj,如果他们的对应元素相减得到一棵新树TjTi,这棵树其实就是区间 [i+1,j] 所对应的线段树。

例如vijos1459车展一题。用反证法不难证明题目中要求的即是

i=lr|ximid|

其中,mid为区间 x[l,r] 的中位数。

由于涉及了区间中位数,可以考虑使用树套树实现。但树套树代码复杂度较高且不宜于调试,可以考虑用可持久化线段树代替。

将输入的 xi 按顺序建立一棵可持久化线段树,分别维护sumnum_sum,第一个为区间内元素的和,第二个为区间内元素出现的次数。利用 TrTl1 得到区间 [l,r] 内的线段树来计算。

Code

// 可持久化线段树
// 维护两个值
#include <bits/stdc++.h>
using namespace std;

#define maxn 1005
struct node {
    int l, r, lc, rc;
    long long sum;
    int num_sum;
    node(){l = r = lc = rc = sum = num_sum = 0; }
}tree[15*maxn];
int root[200005], top = 0;
int n, m;

inline long long read()
{
    long long a = 0; int c;
    do c = getchar(); while(!isdigit(c));
    while (isdigit(c)) {
        a = a*10 + c - '0';
        c = getchar();
    }
    return a;
}

int sorted[1005]; // 离散化
int dat[1005]; // 原始数据

inline void update(int i) {
    tree[i].sum = tree[tree[i].lc].sum + tree[tree[i].rc].sum;
    tree[i].num_sum = tree[tree[i].lc].num_sum + tree[tree[i].rc].num_sum;
}

inline int new_node(int l, int r) {
    tree[++top].l = l;
    tree[top].r = r;
    return top;
}

void build(int &nd, int l, int r) {
    if (l > r) return;
    if (l == r) {nd = new_node(l, r);return;}
    int mid = (l+r)>>1;
    nd = new_node(l, r);
    build(tree[nd].lc, l, mid);
    build(tree[nd].rc, mid+1, r);
}

void insert(int pre, int &now, int k, long long dat) {
    if (tree[pre].l == tree[pre].r) {
        now = new_node(k, k);
        tree[now].sum = dat;
        tree[now].num_sum = 1;
    } else {
        now = new_node(tree[pre].l, tree[pre].r);
        tree[now] = tree[pre];
        if (k <= tree[tree[pre].lc].r) insert(tree[pre].lc, tree[now].lc, k, dat);
        else insert(tree[pre].rc, tree[now].rc, k, dat);
        update(now);
    }
}

// 查找区间和(sum)
long long get_sum(int pre, int now, int l, int r)
{
    if (l > r || !pre || !now) return 0;
    if (l == tree[pre].l && r == tree[now].r) return tree[now].sum - tree[pre].sum;
    return get_sum(tree[pre].lc, tree[now].lc, l, min(r, tree[tree[pre].lc].r)) +
           get_sum(tree[pre].rc, tree[now].rc, max(tree[tree[pre].rc].l, l), r);
}

// 区间内数字个数的和
int get_num_sum(int pre, int now, int l, int r)
{
    if (l > r || !pre || !now) return 0;
    if (l == tree[pre].l && r == tree[now].r) return tree[now].num_sum - tree[pre].num_sum;
    return get_num_sum(tree[pre].lc, tree[now].lc, l, min(r, tree[tree[pre].lc].r)) +
           get_num_sum(tree[pre].rc, tree[now].rc, max(tree[tree[pre].rc].l, l), r);
}

int find_mid(int pre, int now, int k) // 查找中位数(第k个数)的位置
{
    if (tree[now].l == tree[now].r) return tree[now].l;
    if (tree[tree[now].lc].num_sum - tree[tree[pre].lc].num_sum >= k)
        return find_mid(tree[pre].lc, tree[now].lc, k);
    else
        return find_mid(tree[pre].rc, tree[now].rc, k-(tree[tree[now].lc].num_sum - tree[tree[pre].lc].num_sum));
}

// 查询区间
long long query(int l, int r) {
    int pos = find_mid(root[l-1], root[r], ((l+r)>>1)-l+1);
    long long lft = get_sum(root[l-1], root[r], 1, pos);int ln = get_num_sum(root[l-1], root[r], 1, pos);
    long long rgt = get_sum(root[l-1], root[r], pos+1, n);int rn = get_num_sum(root[l-1], root[r], pos+1, n);
    return rgt - rn*sorted[pos] + ln*sorted[pos] - lft;
}

void dfs(int rt, int tab = 0) {
    if (rt) {
        for (size_t i = 0; i < tab; i++) putchar(' ');
        cout << tree[rt].l << "->" << tree[rt].r << " " << tree[rt].sum << " " << tree[rt].num_sum << endl;
        dfs(tree[rt].lc, tab+2);
        dfs(tree[rt].rc, tab+2);
    }
}

int main()
{
    n = read(); m = read();
    build(root[0], 1, n);
    long long a, l, r;
    for (size_t i = 1; i <= n; i++)
        sorted[i] = dat[i] = read();
    sort(sorted+1, sorted+n+1);
    for (size_t i = 1; i <= n; i++) {
        insert(root[i-1], root[i], lower_bound(sorted+1, sorted+n+1, dat[i])-sorted, dat[i]);
    }
    long long ans = 0;
    for (size_t i = 1; i <= m; i++) {
        l = read(); r = read();
        ans += query(l, r);
    }
    cout << ans << endl;
    return 0;
}
原文地址:https://www.cnblogs.com/ljt12138/p/6684357.html