LeetCode 327. Count of Range Sum(线段树)

题目

题意:找出所有区间和在某个范围之内的个数

题解:区间问题用线段树来做。首先n^2 可以遍历所有的区间,这样会超时。

       我们用线段树,期望可以在遍历整个线段树的过程中把问题解决掉,遍历整个线段树的效率是O(n*logn) 如果遍历每个节点上的区间上所花的时间是n*logn,也可以接受,总的效率就是O(n*logn*logn)

       线段树每个节点,存储这个区间的前缀区间和,和后缀区间和,而且要是排好序的。

       父节点的区间个数,需要计算它的两个子节点中,左子节点的后缀区间和和右子节点的前缀区间和,相加有没有符合条件的。也就是两个排好序的数组,求两个数组里两个数字之和在某个范围的组合的个数。这里可以用n*logn的方法解决。
       同时父亲节点的前缀区间和和后缀区间和,也要由两个子节点得来,使用归并排序,O(n)。

       另外一定要注意,vector 超时,只能用数组了。
typedef long long int _int;

class Solution {
public:
    _int* l[10005*17];
    _int* r[10005*17];
    int sizel[10005*17];
    int sizer[10005*17];
    _int range[10005*17];
    vector<int> num;
    _int low;
    _int upp;
    _int ans;
    int countRangeSum(vector<int>& nums, int lower, int upper) {

        if(nums.size()==0)
            return 0;

        low = lower;
        upp = upper;
        num = nums;
        build(1,0,nums.size()-1);

        return ans;
    }

    int fun(int node,_int* a,int lena,_int* b,int lenb,int y,_int* &c)
    {
        _int x = range[y];
        int i=0,j=0;
        c = new _int[lena+ lenb];
        int pos=0;
        while(i<lena||j<lenb)
        {
            if(i>=lena)
            {
                c[pos++]=b[j];
                j++;
                continue;
            }
            if(j>=lenb)
            {
                c[pos++]=a[i]+x;
                i++;
                continue;
            }

            _int p = a[i]+x;
            _int q = b[j];

            if(p < q)
            {
                c[pos++]=p;
                i++;
            }
            else
            {
                c[pos++]=q;
                j++;
            }
        }
        return lena+lenb;
    }

    void pushup(int node)
    {

        sizer[node]=fun(node,r[node<<1],sizer[node<<1],r[node<<1|1],sizer[node<<1|1],node<<1|1,r[node]);
        sizel[node]=fun(node,l[node<<1|1],sizel[node<<1|1],l[node<<1],sizel[node<<1],node<<1,l[node]);

        int len = sizel[node<<1|1];
        for(int i=0;i<len;i++)
        {
            int pos2 = upper_bound(r[node<<1],r[node<<1]+sizer[node<<1],upp-l[node<<1|1][i])-r[node<<1];
            if(pos2==0)
                break;

            int pos = lower_bound(r[node<<1],r[node<<1]+sizer[node<<1],low-l[node<<1|1][i])-r[node<<1];

            ans+=pos2-pos;
        }

        range[node]=range[node<<1]+range[node<<1|1];

    }

    void build(int node,int start,int end)
    {
        if(start==end)
        {
            l[node]=new _int{num[start]};
            sizel[node]=1;
            r[node]=new _int{num[start]};
            sizer[node]=1;
            range[node]=num[start];

            if(range[node]<=upp&&range[node]>=low)
                ans++;

            return;
        }

        int mid = (start+end)>>1;
        build(node<<1,start,mid);
        build(node<<1|1,mid+1,end);

        pushup(node);
    }
};

原文地址:https://www.cnblogs.com/dacc123/p/12487157.html