Table of Contents:

AVL Tree

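AVL trees are binary search trees that rebalance themselves on every insertion, so the heights of any node's two subtrees never differ by more than one. That keeps the height of the whole tree proportional to log(n), which is what guarantees O(log(n)) search times. The most confusing part of AVL trees is the means by which they stay balanced: after an insertion, every node on the path back to the root compares the heights of its two subtrees and, if they differ by two, repairs the imbalance with one or two rotations.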

Python Example


def right_rotation(node):
    """Rotate the subtree rooted at ``node`` to the right.

    The left child of ``node`` becomes the root of the subtree, ``node``
    becomes that child's right child, and the child's former right subtree
    is re-attached as the new left subtree of ``node``.  The in-order
    sequence of the values is unchanged.

    The cached ``height`` of the two nodes that moved is recomputed from
    their new children (a missing child counts as height 0), so callers can
    read balance information straight from the returned subtree.

    Args:
        node: Root of the subtree to rotate.  It and its children must
            expose ``left``, ``right`` and ``height`` attributes.

    Returns:
        The new root of the subtree.  When ``node`` has no left child there
        is nothing to rotate and ``node`` itself is returned untouched
        (returning ``None`` here would make the caller drop the subtree).
    """
    pivot = node.left
    if pivot is None:
        return node

    node.left = pivot.right
    pivot.right = node

    # ``node`` is now a child of ``pivot``, so its height has to be
    # refreshed first: the height of ``pivot`` depends on it.
    for moved in (node, pivot):
        left_height = moved.left.height if moved.left is not None else 0
        right_height = moved.right.height if moved.right is not None else 0
        moved.height = 1 + max(left_height, right_height)
    return pivot

def left_rotation(node):
    """Rotate the subtree rooted at ``node`` to the left.

    The right child of ``node`` becomes the root of the subtree, ``node``
    becomes that child's left child, and the child's former left subtree is
    re-attached as the new right subtree of ``node``.  The in-order sequence
    of the values is unchanged.

    The cached ``height`` of the two nodes that moved is recomputed from
    their new children (a missing child counts as height 0), so callers can
    read balance information straight from the returned subtree.

    Args:
        node: Root of the subtree to rotate.  It and its children must
            expose ``left``, ``right`` and ``height`` attributes.

    Returns:
        The new root of the subtree.  When ``node`` has no right child there
        is nothing to rotate and ``node`` itself is returned untouched
        (returning ``None`` here would make the caller drop the subtree).
    """
    pivot = node.right
    if pivot is None:
        return node

    node.right = pivot.left
    pivot.left = node

    # ``node`` is now a child of ``pivot``, so its height has to be
    # refreshed first: the height of ``pivot`` depends on it.
    for moved in (node, pivot):
        left_height = moved.left.height if moved.left is not None else 0
        right_height = moved.right.height if moved.right is not None else 0
        moved.height = 1 + max(left_height, right_height)
    return pivot


class Node:
    """One node of an AVL tree, doubling as the root of its own subtree.

    Ordering comes from two caller-supplied callables instead of the values'
    own operators.  Both are always invoked as ``f(probe, stored)``, where
    ``probe`` is the argument being searched for or inserted and ``stored``
    is the value held by the node being visited:

    * ``less(probe, stored)`` -- true when ``probe`` belongs in the left
      subtree of the node holding ``stored``.
    * ``equal(probe, stored)`` -- true when ``probe`` matches ``stored``.

    Methods that can change the shape of the tree (``insert`` and
    ``rebalance``) return the root of the resulting subtree; the caller must
    store that return value in place of the node it called the method on.
    """

    def __init__(self, value, less, equal):
        """Create a leaf holding ``value`` that orders with ``less``/``equal``."""
        self.value = value
        self.less = less
        self.equal = equal
        self.left = None
        self.right = None
        # Number of nodes on the longest path from this node down to a leaf;
        # a leaf therefore has height 1 and a missing child counts as 0.
        self.height = 1

    def get(self, value):
        """Return the stored value that matches ``value``, or ``None``.

        The search starts at this node and follows ``less`` downwards until
        ``equal`` reports a match or the path runs out of nodes.
        """
        node = self
        while node is not None:
            if node.equal(value, node.value):
                return node.value
            node = node.left if node.less(value, node.value) else node.right
        return None

    def has(self, value: int) -> bool:
        """Return whether this subtree holds a match for ``value``.

        Implemented as ``get(value) is not None``, so a stored value that is
        itself ``None`` is reported as absent.
        """
        return self.get(value) is not None

    def get_size(self) -> int:
        """Return the number of nodes in this subtree, this node included."""
        size = 1
        for child in (self.left, self.right):
            if child is not None:
                size += child.get_size()
        return size

    def traverse(self):
        """Yield the values of this subtree in order: left, self, right."""
        if self.left is not None:
            yield from self.left.traverse()
        yield self.value
        if self.right is not None:
            yield from self.right.traverse()

    def insert(self, value):
        """Insert ``value`` below this node and return the new subtree root.

        A value for which ``equal`` already reports a match is ignored and
        the tree is left untouched.  Otherwise the value is added as a leaf,
        and on the way back up every node on the insertion path refreshes
        its height and rebalances itself, which may replace it as the root
        of its subtree.
        """
        if self.equal(value, self.value):
            return self
        if self.less(value, self.value):
            if self.left is None:
                self.left = Node(value, less=self.less, equal=self.equal)
            else:
                self.left = self.left.insert(value)
        else:
            if self.right is None:
                self.right = Node(value, less=self.less, equal=self.equal)
            else:
                self.right = self.right.insert(value)
        self.update_height()
        return self.rebalance()

    def update_height(self):
        """Recompute ``height`` from the cached heights of the children."""
        left_height = self.left.height if self.left is not None else 0
        right_height = self.right.height if self.right is not None else 0
        self.height = 1 + max(left_height, right_height)

    def height_factor(self) -> int:
        """Return left height minus right height, from the cached heights.

        Positive means the left subtree is taller, negative means the right
        subtree is taller.
        """
        left_height = self.left.height if self.left is not None else 0
        right_height = self.right.height if self.right is not None else 0
        return left_height - right_height

    def rebalance(self):
        """Restore the AVL balance at this node; return the subtree root.

        Only this node is inspected.  Both subtrees are expected to be
        balanced already and every cached height, including this node's, is
        expected to be current; ``insert`` guarantees that by calling
        ``update_height`` and ``rebalance`` on each node of the insertion
        path from the new leaf upwards.  Keeping the work local is what
        keeps an insertion proportional to the height of the tree instead
        of its size.

        Returns:
            ``self`` when the height factor is within -1..1, otherwise the
            node that took its place after one or two rotations.
        """
        balance = self.height_factor()
        if -2 < balance < 2:
            return self

        if balance > 0:
            # Left subtree is too tall.
            if self.left.height_factor() < 0:
                # LR case: the excess hangs off the left child's right side;
                # rotate the child first so that it becomes an LL case.
                self.left = self.left._rotate_left()
            # LL case -> right rotation.
            return self._rotate_right()

        # Right subtree is too tall.
        if self.right.height_factor() > 0:
            # RL case: rotate the child first so that it becomes an RR case.
            self.right = self.right._rotate_right()
        # RR case -> left rotation.
        return self._rotate_left()

    def _rotate_left(self):
        """Lift the right child above this node and return it.

        Heights are refreshed bottom-up (this node first, then the lifted
        child) so that no stale height is left behind for later balance
        checks.
        """
        pivot = self.right
        self.right = pivot.left
        pivot.left = self
        self.update_height()
        pivot.update_height()
        return pivot

    def _rotate_right(self):
        """Lift the left child above this node and return it.

        Heights are refreshed bottom-up (this node first, then the lifted
        child) so that no stale height is left behind for later balance
        checks.
        """
        pivot = self.left
        self.left = pivot.right
        pivot.right = self
        self.update_height()
        pivot.update_height()
        return pivot
        

def height_of(nd: "Node | None") -> int:
    """Return the cached height of ``nd``, treating a missing node as 0.

    Args:
        nd: A node exposing a ``height`` attribute, or ``None`` for an
            absent child.

    Returns:
        ``nd.height`` for a node, ``0`` for ``None``.
    """
    # Explicit identity test: only ``None`` means "no node here".
    return nd.height if nd is not None else 0

class Tree:
    """Container around a root ``Node`` that also handles the empty tree.

    Storage, lookup and ordering are delegated to ``Node``; this class only
    keeps track of the current root (``head``) and answers the queries that
    a ``Node`` cannot answer when no node exists yet.

    Args:
        less: Callable ``less(a, b)`` handed to every node created here.
        equal: Callable ``equal(a, b)`` handed to every node created here.
    """

    def __init__(self, less, equal):
        self.head = None  # root node, or None while the tree is empty
        self.less = less
        self.equal = equal

    def insert(self, value):
        """Add ``value`` to the tree.

        ``Node.insert`` returns the root of the resulting subtree, which is
        stored back into ``head`` because it may differ from the old root.
        """
        if self.head is None:
            self.head = Node(value, less=self.less, equal=self.equal)
        else:
            self.head = self.head.insert(value)

    def get(self, value: int):
        """Return the stored value matching ``value``, or ``None``.

        ``None`` is returned both for an empty tree and when the root's
        ``get`` finds no match.
        """
        if self.head is None:
            return None
        return self.head.get(value)

    def has(self, value: int) -> bool:
        """Return whether ``get(value)`` finds something other than ``None``."""
        return self.get(value) is not None

    def get_size(self) -> int:
        """Return the number of stored nodes; 0 for an empty tree."""
        if self.head is None:
            return 0
        return self.head.get_size()

    def traverse(self):
        """Yield every value produced by the root's ``traverse``.

        An empty tree yields nothing.
        """
        if self.head is not None:
            yield from self.head.traverse()