剑指offer 面试题40. 最小的k个数(易)

海量数据Top-K问题,要了解从 \(O(nlog_{2}n)\)\(O(n)\) 的多种方法。如果面试时遇到的面试题有多种解法,并且每种解法都各有优缺点,那么我们要向面试官问清楚题目的要求、输入的特点,从而选择最合适的解法。

Question

输入整数数组 arr ,找出其中最小的 k 个数。例如,输入4、5、1、6、2、7、3、8这8个数字,则最小的4个数字是1、2、3、4。

示例 1:

1
2
输入:arr = [3,2,1], k = 2
输出:[1,2] 或者 [2,1]
示例 2:
1
2
输入:arr = [0,1,2,1], k = 1
输出:[0]
限制: 0 <= k <= arr.length <= 10000 0 <= arr[i] <= 10000

测试用例

功能测试(输入的数组中有相同的数字;输入的数组中没有相同的数字)。 边界测试(输入的k等于1或者等于数组的长度)。 特殊输入测试(k小于1;k大于数组的长度;指向数组的指针为NULL)。

本题考点

考查应聘者对时间复杂度的分析能力。面试的时候每想出一种解法,我们都要能分析出这种解法的时间复杂度是多少。 如果采用第一种思路,则本题考查应聘者对Partition函数的理解。这个函数既是快速排序的基础,也可以用来查找n个数中第k大的数字。 如果采用第二种思路,则本题考查应聘者对堆、红黑树等数据结构的理解。当需要在某个数据容器内频繁查找及替换最大值时,我们要想到二叉树是一个合适的选择,并能想到用堆或者红黑树等特殊的二叉树来实现。

Intuition

基于排序的方法

速度慢,时间复杂度为 \(O(nlog_{2}n)\)基于快排将输入序列排序,取排序之后的前 \(n\) 个数字作为结果。

基于二分查找的方法

速度快,时间复杂度为 \(O(n)\)从解决面试题39“数组中出现次数超过一半的数字”得到了启发,我们同样可以基于Partition 函数来解决这个问题。如果基于数组的第k个数字来调整,则使得比第k个数字小的所有数字都位于数组的左边,比第k个数字大的所有数字都位于数组的右边。这样调整之后,位于数组中左边的k个数字就是最小的k个数字(这k个数字不一定是排序的)。

基于额外空间的方法

我们可以先创建一个大小为 \(k\) 的数据容器来存储最小的 \(k\) 个数字,接下来每次从输入的n个整数中读入一个数。如果容器中已有的数字少于 \(k\) 个,则直接把这次读入的整数放入容器之中;如果容器中已有 \(k\) 个数字了,也就是容器已满,此时我们不能再插入新的数字而只能替换已有的数字。找出这已有的k个数中的最大值,然后拿这次待插入的整数和最大值进行比较。如果待插入的值比当前已有的最大值小,则用这个数替换当前已有的最大值;如果待插入的值比当前已有的最大值还要大,那么这个数不可能是最小的k个整数之一,于是我们可以抛弃这个整数。

如果我们使用一个 list 保存k个最小数,使用一个 int 保存 \(max\ number\),找出这已有的 \(k\) 个数中的最大值的时间开销就是 \(O(1)\),整体的时间复杂度就是遍历 \(n\) 个数带来的 \(O(n)\)这种方法不可行,必须要进行查找,才能确认新加入的数字和之前剩余数字中哪一个最大。 list + int 的方法,也避免不了查找过程。

可以选择用红黑树大根堆来实现存储最小的 \(k\) 个数字。红黑树 通过把节点分为红、黑两种颜色并根据一些规则确保树在一定程度上是平衡的,从而保证在红黑树中的查找、删除和插入操作都只需要 \(O(logK)\) 时间。在 大根堆 中,根节点的值总是大于它的子树中任意节点的值。于是我们每次可以在 \(O(1)\)时间内得到已有的 \(k\) 个数字中的最大值,但需要 \(O(logK)\) 时间完成删除及插入操作。由于单次在 \(k\) 个数字中查找得到最大数的开销是 \(O(logK)\),因此,对于 \(n\) 个输入数字而言,总的时间效率就是 \(O(nlogK)\)

对比

基于Partition的算法 基于红黑树或者堆的算法
时间复杂度 \(O(n)\) \(O(nlog_{2}K)\)
是否是原地算法
是否适用于海量数据

Code

  • 基于排序的方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class Solution:
def getLeastNumbers(self, arr: List[int], k: int) -> List[int]:
_len = len(arr)
if not self.checkValid(arr, _len):
return None
arr = self.quickSort(arr, 0, len(arr)-1)
print(arr)
return arr[:k]

def checkValid(self, arr, _len):
if arr == None or _len < 1:
return False
else:
return True

def quickSort(self, nums, left, right):
if left >= right:
return
i, j, key = left, right, nums[left]
while i < j:
# 必须先从右向左找第一个小于等于key的值
while i < j and nums[j] >= key: # 从右向左找第一个小于等于key的值
j -= 1
if i < j: # 找到了小于等于key的就扔到前面
nums[i] = nums[j]
i += 1

while i < j and nums[i] <= key: # 从左向右找第一个大于等于key的值
i += 1
if i < j: # 找到大于等于key的就扔到后面
nums[j] = nums[i]
j -= 1

nums[i] = key
self.quickSort(nums, left, i-1)
self.quickSort(nums, i+1, right)
return nums
  • 基于二分查找的方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class Solution:
# O(n)的算法, 只有当我们可以修改输入的数组时可用
# 基于Partition的方法
def getLeastNumbers(self, nums: List[int], k: int) -> List[int]:
_len = len(nums)
if not self.checkValid(nums, _len, k):
return []
start = 0
end = _len - 1
index = self.partition(nums, _len, start, end)
# 一直二分查找,直到index为k-1才停止。
while index != k-1:
# 如果index大于k-1,就去排序前面部分。
if index > k-1:
end = index - 1
index = self.partition(nums, _len, start, end)
# 如果index小于k-1,就去排序后面部分
else:
start = index + 1
index = self.partition(nums, _len, start, end)

# 到这里nums的前k个元素就已经排序完成了
output = nums[:k]
# output.sort() # 这一行可要可不要
return output

# 划分函数,只进行一次二分查找,不递归。
def partition(self, nums, _len, start, end):
if nums == None or _len < 1 or start < 0 or end >= _len:
return None
if end == start:
return end
pivotvlue = nums[start]
leftmark = start + 1
rightmark = end

done = False
# 开始将小于pivotvlue的元素移到右边,大于pivotvlue的元素移到左边。
while not done:
# 从左向右找到一个大于pivotvlue的元素
# leftmark <= rightmark必须放到nums[leftmark] <= pivotvlue前面
while leftmark <= rightmark and nums[leftmark] <= pivotvlue:
leftmark += 1
# 从右向左找到一个小于pivotvlue的元素
while rightmark >= leftmark and nums[rightmark] >= pivotvlue:
rightmark -= 1
# 交换这两个元素
if leftmark > rightmark:
done = True
else:
nums[leftmark], nums[rightmark] = nums[rightmark], nums[leftmark]
nums[rightmark], nums[start] = nums[start], nums[rightmark]
return rightmark

def checkValid(self, nums, _len, k):
if nums == None or _len < k or _len < 1 or k < 1:
return False
else:
return True
  • 基于额外空间的方法(这里使用了堆来完成k个元素的存储),自建 heap Class
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# heap Class
class heap(object):
def __init__(self):
'''
初始化一个空堆,使用数组来在存放堆元素,节省存储
'''
self.data_list = []

def get_parent_index(self, index):
'''
返回父节点的下标
'''
if index == 0 or index > len(self.data_list) - 1:
return None
else:
return (index - 1) >> 1

def swap(self,index_a, index_b):
'''
交换数组中的两个元素
'''
self.data_list[index_a],self.data_list[index_b] = self.data_list[index_b],self.data_list[index_a]


def insert(self, data):
'''
先把元素放在最后,然后从后往前依次堆化
这里以大顶堆为例,如果插入元素比父节点大,则交换,直到最后
'''
self.data_list.append(data)
index = len(self.data_list) - 1
parent = self.get_parent_index(index)
# 循环,直到该元素成为堆顶,或小于父节点(对于大顶堆)
while parent is not None and self.data_list[parent] < self.data_list[index]:
# 交换操作
self.swap(parent,index)
index = parent
parent = self.get_parent_index(parent)

def removeMax(self):
'''
删除堆顶元素,然后将最后一个元素放在堆顶,再从上往下依次堆化
'''
if len(self.data_list) > 0:
remove_data = self.data_list[0]
self.data_list[0] = self.data_list[-1]
del self.data_list[-1]
# 堆化
self.heapify(0)
return remove_data
else:
print('堆空')
return None


def heapify(self, index):
'''
从上往下堆化,从index 开始堆化操作 (大顶堆)
'''
total_index = len(self.data_list) - 1
while True:
maxvalue_index = index
if 2*index + 1 <= total_index and self.data_list[2*index + 1] > self.data_list[maxvalue_index]:
maxvalue_index = 2*index + 1
if 2*index + 2 <= total_index and self.data_list[2*index + 2] > self.data_list[maxvalue_index]:
maxvalue_index = 2*index + 2
if maxvalue_index == index:
break
self.swap(index,maxvalue_index)
index = maxvalue_index

def getMax(self):
"""
获取堆顶元素
"""
if len(self.data_list) > 0:
return self.data_list[0]
else:
print('堆空')
return None


class Solution:
# O(nlogk)的算法, 适合海量数据
# 利用一个k容量的容器存放数组, 构造最大堆, 当下一个数据大于最大数, 跳过, 小于最大数, 则进入容器替换之前的最大数
def getLeastNumbers(self, nums: List[int], k: int) -> List[int]:
_len = len(nums)
if not self.checkValid(nums, _len, k):
return []
_heap = heap()
for num in nums:
if len(_heap.data_list) < k:
_heap.insert(num)
else:
_max = _heap.getMax()
if num < _max:
_heap.removeMax()
_heap.insert(num)
return _heap.data_list[::-1] # 如果用的小根堆,就不用逆序。

def checkValid(self, nums, _len, k):
if nums == None or _len < k or _len < 1 or k < 1:
return False
else:
return True
  • 基于额外空间的方法(这里使用了堆来完成k个元素的存储),调用系统的 heap Class
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import heapq
class Solution:
# O(nlogk)的算法, 适合海量数据
# 利用一个k容量的容器存放数组, 构造最大堆, 当下一个数据大于最大数, 跳过, 小于最大数, 则进入容器替换之前的最大数
def getLeastNumbers(self, nums: List[int], k: int) -> List[int]:
_len = len(nums)
if not self.checkValid(nums, _len, k):
return []
res = []
for num in nums:
if len(res) < k:
res.append(num)
else:
res = heapq.nlargest(k, res)
if num < res[0]:
res[0] = num
return res[::-1]

def checkValid(self, nums, _len, k):
if nums == None or _len < k or _len < 1 or k < 1:
return False
else:
return True