树状数组 Binary Indexed Tree/Fenwick Tree

2018-03-25 17:29:29

树状数组是一个比较小众的数据结构,主要应用领域是快速的对mutable array进行区间求和。

对于一般的一维情况下的区间和问题,一般有以下两种解法:

1)DP

预处理:建立长度为n的数组,每个结点i保存前i个数的和,时间复杂度O(n)。

查询:直接从数组中取两个段相减,时间复杂度O(1)。

更新:这种方法比较适用与immutable数组,对于mutable数组的更新需要重新建立表,所以时间复杂度为O(n)。

2)树状数组 BIT

预处理:建立树状数组,对数组中的每个数进行update操作,时间复杂度O(nlogn)。

查询:从当前结点向根结点遍历,求总和,时间复杂度O(logn)。

更新:从当前结点向根结点遍历,更新这条路径上的所有结点的值,时间复杂度O(logn)。

一、一维树状数组

在DP算法中我们在每个结点存放的是当前结点与前面所有结点的和,这就直接导致了在做更新的时候我们也只能进行大规模的修正,树状数组的提出就是为了在更新的过程中也保证有较低的时间复杂度,要实现这个目的,显然,每个结点我们不能再存储全部的前i个数,只能进行部分存储,而这部分存储的个数,就是整个树状数组的核心和关键。

在树状数组中,每个结点存的和的个数是lowbit(i)个,这里的lowbit就是i的二进制表示的第一个1所表示的数,举例:4(0100),lowbit(4) = 100,也就是在4号结点位置要存储4个结点的和。而这3个数(自己除外)都是4号结点的子孙,这三个数显然是0011,0010,0001。

那么,parent结点和child结点到底有什么关系呢?在树状数组中,我们规定parent = child + lowbit(child)。在更新操作中,我们可以递归的向上遍历,将所有该结点的父亲结点都进行更新,时间复杂度为O(logn)。

那么,又如何进行查询呢?对于sum(1, j) = nums[1] + nums[2] + ... + nums[j]。由于在树状数组中j号结点中存了lowbit(j)个数的和,所以原式可以写成sum(1, j) = sum(1, j - lowbit(j)) + BIT(j)。因此也可以进行递归或者迭代的求解。更进一步的分析,我们可以得知在查询的过程中,其实也生成了一组树,如下图。

public class BinaryIndexedTree1D {
    int[] BIT;

    BinaryIndexedTree1D(int n) {
        this.BIT = new int[n + 1];
    }

    void update(int index, int delta) {
        for (int i = index; i < BIT.length; i += (i & -i)) {
            BIT[i] += delta;
        }
    }

    int query(int index) {
        int res = 0;
        for (int i = index; i > 0; i -= (i & -i)) {
            res += BIT[i];
        }
        return res;
    }
}

问题描述:

问题求解:

public class NumArray {
    FenwickTree ft;
    int[] ls;

    public NumArray(int[] nums) {
        ft = new FenwickTree(nums.length);
        ls = nums;
        for (int i = 0; i < ls.length; i++) {
            ft.update(i + 1, ls[i]);
        }
    }

    public void update(int i, int val) {
        ft.update(i + 1, val - ls[i]);
        ls[i] = val;
    }

    public int sumRange(int i, int j) {
        return ft.query(j + 1) - ft.query(i);
    }
}

class FenwickTree {
    int[] BIT;

    FenwickTree(int n) {
        this.BIT = new int[n + 1];
    }

    void update(int index, int delta) {
        for (int i = index; i < BIT.length; i += (i & -i)) {
            BIT[i] += delta;
        }
    }

    int query(int index) {
        int res = 0;
        for (int i = index; i > 0; i -= (i & -i)) {
            res += BIT[i];
        }
        return res;
    }
}

2019.04.28

class NumArray {
    int[] bit;
    int[] numsCopy;
    int n;

    public NumArray(int[] nums) {
        bit = new int[nums.length + 1];
        numsCopy = new int[nums.length];
        n = nums.length + 1;
        for (int i = 0; i < nums.length; i++) {
            update(i, nums[i]);
            numsCopy[i] = nums[i];
        }
    }
    
    public void update(int i, int val) {
        int idx = i + 1;
        int delta = val - numsCopy[i];
        numsCopy[i] = val;
        for (int k = ++i; k < n; k += (k & -k)) {
            bit[k] += delta;
        }
    }
    
    private int query(int i) {
        int res = 0;
        for (int k = i; k > 0; k -= (k & -k)) {
            res += bit[k];
        }
        return res;
    }
    
    public int sumRange(int i, int j) {
        return query(j + 1) - query(i);
    }
}

二、二维树状数组

二维的树状数组其实就是分别对每行每列进行树状数组化,编写代码上面和一维数组是非常类似的。

public class BinaryIndexedTree2D {
    int[][] BIT;

    BinaryIndexedTree2D(int N, int M) {
        BIT = new int[N + 1][M + 1];
    }

    void update(int n, int m, int delta) {
        for (int i = n; i < BIT.length; i += (i & -i)) {
            for (int j = m; j < BIT[0].length; j += (j & -j)) {
                BIT[i][j] += delta;
            }
        }
    }

    int query(int n, int m) {
        int res = 0;
        for (int i = n; i > 0; i -= (i & -i)) {
            for (int j = m; j > 0; j -= (j & -j)) {
                res += BIT[i][j];
            }
        }
        return res;
    }
}

问题描述:

问题求解:

裸的2维树状数组问题,有个坑是所有的数字都需要对1e9进行取模,并且最终的sum结果不能为负数。

import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        int mod = (int)1e9 + 7;
        Scanner sc = new Scanner(System.in);
        String s = sc.nextLine();
        int N = Integer.valueOf(s.split(" ")[0]);
        int M = Integer.valueOf(s.split(" ")[1]);
        BinaryIndexedTree2D bit = new BinaryIndexedTree2D(N, N);
        for (int k = 0; k < M; k++) {
            s = sc.nextLine();
            String[] ls = s.split(" ");
            if (ls[0].equals("Add")) {
                int i = Integer.valueOf(ls[1]);
                int j = Integer.valueOf(ls[2]);
                int val = Integer.valueOf(ls[3]);
                bit.update(i + 1, j + 1, val);
            }
            else if (ls[0].equals("Sum")) {
                int x1 = Integer.valueOf(ls[1]);
                int y1 = Integer.valueOf(ls[2]);
                int x2 = Integer.valueOf(ls[3]);
                int y2 = Integer.valueOf(ls[4]);
                System.out.println((bit.query(x2 + 1, y2 + 1) - bit.query(x2 + 1, y1) - bit.query(x1, y2 + 1) + bit.query(x1, y1) + mod) % mod);
            }
        }
    }
}

class BinaryIndexedTree2D {
    int[][] BIT;
    int mod;

    BinaryIndexedTree2D(int N, int M) {
        BIT = new int[N + 1][M + 1];
        mod = (int)1e9 + 7; 
    }

    void update(int n, int m, int delta) {
        for (int i = n; i < BIT.length; i += (i & -i)) {
            for (int j = m; j < BIT[0].length; j += (j & -j)) {
                BIT[i][j] = (BIT[i][j] + delta) % mod;
            }
        }
    }

    int query(int n, int m) {
        int res = 0;
        for (int i = n; i > 0; i -= (i & -i)) {
            for (int j = m; j > 0; j -= (j & -j)) {
                res = (res + BIT[i][j]) % mod;
            }
        }
        return res;
    }
}

三、逆序对问题

逆序对问题是一个经典的问题,使用树状数组可以很好的解决这个问题。

树状数组的核心是单点更新,区间求和。

问题的核心就变成了如何将逆序对问题转化程区间求和的问题,简单的转化方式有,构建一个频率计数桶,将出现的元素放到相应的桶中,并将桶中的数量加一。从后向前逆序遍历数组,边遍历边更新桶中的数量,当遍历到一个元素的时候,计算getSum(num - 1)就可以得到当前元素的逆序对个数。

这个方法的问题就是单纯采用数字的大小来建立桶的话,这个桶的范围可能会很大,其实我们需要的只是相对的大小,所以我们可以将nums mapping 到sort后的idx上,这样整个的空间复杂度就降到了unique num的数量级。

问题描述:

问题求解:

    public List<Integer> countSmaller(int[] nums) {
        List<Integer> res = new ArrayList<>();
        int[] sorted = Arrays.copyOf(nums, nums.length);
        Arrays.sort(sorted);
        Map<Integer, Integer> ranks = new HashMap<>();
        int rank = 0;
        for (int i = 0; i < sorted.length; ++i)
            if (i == 0 || sorted[i] != sorted[i - 1])
                ranks.put(sorted[i], ++rank);
        int[] bit = new int[ranks.size() + 1];
        for (int i = nums.length - 1; i >= 0; --i) {
            int sum = query(bit, ranks.get(nums[i]) - 1);
            res.add(sum);
            update(bit, ranks.get(nums[i]), 1);
        }
        Collections.reverse(res);
        return res;
    }

    private int query(int[] bit, int i) {
        int res = 0;
        for (int k = i; k > 0; k -= (k & -k)) {
            res += bit[k];
        }
        return res;
    }

    private void update(int[] bit, int i, int delta) {
        for (int k = i; k < bit.length; k += (k & -k)) {
            bit[k] += delta;
        }
    }

  

原文地址:https://www.cnblogs.com/hyserendipity/p/8645723.html