基本原理

最大堆是一个二叉树,要求这个二叉树的父节点大于它的子节点,同时这个二叉树是一个完全二叉树,也就是说这个二叉树除了最底层之外的其它节点都应该被填满,最底层应该从左到右被填满。显然,最大堆的顶部节点的值是整个二叉树中最大的。

我们使用数组来构建一个最大堆,使用数组构建一个二叉树最大堆存在如下性质。假设二叉树某节点在数组中的下标索引为index,则它的父节点在数组中的下标索引为parent = (index - 1) // 2,它的左子节点的下标索引为child_left = index * 2 + 1,右子节点的下标索引为child_right = index * 2 + 2。如果计算出来parent小于0或者child大于了数组最大值,就说明没有父节点或者子节点。

代码实现

接下来我们创建最大堆类,存储一些堆的基本信息以及工具方法

1
2
3
4
5
6
7
class Heap(object):
def __init__(self) -> None:
self._data = []
def _size(self) -> int:
return len(self._data)
def _swap(self, i: int, j: int) -> None:
self._data[i], self._data[j] = self._data[j], self._data[i]

Heap类包含了三个方法,一个初始化方法创建了一个数组,_size方法返回数组的长度,_swap方法用于交换数组中的两个元素的值。

插入元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def push(self, item: int) -> None:
self._data.append(item)
# 将最后一个元素,也就是刚刚才添加的元素,上移到保证树大小有序的位置
self._siftup(self._size() - 1)

def _siftup(self, index: int) -> None:
while True:
parent = (index - 1) >> 1 # 右移一位,等价于除以二
# 如果移动到了头部,或者说父元素已经大于了当前元素,已经有序不需要再移动
if index <= 0 or self._data[parent] >= self._data[index]:
break
# 还没有序,需要将当前值和父节点的值交换,并且让父节点继续进行下一轮移动
self._swap(index, parent)
index = parent

插入元素的过程很简单,就是先把元素append到数组的末尾,之后不断地把该元素往上移,直到该元素小于父元素、或者该元素已经到了二叉树的头部,此时二叉树有序,push结束。

弹出元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def pop(self) -> int:
max_item = self._data[0]
# 将头部元素和最后一个元素做交换
self._swap(0, self._size() - 1)
# 弹出最后一个元素
self._data.pop()
# 将第一个元素,也就是从尾部置换而来的元素移动到合适的位置
self._shift_down(0)
return max_item

def _shift_down(self, index) -> None:
while True:
child = (index << 1) + 1
# 如果右侧子节点更大,就使用右侧子节点进行置换
if child + 1 < self._size() and self._data[child] < self._data[child + 1]:
child += 1
# 如果移动到了末尾,或者当前节点已经大于子节点,则可以停止移动了
if child >= self._size() or self._data[index] >= self._data[child]:
break
# 还没有序,需要将当前值和子节点的值交换,并且让子节点继续进行下一轮移动
self._swap(index, child)
index = child

弹出过程也不复杂,先把顶部的元素弹出。之后为了操作简单,我们把末尾元素和顶部元素进行交换,随后移除末尾元素。然后我们只需要不断地把顶部元素向下移动,直到它大于了自己的子节点、或者移动到了末尾,此时二叉树有序,pop结束。

总结

我们可以使用如下代码对最大堆进行测试

flat
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import math
import random


class Heap(object):
def __init__(self) -> None:
self._data = []

def _size(self) -> int:
return len(self._data)

def pretty_print(self) -> None:
"""
打印二叉树
:return:
"""
height = int(math.log2(len(self._data))) + 1
for i in range(height):
width = 2 ** (height - i) - 2
print(' ' * width, end='')
blank = ' ' * (width * 2 + 2)
print(
blank.join(
['{: >2d}'.format(num) for num in self._data[2 ** i - 1:min(2 ** (i + 1) - 1, len(self._data))]]))
print()

def _swap(self, i: int, j: int) -> None:
"""
交换两个元素的值
:param i:
:param j:
:return:
"""
self._data[i], self._data[j] = self._data[j], self._data[i]

def push(self, item: int) -> None:
"""
往树中添加一个元素
:param item:
:return:
"""
self._data.append(item)
# 将最后一个元素,也就是刚刚才添加的元素,上移到保证树大小有序的位置
self._siftup(self._size() - 1)

def _siftup(self, index: int) -> None:
while True:
parent = (index - 1) >> 1 # 右移一位,等价于除以二
# 如果移动到了头部,或者说父元素已经大于了当前元素,已经有序不需要再移动
if index <= 0 or self._data[parent] >= self._data[index]:
break
# 还没有序,需要将当前值和父节点的值交换,并且让父节点继续进行下一轮移动
self._swap(index, parent)
index = parent

def pop(self) -> int:
"""
弹出顶部最大元素
:return:
"""
max_item = self._data[0]
# 将头部元素和最后一个元素做交换
self._swap(0, self._size() - 1)
# 弹出最后一个元素
self._data.pop()
# 将第一个元素,也就是从尾部置换而来的元素移动到合适的位置
self._shift_down(0)
return max_item

def _shift_down(self, index) -> None:
while True:
child = (index << 1) + 1
# 如果右侧子节点更大,就使用右侧子节点进行置换
if child + 1 < self._size() and self._data[child] < self._data[child + 1]:
child += 1
# 如果移动到了末尾,或者当前节点已经大于子节点,则可以停止移动了
if child >= self._size() or self._data[index] >= self._data[child]:
break
# 还没有序,需要将当前值和子节点的值交换,并且让子节点继续进行下一轮移动
self._swap(index, child)
index = child


if __name__ == '__main__':
range_num, num = 1000, 300
data = random.sample(range(range_num), num)

sorted_data = sorted(data, reverse=True)
heap = Heap()
for d in data:
heap.push(d)

heap.pretty_print()

for n in range(num):
v1 = sorted_data[n]
v2 = heap.pop()
assert v1 == v2, ''
print('{} -> {}'.format(v1, v2))

for n in range(10):
print('-' * 100)
heap.push(n)
heap.pretty_print()

参考

https://www.cnblogs.com/q1214367903/p/14220949.html