[LeetCode] Kth Largest in an Array | 数组第K大元素

leetcode 215. kth largest in an array

https://leetcode.com/problems/kth-largest-element-in-an-array/?tab=Description

方法1:直接按从大到小排序,返回第k个。

Time: O(n logn)
Extra space: O(1)

int findKthLargest(vector<int> nums, int k) {
    sort(nums.begin(), nums.end(), greater<int>());
    return nums[k-1];
}

方法2:维护一个Priority Queue / Min Heap,扫一遍数组不断往队列里push元素,当队列的size大于k时就pop,最后队列里一定还剩k个元素,而且是前k大的元素(因为是最小堆,pop出去的是最小的n-k个元素),这时队列头部(堆顶)的就是第k大。

Time: O(n logn)
Extra space: O(n)

int findKthLargest(vector<int> nums, int k) {
    priority_queue<int, vector<int>, greater<int> > min_heap;
    for (auto& num : nums) {
        min_heap.push(num);
        if (min_heap.size() > k) min_heap.pop();
    }
    return min_heap.top();
}

方法3:借用Quick-Sort的Partition思想,因为Partition之后数组整体有序,且有一个元素的位置直到数组全部有序都不会改变,如果这个位置恰好就是目标位置,说明我们找到了第k大。不过这里就不需要像快排那样对两边都排序,采用类似于折半的思想去做即可。

Time: 平均O(n),最坏O(n^2)
Extra space: O(1)

class Solution {
public:
    int findKthLargest(vector<int> nums, int k)
    {
        int target_pos = nums.size() - k;
        int start = 0, end = nums.size() - 1;
        int mid = partition(nums, start, end);
        while (mid != target_pos) {
            if (mid < target_pos)
                start = mid + 1;
            else
                end = mid - 1;
            mid = partition(nums, start, end);
        }
        return nums[mid];
    }
    
    int partition(vector<int>& nums, int start, int end) {
        int pivot = nums[end], fst_larger = start;
        for (int i = start; i < end; ++i) {
            if (nums[i] < pivot) {
                swap(nums[i], nums[fst_larger++]);
            }
        }
        swap(nums[end], nums[fst_larger]);
        return fst_larger;
    }
};

对于极端的input case,方法3可能会有O(n^2)的表现,原因在于算法性能对pivot比较敏感。而pivot的选择是一项技术活,通常的方案是采用randomly select、三位置(首、尾、中间)取均值等手段来提升极端input下的性能,详细的解释参见: https://en.wikipedia.org/wiki/Quickselect

int partition(vector<int>& nums, int start, int end) {
    int pivot_idx = start + floor(rand() % (end - start + 1));  // randomly select
    swap(nums[pivot_idx], nums[end]);  // move pivot to end
    
    int pivot = nums[end], fst_larger = start;
    for (int i = start; i < end; ++i) {
        if (nums[i] < pivot) {
            swap(nums[i], nums[fst_larger++]);
        }
    }
    swap(nums[end], nums[fst_larger]);
    return fst_larger;
}

ps. 随机pivot方法跑leetcode所有测试数据大约8ms(C++)

原文地址:https://www.cnblogs.com/ilovezyg/p/6483358.html