海量数据Top-K问题,要了解从 \(O(nlog_{2}n)\) 到 \(O(n)\) 的多种方法。如果面试时遇到的面试题有多种解法,并且每种解法都各有优缺点,那么我们要向面试官问清楚题目的要求、输入的特点,从而选择最合适的解法。
Question
输入整数数组 arr ,找出其中最小的 k 个数。例如,输入4、5、1、6、2、7、3、8这8个数字,则最小的4个数字是1、2、3、4。
示例 1: 示例 2: 限制: 0 <= k <= arr.length <= 10000
0 <= arr[i] <= 10000
测试用例
功能测试 (输入的数组中有相同的数字;输入的数组中没有相同的数字)。 边界测试 (输入的k等于1或者等于数组的长度)。 特殊输入测试 (k小于1;k大于数组的长度;指向数组的指针为NULL)。
本题考点
考查应聘者对时间复杂度的分析能力。面试的时候每想出一种解法,我们都要能分析出这种解法的时间复杂度是多少。 如果采用第一种思路,则本题考查应聘者对Partition函数的理解。这个函数既是快速排序的基础,也可以用来查找n个数中第k大的数字。 如果采用第二种思路,则本题考查应聘者对堆、红黑树等数据结构的理解。当需要在某个数据容器内频繁查找及替换最大值时,我们要想到二叉树是一个合适的选择,并能想到用堆或者红黑树等特殊的二叉树来实现。
Intuition
基于排序的方法
速度慢,时间复杂度为 \(O(nlog_{2}n)\) 。 基于快排将输入序列排序,取排序之后的前 \(n\) 个数字作为结果。
基于二分查找的方法
速度快,时间复杂度为 \(O(n)\) 。 从解决面试题39“数组中出现次数超过一半的数字”得到了启发,我们同样可以基于Partition
函数来解决这个问题。如果基于数组的第k个数字来调整,则使得比第k个数字小的所有数字都位于数组的左边,比第k个数字大的所有数字都位于数组的右边。这样调整之后,位于数组中左边的k个数字就是最小的k个数字(这k个数字不一定是排序的)。
基于额外空间的方法
我们可以先创建一个大小为 \(k\) 的数据容器来存储最小的 \(k\) 个数字,接下来每次从输入的n个整数中读入一个数。如果容器中已有的数字少于 \(k\) 个,则直接把这次读入的整数放入容器之中;如果容器中已有 \(k\) 个数字了,也就是容器已满,此时我们不能再插入新的数字而只能替换已有的数字。找出这已有的k个数中的最大值,然后拿这次待插入的整数和最大值进行比较。如果待插入的值比当前已有的最大值小,则用这个数替换当前已有的最大值;如果待插入的值比当前已有的最大值还要大,那么这个数不可能是最小的k个整数之一,于是我们可以抛弃这个整数。
如果我们使用一个 list
保存k个最小数,使用一个 int
保存 \(max\ number\) ,找出这已有的 \(k\) 个数中的最大值的时间开销就是 \(O(1)\) ,整体的时间复杂度就是遍历 \(n\) 个数带来的 \(O(n)\) 。这种方法不可行,必须要进行查找 ,才能确认新加入的数字和之前剩余数字中哪一个最大。 list
+ int
的方法,也避免不了查找过程。
可以选择用红黑树 和大根堆 来实现存储最小的 \(k\) 个数字。红黑树 通过把节点分为红、黑两种颜色并根据一些规则确保树在一定程度上是平衡的,从而保证在红黑树中的查找、删除和插入操作都只需要 \(O(logK)\) 时间。在 大根堆 中,根节点的值总是大于它的子树中任意节点的值。于是我们每次可以在 \(O(1)\) 时间内得到已有的 \(k\) 个数字中的最大值,但需要 \(O(logK)\) 时间完成删除及插入操作。由于单次在 \(k\) 个数字中查找得到最大数的开销是 \(O(logK)\) ,因此,对于 \(n\) 个输入数字而言,总的时间效率就是 \(O(nlogK)\) 。
对比
时间复杂度
\(O(n)\)
\(O(nlog_{2}K)\)
是否是原地算法
否
是
是否适用于海量数据
是
否
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 class Solution : def getLeastNumbers (self, arr: List[int ], k: int ) -> List[int]: _len = len (arr) if not self.checkValid(arr, _len): return None arr = self.quickSort(arr, 0 , len (arr)-1 ) print(arr) return arr[:k] def checkValid (self, arr, _len ): if arr == None or _len < 1 : return False else : return True def quickSort (self, nums, left, right ): if left >= right: return i, j, key = left, right, nums[left] while i < j: while i < j and nums[j] >= key: j -= 1 if i < j: nums[i] = nums[j] i += 1 while i < j and nums[i] <= key: i += 1 if i < j: nums[j] = nums[i] j -= 1 nums[i] = key self.quickSort(nums, left, i-1 ) self.quickSort(nums, i+1 , right) return nums
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 class Solution : def getLeastNumbers (self, nums: List[int ], k: int ) -> List[int]: _len = len (nums) if not self.checkValid(nums, _len, k): return [] start = 0 end = _len - 1 index = self.partition(nums, _len, start, end) while index != k-1 : if index > k-1 : end = index - 1 index = self.partition(nums, _len, start, end) else : start = index + 1 index = self.partition(nums, _len, start, end) output = nums[:k] return output def partition (self, nums, _len, start, end ): if nums == None or _len < 1 or start < 0 or end >= _len: return None if end == start: return end pivotvlue = nums[start] leftmark = start + 1 rightmark = end done = False while not done: while leftmark <= rightmark and nums[leftmark] <= pivotvlue: leftmark += 1 while rightmark >= leftmark and nums[rightmark] >= pivotvlue: rightmark -= 1 if leftmark > rightmark: done = True else : nums[leftmark], nums[rightmark] = nums[rightmark], nums[leftmark] nums[rightmark], nums[start] = nums[start], nums[rightmark] return rightmark def checkValid (self, nums, _len, k ): if nums == None or _len < k or _len < 1 or k < 1 : return False else : return True
基于额外空间的方法(这里使用了堆来完成k个元素的存储),自建 heap Class
。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 class heap (object ): def __init__ (self ): ''' 初始化一个空堆,使用数组来在存放堆元素,节省存储 ''' self.data_list = [] def get_parent_index (self, index ): ''' 返回父节点的下标 ''' if index == 0 or index > len (self.data_list) - 1 : return None else : return (index - 1 ) >> 1 def swap (self,index_a, index_b ): ''' 交换数组中的两个元素 ''' self.data_list[index_a],self.data_list[index_b] = self.data_list[index_b],self.data_list[index_a] def insert (self, data ): ''' 先把元素放在最后,然后从后往前依次堆化 这里以大顶堆为例,如果插入元素比父节点大,则交换,直到最后 ''' self.data_list.append(data) index = len (self.data_list) - 1 parent = self.get_parent_index(index) while parent is not None and self.data_list[parent] < self.data_list[index]: self.swap(parent,index) index = parent parent = self.get_parent_index(parent) def removeMax (self ): ''' 删除堆顶元素,然后将最后一个元素放在堆顶,再从上往下依次堆化 ''' if len (self.data_list) > 0 : remove_data = self.data_list[0 ] self.data_list[0 ] = self.data_list[-1 ] del self.data_list[-1 ] self.heapify(0 ) return remove_data else : print('堆空' ) return None def heapify (self, index ): ''' 从上往下堆化,从index 开始堆化操作 (大顶堆) ''' total_index = len (self.data_list) - 1 while True : maxvalue_index = index if 2 *index + 1 <= total_index and self.data_list[2 *index + 1 ] > self.data_list[maxvalue_index]: maxvalue_index = 2 *index + 1 if 2 *index + 2 <= total_index and self.data_list[2 *index + 2 ] > self.data_list[maxvalue_index]: maxvalue_index = 2 *index + 2 if maxvalue_index == index: break self.swap(index,maxvalue_index) index = maxvalue_index def getMax (self ): """ 获取堆顶元素 """ if len (self.data_list) > 0 : return self.data_list[0 ] else : print('堆空' ) return None class Solution : def getLeastNumbers (self, nums: List[int ], k: int ) -> List[int]: _len = len (nums) if not self.checkValid(nums, _len, k): return [] _heap = heap() for num in nums: if len (_heap.data_list) < k: _heap.insert(num) else : _max = _heap.getMax() if num < _max: _heap.removeMax() _heap.insert(num) return _heap.data_list[::-1 ] def checkValid (self, nums, _len, k ): if nums == None or _len < k or _len < 1 or k < 1 : return False else : return True
基于额外空间的方法(这里使用了堆来完成k个元素的存储),调用系统的 heap Class
。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 import heapqclass Solution : def getLeastNumbers (self, nums: List[int ], k: int ) -> List[int]: _len = len (nums) if not self.checkValid(nums, _len, k): return [] res = [] for num in nums: if len (res) < k: res.append(num) else : res = heapq.nlargest(k, res) if num < res[0 ]: res[0 ] = num return res[::-1 ] def checkValid (self, nums, _len, k ): if nums == None or _len < k or _len < 1 or k < 1 : return False else : return True