LeetCode 99. 恢复二叉搜索树(难)

看到二叉搜索树,就要学会去利用其中序遍历有序递增的特点解题。

题目

二叉搜索树中的两个节点(从给的示例来看,交换的一般是相邻的两个节点)被错误地交换。
请在不改变其结构的情况下,恢复这棵树。
进阶:
使用 O(n) 空间复杂度的解法很容易实现。
你能想出一个只使用常数空间的解决方案吗?某些算法在最好情况下可以实现常数空间的时间复杂度完成运算,见方法二,不过仅限某些最好情况,无法保证一直是 \(O(1)\)

示例

示例 1:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
输入: [1,3,null,null,2]

  1
  /
 3
  \
  2

输出: [3,1,null,null,2]

  3
  /
 1
  \
  2

示例 2:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
输入: [3,1,4,null,null,2]

3
/ \
1 4
  /
  2

输出: [2,1,4,null,null,3]

2
/ \
1 4
  /
 3

考察知识点

树、深度优先搜索

核心思想

方法一、基于中序遍历结果的方法

  • step1:获得输入二叉搜索树的中序遍历结果。
  • step2:根据二叉搜索树中序遍历递增的特点,找到中序遍历结果中降序的部分,进而找到交换了的两个节点,x和y。(这个函数是关键)
  • step3:交换x和y节点,直接修改val就行,注意遇到None部分直接continue。

时间复杂度:\(O(n)\)
空间复杂度:\(O(n)\),需要一个 list 保存中序遍历的额外结果。

方法二、迭代中序遍历

方法一是三个步骤,其中第一个步骤完成之后,肯定会有 n 次循环出来,时间复杂度最好情况也是 \(O(n)\)。我们通过迭代构造中序遍历,并在一次遍历中找到交换的节点,把三个步骤放到一起来完成,在某些特别好的情况下,就能实现时间复杂度为 \(O(1)\)
迭代顺序很简单:

  • 尽可能的向左走,然后向右走一步,重复一直到结束。
  • 若要找到交换的节点,就记录中序遍历中的最后一个节点 pred(即当前节点的前置节点),并与当前节点的值进行比较。
    • 如果当前节点的值小于前置节点 pred 的值(正常来说二叉搜索树中序遍历的节点之间必须是递增的,也也就是升序的),说明该节点和前置节点 pred 就是要交换的两个节点,记录当前节点 rooty,当前节点的前置节点 prex,。
    • 由于交换的节点只有两个,因此无论如何最多只会出现两次降序(前节点的值小于前置节点 pred 的值),当第二次遇到降序时,记录当前节点 rooty,就可以可以终止遍历。
  • 这样,就可以直接获取节点(而不仅仅是它们的值),从而实现 \(O(1)\) 的交换时间,大大减少了步骤 3 所需的时间。

时间复杂度:最好的情况下是 \(O(1)\);最坏的情况下是交换节点之一是最右边的叶节点时,此时是 \(O(N)\)
空间复杂度:最大是 \(O(H)\) 来维持栈的大小,其中 \(H\) 指的是树的高度。

Python版本

  • 方法一基于中序遍历结果方法的实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from typing import List

# Definition for a binary tree node.
class TreeNode:
def __init__(self, x):
self.val = x
self.left = None
self.right = None

def __str__(self):
return str(self.val)

# 普通二叉树按照root/left/right的顺序进行进行构建树,不对数值大小进行对比。
class BinaryTree():


# 创建二叉树 同样的一些数字,以不同的顺序输入,会影响二叉搜索树的创建。 index为要插入的位置。
def Create(self, items):
if len(items) == 0: return TreeNode(0)
nodeQueue = [] # nodeQueue保存的是这一层所有非空节点,要留着给下一层创建左右子树的时候用。
root = TreeNode(items[0]) # 创建一个根节点
nodeQueue.append(root) # 将根节点加入队列
cur = None
lineNodeNum = 2 # 记录下一层需要填充的节点的数量(注意不一定是2的幂,而是上一行中非空节点的数量乘2),第二层的节点数肯定是1*2。
startIndex = 1 # 记录当前行中数字在数组中的开始位置,第0个位置在根节点,则第第二行第一个位置肯定是items[1]。
restLength = len(items) - 1 # 记录数组中除去根节点剩余的元素的数量。
while restLength > 0: # 只要还有剩,就要继续添加。
i = startIndex
while i < startIndex + lineNodeNum: # 一次while循环会处理掉一个layer。
if i == len(items): return root # 说明已经将nums中的数字用完,此时应停止遍历,并可以直接返回root
cur = nodeQueue.pop(0) # cur是上一行的根节点。
if items[i] != None: # 如果是None,就不做处理,保留左节点为None的情况。
cur.left = TreeNode(items[i])
nodeQueue.append(cur.left)

if i + 1 == len(items): return root # 同上,做一个边界判定。
if items[i + 1] != None: # 同上,如果是None,就不做处理,保留右节点为None的情况。
cur.right = TreeNode(items[i + 1])
nodeQueue.append(cur.right)
i += 2
startIndex += lineNodeNum # 加上这一层已经处理掉的数字,就是下一层要处理的数字的起点。
restLength -= lineNodeNum # 减去这一层已经处理掉的数字,就是剩余还未处理的数字的个数。
lineNodeNum = len(nodeQueue) * 2 # nodeQueue的长度就是这一层的非空节点的个数,乘以2就是下一层需要填充的节点的数量。
return root


# 递归 前序遍历,根结点 ---> 左子树 ---> 右子树。
def PreOrder(self, root):
"""
Pre-Order Traversal of binary tree
:param root: root node of target binary tree
:return: TreeNode
"""
# if not root: return [] # 不处理None模式的一行代码
if not root: return [None] # 不处理None模式的一行代码
res = [root.val]
left_tree = self.PreOrder(root.left)
right_tree = self.PreOrder(root.right)
return res + left_tree + right_tree # 如果是递归调用二叉(搜索)树,前/中/后序遍历只用改这里的顺序就行。


class Solution:
def find_two_swapped(self, nums: List[int]) -> (int, int):
n = len(nums)
x = y = -1 # x指的是被移到前面的那个大的数,y指的是被移到后面的那个小的数。
for i in range(n - 1):
if nums[i + 1] < nums[i]:
y = nums[i + 1] # y照常后移
# first swap occurence # 第一次出现节点交换
if x == -1: x = nums[i]
# second swap occurence # 第二次出现交换节点,无论节点在哪个位置进行了交换,在这个大体上有序的序列中,都只会出现两次降序,所以x锁定第一次降序的首位数字,y照常后移,锁定第一次或者第二次降序的末尾数即可。
else: break
return x, y

# 递归 中序遍历,左子树---> 根结点 ---> 右子树。
def InOder(self, root):
"""
In-Order Traversal of binary tree
:param root: root node of target binary tree
:return: TreeNode
"""
# if not root: return [] # 不处理None模式的一行代码
if not root: return [None] # 不处理None模式的一行代码
res = [root.val]
left_tree = self.InOder(root.left)
right_tree = self.InOder(root.right)
return left_tree + res + right_tree


def recoverTree(self, root: TreeNode) -> None:
"""
Do not return anything, modify root in-place instead.
"""
# step1:中序遍历得到结果
res = self.InOder(root)
for i in res:
if i == None:
res.remove(None)

# step2:根据二叉搜索树中序遍历递增的特点,找到中序遍历结果中降序的部分,进而找到x和y。
x, y = self.find_two_swapped(res)

# step3:交换x和y节点,直接修改val就行,注意遇到None部分直接continue。
queue = []
queue.append(root)
flag_x = flag_y = False
while len(queue):
t = queue.pop(0)
if t == None:
continue
if flag_x and flag_y:
break
if not flag_x and t.val == x:
t.val = y
flag_x = True
elif not flag_y and t.val == y:
t.val = x
flag_y = True
queue.append(t.left)
queue.append(t.right)
return root


Input = [[1,3,None,None,2], [3,1,4,None,None,2]]
Answer = [[3,1,None,None,2], [2,1,4,None,None,3]]
if __name__ == "__main__":
solution = Solution()
Btree = BinaryTree()
for i in range(len(Input)):
print("-"*50)
root = Btree.Create(Input[i])
solution.recoverTree(root)
print(Btree.PreOrder(root))
answerTree = Btree.Create(Answer[i])
print(Btree.PreOrder(answerTree))

时间复杂度:\(O(n)\)
空间复杂度:\(O(n)\),需要一个 list 保存中序遍历的额外结果。

  • 方法二迭代中序遍历的实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class Solution:
def recoverTree(self, root: TreeNode):
stack = []
x = y = pred = None
# 该方法在 函数find_two_swapped的基础上修改得到,将三个步骤合为一个步骤。
# 开始中序遍历,处理顺序是 左-> 中 -> 右
while stack or root:
while root: # 先把 "左"子树入栈
stack.append(root)
root = root.left
root = stack.pop() # 先入后出
if pred and root.val < pred.val: # 然后处理 "根"
y = root # y照常后移
if x is None: x = pred # x只指定一次就不变了
else: break
pred = root
root = root.right # 最后处理 "右"子树
x.val, y.val = y.val, x.val

时间复杂度:最好的情况下是 \(O(1)\);最坏的情况下是交换节点之一是最右边的叶节点时,此时是 \(O(N)\)
空间复杂度:最大是 \(O(H)\) 来维持递归调用堆栈的大小,其中 \(H\) 指的是树的高度。

将递归改成迭代

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution:
def recoverTree(self, root):
"""
:type root: TreeNode
:rtype: void Do not return anything, modify root in-place instead.
"""
def find_two_swapped(root: TreeNode):
nonlocal x, y, pred
if root is None:
return

find_two_swapped(root.left)
if pred and root.val < pred.val:
y = root
# first swap occurence
if x is None:
x = pred
# second swap occurence
else:
return
pred = root
find_two_swapped(root.right)

x = y = pred = None
find_two_swapped(root)
x.val, y.val = y.val, x.val

时间复杂度:最好的情况下是 \(O(1)\);最坏的情况下是交换节点之一是最右边的叶节点时,此时是 \(O(N)\)
空间复杂度:最大是 \(O(H)\) 来维持递归调用堆栈的大小,其中 \(H\) 指的是树的高度。

Morris 中序遍历方法 推荐该方法,最简洁。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class Solution:
def recoverTree(self, root):
"""
:type root: TreeNode
:rtype: void Do not return anything, modify root in-place instead.
"""
# predecessor is a Morris predecessor.
# In the 'loop' cases it could be equal to the node itself predecessor == root.
# pred is a 'true' predecessor,
# the previous node in the inorder traversal.
x = y = predecessor = pred = None

while root:
# If there is a left child
# then compute the predecessor.
# If there is no link predecessor.right = root --> set it.
# If there is a link predecessor.right = root --> break it.
if root.left:
# Predecessor node is one step left
# and then right till you can.
predecessor = root.left
while predecessor.right and predecessor.right != root:
predecessor = predecessor.right

# set link predecessor.right = root
# and go to explore left subtree
if predecessor.right is None:
predecessor.right = root
root = root.left
# break link predecessor.right = root
# link is broken : time to change subtree and go right
else:
# check for the swapped nodes
if pred and root.val < pred.val:
y = root
if x is None:
x = pred
pred = root

predecessor.right = None
root = root.right
# If there is no left child
# then just go right.
else:
# check for the swapped nodes
if pred and root.val < pred.val:
y = root
if x is None:
x = pred
pred = root

root = root.right

x.val, y.val = y.val, x.val

时间复杂度:\(O(N)\),我们访问每个节点两次。
空间复杂度:\(O(1)\)

参考链接

LeetCode官方题解