pygorithm/tree/avl_tree.py

147 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
二叉搜素树的问题:在多次删除和插入操作之后,二叉搜素树可能退化为链表
平衡树仲的每一个节点的平衡因子,都处于 [-1, -1] 范围内。当平衡因子为-2时执行右旋操作当平衡因子为2时执行左旋操作。
右旋步骤:
1. 右旋即失衡的节点 node 旋转向下,如果其子节点只有一个左子节点,则 node 直接作为子节点的右子节点(满足左 < 中 < 右)
2. 如有其子节点有两个节点,则 node 作为子节点的右子节点,并将原右子节点作为 node 左子节点
3. 选转完后,将选择之后的新根节点接回之前 node 的父节点
左旋只需镜像右旋操作即可
然而,在有些时候,无论左旋还是右旋都无法达到平衡,这时候需要先左旋后右旋,或者先右旋后左旋:
1. 如果是左倾树,且失衡节点的子节点的平衡因子为-1即右倾则将失衡节点的子节点先左旋再对失衡节点右旋
2. 如果是右倾树且失衡节点的子节点的平衡因子为1即左倾则将失衡节点的子节点先右旋再对失衡节点左旋
插入节点操作:
插入节点操作基本和二叉查找树类似,但每次插入时为了防止失衡,需要从这个节点开始,自底向上执行旋转操作,使所有失衡节点恢复平衡
"""
from typing import Self
class TreeNode:
"""AVL 树节点类"""
def __init__(self, val: int):
self.val: int = val # 节点值
self.height: int = 0 # 节点高度,指从该节点到其最远叶节点的高度
self.left: TreeNode | None = None # 左子节点引用
self.right: TreeNode | None = None # 右子节点引用
def get_height(self, node: Self | None) -> int:
# 获取树的高度
if node is not None:
return node.height
return -1
def update_height(self, node: Self | None):
# 节点高度等于最高子树高度 + 1
node.height = max([self.get_height(node.left), self.get_height(node.right)]) + 1
def balance_factor(self, node: Self | None) -> int:
# 获取节点的平衡因子
if node is None:
return 0
# 节点平衡因子 = 左子树高度 - 右子树高度
return self.get_height(node.left) - self.get_height(node.right)
def right_rotate(self, node: Self | None) -> Self | None:
child = node.left
grand_child = child.right
# 以 child 为原点,将 node 向右旋转
child.right = node
node.left = grand_child
# 更新节点高度
self.update_height(node)
self.update_height(child)
# 返回旋转后的子树的根节点
return child
def left_rotate(self, node: Self | None) -> Self | None:
child = node.right
grand_child = child.left
# 以 child 为原点,将 node 向左旋转
child.left = node
node.right = grand_child
# 更新节点高度
self.update_height(node)
self.update_height(child)
# 返回旋转后的子树的根节点
return child
def rotate(self, node: Self | None) -> Self | None:
balance_factor = self.balance_factor(node)
# 左倾树
if balance_factor > 1:
if self.balance_factor(node.left) >= 0:
return self.right_rotate(node)
else:
# 先左旋后右旋
node.left = self.left_rotate(node.left)
return self.right_rotate(node)
# 右倾树
elif balance_factor < -1:
if self.balance_factor(node.right) <= 0:
return self.left_rotate(node)
else:
node.right = self.right_rotate(node.right)
return self.left_rotate(node)
return node
def insert(self, val):
self._root = self.insert_helper(self._root, val)
def insert_helper(self, node: Self | None, val: int) -> Self:
"""递归插入节点"""
if node is None:
# 作为根节点
return TreeNode(val)
# 查找插入位置:
if val < node.val:
# 递归左子树插入
# 接收旋转之后返回的根节点,并将其设为左子节点
node.left = self.insert_helper(node.left, val)
elif val > node.val:
# 递归右子树插入
# 接收旋转之后返回的根节点,并将其设为左子节点
node.right = self.insert_helper(node.right, val)
else:
# 重复节点不插入
return node
# 插入之后,更新节点高度
self.update_height(node)
# 返回旋转之后的节点,如果未发生旋转则还是 node如果发生了旋转则是 node 的其中一个子节点
return self.rotate(node)
def remove(self, val: int):
self._root = self.remove_helper(self._root, val)
def remove_helper(self, node: Self | None, val: int) -> Self:
if node is None:
return None
if val < node.val:
node.left = self.remove_helper(node.left, val)
elif val > node.val:
node.right = self.remove_helper(node.right, val)
else:
if node.left is None or node.right is None:
child = node.left or node.right
if child is None:
return None
else:
node = child
else:
temp = node.right
while temp.left is not None:
temp = temp.left
node.right = self.remove_helper(node.right, temp.val)
node.val = temp.val
self.update_height(node)
return self.rotate(node)