В настоящее время я реализую класс AVL Tree в python.
class TreeNode:
def __init__(self, val):
self.val = val
self.left = None
self.right = None
self.height = 1
class AVLTree:
def insert(self, root, key):
# 1) Perform Normal BST Insert
if not root:
return TreeNode(key)
elif key < root.val:
root.left = self.insert(root.left, key)
else:
root.right = self.insert(root.right, key)
# 2) Update Height
root.height = max(self.get_height(root.left), self.get_height(root.right)) + 1
# 3) Get Balance Factor
balance = self.get_balance(root)
# 4) Rotate based on balance factor
if balance > 1 and root.left.val > key:
return self.right_rotate(root)
if balance > 1 and root.left.val < key:
root.left = self.left_rotate(root.left)
return self.right_rotate(root)
if balance < -1 and root.right.val < key:
return self.left_rotate(root)
if balance < -1 and root.right.val > key:
root.right = self.right_rotate(root.right)
return self.left_rotate(root)
return root
def remove(self, root, key):
print('To Remove:', key)
def getPredVal(root):
root = root.left
while root.right:
root = root.right
return root.val
def getSuccVal(root):
root = root.right
while root.left:
root = root.left
return root.val
if not root:
return root
# 1) Perform a normal remove
if root.val == key:
if not root.left and not root.right:
return None
elif root.left:
predecessor_val = getPredVal(root)
root.val = predecessor_val
root.left = self.remove(root.left, predecessor_val)
elif root.right:
successor_val = getSuccVal(root)
root.val = successor_val
root.right = self.remove(root.right, successor_val)
elif root.val < key:
root.right = self.remove(root.right, key)
elif root.val > key:
root.left = self.remove(root.left, key)
# 2) Update Height
root.height = 1 + max(self.get_height(root.left), self.get_height(root.right))
# 3) Get balance factor
balance = self.get_balance(root)
# 4) Rotate if needed
if balance > 1 and self.get_balance(root.left) >= 0:
print('RIGHT ROTATE')
print('Curr Root', root.val, root.left.val, root.left.left.val, root.left.right.val)
return self.right_rotate(root)
if balance > 1 and self.get_balance(root.right) < 0:
root.left = self.left_rotate(root.left)
return self.right_rotate(root)
if balance < -1 and self.get_balance(root.right) <= 0:
return self.left_rotate(root)
if balance < -1 and self.get_balance(root.right) > 0:
root.right = self.right_rotate(root.right)
return self.left_rotate(root)
return root
def right_rotate(self, root):
new_root = root.left
subtree = new_root.right
new_root.right = root
root.left = subtree
new_root.height = 1 + max(self.get_height(new_root.left), self.get_height(new_root.right))
root.height = 1 + max(self.get_height(root.left), self.get_height(root.right))
print('New Root', new_root.val, new_root.left.val, new_root.right.val, new_root.right.left.val)
return new_root
def left_rotate(self, root):
new_root = root.right
subtree = new_root.left
new_root.left = root
root.right = subtree
new_root.height = 1 + max(self.get_height(new_root.left), self.get_height(new_root.right))
root.height = 1 + max(self.get_height(root.left), self.get_height(root.right))
return new_root
def get_height(self, node):
if not node:
return 0
return node.height
def get_balance(self, node):
if not node:
return 0
left_height = self.get_height(node.left)
right_height = self.get_height(node.right)
return left_height - right_height
def get_min(self, root):
if not root:
return root
curr = root
while curr.left:
curr = curr.left
return curr.val
def get_max(self, root):
if not root:
return root
curr = root
while curr.right:
curr = curr.right
return curr.val
def pop_min(self, root):
min_node = self.get_min(root)
self.remove(root, min_node)
return min_node
def pop_max(self, root):
max_node = self.get_max(root)
self.remove(root, max_node)
print('MAX', max_node)
return max_node
Нет необходимости читать весь код, поскольку я реализовал его довольно тщательно. Ниже мой тестовый код.
avl_tree = AVLTree()
root = None
nums = [1,2,3,4,5]
for num in nums:
root = avl_tree.insert(root, num)
prettyPrintTree(root)
assert avl_tree.get_max(root) == 5
assert avl_tree.get_min(root) == 1
avl_tree.pop_max(root)
prettyPrintTree(root)
Первый prettyPrinTree
после добавления 1,2,3,4,5 дает мне
│ ┌── 5
└── 4
│ ┌── 3
└── 2
└── 1
(4 - это root. Это вид сбоку)
Это кажется правильным.
Однако, когда я делаю pop_max(root)
, который должен выскочить 5 и повторно сбалансировать дерево, prettyPrintTree
дает мне
└── 4
└── 3
(4 - это root, а 3 - левый дочерний элемент)
Это очень странно для меня, потому что когда я пытаюсь распечатать все узлы внутри функции right_rotate
, Я вижу
Curr Root 4 2 1 3
New Root 2 1 4 3
Но я не уверен, почему структура отличается, когда я пытаюсь распечатать дерево.
Я не вижу любая проблема в pop_max
, которая могла бы изменить структуру кода.
Любая помощь, пожалуйста?