pygorithm/tree/bin_search_tree.py

124 lines
3.8 KiB
Python

"""
二叉搜素树满足一下条件:
1. 对于根节点,左子树中所有节点的值 小于 根节点的值 小于 右子树中所有节点的值。
2. 任意节点的左、右子树也是二叉搜索树,即同样满足条件 1. 。
查找节点:
1. 从根节点查找,循环比较节点值和要查找的值大小
2. 若节点值小于查找值,则走右子树
3. 若节点值大于查找值,则走左子树
"""
from typing import Self
class BinarySearchTree:
def __init__(self, val: int = None, key: int = None):
self.val: int | None = val
self.key: int | None = key
self.left: BinarySearchTree | None = None
self.right: BinarySearchTree | None = None
def search_val(self, key: int) -> int | None:
if self.key is None:
return None
cur = self
while cur:
if cur.key == key:
return cur.val
if cur.key > key:
cur = cur.left
else:
cur = cur.right
return None
def insert(self, key: int, val: int):
if self.val is None:
self.key = key
self.val = val
return
cur, pre = self, None
while cur:
if cur.key == key:
cur.val = val
return
if cur.key > key:
if cur.left is None:
cur.left = BinarySearchTree(key, val)
return
else:
cur, pre = cur.left, cur
else:
if cur.right is None:
cur.right = BinarySearchTree(key, val)
return
else:
cur, pre = cur.right, cur
if pre.key < key:
pre.right = BinarySearchTree(key, val)
else:
pre.left = BinarySearchTree(key, val)
def remove(self, key: int) -> int | None:
if self.key is None:
return None
cur, pre = self, None
while cur:
if cur.key == key:
break
if cur.key > key:
cur, pre = cur.left, cur
else:
cur, pre = cur.right, cur
if cur is None:
# 没有找到要删除的节点
return
if cur.left is None or cur.right is None:
# 子节点数量有 0 个或 1 个
child: BinarySearchTree | None = cur.left or cur.right
if cur.key != self.key:
# 不是根节点
if pre.left.key == cur.key:
# 查找到的当前节点是前一个节点的左子节点
pre.left = child
else:
pre.right = child
else:
if child:
self.key = child.key
self.val = child.val
else:
self.key = None
self.val = None
else:
# 子节点数量有两个, 无法直接删除它,而需要使用一个节点替换该节点
# 由于要保持二叉搜索树“左子树 < 根节点 < 右子树”的性质,
# 因此这个节点可以是右子树的最小节点或左子树的最大节点。
temp, pre1 = cur.right, cur
# 找右子树的最小节点
while temp.left:
temp, pre1 = temp.left, temp
if pre.left.key == cur.key:
pre.left = temp
pre1.left = None
else:
pre.right = temp
pre1.left = None
def inorder(self, bst: Self):
# 中序遍历的结果是升序
if bst.left:
self.inorder(self.left)
print(self.val)
if bst.right:
self.inorder(self.right)