4
\$\begingroup\$

I am trying to create a Binary Search Tree, which has STL support (iterators etc.). This is my first attempt. Everything is unit tested using Catch2. Could you please review what I missed and what I could do better? This is not homework, I am trying to create a repository with unit tested C++ algorithms at https://gitlab.com/MartenBE/varia as a hobby and a way to teach myself C++.

Binary Search Tree:

#ifndef BINARY_SEARCH_TREE_H
#define BINARY_SEARCH_TREE_H

#include "nlohmann/json.hpp"

using json = nlohmann::json;

#include <cassert>
#include <memory>
#include <optional>
#include <stack>
#include <utility>

template <class K, class V>
class binary_search_tree
{
private: // Forward declarations
    class node;
    class const_iterator;

public:
    binary_search_tree() = default;
    virtual ~binary_search_tree() = default;

    binary_search_tree(const binary_search_tree& other)
    {
        clear();
        m_size = other.m_size;

        std::stack<std::tuple<const std::unique_ptr<node>*, std::unique_ptr<node>*, node*>> nodes_to_deep_copy;
        nodes_to_deep_copy.push({&(other.m_root), &m_root, nullptr});

        while (!nodes_to_deep_copy.empty())
        {
            auto [other_node_ptr, node_ptr, parent_ptr] = nodes_to_deep_copy.top();
            nodes_to_deep_copy.pop();

            if (*other_node_ptr)
            {
                *node_ptr = std::make_unique<node>((*other_node_ptr)->m_key, (*other_node_ptr)->m_value, parent_ptr);

                nodes_to_deep_copy.push(
                        {&((*other_node_ptr)->m_left_child), &((*node_ptr)->m_left_child), node_ptr->get()});

                nodes_to_deep_copy.push(
                        {&((*other_node_ptr)->m_right_child), &((*node_ptr)->m_right_child), node_ptr->get()});
            }
        }
    }

    binary_search_tree& operator=(const binary_search_tree& other)
    {
        binary_search_tree temp{other};
        swap(*this, temp);

        return *this;
    }

    binary_search_tree(binary_search_tree&& other) : m_root{std::move(other.m_root)}, m_size{other.m_size}
    {
        other.m_size = 0;
    }

    binary_search_tree& operator=(binary_search_tree&& other)
    {
        binary_search_tree temp{std::move(other)};
        swap(*this, temp);

        return *this;
    }

    friend void swap(binary_search_tree& first, binary_search_tree& second)
    {
        using std::swap;
        swap(first.m_root, second.m_root);
        swap(first.m_size, second.m_size);
    }

    std::optional<V> search(const K& key)
    {
        auto [node_ptr, parent_node_ptr] = search_node_ptr(key);

        if (!node_ptr)
        {
            return std::nullopt;
        }

        return node_ptr->m_value;
    }

    void add(const K& key, const V& value)
    {
        auto [node_ptr, parent_node_ptr] = search_node_ptr(key);
        assert(!node_ptr);

        std::unique_ptr<node>* place_of_new_node = nullptr;

        if (!parent_node_ptr)
        {
            place_of_new_node = &m_root;
        }
        else
        {
            if (key < parent_node_ptr->m_key)
            {
                place_of_new_node = &(parent_node_ptr->m_left_child);
            }
            else
            {
                place_of_new_node = &(parent_node_ptr->m_right_child);
            }
        }

        *place_of_new_node = std::make_unique<node>(key, value, parent_node_ptr);

        m_size++;
    }

    bool empty() const
    {
        return (size() == 0);
    }

    void clear()
    {
        m_root.reset();
        m_size = 0;
    }

    const K& find_minimum() const
    {
        assert(!empty());

        return find_minimum_node_ptr(*m_root)->m_key;
    }

    const K& find_maximum() const
    {
        assert(!empty());

        return find_maximum_node_ptr(*m_root)->m_key;
    }

    std::optional<K> find_predecessor(const K& key) const
    {
        assert(!empty());

        auto [node_ptr, parent_node_ptr] = search_node_ptr(key);
        assert(node_ptr);

        node* predecessor_ptr = find_predecessor_node_ptr(*node_ptr);

        if (!predecessor_ptr)
        {
            return std::nullopt;
        }

        return predecessor_ptr->m_key;
    }

    std::optional<K> find_successor(const K& key) const
    {
        assert(!empty());

        auto [node_ptr, parent_node_ptr] = search_node_ptr(key);
        assert(node_ptr);

        node* successor_ptr = find_successor_node_ptr(*node_ptr);

        if (!successor_ptr)
        {
            return std::nullopt;
        }

        return successor_ptr->m_key;
    }

    void remove(const K& key)
    {
        assert(!empty());

        auto [node_ptr, parent_node_ptr] = search_node_ptr(key);
        assert(node_ptr);

        remove(*node_ptr);

        m_size--;
    }

    int size() const
    {
        return m_size;
    }

    json to_json() const
    {
        if (empty())
        {
            return json{};
        }

        return m_root->to_json();
    }

    friend bool operator==(const binary_search_tree& lhs, const binary_search_tree& rhs)
    {
        if (lhs.empty() && rhs.empty())
        {
            return true;
        }

        if (lhs.size() != rhs.size())
        {
            return false;
        }

        std::stack<std::pair<node*, node*>> nodes_to_check;
        nodes_to_check.push({lhs.m_root.get(), rhs.m_root.get()});

        while (!nodes_to_check.empty())
        {
            std::pair<node*, node*> nodes = nodes_to_check.top();
            nodes_to_check.pop();
            node* node_lhs = nodes.first;
            node* node_rhs = nodes.second;

            if ((!node_lhs && node_rhs) || (node_lhs && !node_rhs))
            {
                return false;
            }

            if (node_lhs && node_rhs)
            {
                if (node_lhs->m_key != node_rhs->m_key)
                {
                    return false;
                }

                nodes_to_check.push({node_lhs->m_left_child.get(), node_rhs->m_left_child.get()});
                nodes_to_check.push({node_lhs->m_right_child.get(), node_rhs->m_right_child.get()});
            }
        }

        return true;
    }

    friend bool operator!=(const binary_search_tree& lhs, const binary_search_tree& rhs)
    {
        return !(lhs == rhs);
    }

    friend std::ostream& operator<<(std::ostream& os, const binary_search_tree& bt)
    {
        os << bt.to_json().dump(4);

        return os;
    }

    const_iterator begin()
    {
        if (empty())
        {
            return end();
        }

        return const_iterator{(*this), find_minimum_node_ptr(*m_root)};
    }

    const_iterator end()
    {
        return const_iterator{(*this), nullptr};
    }

private:
    std::pair<node*, node*> search_node_ptr(const K& key) const
    {
        node* current_node = m_root.get();
        node* parent_node = nullptr;

        while (current_node && (current_node->m_key != key))
        {
            parent_node = current_node;

            if (key < current_node->m_key)
            {
                current_node = (current_node->m_left_child).get();
            }
            else
            {
                current_node = (current_node->m_right_child).get();
            }
        }

        return std::pair<node*, node*>{current_node, parent_node};
    }

    node* find_minimum_node_ptr(node& root) const
    {
        node* current_node = &root;
        while (current_node->m_left_child)
        {
            current_node = (current_node->m_left_child).get();
        }

        return current_node;
    }

    node* find_maximum_node_ptr(node& root) const
    {
        node* current_node = &root;
        while (current_node->m_right_child)
        {
            current_node = (current_node->m_right_child).get();
        }

        return current_node;
    }

    node* find_predecessor_node_ptr(node& root) const
    {
        if (root.m_left_child)
        {
            return find_maximum_node_ptr(*(root.m_left_child));
        }
        else
        {
            node* current_node_ptr = &root;
            node* parent_ptr = current_node_ptr->m_parent;
            while (parent_ptr && ((parent_ptr->m_left_child).get() == current_node_ptr))
            {
                current_node_ptr = current_node_ptr->m_parent;
                parent_ptr = parent_ptr->m_parent;
            }

            return parent_ptr;
        }
    }

    node* find_successor_node_ptr(node& root) const
    {
        if (root.m_right_child)
        {
            return find_minimum_node_ptr(*(root.m_right_child));
        }
        else
        {
            node* current_node_ptr = &root;
            node* parent_ptr = current_node_ptr->m_parent;
            while (parent_ptr && ((parent_ptr->m_right_child).get() == current_node_ptr))
            {
                current_node_ptr = current_node_ptr->m_parent;
                parent_ptr = parent_ptr->m_parent;
            }

            return parent_ptr;
        }
    }

    void remove(node& root)
    {
        node* node_ptr = &root;

        std::unique_ptr<node>* node_owner_ptr = get_owner_pointer(node_ptr);

        if (!(node_ptr->m_left_child) && !(node_ptr->m_right_child))
        {
            node_owner_ptr->reset();
        }
        else if ((node_ptr->m_left_child) && (node_ptr->m_right_child))
        {
            node* replacement_ptr = find_successor_node_ptr(*node_ptr);

            node_ptr->m_key = std::move(replacement_ptr->m_key);
            node_ptr->m_value = std::move(replacement_ptr->m_value);

            remove(*replacement_ptr);
        }
        else
        {
            node* parent = node_ptr->m_parent;

            if (node_ptr->m_left_child)
            {
                *node_owner_ptr = std::move(node_ptr->m_left_child);
            }
            else
            {
                *node_owner_ptr = std::move(node_ptr->m_right_child);
            }

            (*node_owner_ptr)->m_parent = parent;
        }
    }

    std::unique_ptr<node>* get_owner_pointer(node* node_ptr)
    {
        if (node_ptr == m_root.get())
        {
            return &m_root;
        }

        if ((node_ptr->m_parent->m_left_child).get() == node_ptr)
        {
            return &(node_ptr->m_parent->m_left_child);
        }
        else
        {
            return &(node_ptr->m_parent->m_right_child);
        }
    }

    class node
    {
    public:
        node(const K& key, const V& value, node* parent) : m_key{key}, m_value{value}, m_parent{parent}
        {
        }

        node(const node& other) = delete;
        node& operator=(const node& other) = delete;
        node(node&& other) = default;
        node& operator=(node&& other) = default;
        virtual ~node() = default;

        friend void swap(node& first, node& second)
        {
            swap(first.m_key, second.m_key);
            swap(first.m_value, second.m_value);
            swap(first.m_left_child, second.m_left_child);
            swap(first.m_right_child, second.m_right_child);
            swap(first.m_parent, second.m_parent);
        }

        json to_json() const
        {
            json node_json;

            node_json["key"] = m_key;
            node_json["value"] = m_value;

            if (m_left_child)
            {
                node_json["left_child"] = m_left_child->to_json();
            }
            else
            {
                node_json["left_child"] = "";
            }

            if (m_right_child)
            {
                node_json["right_child"] = m_right_child->to_json();
            }
            else
            {
                node_json["right_child"] = "";
            }

            if (m_parent)
            {
                node_json["parent_key"] = m_parent->m_key;
            }
            else
            {
                node_json["parent_key"] = "";
            }

            return node_json;
        }

        K m_key;
        V m_value;

        std::unique_ptr<node> m_left_child = nullptr;
        std::unique_ptr<node> m_right_child = nullptr;
        node* m_parent = nullptr;
    };

    class const_iterator
    {
    public:
        using value_type = std::pair<K, V>;
        using pointer = const value_type*;
        using reference = const value_type&;
        using difference_type = int;
        using iterator_category = std::bidirectional_iterator_tag;

        const_iterator(const binary_search_tree& bst, node* node_ptr) : bst{bst}, node_ptr{node_ptr}
        {
        }

        const_iterator(const const_iterator& other) = default;
        const_iterator(const_iterator&& other) = default;
        const_iterator& operator=(const const_iterator& other) = default;
        const_iterator& operator=(const_iterator&& other) = default;
        virtual ~const_iterator() = default;

        const_iterator& operator++()
        {
            assert(node_ptr);

            node_ptr = bst.find_successor_node_ptr(*node_ptr);

            return *this;
        }

        const_iterator operator++(int)
        {
            const_iterator temp{*this};
            ++(*this);

            return temp;
        }

        const value_type operator*() const
        {
            assert(node_ptr);

            return value_type{node_ptr->m_key, node_ptr->m_value};
        }

        friend bool operator==(const const_iterator& lhs, const const_iterator& rhs)
        {
            return (lhs.node_ptr == rhs.node_ptr);
        }

        friend bool operator!=(const const_iterator& lhs, const const_iterator& rhs)
        {
            return !(lhs == rhs);
        }

    private:
        const binary_search_tree& bst;
        node* node_ptr;
    };

    std::unique_ptr<node> m_root = nullptr;
    int m_size = 0;
};

#endif

Unit tests:

#include "catch2/catch.hpp"

#include "binary-search-tree.hpp"

#include <sstream>

TEST_CASE("Create and fill a binary tree")
{
    binary_search_tree<int, int> bst_empty;

    binary_search_tree<int, int> bst_one_element;
    bst_one_element.add(10, 1010);

    binary_search_tree<int, int> bst;

    bst.add(8, 88);
    bst.add(3, 33);
    bst.add(10, 1010);
    bst.add(1, 11);
    bst.add(6, 66);
    bst.add(14, 1414);
    bst.add(4, 44);
    bst.add(7, 77);
    bst.add(13, 1313);

    SECTION("Constructor")
    {
        std::stringstream out;
        out << bst;

        auto actual_json = json::parse(out.str());
        auto expected_json = json::parse(R"(
            {
            "key": 8,
            "left_child":
            {
                "key": 3,
                "left_child":
                {
                    "key": 1,
                    "left_child": "",
                    "parent_key": 3,
                    "right_child": "",
                    "value": 11
                },
                "parent_key": 8,
                "right_child":
                {
                    "key": 6,
                    "left_child":
                    {
                        "key": 4,
                        "left_child": "",
                        "parent_key": 6,
                        "right_child": "",
                        "value": 44
                    },
                    "parent_key": 3,
                    "right_child":
                    {
                        "key": 7,
                        "left_child": "",
                        "parent_key": 6,
                        "right_child": "",
                        "value": 77
                    },
                    "value": 66
                },
                "value": 33
            },
            "parent_key": "",
            "right_child":
            {
                "key": 10,
                "left_child": "",
                "parent_key": 8,
                "right_child":
                {
                    "key": 14,
                    "left_child":
                    {
                        "key": 13,
                        "left_child": "",
                        "parent_key": 14,
                        "right_child": "",
                        "value": 1313
                    },
                    "parent_key": 10,
                    "right_child": "",
                    "value": 1414
                },
                "value": 1010
            },
            "value": 88
            }
        )");

        REQUIRE(actual_json == expected_json);
    }

    SECTION("Copy constructor")
    {
        binary_search_tree<int, int> another_bst{bst};
        another_bst.remove(8);

        std::stringstream out;
        out << another_bst;

        auto actual_json = json::parse(out.str());
        auto expected_json = json::parse(R"(
            {
                "key": 10,
                "left_child":
                {
                    "key": 3,
                    "left_child": {
                        "key": 1,
                        "left_child": "",
                        "parent_key": 3,
                        "right_child": "",
                        "value": 11
                    },
                    "parent_key": 10,
                    "right_child":
                    {
                        "key": 6,
                        "left_child":
                        {
                            "key": 4,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 44
                        },
                        "parent_key": 3,
                        "right_child":
                        {
                            "key": 7,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 77
                        },
                        "value": 66
                    },
                    "value": 33
                },
                "parent_key": "",
                "right_child":
                {
                    "key": 14,
                    "left_child":
                    {
                        "key": 13,
                        "left_child": "",
                        "parent_key": 14,
                        "right_child": "",
                        "value": 1313
                    },
                    "parent_key": 10,
                    "right_child": "",
                    "value": 1414
                },
                "value": 1010
            }
        )");

        REQUIRE(actual_json == expected_json);

        out = std::stringstream{};
        out << bst;

        actual_json = json::parse(out.str());
        expected_json = json::parse(R"(
            {
                "key": 8,
                "left_child":
                {
                    "key": 3,
                    "left_child":
                    {
                        "key": 1,
                        "left_child": "",
                        "parent_key": 3,
                        "right_child": "",
                        "value": 11
                    },
                    "parent_key": 8,
                    "right_child":
                    {
                        "key": 6,
                        "left_child":
                        {
                            "key": 4,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 44
                        },
                        "parent_key": 3,
                        "right_child":
                        {
                            "key": 7,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 77
                        },
                        "value": 66
                    },
                    "value": 33
                },
                "parent_key": "",
                "right_child":
                {
                    "key": 10,
                    "left_child": "",
                    "parent_key": 8,
                    "right_child":
                    {
                        "key": 14,
                        "left_child":
                        {
                            "key": 13,
                            "left_child": "",
                            "parent_key": 14,
                            "right_child": "",
                            "value": 1313
                        },
                        "parent_key": 10,
                        "right_child": "",
                        "value": 1414
                    },
                    "value": 1010
                },
                "value": 88
            }
        )");

        REQUIRE(actual_json == expected_json);
    }

    SECTION("Copy assignment")
    {
        binary_search_tree<int, int> another_bst = bst;
        another_bst.remove(8);

        std::stringstream out;
        out << another_bst;

        auto actual_json = json::parse(out.str());
        auto expected_json = json::parse(R"(
            {
                "key": 10,
                "left_child":
                {
                    "key": 3,
                    "left_child":
                    {
                        "key": 1,
                        "left_child": "",
                        "parent_key": 3,
                        "right_child": "",
                        "value": 11
                    },
                    "parent_key": 10,
                    "right_child":
                    {
                        "key": 6,
                        "left_child":
                        {
                            "key": 4,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 44
                        },
                        "parent_key": 3,
                        "right_child":
                        {
                            "key": 7,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 77
                        },
                        "value": 66
                    },
                    "value": 33
                },
                "parent_key": "",
                "right_child":
                {
                    "key": 14,
                    "left_child":
                    {
                        "key": 13,
                        "left_child": "",
                        "parent_key": 14,
                        "right_child": "",
                        "value": 1313
                    },
                    "parent_key": 10,
                    "right_child": "",
                    "value": 1414
                },
                "value": 1010
            }
        )");

        REQUIRE(actual_json == expected_json);

        out = std::stringstream{};
        out << bst;

        actual_json = json::parse(out.str());
        expected_json = json::parse(R"(
            {
                "key": 8,
                "left_child":
                {
                    "key": 3,
                    "left_child":
                    {
                        "key": 1,
                        "left_child": "",
                        "parent_key": 3,
                        "right_child": "",
                        "value": 11
                    },
                    "parent_key": 8,
                    "right_child":
                    {
                        "key": 6,
                        "left_child":
                        {
                            "key": 4,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 44
                        },
                        "parent_key": 3,
                        "right_child":
                        {
                            "key": 7,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 77
                        },
                        "value": 66
                    },
                    "value": 33
                },
                "parent_key": "",
                "right_child":
                {
                    "key": 10,
                    "left_child": "",
                    "parent_key": 8,
                    "right_child":
                    {
                        "key": 14,
                        "left_child":
                        {
                            "key": 13,
                            "left_child": "",
                            "parent_key": 14,
                            "right_child": "",
                            "value": 1313
                        },
                        "parent_key": 10,
                        "right_child": "",
                        "value": 1414
                    },
                    "value": 1010
                },
                "value": 88
            }
        )");

        REQUIRE(actual_json == expected_json);
    }

    SECTION("Move constructor")
    {
        binary_search_tree<int, int> another_bst_empty{std::move(bst_empty)};
        binary_search_tree<int, int> another_bst_one_element{std::move(bst_one_element)};
        binary_search_tree<int, int> another_bst{std::move(bst)};

        REQUIRE(bst_empty.empty());
        REQUIRE(bst_one_element.empty());
        REQUIRE(bst.empty());

        REQUIRE(another_bst_empty.empty());

        std::stringstream out;
        out << another_bst_one_element;

        auto actual_json = json::parse(out.str());
        auto expected_json = json::parse(R"(
            {
                "key": 10,
                "left_child": "",
                "parent_key": "",
                "right_child": "",
                "value": 1010
            }
        )");

        REQUIRE(actual_json == expected_json);

        out = std::stringstream{};
        out << another_bst;

        actual_json = json::parse(out.str());
        expected_json = json::parse(R"(
            {
                "key": 8,
                "left_child":
                {
                    "key": 3,
                    "left_child":
                    {
                        "key": 1,
                        "left_child": "",
                        "parent_key": 3,
                        "right_child": "",
                        "value": 11
                    },
                    "parent_key": 8,
                    "right_child":
                    {
                        "key": 6,
                        "left_child":
                        {
                            "key": 4,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 44
                        },
                        "parent_key": 3,
                        "right_child":
                        {
                            "key": 7,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 77
                        },
                        "value": 66
                    },
                    "value": 33
                },
                "parent_key": "",
                "right_child":
                {
                    "key": 10,
                    "left_child": "",
                    "parent_key": 8,
                    "right_child":
                    {
                        "key": 14,
                        "left_child":
                        {
                            "key": 13,
                            "left_child": "",
                            "parent_key": 14,
                            "right_child": "",
                            "value": 1313
                        },
                        "parent_key": 10,
                        "right_child": "",
                        "value": 1414
                    },
                    "value": 1010
                },
                "value": 88
            }
        )");

        REQUIRE(actual_json == expected_json);
    }

    SECTION("Move assignment")
    {
        binary_search_tree<int, int> another_bst;

        another_bst = std::move(bst_empty);

        REQUIRE(bst_empty.empty());
        REQUIRE(another_bst.empty());

        another_bst = std::move(bst_one_element);

        REQUIRE(bst_one_element.empty());

        std::stringstream out;
        out << another_bst;

        auto actual_json = json::parse(out.str());
        auto expected_json = json::parse(R"(
            {
                "key": 10,
                "left_child": "",
                "parent_key": "",
                "right_child": "",
                "value": 1010
            }
        )");

        REQUIRE(actual_json == expected_json);

        another_bst = std::move(bst);

        REQUIRE(bst.empty());

        out = std::stringstream{};
        out << another_bst;

        actual_json = json::parse(out.str());
        expected_json = json::parse(R"(
            {
                "key": 8,
                "left_child":
                {
                    "key": 3,
                    "left_child":
                    {
                        "key": 1,
                        "left_child": "",
                        "parent_key": 3,
                        "right_child": "",
                        "value": 11
                    },
                    "parent_key": 8,
                    "right_child":
                    {
                        "key": 6,
                        "left_child":
                        {
                            "key": 4,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 44
                        },
                        "parent_key": 3,
                        "right_child":
                        {
                            "key": 7,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 77
                        },
                        "value": 66
                    },
                    "value": 33
                },
                "parent_key": "",
                "right_child":
                {
                    "key": 10,
                    "left_child": "",
                    "parent_key": 8,
                    "right_child":
                    {
                        "key": 14,
                        "left_child":
                        {
                            "key": 13,
                            "left_child": "",
                            "parent_key": 14,
                            "right_child": "",
                            "value": 1313
                        },
                        "parent_key": 10,
                        "right_child": "",
                        "value": 1414
                    },
                    "value": 1010
                },
                "value": 88
            }
        )");

        REQUIRE(actual_json == expected_json);
    }

    SECTION("Swap")
    {
    }

    SECTION("Search node")
    {
        auto search_result = bst.search(14);

        REQUIRE(search_result.has_value());
        REQUIRE(search_result.value() == 1414);

        search_result = bst.search(1337);

        REQUIRE(!search_result.has_value());
    }

    SECTION("Empty")
    {
        REQUIRE(bst_empty.empty());
        REQUIRE(!bst_one_element.empty());
        REQUIRE(!bst.empty());
    }

    SECTION("Clear")
    {
        bst_empty.clear();
        bst_one_element.clear();
        bst.clear();

        REQUIRE(bst_empty.empty());
        REQUIRE(bst_one_element.empty());
        REQUIRE(bst.empty());
    }

    SECTION("Find minimum")
    {
        REQUIRE(bst_one_element.find_minimum() == 10);
        REQUIRE(bst.find_minimum() == 1);
    }

    SECTION("Find maximum")
    {
        REQUIRE(bst_one_element.find_maximum() == 10);
        REQUIRE(bst.find_maximum() == 14);
    }

    SECTION("Find predecessor")
    {
        REQUIRE(!(bst_one_element.find_predecessor(10).has_value()));

        auto successor = bst.find_predecessor(1);
        REQUIRE(!(successor.has_value()));

        successor = bst.find_predecessor(3);
        REQUIRE(successor.has_value());
        REQUIRE(successor.value() == 1);

        successor = bst.find_predecessor(4);
        REQUIRE(successor.has_value());
        REQUIRE(successor.value() == 3);

        successor = bst.find_predecessor(6);
        REQUIRE(successor.has_value());
        REQUIRE(successor.value() == 4);

        successor = bst.find_predecessor(7);
        REQUIRE(successor.has_value());
        REQUIRE(successor.value() == 6);

        successor = bst.find_predecessor(8);
        REQUIRE(successor.has_value());
        REQUIRE(successor.value() == 7);

        successor = bst.find_predecessor(10);
        REQUIRE(successor.has_value());
        REQUIRE(successor.value() == 8);

        successor = bst.find_predecessor(13);
        REQUIRE(successor.has_value());
        REQUIRE(successor.value() == 10);

        successor = bst.find_predecessor(14);
        REQUIRE(successor.has_value());
        REQUIRE(successor.value() == 13);
    }

    SECTION("Find successor")
    {
        REQUIRE(!(bst_one_element.find_successor(10).has_value()));

        auto successor = bst.find_successor(1);
        REQUIRE(successor.has_value());
        REQUIRE(successor.value() == 3);

        successor = bst.find_successor(3);
        REQUIRE(successor.has_value());
        REQUIRE(successor.value() == 4);

        successor = bst.find_successor(4);
        REQUIRE(successor.has_value());
        REQUIRE(successor.value() == 6);

        successor = bst.find_successor(6);
        REQUIRE(successor.has_value());
        REQUIRE(successor.value() == 7);

        successor = bst.find_successor(7);
        REQUIRE(successor.has_value());
        REQUIRE(successor.value() == 8);

        successor = bst.find_successor(8);
        REQUIRE(successor.has_value());
        REQUIRE(successor.value() == 10);

        successor = bst.find_successor(10);
        REQUIRE(successor.has_value());
        REQUIRE(successor.value() == 13);

        successor = bst.find_successor(13);
        REQUIRE(successor.has_value());
        REQUIRE(successor.value() == 14);

        successor = bst.find_successor(14);
        REQUIRE(!(successor.has_value()));
    }

    SECTION("Remove node without children")
    {
        bst.remove(13);

        std::stringstream out;
        out << bst;

        auto actual_json = json::parse(out.str());
        auto expected_json = json::parse(R"(
            {
                "key": 8,
                "left_child":
                {
                    "key": 3,
                    "left_child":
                    {
                        "key": 1,
                        "left_child": "",
                        "parent_key": 3,
                        "right_child": "",
                        "value": 11
                    },
                    "parent_key": 8,
                    "right_child":
                    {
                        "key": 6,
                        "left_child":
                        {
                            "key": 4,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 44
                        },
                        "parent_key": 3,
                        "right_child":
                        {
                            "key": 7,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 77
                        },
                        "value": 66
                    },
                    "value": 33
                },
                "parent_key": "",
                "right_child":
                {
                    "key": 10,
                    "left_child": "",
                    "parent_key": 8,
                    "right_child":
                    {
                        "key": 14,
                        "left_child": "",
                        "parent_key": 10,
                        "right_child": "",
                        "value": 1414
                    },
                    "value": 1010
                },
                "value": 88
            }
        )");

        REQUIRE(actual_json == expected_json);
    }

    SECTION("Remove node with only left child")
    {
        bst.remove(14);

        std::stringstream out;
        out << bst;

        auto actual_json = json::parse(out.str());
        auto expected_json = json::parse(R"(
            {
                "key": 8,
                "left_child":
                {
                    "key": 3,
                    "left_child":
                    {
                        "key": 1,
                        "left_child": "",
                        "parent_key": 3,
                        "right_child": "",
                        "value": 11
                    },
                    "parent_key": 8,
                    "right_child":
                    {
                        "key": 6,
                        "left_child":
                        {
                            "key": 4,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 44
                        },
                        "parent_key": 3,
                        "right_child":
                        {
                            "key": 7,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 77
                        },
                        "value": 66
                    },
                    "value": 33
                },
                "parent_key": "",
                "right_child":
                {
                    "key": 10,
                    "left_child": "",
                    "parent_key": 8,
                    "right_child":
                    {
                        "key": 13,
                        "left_child": "",
                        "parent_key": 10,
                        "right_child": "",
                        "value": 1313
                    },
                    "value": 1010
                },
                "value": 88
            }
        )");

        REQUIRE(actual_json == expected_json);
    }

    SECTION("Remove node with only right child")
    {
        bst.remove(10);

        std::stringstream out;
        out << bst;

        auto actual_json = json::parse(out.str());
        auto expected_json = json::parse(R"(
            {
                "key": 8,
                "left_child":
                {
                    "key": 3,
                    "left_child":
                    {
                        "key": 1,
                        "left_child": "",
                        "parent_key": 3,
                        "right_child": "",
                        "value": 11
                    },
                    "parent_key": 8,
                    "right_child":
                    {
                        "key": 6,
                        "left_child":
                        {
                            "key": 4,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 44
                        },
                        "parent_key": 3,
                        "right_child":
                        {
                            "key": 7,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 77
                        },
                        "value": 66
                    },
                    "value": 33
                },
                "parent_key": "",
                "right_child":
                {
                    "key": 14,
                    "left_child":
                    {
                        "key": 13,
                        "left_child": "",
                        "parent_key": 14,
                        "right_child": "",
                        "value": 1313
                    },
                    "parent_key": 8,
                    "right_child": "",
                    "value": 1414
                },
                "value": 88
            }
        )");

        REQUIRE(actual_json == expected_json);
    }

    SECTION("Remove node with 2 children")
    {
        bst.remove(3);

        std::stringstream out;
        out << bst;

        auto actual_json = json::parse(out.str());
        auto expected_json = json::parse(R"(
            {
                "key": 8,
                "left_child":
                {
                    "key": 4,
                    "left_child":
                    {
                        "key": 1,
                        "left_child": "",
                        "parent_key": 4,
                        "right_child": "",
                        "value": 11
                    },
                    "parent_key": 8,
                    "right_child":
                    {
                        "key": 6,
                        "left_child": "",
                        "parent_key": 4,
                        "right_child":
                        {
                            "key": 7,
                            "left_child": "",
                            "parent_key": 6,
                            "right_child": "",
                            "value": 77
                        },
                        "value": 66
                    },
                    "value": 44
                },
                "parent_key": "",
                "right_child":
                {
                    "key": 10,
                    "left_child": "",
                    "parent_key": 8,
                    "right_child":
                    {
                        "key": 14,
                        "left_child":
                        {
                            "key": 13,
                            "left_child": "",
                            "parent_key": 14,
                            "right_child": "",
                            "value": 1313
                        },
                        "parent_key": 10,
                        "right_child": "",
                        "value": 1414
                    },
                    "value": 1010
                },
                "value": 88
            }
        )");

        REQUIRE(actual_json == expected_json);
    }

    SECTION("Remove root node without children")
    {
        bst_one_element.remove(10);

        REQUIRE(bst_one_element.empty());
    }

    SECTION("Remove root node with only left child")
    {
        bst_one_element.add(5, 55);
        bst_one_element.remove(10);

        REQUIRE(bst_one_element.size() == 1);

        auto search_result = bst_one_element.search(5);

        REQUIRE(search_result.has_value());
        REQUIRE(search_result.value() == 55);
    }

    SECTION("Remove root node with only right child")
    {
        bst_one_element.add(15, 1515);
        bst_one_element.remove(10);

        REQUIRE(bst_one_element.size() == 1);

        auto search_result = bst_one_element.search(15);

        REQUIRE(search_result.has_value());
        REQUIRE(search_result.value() == 1515);
    }

    SECTION("Remove root node with 2 children")
    {
        bst_one_element.add(5, 55);
        bst_one_element.add(15, 1515);
        bst_one_element.remove(10);

        REQUIRE(bst_one_element.size() == 2);

        std::stringstream out;
        out << bst_one_element;

        auto actual_json = json::parse(out.str());
        auto expected_json = json::parse(R"(
            {
                "key": 15,
                "left_child":
                {
                    "key": 5,
                    "left_child": "",
                    "parent_key": 15,
                    "right_child": "",
                    "value": 55
                },
                "parent_key": "",
                "right_child": "",
                "value": 1515
            }
        )");

        REQUIRE(actual_json == expected_json);
    }

    SECTION("Size")
    {
        REQUIRE(bst_empty.size() == 0);
        REQUIRE(bst_one_element.size() == 1);
        REQUIRE(bst.size() == 9);
    }

    SECTION("Equals")
    {
        REQUIRE(bst_empty == bst_empty);
        REQUIRE(bst_one_element == bst_one_element);
        REQUIRE(bst == bst);

        binary_search_tree<int, int> bst_empty_2;

        binary_search_tree<int, int> bst_one_element_2;
        bst_one_element_2.add(10, 1010);

        binary_search_tree<int, int> bst_2;
        bst_2.add(8, 88);
        bst_2.add(3, 33);
        bst_2.add(10, 1010);
        bst_2.add(1, 11);
        bst_2.add(6, 66);
        bst_2.add(14, 1414);
        bst_2.add(4, 44);
        bst_2.add(7, 77);
        bst_2.add(13, 1313);

        REQUIRE(bst_empty == bst_empty_2);
        REQUIRE(bst_empty_2 == bst_empty);
        REQUIRE(bst_one_element == bst_one_element_2);
        REQUIRE(bst_one_element_2 == bst_one_element);
        REQUIRE(bst == bst_2);
        REQUIRE(bst_2 == bst);
    }

    SECTION("Not equals")
    {
        binary_search_tree<int, int> bst_other_root;
        bst_other_root.add(9, 99);
        bst_other_root.add(3, 33);
        bst_other_root.add(10, 1010);
        bst_other_root.add(1, 11);
        bst_other_root.add(6, 66);
        bst_other_root.add(14, 1414);
        bst_other_root.add(4, 44);
        bst_other_root.add(7, 77);
        bst_other_root.add(13, 1313);

        binary_search_tree<int, int> bst_other_intermediary_element;
        bst_other_intermediary_element.add(8, 88);
        bst_other_intermediary_element.add(3, 33);
        bst_other_intermediary_element.add(10, 1010);
        bst_other_intermediary_element.add(1, 11);
        bst_other_intermediary_element.add(5, 55);
        bst_other_intermediary_element.add(14, 1414);
        bst_other_intermediary_element.add(4, 44);
        bst_other_intermediary_element.add(7, 77);
        bst_other_intermediary_element.add(13, 1313);

        binary_search_tree<int, int> bst_other_leaf;
        bst_other_leaf.add(8, 88);
        bst_other_leaf.add(3, 33);
        bst_other_leaf.add(10, 1010);
        bst_other_leaf.add(1, 11);
        bst_other_leaf.add(6, 66);
        bst_other_leaf.add(14, 1414);
        bst_other_leaf.add(4, 44);
        bst_other_leaf.add(7, 77);
        bst_other_leaf.add(15, 1515);

        binary_search_tree<int, int> bst_less_elements;
        bst_less_elements.add(3, 33);
        bst_less_elements.add(10, 1010);
        bst_less_elements.add(1, 11);
        bst_less_elements.add(6, 66);
        bst_less_elements.add(14, 1414);
        bst_less_elements.add(4, 44);
        bst_less_elements.add(7, 77);
        bst_less_elements.add(13, 1313);

        binary_search_tree<int, int> bst_more_elements;
        bst_more_elements.add(8, 88);
        bst_more_elements.add(3, 33);
        bst_more_elements.add(10, 1010);
        bst_more_elements.add(1, 11);
        bst_more_elements.add(6, 66);
        bst_more_elements.add(14, 1414);
        bst_more_elements.add(4, 44);
        bst_more_elements.add(7, 77);
        bst_more_elements.add(13, 1313);
        bst_more_elements.add(15, 1515);

        REQUIRE(bst != bst_empty);
        REQUIRE(bst_empty != bst);
        REQUIRE(bst != bst_one_element);
        REQUIRE(bst_one_element != bst);
        REQUIRE(bst != bst_other_root);
        REQUIRE(bst_other_root != bst);
        REQUIRE(bst != bst_other_intermediary_element);
        REQUIRE(bst_other_intermediary_element != bst);
        REQUIRE(bst != bst_other_leaf);
        REQUIRE(bst_other_leaf != bst);
        REQUIRE(bst != bst_less_elements);
        REQUIRE(bst_less_elements != bst);
        REQUIRE(bst != bst_more_elements);
        REQUIRE(bst_more_elements != bst);
    }

    SECTION("Const iterator")
    {
        std::vector<std::pair<int, int>> expected_values = {
                {1, 11}, {3, 33}, {4, 44}, {6, 66}, {7, 77}, {8, 88}, {10, 1010}, {13, 1313}, {14, 1414}};

        std::vector<std::pair<int, int>> actual_values;

        for (const auto& element : bst)
        {
            actual_values.push_back(element);
        }

        REQUIRE(actual_values == expected_values);
    }
}

EDIT: There was a small error in the code for remove and its unit test. It didn't adjust the m_parent member of the node when replacing nodes with a single child.

\$\endgroup\$
2
  • \$\begingroup\$ You realize that std::map and std::set have the same characteristics (in terms of insert/delete/search) as a BST. \$\endgroup\$ Commented Jul 12, 2020 at 18:12
  • \$\begingroup\$ Yes, this is educational (to teach myself C++17 and newer). The goal is to advance into more difficult algorithms when I am sure of my basics. \$\endgroup\$ Commented Jul 12, 2020 at 18:19

1 Answer 1

2
\$\begingroup\$

The title claims that this is "STL compatible" - but I see no declaration of the usual member types: value_type, size_type and the rest.


I don't think this namespace alias should be at global scope in a header:

using json = nlohmann::json;

It would be more appropriate internal to the class body.


It's surprising that we use a relatively small, signed integer type to record the number of elements in the tree. I think that std::size_t is likely to be much more suitable.


Be aware that this innocuous-looking function hides a lot of complexity:

void clear()
{
    m_root.reset();
    m_size = 0;
}

When we reset() the root pointer, its destructor is called, which destroys its m_left_child and m_right_child, each of which might also have destructors. We can end up quite deeply recursed into these destructors in a large tree, which can result in stack overflow.

I would suggest walking the elements and freeing leaf nodes before their parents, iteratively until the root is reached. And then add a call to clear() into the tree's destructor, so that the default constructor doesn't have the same risk.


The copy constructor doesn't need to clear(), since it's initialised to empty state. And it shouldn't need the auxiliary storage of nodes_to_deep_copy - unless you're addressing a measured performance problem, stick to the simpler approach of for (auto node: other) emplace(node);.


The two overloads of operator=() can be replaced by a single function that takes its argument by value:

    binary_search_tree& operator=(binary_search_tree other)
    {
        swap(*this, other);
        return *this;
    }

The assert() in add() can't be justified. It's very easy for client code to violate the expectation, simply by adding the same key twice.


operator==() shouldn't need auxiliary storage. Given working iterators, it should be sufficient to std::ranges::equal() the contents of the two trees.


These very verbose assignments can be simplified:

        if (m_left_child)
        {
            node_json["left_child"] = m_left_child->to_json();
        }
        else
        {
            node_json["left_child"] = "";
        }

        if (m_right_child)
        {
            node_json["right_child"] = m_right_child->to_json();
        }
        else
        {
            node_json["right_child"] = "";
        }

        if (m_parent)
        {
            node_json["parent_key"] = m_parent->m_key;
        }
        else
        {
            node_json["parent_key"] = "";
        }

More readable IMO as

            node_json["left_child"]  = m_left_child  ? m_left_child->to_json()  : "";
            node_json["right_child"] = m_right_child ? m_right_child->to_json() : "";
            node_json["parent_key"]  = m_parent      ? json(m_parent->m_key)    : "";
\$\endgroup\$

You must log in to answer this question.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.