diff --git a/include/hage/ds/avl_tree.hpp b/include/hage/ds/avl_tree.hpp index 8f19aee..6b7aadc 100644 --- a/include/hage/ds/avl_tree.hpp +++ b/include/hage/ds/avl_tree.hpp @@ -30,7 +30,7 @@ class AVLTree AVLTree() { - // This nodes parent is always the last node. + // This nodes parent is always the last node. It's value is always 0 m_end = get_node(); } @@ -43,11 +43,11 @@ class AVLTree m_nodes[m_end].m_parent = m_root; m_size++; - return { make_iterator(m_root), true }; + return { make_iterator(m_root), true }; } const auto [id, inserted] = internal_try_emplace(m_root, std::forward(key), std::forward(args)...); - return { make_iterator(id), inserted }; + return { make_iterator(id), inserted }; } [[nodiscard]] constexpr iterator find(const Key& key) @@ -56,51 +56,73 @@ class AVLTree if (id == -1) { return end(); } else { - return make_iterator(id); + return make_iterator(id); } } - [[nodiscard]] constexpr iterator end() noexcept { return make_iterator(m_end); } - [[nodiscard]] constexpr iterator begin() noexcept { return make_iterator(m_begin); } + [[nodiscard]] constexpr const_iterator find(const Key& key) const + { + auto id = internal_find(m_root, key); + if (id == -1) { + return end(); + } else { + return make_const_iterator(id); + } + } + + [[nodiscard]] constexpr bool contains(const Key& key) const { return internal_find(m_root, key) != -1; } + + [[nodiscard]] constexpr size_type size() const { return m_size; } + [[nodiscard]] constexpr bool empty() const { return size() == 0; } - [[nodiscard]] constexpr const_iterator end() const noexcept { return make_iterator(m_end); } - [[nodiscard]] constexpr const_iterator begin() const noexcept { return make_iterator(m_begin); } + constexpr void clear() + { + // Important to note here, that the end iterator, must be valid. + // we have the understanding that m_last is always 0. + // we are freeing from the back, so that the smallest nodes will be used first. + for (std::size_t i = m_nodes.size() - 1; 0 < i; i--) { + free_node(i); + } + m_size = 0; + m_root = -1; + } - [[nodiscard]] constexpr const_iterator cend() const noexcept { return make_iterator(m_end); } - [[nodiscard]] constexpr const_iterator cbegin() const noexcept { return make_iterator(m_begin); } + [[nodiscard]] constexpr iterator end() noexcept { return make_iterator(m_end); } + [[nodiscard]] constexpr iterator begin() noexcept { return make_iterator(m_begin); } + + [[nodiscard]] constexpr const_iterator end() const noexcept { return make_const_iterator(m_end); } + [[nodiscard]] constexpr const_iterator begin() const noexcept { return make_const_iterator(m_begin); } + + [[nodiscard]] constexpr const_iterator cend() const noexcept { return make_const_iterator(m_end); } + [[nodiscard]] constexpr const_iterator cbegin() const noexcept { return make_const_iterator(m_begin); } [[nodiscard]] constexpr reverse_iterator rbegin() noexcept { - return std::make_reverse_iterator(make_iterator(m_end)); + return std::make_reverse_iterator(make_iterator(m_end)); } [[nodiscard]] constexpr reverse_iterator rend() noexcept { - return std::make_reverse_iterator(make_iterator(m_begin)); + return std::make_reverse_iterator(make_iterator(m_begin)); } - [[nodiscard]] constexpr reverse_iterator rbegin() const noexcept + [[nodiscard]] constexpr const_reverse_iterator rbegin() const noexcept { - return std::make_reverse_iterator(make_iterator(m_end)); + return std::make_reverse_iterator(make_const_iterator(m_end)); } - [[nodiscard]] constexpr reverse_iterator rend() const noexcept + [[nodiscard]] constexpr const_reverse_iterator rend() const noexcept { - return std::make_reverse_iterator(make_iterator(m_begin)); + return std::make_reverse_iterator(make_const_iterator(m_begin)); } - [[nodiscard]] constexpr reverse_iterator crbegin() const noexcept + [[nodiscard]] constexpr const_reverse_iterator crbegin() const noexcept { - return std::make_reverse_iterator(make_iterator(m_end)); + return std::make_reverse_iterator(make_const_iterator(m_end)); } - [[nodiscard]] constexpr reverse_iterator crend() const noexcept + [[nodiscard]] constexpr const_reverse_iterator crend() const noexcept { - return std::make_reverse_iterator(make_iterator(m_begin)); + return std::make_reverse_iterator(make_const_iterator(m_begin)); } - [[nodiscard]] constexpr bool contains(const Key& key) const { return internal_find(m_root, key) != -1; } - - [[nodiscard]] constexpr size_type size() const { return m_size; } - [[nodiscard]] constexpr bool empty() const { return size() == 0; } - private: class Node { @@ -141,17 +163,18 @@ class AVLTree class Iterator { private: + using tree_type = std::conditional_t; using node_type = std::conditional_t; public: using iterator_category = std::bidirectional_iterator_tag; using value_type = node_type; using difference_type = std::int32_t; - using pointer = node_type*; - using reference = node_type&; + using pointer = value_type*; + using reference = value_type&; Iterator() = default; - Iterator(AVLTree* tree, node_id_type id) : m_tree{ tree }, m_id{ id } {} + Iterator(tree_type* tree, node_id_type id) : m_tree{ tree }, m_id{ id } {} constexpr reference operator*() const { return m_tree->m_nodes[m_id]; } constexpr pointer operator->() const { return &m_tree->m_nodes[m_id]; } @@ -243,17 +266,18 @@ class AVLTree private: friend class AVLTree; - AVLTree* m_tree{ nullptr }; + tree_type* m_tree{ nullptr }; node_id_type m_id{ -1 }; // TODO(rHermes): create an "end" iterator, that doesn't take up a node. This could be done if we track the // iterators a bit differently. }; - template - [[nodiscard]] constexpr Iterator make_iterator(node_id_type id) + [[nodiscard]] constexpr iterator make_iterator(node_id_type id) { return iterator{ this, id }; } + + [[nodiscard]] constexpr const_iterator make_const_iterator(node_id_type id) const { - return Iterator{ this, id }; + return const_iterator{ this, id }; } template diff --git a/tests/avl_tree_tests.cpp b/tests/avl_tree_tests.cpp index b180b81..6251cbe 100644 --- a/tests/avl_tree_tests.cpp +++ b/tests/avl_tree_tests.cpp @@ -44,7 +44,7 @@ TEST_CASE("Simple avl tests") REQUIRE_EQ(it, tree.end()); } - SUBCASE("We should be able to use -- to iteratte") + SUBCASE("We should be able to use -- to iterate") { auto it = tree.find(10); REQUIRE_EQ(it->key(), 10); @@ -54,6 +54,16 @@ TEST_CASE("Simple avl tests") REQUIRE_EQ(it->key(), 100); REQUIRE_EQ(it->value(), 10); } + + SUBCASE("Clearing should return all nodes to zero") + { + tree.clear(); + REQUIRE_EQ(tree.size(), 0); + + // inserting should work again + auto res = tree.try_emplace(10, 0xf001); + REQUIRE_UNARY(res.second); + } } TEST_CASE("AVL Tree iterator tests") @@ -93,6 +103,18 @@ TEST_CASE("AVL Tree iterator tests") } } + SUBCASE("Const Forward iteration should work") + { + auto it = tree.cbegin(); + int i = 0; + while (it != tree.cend()) { + REQUIRE_EQ(it->key(), i); + REQUIRE_EQ(it->value(), N - i); + it++; + i++; + } + } + SUBCASE("Reverse iteration should work") { auto it = tree.rbegin(); @@ -104,6 +126,18 @@ TEST_CASE("AVL Tree iterator tests") i--; } } + + SUBCASE("Const Reverse iteration should work") + { + auto it = tree.crbegin(); + int i = N - 1; + while (it != tree.crend()) { + REQUIRE_EQ(it->key(), i); + REQUIRE_EQ(it->value(), N - i); + it++; + i--; + } + } } TEST_SUITE_END();