After not programming in C++ for some time, I decided to write an AVL tree implementation to get back in shape (I wasn't that good to begin with; I'm still an amateur).
Header File
bst.hpp
#pragma once
#include <iostream>
#include <utility>
namespace bst
{
/// A single node of a binary search tree.
///
/// A node owns its `left` and `right` children: they are freed by
/// deleteChildren() and by the destructor. `parent` is a non-owning back pointer.
/// The private `height` field caches the subtree height used by the AVL code.
template <typename T>
class BstNode
{
public:
    // --- construction, copying, destruction ---
    explicit BstNode(T key, BstNode *parent = nullptr);
    BstNode(const BstNode &other);
    // NOTE(review): the parameter is a non-const reference; copy assignment
    // conventionally takes `const BstNode &`.
    BstNode &operator=(BstNode &other);
    ~BstNode();
    void deleteChildren();

    // --- insertion and lookup ---
    BstNode *insert(T key);
    BstNode *getNode(const T &key);
    const BstNode *search(const T &key) const;
    BstNode *findMin();
    BstNode *findMax();
    BstNode *nextLarger();

    // --- traversal and measurement ---
    void printInOrder() const;
    int placeInOrder(T array[], int begin) const;
    int findHeight() const;

    // Public data; declaration order is also member-initialization order.
    T key;
    BstNode *left = nullptr;
    BstNode *right = nullptr;
    BstNode *parent;

private:
    int height = 0;
    int updateHeight();

    template <typename U>
    friend int getHeight(BstNode<U> *node);
    template <typename U>
    friend class BST;
    template <typename U>
    friend class AvlTree;
};
/// Unbalanced binary search tree built from heap-allocated BstNode objects.
///
/// A default-constructed tree is empty (`root == nullptr`). BstNode::insert places
/// a key that is not less than a node's key into that node's right subtree, so
/// duplicates are allowed.
template <typename T>
class BST
{
public:
    BST();
    explicit BST(T root_key);
    BST(const BST<T> &other);
    BST<T> &operator=(const BST<T> &other);
    // Virtual because BST is a polymorphic base: insert()/height() are virtual and
    // AvlTree derives from it. Deleting a derived tree through a BST* must run
    // the right destructor.
    virtual ~BST();
    virtual BstNode<T> *insert(T key);
    BstNode<T> *getNode(const T &key);
    const BstNode<T> *search(const T &key) const;
    BstNode<T> *minNode();
    BstNode<T> *maxNode();
    T min();
    T max();
    virtual int height() const;
    int placeInOrder(T array[], int begin) const;
    void printInOrder() const;
    void deleteNode(BstNode<T> *node);
    static void deleteSubtree(BstNode<T> *node);
protected:
    BstNode<T> *root;
};
/// AVL variant of BST: insert() restores the AVL balance condition along the path
/// from the newly inserted node up to the root.
template <typename T>
class AvlTree : public BST<T>
{
public:
    AvlTree();
    explicit AvlTree(T key);
    /// Inserts `key`, rebalances, and returns the node that holds the new key.
    BstNode<T> *insert(T key) override;
    /// Cached height of the root node; -1 for an empty tree.
    int height() const override;
    // NOTE(review): deleteNode() is inherited unchanged from BST and performs no
    // rotations, so deletions can leave the tree unbalanced.
private:
    void rightRotate(BstNode<T> *node);
    void leftRotate(BstNode<T> *node);
    void rebalance(BstNode<T> *node);
};
/// Returns the cached height of the subtree rooted at `node`.
/// An empty subtree (nullptr) counts as height -1, so a lone leaf has height 0.
template <typename T>
int getHeight(BstNode<T> *node)
{
    if (node == nullptr)
        return -1;
    return node->height;
}
/// Returns the larger of `a` and `b` (by operator>); returns `b` when they are equal.
template <typename T>
T max(T a, T b)
{
    if (a > b)
        return a;
    return b;
}
/// Exchanges the values of `a` and `b`.
///
/// The values are moved through a temporary. Move-only types therefore work, and
/// expensive-to-copy types are not copied three times.
/// NOTE(review): for argument types from namespace std, an unqualified
/// `swap(x, y)` also finds std::swap through ADL and the call can be ambiguous.
/// Prefer a qualified call.
template <typename T>
void swap(T &a, T &b)
{
    T tmp = std::move(a);
    a = std::move(b);
    b = std::move(tmp);
}
/// Creates a childless node holding `key` (moved in), optionally linked upward
/// to `parent`. The children come from the in-class `nullptr` initializers.
template <typename T>
BstNode<T>::BstNode(T key, BstNode<T> *parent)
    : key{std::move(key)},
      parent{parent}
{
}
/// Deep-copies `other` together with its whole subtree.
///
/// The copy is detached: its `parent` is nullptr, while the copied children point
/// back at the new node. If copying the right subtree throws, the already-copied
/// left subtree is freed before the exception propagates.
template <typename T>
BstNode<T>::BstNode(const BstNode<T> &other)
    : key{other.key}, parent{nullptr}, height{other.height}
{
    // The initializer list follows declaration order (key, left, right, parent,
    // height); a mismatched order triggers -Wreorder and misleads readers.
    if (other.left != nullptr)
    {
        left = new BstNode<T>{*(other.left)};
        left->parent = this;
    }
    if (other.right != nullptr)
    {
        try
        {
            right = new BstNode<T>{*(other.right)};
        }
        catch (...)
        {
            // The destructor does not run for a partially constructed object,
            // so release the left copy here to avoid a leak.
            delete left;
            left = nullptr;
            throw;
        }
        right->parent = this;
    }
}
/// Replaces this node's key, cached height and subtrees with those of `other`.
///
/// The subtrees are deep copies, and the previous children are freed.
/// `parent` is left unchanged. Assigning a node to itself is a no-op.
template <typename T>
BstNode<T> &BstNode<T>::operator=(BstNode<T> &other)
{
    if (this == &other)
        return *this;
    // Take every copy before freeing anything: `other` may live inside our own
    // subtree, in which case deleteChildren() below destroys it.
    T new_key = other.key;
    const int new_height = other.height;
    BstNode<T> *new_left = nullptr;
    BstNode<T> *new_right = nullptr;
    if (other.left != nullptr)
        new_left = new BstNode<T>{*(other.left)}; // deep copy of the LEFT child
    if (other.right != nullptr)
    {
        try
        {
            new_right = new BstNode<T>{*(other.right)};
        }
        catch (...)
        {
            // Copies are detached (parent == nullptr), so this frees only the copy.
            delete new_left;
            throw;
        }
    }
    deleteChildren();
    key = std::move(new_key);
    height = new_height;
    left = new_left;
    right = new_right;
    if (left != nullptr)
        left->parent = this;
    if (right != nullptr)
        right->parent = this;
    return *this;
}
/// Frees the whole subtree below this node.
/// If the node is still linked to a parent, clears the parent's pointer to it.
template <typename T>
BstNode<T>::~BstNode()
{
    deleteChildren();
    if (parent == nullptr)
        return;
    BstNode<T> *&link_to_me = (parent->left == this) ? parent->left : parent->right;
    link_to_me = nullptr;
}
/// Recursively frees the left and then the right subtree and resets both
/// child pointers to nullptr. The node itself is not freed.
template <typename T>
void BstNode<T>::deleteChildren()
{
    BstNode<T> **const children[] = {&left, &right};
    for (BstNode<T> **child : children)
    {
        if (*child == nullptr)
            continue;
        (*child)->deleteChildren();
        delete *child;
        *child = nullptr;
    }
}
/// Inserts `key` as a new leaf under this node and returns that leaf.
/// Keys less than a node's key go left; everything else (including equal keys)
/// goes right. No heights are updated here.
template <typename T>
BstNode<T> *BstNode<T>::insert(T key)
{
    BstNode<T> *current = this;
    for (;;)
    {
        BstNode<T> *&slot = (key < current->key) ? current->left : current->right;
        if (slot == nullptr)
        {
            slot = new BstNode<T>{std::move(key), current};
            return slot;
        }
        current = slot;
    }
}
/// Mutable lookup: returns the node in this subtree whose key equals `key`
/// (neither less nor greater), or nullptr if there is none.
template <typename T>
BstNode<T> *BstNode<T>::getNode(const T &key)
{
    // Reuse the const search() instead of duplicating its traversal. Casting
    // const away is safe: *this is non-const here and the nodes reachable from it
    // were created with plain `new BstNode<T>`.
    return const_cast<BstNode<T> *>(search(key));
}
/// Read-only lookup: returns the node in this subtree whose key is neither less
/// than nor greater than `key`, or nullptr if no such node exists.
template <typename T>
const BstNode<T> *BstNode<T>::search(const T &key) const
{
    const BstNode<T> *current = this;
    while (current != nullptr)
    {
        if (key < current->key)
            current = current->left;
        else if (key > current->key)
            current = current->right;
        else
            return current;
    }
    return nullptr;
}
/// Returns the leftmost (smallest-key) node of this subtree.
template <typename T>
BstNode<T> *BstNode<T>::findMin()
{
    BstNode<T> *leftmost = this;
    for (; leftmost->left != nullptr; leftmost = leftmost->left)
    {
    }
    return leftmost;
}
/// Returns the rightmost (largest-key) node of this subtree.
template <typename T>
BstNode<T> *BstNode<T>::findMax()
{
    BstNode<T> *rightmost = this;
    for (; rightmost->right != nullptr; rightmost = rightmost->right)
    {
    }
    return rightmost;
}
/// Returns this node's in-order successor, or nullptr if it is the last node.
template <typename T>
BstNode<T> *BstNode<T>::nextLarger()
{
    // With a right subtree, the successor is that subtree's leftmost node.
    if (right != nullptr)
        return right->findMin();
    // Otherwise climb while we are a right child. The first ancestor reached
    // from its left side is the successor.
    BstNode<T> *child = this;
    BstNode<T> *ancestor = parent;
    while (ancestor != nullptr && child == ancestor->right)
    {
        child = ancestor;
        ancestor = ancestor->parent;
    }
    return ancestor;
}
/// Writes the subtree's keys to std::cout in in-order sequence.
/// Each key is followed by a single space.
template <typename T>
void BstNode<T>::printInOrder() const
{
    if (left)
        left->printInOrder();
    std::cout << key << ' ';
    if (right)
        right->printInOrder();
}
/// Writes the subtree's keys in in-order sequence into `array`, starting at
/// index `begin`. Returns the index one past the last element written.
/// The caller must provide enough room.
template <typename T>
int BstNode<T>::placeInOrder(T array[], int begin) const
{
    int next = (left != nullptr) ? left->placeInOrder(array, begin) : begin;
    array[next] = key;
    ++next;
    return (right != nullptr) ? right->placeInOrder(array, next) : next;
}
/// Recomputes this subtree's height by full traversal (ignores the cached value).
/// A leaf has height 0.
template <typename T>
int BstNode<T>::findHeight() const
{
    const int left_height = left ? left->findHeight() : -1;
    const int right_height = right ? right->findHeight() : -1;
    return 1 + max(left_height, right_height);
}
/// Refreshes the cached `height` from the children's cached heights and returns it.
/// The children are assumed to be up to date already.
template <typename T>
int BstNode<T>::updateHeight()
{
    const int tallest_child = max(getHeight(left), getHeight(right));
    height = tallest_child + 1;
    return height;
}
/// Creates an empty tree (no root node).
template <typename T>
BST<T>::BST()
    : root(nullptr)
{
}
template <typename T>
BST<T>::BST(T root_key)
{
root = new BstNode<T>{root_key};
}
template <typename T>
BST<T>::BST(const BST<T> &other)
{
root = new BstNode<T>{*(other.root)};
}
/// Replaces this tree's contents with a deep copy of `other`.
///
/// Safe for self-assignment and for empty trees on either side. The copy is made
/// before the old nodes are freed, so a throwing copy leaves *this unchanged.
template <typename T>
BST<T> &BST<T>::operator=(const BST<T> &other)
{
    if (this == &other)
        return *this;
    BstNode<T> *copy = other.root != nullptr ? new BstNode<T>{*(other.root)} : nullptr;
    if (root != nullptr)
        deleteSubtree(root);
    root = copy;
    return *this;
}
/// Frees every node of the tree.
template <typename T>
BST<T>::~BST()
{
    // A default-constructed or emptied tree has no root; skip the deletion when
    // there is nothing to free.
    if (root != nullptr)
        deleteSubtree(root);
}
/// Unlinks `node` from its parent (if any) and frees it together with its subtree.
/// A nullptr argument is ignored.
///
/// This is a static function: it cannot reset a tree's `root`. Passing a tree's
/// root therefore leaves that tree's `root` pointer dangling.
template <typename T>
void BST<T>::deleteSubtree(BstNode<T> *node)
{
    if (node == nullptr)
        return;
    if (node->parent != nullptr)
    {
        if (node->parent->left == node)
            node->parent->left = nullptr;
        else
            node->parent->right = nullptr;
    }
    delete node;
}
/// Inserts `key` without rebalancing and returns the node that holds it.
/// On an empty tree the new node becomes the root.
template <typename T>
BstNode<T> *BST<T>::insert(T key)
{
    // A default-constructed tree has no root to delegate to (same handling as
    // AvlTree::insert).
    if (root == nullptr)
    {
        root = new BstNode<T>{std::move(key)};
        return root;
    }
    return root->insert(std::move(key));
}
/// Mutable lookup of `key`; returns nullptr if the key is absent or the tree is empty.
template <typename T>
BstNode<T> *BST<T>::getNode(const T &key)
{
    return root != nullptr ? root->getNode(key) : nullptr;
}
/// Read-only lookup of `key`; returns nullptr if the key is absent or the tree is empty.
template <typename T>
const BstNode<T> *BST<T>::search(const T &key) const
{
    return root != nullptr ? root->search(key) : nullptr;
}
/// Returns the node with the smallest key, or nullptr for an empty tree.
template <typename T>
BstNode<T> *BST<T>::minNode()
{
    return root != nullptr ? root->findMin() : nullptr;
}
/// Returns the node with the largest key, or nullptr for an empty tree.
template <typename T>
BstNode<T> *BST<T>::maxNode()
{
    return root != nullptr ? root->findMax() : nullptr;
}
/// Returns a copy of the smallest key.
/// Precondition: the tree is not empty (the result of minNode() is dereferenced).
template <typename T>
T BST<T>::min()
{
    BstNode<T> *smallest = minNode();
    return smallest->key;
}
/// Returns a copy of the largest key.
/// Precondition: the tree is not empty (the result of maxNode() is dereferenced).
template <typename T>
T BST<T>::max()
{
    BstNode<T> *largest = maxNode();
    return largest->key;
}
/// Height of the tree computed by full traversal; -1 for an empty tree
/// (the same convention as getHeight()).
template <typename T>
int BST<T>::height() const
{
    return root != nullptr ? root->findHeight() : -1;
}
/// Writes all keys in ascending order into `array`, starting at `begin`.
/// Returns one past the last index written, or `begin` unchanged for an empty tree.
template <typename T>
int BST<T>::placeInOrder(T array[], int begin) const
{
    if (root == nullptr)
        return begin;
    return root->placeInOrder(array, begin);
}
/// Prints all keys in ascending order to std::cout; prints nothing for an empty tree.
template <typename T>
void BST<T>::printInOrder() const
{
    if (root != nullptr)
        root->printInOrder();
}
/// Removes `node` from the tree and frees it. A nullptr argument is ignored.
///
/// A node with two children is not unlinked itself. Instead its key is swapped
/// with its in-order successor's key, and the successor node (which has no left
/// child) is removed. Afterwards the cached heights on the path to the root are
/// refreshed.
/// NOTE(review): assumes `node` belongs to this tree. For AvlTree no rotations
/// are performed here, so the tree may become unbalanced (TODO).
template <typename T>
void BST<T>::deleteNode(BstNode<T> *node)
{
    if (node == nullptr)
        return;
    if ((node->left != nullptr) && (node->right != nullptr))
    {
        BstNode<T> *successor = node->nextLarger();
        // Qualified call: an unqualified swap can be ambiguous when ADL also finds std::swap.
        std::swap(node->key, successor->key);
        deleteNode(successor);
        return;
    }
    // At most one child: splice it (or nullptr for a leaf) into node's place.
    BstNode<T> *parent = node->parent;
    BstNode<T> *child = (node->left != nullptr) ? node->left : node->right;
    if (child != nullptr)
        child->parent = parent;
    if (parent == nullptr)
        root = child; // also clears `root` when the last node is deleted
    else if (parent->left == node)
        parent->left = child;
    else
        parent->right = child;
    // Fully detach `node` so its destructor neither frees the spliced-in child
    // nor writes through `parent`.
    node->parent = nullptr;
    node->left = nullptr;
    node->right = nullptr;
    delete node;
    // Keep cached heights (used by AvlTree::height and rebalancing) consistent.
    for (BstNode<T> *ancestor = parent; ancestor != nullptr; ancestor = ancestor->parent)
        ancestor->updateHeight();
}
template <typename T>
AvlTree<T>::AvlTree()
{}
/// Creates an AVL tree containing a single node that holds `key`.
template <typename T>
AvlTree<T>::AvlTree(T key)
    : BST<T>(std::move(key))
{
}
/// Inserts `key` and returns the node holding it.
/// The new node becomes the root of an empty tree. Otherwise it is added as a
/// leaf and the path above it is rebalanced.
template <typename T>
BstNode<T> *AvlTree<T>::insert(T key)
{
    BstNode<T> *inserted = nullptr;
    if (this->root != nullptr)
    {
        inserted = this->root->insert(std::move(key));
        rebalance(inserted);
    }
    else
    {
        inserted = this->root = new BstNode<T>{std::move(key)};
    }
    return inserted;
}
// Left rotation around `node`.
// Its right child (`temp`) takes node's place under node's parent (or becomes
// the root), `node` becomes temp's left child, and temp's former left subtree
// becomes node's right subtree. The in-order sequence of keys is unchanged.
// Precondition: node->right is non-null (it is dereferenced immediately).
// Only the cached heights of `node` and then `temp` are recomputed (in that
// order, since temp now sits above node). Ancestors are handled by rebalance(),
// which keeps walking upward.
template <typename T>
void AvlTree<T>::leftRotate(BstNode<T> *node)
{
auto *temp = node->right;
// Hook temp into node's old position.
temp->parent = node->parent;
if (node->parent == nullptr)
this->root = temp;
else
{
if (node->parent->left == node)
node->parent->left = temp;
else
node->parent->right = temp;
}
// temp's left subtree moves across to become node's right subtree.
node->right = temp->left;
if (node->right != nullptr)
node->right->parent = node;
// node drops below temp.
temp->left = node;
node->parent = temp;
node->updateHeight();
temp->updateHeight();
}
// Right rotation around `node` (mirror image of leftRotate).
// Its left child (`temp`) takes node's place under node's parent (or becomes
// the root), `node` becomes temp's right child, and temp's former right subtree
// becomes node's left subtree. The in-order sequence of keys is unchanged.
// Precondition: node->left is non-null (it is dereferenced immediately).
// Only the cached heights of `node` and then `temp` are recomputed.
template <typename T>
void AvlTree<T>::rightRotate(BstNode<T> *node)
{
auto *temp = node->left;
// Hook temp into node's old position.
temp->parent = node->parent;
if (node->parent == nullptr)
this->root = temp;
else
{
if (node->parent->left == node)
node->parent->left = temp;
else
node->parent->right = temp;
}
// temp's right subtree moves across to become node's left subtree.
node->left = temp->right;
if (node->left != nullptr)
node->left->parent = node;
// node drops below temp.
temp->right = node;
node->parent = temp;
node->updateHeight();
temp->updateHeight();
}
// Walks from `node` up to the root. At each step it refreshes the node's cached
// height and, where the child heights differ by more than 1, restores balance
// with a single or double rotation.
// Precondition: `node` is non-null (the do-while dereferences it before any test).
template <typename T>
void AvlTree<T>::rebalance(BstNode<T> *node)
{
do
{
node->updateHeight();
// Left side too tall.
if (getHeight(node->left) > getHeight(node->left) + 0 && false) {}
if (getHeight(node->left) > getHeight(node->right) + 1)
{
// Left-left case: one right rotation suffices.
if (getHeight(node->left->left) >= getHeight(node->left->right))
rightRotate(node);
// Left-right case: rotate the left child left first, then rotate node right.
else
{
leftRotate(node->left);
rightRotate(node);
}
}
// Right side too tall (mirror image of the branch above).
else if (getHeight(node->right) > getHeight(node->left) + 1)
{
// Right-right case: one left rotation suffices.
if (getHeight(node->right->right) >= getHeight(node->right->left))
leftRotate(node);
// Right-left case: rotate the right child right first, then rotate node left.
else
{
rightRotate(node->right);
leftRotate(node);
}
}
// After a rotation `node` has moved down one level, so its new parent (the
// rotated subtree's root) is visited next and its height refreshed again.
node = node->parent;
}
while(node != nullptr);
}
/// Height read from the root's cached value: O(1), and -1 for an empty tree.
/// Relies on the cached heights being current.
template <typename T>
int AvlTree<T>::height() const
{
    BstNode<T> *const top = this->root;
    return getHeight(top);
}
} // namespace bst
Unit Tests
#include <iostream>
#include <cassert>
#include "bst.hpp"
// BstNode::insert must return the freshly created leaf, attached where BST
// ordering puts it (smaller keys left, others right).
void insertionTest()
{
    auto *root = new bst::BstNode<int>{4};
    // Left half of the tree.
    assert(root->insert(2) == root->left);
    assert(root->insert(3) == root->left->right);
    assert(root->insert(1) == root->left->left);
    // Right half of the tree.
    assert(root->insert(6) == root->right);
    assert(root->insert(7) == root->right->right);
    assert(root->insert(5) == root->right->left);
    delete root;
}
// Builds a plain BST from the keys 1..9 inserted in ascending order.
// Every key goes to the right, so the tree degenerates into a chain.
bst::BST<int> makeBst()
{
    bst::BST<int> tree{1};
    for (int next_key = 2; next_key <= 9; ++next_key)
        tree.insert(next_key);
    return tree;
}
// Builds an AVL tree from the keys 1..9 inserted in ascending order;
// rebalancing keeps it shallow.
bst::AvlTree<int> makeAvl()
{
    bst::AvlTree<int> tree{1};
    for (int next_key = 2; next_key <= 9; ++next_key)
        tree.insert(next_key);
    return tree;
}
// The chain-shaped BST of 9 keys has height 8; the balanced AVL tree has height 3.
void testHeight(const bst::BST<int> &tree1, const bst::AvlTree<int> &tree2)
{
    assert(tree1.height() == 8); // unbalanced
    assert(tree2.height() == 3); // balanced
}
// Both trees were built from 1..9, so the smallest key is 1.
void testMin(bst::BST<int> &tree1, bst::AvlTree<int> &tree2)
{
    assert(tree1.min() == 1);
    assert(tree2.min() == 1);
}
// Both trees were built from 1..9, so the largest key is 9.
void testMax(bst::BST<int> &tree1, bst::AvlTree<int> &tree2)
{
    assert(tree1.max() == 9);
    assert(tree2.max() == 9);
}
// search() finds a present key and reports an absent one with nullptr.
void testSearch(const bst::AvlTree<int> &tree)
{
    assert(tree.search(3)->key == 3); // present
    assert(tree.search(10) == nullptr); // absent
}
// Copy construction must yield an independent deep copy: the same keys stored
// in different node objects.
void testCopy(const bst::BST<int> &tree1, const bst::AvlTree<int> &tree2)
{
    auto bst_copy = tree1;
    auto avl_copy = tree2;
    assert(tree1.search(8)->key == bst_copy.search(8)->key);
    assert(tree1.search(8) != bst_copy.search(8));
    assert(tree2.search(8)->key == avl_copy.search(8)->key);
    assert(tree2.search(8) != avl_copy.search(8));
}
// Copy assignment over an existing one-node tree must yield an independent
// deep copy of the source.
void testAssignment(const bst::BST<int> &tree1, const bst::AvlTree<int> &tree2)
{
    auto bst_target = bst::BST<int>{0};
    auto avl_target = bst::AvlTree<int>{0};
    bst_target = tree1;
    avl_target = tree2;
    assert(tree1.search(8)->key == bst_target.search(8)->key);
    assert(tree1.search(8) != bst_target.search(8));
    assert(tree2.search(8)->key == avl_target.search(8)->key);
    assert(tree2.search(8) != avl_target.search(8));
}
// An in-order dump of a tree holding 1..9 must come out sorted as 1, 2, ..., 9.
void testInOrderTraversal(const bst::AvlTree<int> &tree)
{
    int sorted[9];
    tree.placeInOrder(sorted, 0);
    int expected = 1;
    for (int value : sorted)
    {
        assert(value == expected);
        ++expected;
    }
}
// The in-order successor of 4 in a tree holding 1..9 is 5.
void testNextLarger(bst::AvlTree<int> &tree)
{
    bst::BstNode<int> *four = tree.getNode(4);
    assert(four->nextLarger()->key == 5);
}
// Deleting key 4 removes it while leaving other keys reachable.
// The trees are taken by value, so the caller's trees are not modified.
void testDelete(bst::BST<int> tree1, bst::AvlTree<int> tree2)
{
    // Plain BST.
    tree1.deleteNode(tree1.getNode(4));
    assert(tree1.search(4) == nullptr);
    assert(tree1.search(1)->key == 1);

    // AVL tree.
    tree2.deleteNode(tree2.getNode(4));
    assert(tree2.search(4) == nullptr);
    assert(tree2.search(1)->key == 1);
}
// Runs every test in order against one plain BST and one AVL tree.
void testBundle()
{
    insertionTest();

    auto tree1 = makeBst();
    auto tree2 = makeAvl();

    // Queries.
    testHeight(tree1, tree2);
    testMin(tree1, tree2);
    testMax(tree1, tree2);
    testSearch(tree2);

    // Copy semantics.
    testCopy(tree1, tree2);
    testAssignment(tree1, tree2);

    // Traversal and mutation.
    testInOrderTraversal(tree2);
    testNextLarger(tree2);
    testDelete(tree1, tree2);
}
// Test entry point: any failing assert aborts the program.
int main()
{
    testBundle();
    return 0;
}
Example Use Case
#include <iostream>
#include <random>
#include <chrono>
#include "bst.hpp"
// File-local Mersenne Twister, seeded once from the current time so each run
// produces different numbers.
static std::mt19937 random_engine(
    static_cast<std::mt19937::result_type>(
        std::chrono::high_resolution_clock::now().time_since_epoch().count()));
// Fills array[begin..end] (end inclusive) with integers drawn uniformly from [0, 99].
void fillRandomly(int array[], int begin, int end)
{
    std::uniform_int_distribution<int> distribution{0, 99};
    int index = begin;
    while (index <= end)
    {
        array[index] = distribution(random_engine);
        ++index;
    }
}
// Sorts array[begin..end] (end inclusive) in ascending order: inserts every
// element into an AVL tree, then writes the keys back with an in-order traversal.
void avlSort(int array[], int begin, int end)
{
    // Empty range: nothing to sort. This also avoids calling placeInOrder on
    // (and destroying) a tree that never got a root.
    if (begin > end)
        return;
    bst::AvlTree<int> tree;
    for (int i = begin; i <= end; ++i)
        tree.insert(array[i]);
    tree.placeInOrder(array, begin);
}
// Prints array[begin..end] (end inclusive) to std::cout, each value followed by a space.
void printArray(int array[], int begin, int end)
{
    for (int index = begin; index <= end; ++index)
    {
        const int value = array[index];
        std::cout << value << ' ';
    }
}
// Demo: fill 100 slots with random values, print them, AVL-sort them, print again.
int main()
{
    constexpr int kLast = 99;
    int array[kLast + 1];
    fillRandomly(array, 0, kLast);
    printArray(array, 0, kLast);
    std::cout << "\n\n";
    avlSort(array, 0, kLast);
    printArray(array, 0, kLast);
}
Notes:
I didn't use smart pointers for two reasons. First, I wanted to practice using raw pointers. Second, since I'm using parent pointers, I would have had to use `std::shared_ptr` for the smart pointers, which adds some overhead. (I heard it's possible to implement AVL trees without parent pointers, but I can't even imagine how deleting a node is possible without them.) `BST::min`, `BST::max` and `BstNode::nextLarger` should be `const`, but I couldn't find an elegant way to define them that way without messing up `BST::deleteNode`.
Any suggestions for additional functionality or improved readability (e.g., using modern C++ features) are welcome.
The `deleteNode` function has some problems; my tests weren't rigorous enough to catch the errors.

[Comment, beginning truncated] … `findHeight()` with the ternary operator, like this: `int left_height = left ? left->findHeight() : -1;`