8
\$\begingroup\$

After not programming in c++ for some time, I decided to write an AVL tree implementation to get back in shape (I wasn't that good anyway. Still an amateur).

Header File

bst.hpp

#pragma once


#include <iostream>
#include <utility>


namespace bst
{
    template <typename T>
    class BstNode
    {
    public:
        explicit BstNode(T key, BstNode<T> *parent = nullptr);
        BstNode(const BstNode<T> &other);
        BstNode<T> &operator=(BstNode<T> &other);
        ~BstNode();

        BstNode<T> *insert(T key);
        BstNode<T> *getNode(const T &key);
        const BstNode<T> *search(const T &key) const;
        BstNode<T> *findMin();
        BstNode<T> *findMax();
        BstNode<T> *nextLarger();
        
        void printInOrder() const;
        int placeInOrder(T array[], int begin) const;
        int findHeight() const;
        void deleteChildren();

        T key;
        BstNode<T> *left = nullptr, *right = nullptr, *parent;
        
    private:
        int height = 0;
        int updateHeight();

        template <typename U>
        friend int getHeight(BstNode<U> *node);
        template <typename U>
        friend class BST;
        template <typename U>
        friend class AvlTree;
    };

    template <typename T>
    class BST
    {
    public:
        BST();
        explicit BST(T root_key);
        BST(const BST<T> &other);
        BST<T> &operator=(const BST<T> &other);
        ~BST();

        virtual BstNode<T> *insert(T key);
        BstNode<T> *getNode(const T &key);
        const BstNode<T> *search(const T &key) const;
        BstNode<T> *minNode();
        BstNode<T> *maxNode();
        T min();
        T max();
        virtual int height() const;
        int placeInOrder(T array[], int begin) const;
        void printInOrder() const;
        void deleteNode(BstNode<T> *node);
        static void deleteSubtree(BstNode<T> *node);

    protected:
        BstNode<T> *root;
    };

    template <typename T>
    class AvlTree : public BST<T>
    {
    public:
        AvlTree();
        explicit AvlTree(T key);

        BstNode<T> *insert(T key);
        int height() const;
    
    private:
        void rightRotate(BstNode<T> *node);
        void leftRotate(BstNode<T> *node);
        void rebalance(BstNode<T> *node);
    };


    template <typename T>
    int getHeight(BstNode<T> *node)
    {
        // if a node is null, we define its height as -1
        return node != nullptr ? node->height : -1;
    }

    template <typename T>
    T max(T a, T b)
    {
        return a > b ? a : b;
    }

    template <typename T>
    void swap(T &a, T &b)
    {
        T c = a;
        a = b;
        b = c;
    }


    template <typename T>
    BstNode<T>::BstNode(T key, BstNode<T> *parent) : key{std::move(key)}, parent{parent}
    {}

    template <typename T>
    BstNode<T>::BstNode(const BstNode<T> &other) : key{other.key},
    height{other.height}, parent{nullptr}
    {
        if (other.left != nullptr)
        {
            left = new BstNode<T>{*(other.left)};
            left->parent = this;
        }
        if (other.right != nullptr)
        {
            right = new BstNode<T>{*(other.right)};
            right->parent = this;
        }
    }

    template <typename T>
    BstNode<T> &BstNode<T>::operator=(BstNode<T> &other)
    {
        key = other.key;
        height = other.height;
        deleteChildren();
        if (other.left != nullptr)
        {
            left = new BstNode<T>{other.right->key};
            left->parent = this;
        }
        if (other.right != nullptr)
        {
            right = new BstNode<T>{*(other.right)};
            right->parent = this;
        }
        return *this;
    }

    template <typename T>
    BstNode<T>::~BstNode()
    {
        deleteChildren();
        if (parent != nullptr)
        {
            if (parent->left == this)
                parent->left = nullptr;
            else
                parent->right = nullptr;
        }
    }

    template <typename T>
    void BstNode<T>::deleteChildren()
    {
        if (left != nullptr)
        {
            left->deleteChildren();
            delete left;
            left = nullptr;
        }
        if (right != nullptr)
        {
            right->deleteChildren();
            delete right;
            right = nullptr;
        }
    }

    template <typename T>
    BstNode<T> *BstNode<T>::insert(T key)
    {
        if (key < this->key)
        {
            if (left == nullptr)
            {
                left = new BstNode<T>{std::move(key), this};
                return left;
            }
            else
                return left->insert(std::move(key));
        }
        else
        {
            if (right == nullptr)
            {
                right = new BstNode<T>{std::move(key), this};
                return right;
            }
            else
                return right->insert(std::move(key));
        }
    }

    template <typename T>
    BstNode<T> *BstNode<T>::getNode(const T &key)
    {
        if (key < this->key)
        {
            if (left == nullptr)
                return nullptr;
            return left->getNode(key);
        }
        else if (key > this->key)
        {
            if (right == nullptr)
                return nullptr;
            return right->getNode(key);
        }
        else // key == this->key
            return this;
    }

    template <typename T>
    const BstNode<T> *BstNode<T>::search(const T &key) const
    {
        if (key < this->key)
        {
            if (left == nullptr)
                return nullptr;
            return left->search(key);
        }
        else if (key > this->key)
        {
            if (right == nullptr)
                return nullptr;
            return right->search(key);
        }
        else // key == this->key
            return this;
    }

    template <typename T>
    BstNode<T> *BstNode<T>::findMin()
    {
        auto *node = this;
        while (node->left != nullptr)
            node = node->left;
        return node;
    }

    template <typename T>
    BstNode<T> *BstNode<T>::findMax()
    {
        auto *node = this;
        while (node->right != nullptr)
            node = node->right;
        return node;
    }

    template <typename T>
    BstNode<T> *BstNode<T>::nextLarger()
    {
        if (right != nullptr)
            return right->findMin();
        
        auto *node = this;
        while ((node->parent != nullptr) && (node == node->parent->right))
            node = node->parent;
        return node->parent;
    }

    template <typename T>
    void BstNode<T>::printInOrder() const
    {
        if (left != nullptr)
            left->printInOrder();
        std::cout << key << ' ';
        if (right != nullptr)
            right->printInOrder();
    }

    template <typename T>
    int BstNode<T>::placeInOrder(T array[], int begin) const
    {
        if (left != nullptr)
            begin = left->placeInOrder(array, begin);
        array[begin++] = key;
        if (right != nullptr)
            begin = right->placeInOrder(array, begin);
        return begin;
    }

    template <typename T>
    int BstNode<T>::findHeight() const
    {
        int left_height = -1;
        int right_height = -1;
        if (left != nullptr)
            left_height = left->findHeight();
        if (right != nullptr)
            right_height = right->findHeight();
        
        return max(left_height, right_height) + 1;
    }

    template <typename T>
    int BstNode<T>::updateHeight()
    {
        height = max(getHeight(this->left), getHeight(this->right)) + 1;
        return height;
    }


    template <typename T>
    BST<T>::BST() : root{nullptr}
    {}

    template <typename T>
    BST<T>::BST(T root_key)
    {
        root = new BstNode<T>{root_key};
    }

    template <typename T>
    BST<T>::BST(const BST<T> &other)
    {
        root = new BstNode<T>{*(other.root)};
    }

    template <typename T>
    BST<T> &BST<T>::operator=(const BST<T> &other)
    {
        deleteSubtree(root);
        root = new BstNode<T>{*(other.root)};
        return *this;
    }

    template <typename T>
    BST<T>::~BST()
    {
        deleteSubtree(root);
    }

    template <typename T>
    void BST<T>::deleteSubtree(BstNode<T> *node)
    {
        if (node->parent != nullptr)
        {
            if (node->parent->left == node)
                node->parent->left = nullptr;
            else
                node->parent->right = nullptr;
        }
        delete node;
    }

    template <typename T>
    BstNode<T> *BST<T>::insert(T key)
    {
        return root->insert(std::move(key));
    }

    template <typename T>
    BstNode<T> *BST<T>::getNode(const T &key)
    {
        return root->getNode(key);
    }

    template <typename T>
    const BstNode<T> *BST<T>::search(const T &key) const
    {
        return root->search(key);
    }

    template <typename T>
    BstNode<T> *BST<T>::minNode()
    {
        return root->findMin();
    }

    template <typename T>
    BstNode<T> *BST<T>::maxNode()
    {
        return root->findMax();
    }

    template <typename T>
    T BST<T>::min()
    {
        return minNode()->key;
    }

    template <typename T>
    T BST<T>::max()
    {
        return maxNode()->key;
    }

    template <typename T>
    int BST<T>::height() const
    {
        return root->findHeight();
    }

    template <typename T>
    int BST<T>::placeInOrder(T array[], int begin) const
    {
        return root->placeInOrder(array, begin);
    } 

    template <typename T>
    void BST<T>::printInOrder() const
    {
        root->printInOrder();
    }

    template <typename T>
    void BST<T>::deleteNode(BstNode<T> *node)
    {
        BstNode<T> *new_node;
        if ((node->left != nullptr) && (node->right == nullptr))
            new_node = node->left;
        else if ((node->left == nullptr) && (node->right != nullptr))
            new_node = node->right;
        else if ((node->left == nullptr) && (node->right == nullptr))
        {
            deleteSubtree(node);
            return;
        }
        else
        {
            new_node = node->nextLarger();
            swap(node->key, new_node->key);
            deleteNode(new_node);
            return;
        }

        new_node->parent = node->parent;
        if (new_node->parent != nullptr)
        {
            if (node->parent->left == node)
                new_node->parent->left = new_node;
            else
                new_node->parent->right = new_node;
        }
        else
            root = new_node;
        
        node->parent = nullptr;
        node->left = nullptr;
        node->right = nullptr;
        delete node;
    }


    template <typename T>
    AvlTree<T>::AvlTree()
    {}

    template <typename T>
    AvlTree<T>::AvlTree(T key) : BST<T>{std::move(key)}
    {}

    template <typename T>
    BstNode<T> *AvlTree<T>::insert(T key)
    {
        if (this->root == nullptr)
        {
            this->root = new BstNode<T>{std::move(key)};
            return this->root;
        } 
        auto *node = this->root->insert(std::move(key));
        rebalance(node);
        return node;
    }

    template <typename T>
    void AvlTree<T>::leftRotate(BstNode<T> *node)
    {
        auto *temp = node->right;
        temp->parent = node->parent;
        if (node->parent == nullptr)
            this->root = temp;
        else
        {
            if (node->parent->left == node)
                node->parent->left = temp;
            else
                node->parent->right = temp;
        }

        node->right = temp->left;
        if (node->right != nullptr)
            node->right->parent = node;
        
        temp->left = node;
        node->parent = temp;

        node->updateHeight();
        temp->updateHeight();
    }

    template <typename T>
    void AvlTree<T>::rightRotate(BstNode<T> *node)
    {
        auto *temp = node->left;
        temp->parent = node->parent;
        if (node->parent == nullptr)
            this->root = temp;
        else
        {
            if (node->parent->left == node)
                node->parent->left = temp;
            else
                node->parent->right = temp;
        }

        node->left = temp->right;
        if (node->left != nullptr)
            node->left->parent = node;
        
        temp->right = node;
        node->parent = temp;

        node->updateHeight();
        temp->updateHeight();
    }

    template <typename T>
    void AvlTree<T>::rebalance(BstNode<T> *node)
    {
        do
        {
            node->updateHeight();
            if (getHeight(node->left) > getHeight(node->right) + 1)
            {
                if (getHeight(node->left->left) >= getHeight(node->left->right))
                    rightRotate(node);
                else
                {
                    leftRotate(node->left);
                    rightRotate(node);
                }
            }
            else if (getHeight(node->right) > getHeight(node->left) + 1)
            {
                if (getHeight(node->right->right) >= getHeight(node->right->left))
                    leftRotate(node);
                else
                {
                    rightRotate(node->right);
                    leftRotate(node);
                }
            }
            node = node->parent;
        }
        while(node != nullptr);
    }

    template <typename T>
    int AvlTree<T>::height() const
    {
        return getHeight(this->root);
    }

} // namespace bst

Unit Tests

#include <iostream>
#include <cassert>
#include "bst.hpp"

void insertionTest()
{
    auto *node = new bst::BstNode<int>{4};
    assert(node->insert(2) == node->left);
    assert(node->insert(3) == node->left->right);
    assert(node->insert(1) == node->left->left);
    assert(node->insert(6) == node->right);
    assert(node->insert(7) == node->right->right);
    assert(node->insert(5) == node->right->left);
    delete node;
}

bst::BST<int> makeBst()
{
    auto tree = bst::BST<int>{1};
    tree.insert(2);
    tree.insert(3);
    tree.insert(4);
    tree.insert(5);
    tree.insert(6);
    tree.insert(7);
    tree.insert(8);
    tree.insert(9);
    return tree;
}

bst::AvlTree<int> makeAvl()
{
    auto tree = bst::AvlTree<int>{1};
    tree.insert(2);
    tree.insert(3);
    tree.insert(4);
    tree.insert(5);
    tree.insert(6);
    tree.insert(7);
    tree.insert(8);
    tree.insert(9);
    return tree;
}

void testHeight(const bst::BST<int> &tree1, const bst::AvlTree<int> &tree2)
{
    assert(tree1.height() == 8);
    assert(tree2.height() == 3);
}

void testMin(bst::BST<int> &tree1, bst::AvlTree<int> &tree2)
{
    assert(tree1.min() == 1);
    assert(tree2.min() == 1);
}

void testMax(bst::BST<int> &tree1, bst::AvlTree<int> &tree2)
{
    assert(tree1.max() == 9);
    assert(tree2.max() == 9);
}

void testSearch(const bst::AvlTree<int> &tree)
{
    assert(tree.search(3)->key == 3);
    assert(tree.search(10) == nullptr);
}

void testCopy(const bst::BST<int> &tree1, const bst::AvlTree<int> &tree2)
{
    auto tree3 = tree1;
    auto tree4 = tree2;
    assert(tree1.search(8)->key == tree3.search(8)->key);
    assert(tree1.search(8) != tree3.search(8));
    assert(tree2.search(8)->key == tree4.search(8)->key);
    assert(tree2.search(8) != tree4.search(8));
}

void testAssignment(const bst::BST<int> &tree1, const bst::AvlTree<int> &tree2)
{
    auto tree3 = bst::BST<int>{0};
    auto tree4 = bst::AvlTree<int>{0};
    tree3 = tree1;
    tree4 = tree2;
    assert(tree1.search(8)->key == tree3.search(8)->key);
    assert(tree1.search(8) != tree3.search(8));
    assert(tree2.search(8)->key == tree4.search(8)->key);
    assert(tree2.search(8) != tree4.search(8));
}

void testInOrderTraversal(const bst::AvlTree<int> &tree)
{
    int a[9];
    tree.placeInOrder(a, 0);
    for (int i = 0; i < 9; ++i)
        assert(a[i] == i + 1);
}

void testNextLarger(bst::AvlTree<int> &tree)
{
    assert(tree.getNode(4)->nextLarger()->key == 5);
}

void testDelete(bst::BST<int> tree1, bst::AvlTree<int> tree2)
{
    tree1.deleteNode(tree1.getNode(4));
    assert(tree1.search(4) == nullptr);
    assert(tree1.search(1)->key == 1);

    tree2.deleteNode(tree2.getNode(4));
    assert(tree2.search(4) == nullptr);
    assert(tree2.search(1)->key == 1);
}

void testBundle()
{
    insertionTest();
    auto tree1 = makeBst();
    auto tree2 = makeAvl();
    testHeight(tree1, tree2);
    testMin(tree1, tree2);
    testMax(tree1, tree2);
    testSearch(tree2);
    testCopy(tree1, tree2);
    testAssignment(tree1, tree2);
    testInOrderTraversal(tree2);
    testNextLarger(tree2);
    testDelete(tree1, tree2);
}

int main()
{
    testBundle();
}

Example Use Case

#include <iostream>
#include <random>
#include <chrono>
#include "bst.hpp"

static std::mt19937 random_engine(
    std::chrono::high_resolution_clock::now().time_since_epoch().count());

void fillRandomly(int array[], int begin, int end)
{
    std::uniform_int_distribution<int> range(0, 99);
    for (int i = begin; i <= end; ++i)
        array[i] = range(random_engine);
}

void avlSort(int array[], int begin, int end)
{
    bst::AvlTree<int> tree;
    for (int i = begin; i <= end; ++i)
        tree.insert(array[i]);
    tree.placeInOrder(array, begin);
}

void printArray(int array[], int begin, int end)
{
    for (int i = begin; i <= end; ++i)
        std::cout << array[i] << ' ';
}

int main()
{
    int array[100];
    fillRandomly(array, 0, 99);
    printArray(array, 0, 99);
    std::cout << "\n\n";
    avlSort(array, 0, 99);
    printArray(array, 0, 99);
}

Notes:

  1. I didn't use smart pointers for two reasons; First, I wanted to practice using raw pointers. Second, since I'm using parent pointers, I had to use std::shared_ptr for smart pointers which adds some overhead. (I heard it's possible to implement AVL trees without parent pointers, but I can't even imagine how deleting a node is possible without using parent pointers)

  2. BST::min, BST::max and BstNode::nextLarger should be const, but I couldn't find an elegant way to define them that way while not messing up BST::deleteNode.

Any suggestion for additional functionality and improving readability (e.g. using modern features of c++) are welcome.

\$\endgroup\$
2
  • \$\begingroup\$ For future reference, my deleteNode function has some problems. My tests weren't rigorous enough to catch the errors. \$\endgroup\$ Commented Aug 30, 2020 at 11:46
  • \$\begingroup\$ When cast to boolean, a pointer results in true iff it's non-null, so you can simplify findHeight() with the ternary operator like this: int left_height = left ? left->findHeight() : -1; \$\endgroup\$ Commented 2 days ago

3 Answers 3

8
\$\begingroup\$

About inheritance

I would advice against using (public) inheritance of BST for AvlTree. Consider that with your code, the following is valid:

bst::AvlTree<int> tree{...};
bst::BST *base_ptr = &tree;
base_ptr->deleteNode(base_ptr->getNode(42));

Without any compiler error or warning, I deleted a node from the AVL tree without it being rebalanced. This might cause the AVL tree to lose its balance, and even worse, it might cause subsequent operations on tree to fail if they depend on the properties being held.

You could make more functions virtual, but it's easy to forget something here. It is better if you make it so you can't access the base class at all.

One way to achieve that is to use private inheritance, and selectively bring features from the base class into the public API of the derived class, like so:

template<typename T>
class AvlTree: private BST<T>
{
public:
    ...
    using BST::getNode;
    using BST::findMin;
    ...
};

Or use composition instead of inheritance:

template<typename T>
class AvlTree>
{
    BST<T> bst;
public:
    ...
    BstNode<T> *getNode(const T &key) {return bst.getNode(key);}
    BstNode<T> *findMin(const T &key) {return bst.findMin(key);}
    ...
};

Once you get rid of the base class, it doesn't make sense to have virtual functions anymore. It also gets rid of the overhead of a vtable.

Alternatively just don't use BST at all inside AvlTree, just rely on BstNode and its member functions. But here also, you want to prevent the ability to do the following:

bst::AvlTree<int> tree{...};
bst::BstNode *node = tree.getNode(...);
node->insert(42);

You can do that by making the member functions of BstNode private, and using a friend declaration to allow AvlTree access to the private members of BstNode. But that's also more complicated than necessary. Personally, I would get rid of the generic BstNode and BST, and just have a class AvlTree which inside declares a class Node, like so:

template <typename T>
class AvlTree {
    class Node {
        T key;
        int height;
        Node *left{};
        Node *right{};
        Node *parent{};
    };

    Node *root;

public:
    ...
};

And instead of having member functions like search() return a pointer to a node, return a pointer to the key itself. This will ensure the implementation of your nodes is completely hidden.

Use the same name for getNode() and search()

These two functions do exactly the same thing. The only difference is that search() operates on a const instance of a tree, and returns a const pointer to a node. You can just use the same name for those functions, the compiler is able to distinguish between the two cases. You can keep your code exactly as it is, just replace getNode with search, or the other way around as you like.

Avoid writing specialized convenience functions

I see functions like printInOrder() and placeInOrder(). While that might be convenient, they are quite specialized. What if instead of printing to std::cout I want to write to a file? Or instead of placing the items into an array, I want it in a std::vector? It is better to add something generic that makes it easy to do those things. For example, add an iterator class, and add begin() and end() member functions to AvlTree. This will allow you to then do the following:

bst:AvlTree<int> tree{...};

// Print all elements of the tree
for (auto key: tree)
    std::cout << key << ' ';

// Copy the elements into a vector
std::vector<int> vec(tree.begin(), tree.end());

Once you have iterators for your class, there is a whole slew of STL algorithms that can operator on your AVL tree, without you having to make custom implementations. This also brings me to:

Try to emulate other STL containers

Have a good look at other STL containers such as std::map, and try to mimic their interface as much as possible. Not only try to use the same names as STL for similar functions (for example, find() instead of search() or getNode()), but you can also find a lot of functions that you are missing, that might be very useful to have in a tree, like size(), empty(), erase(), and so on. If you keep the API the same as STL containers, there is less cognitive overhead, and makes it easier to swap an STL container for your AVL tree, and vice versa.

You'll also notice that a std::map has a key and a value type. Your AvlTree is more like a std::set. Both have their uses. Maybe you could make an AvlMap and an AvlSet?

Finally, the ordered STL containers allow you to specify a custom comparison operator. This is nice to have if the type of keys you want to store do not have a proper ordering of their own, or if I want to sort them in a different way than their natural order (consider for example wanting to store std::strings, but use case-insensitive comparison to order them).

Make use of standard library functions instead of reinventing them

You implemented bst::max() and bst::swap(), but those functions are already part of <algorithm>. Why not use the ones that come with the standard library?

Making more functions const

To ensure BST::min() and BST::max() and can be const, just make variants of findMin() and findMax() that are const (they can overload the non-const versions). For nextLarger(), you would indeed need to make two versions of it, one regular and the other const. Then again, nextLarger() is only used inside deleteNode(), so does it really need to be a separate, public function? I would make it private, and then it doesn't matter it doesn't have a const version. If you do want to keep it public, then you have to make a const and non-const version of nextLarger.

Note that you don't have to write two full implementations of each function that you want to have a const and non-const version for. You can do something like this instead:

const Node *AvlTree::Node::nextLarger() const {
    // full implementation here
    ...
}

Node *AvlTree::Node::nextLarger() {
    return const_cast<Node *>(const_cast<const Node *>(this)->nextLarger());
}

Even though it might look a bit shady, the above is valid C++: since the original objects were not const, const-casting them to const and then back again is allowed.

Since C++23, you can use an "explicit object parameter" instead:

template <typename T>
class AvlTree {
    class Node {
        …
        auto findMin(this auto&& self) {
            auto *node = &self;
            while (node->left != nullptr)
                node = node->left;
            return node;
        }
        …
    };
    …
};

It will work for both const and non-const (as well as volatile) nodes, and the autos inside will correctly deduce the constness.

\$\endgroup\$
0
2
\$\begingroup\$

Example use cases

There are a few opportunities here.

  • Don't use magic numbers. 100 and 99 are used with respect to the dimensions of your array. This dimension should be a named constant.
  • Don't use raw C-style arrays of int, but rather std::array<int>.
  • If you're not actually going to use anything other than the max dimensions of the array, there's no point in having your functions take a start and end index, and if you're using std::array there's no need to pass dimensions at all.
  • We could then use range-based for-loops.
void avlSort(std::array<int> array)
{
    bst::AvlTree<int> tree;
    for (const auto x : array)
        tree.insert(x);
    tree.placeInOrder(array, begin);
}
\$\endgroup\$
1
\$\begingroup\$

In rebalance(), establishing the height before potentially rebalancing looks wrong.
There are two if-else statements where both branches end the same.
Put that same end after the conditional statement and avoid code multiplication. (Same starts go before.)

    {
        int const balance = getHeight(node->right) - getHeight(node->left);
        if (balance < -1)
        {
            if (getHeight(node->left->left) < getHeight(node->left->right))
                leftRotate(node->left);
            rightRotate(node);
        }
        else if (1 < balance)
        {
            if (getHeight(node->right->right) < getHeight(node->right->left))
                rightRotate(node->right);
            leftRotate(node);
        }
        // else done rebalancing(?!)
        node->updateHeight();
        node = node->parent;
    }

Far as I remember a single rotate (if double) is enough for inserts. One can stop retracing the path to root after deletes after the first node not needing a rotate.
Above loop misses early out in both cases.

\$\endgroup\$

You must log in to answer this question.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.