Skip to content

Commit

Permalink
ds: Fix ++ and -- on iterators and implement const
Browse files Browse the repository at this point in the history
  • Loading branch information
rHermes committed Jun 23, 2024
1 parent 059a51b commit 212de76
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 20 deletions.
93 changes: 73 additions & 20 deletions include/hage/ds/avl_tree.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <cstdint>
#include <iterator>
#include <stdexcept>
#include <vector>

Expand All @@ -12,17 +13,20 @@ template<typename Key, typename Value>
class AVLTree
{
private:
template<bool IsConst>
class Iterator;
class ConstIterator;

class Node;

using node_id_type = std::int32_t;

public:
using value_type = Value;
using size_type = std::size_t;
using iterator = Iterator;
using const_iterator = ConstIterator;
using iterator = Iterator<false>;
using const_iterator = Iterator<true>;
using reverse_iterator = std::reverse_iterator<iterator>;
using const_reverse_iterator = std::reverse_iterator<const_iterator>;

AVLTree()
{
Expand All @@ -31,19 +35,19 @@ class AVLTree
}

template<typename K, typename... Args>
constexpr std::pair<Iterator, bool> try_emplace(K&& key, Args&&... args)
constexpr std::pair<iterator, bool> try_emplace(K&& key, Args&&... args)
{
if (m_root == -1) {
m_root = get_node(std::forward<K>(key), Value{ std::forward<Args>(args)... });
m_begin = m_root;
m_nodes[m_end].m_parent = m_root;

m_size++;
return { make_iterator(m_root), true };
return { make_iterator<false>(m_root), true };
}

const auto [id, inserted] = internal_try_emplace(m_root, std::forward<K>(key), std::forward<Args>(args)...);
return { make_iterator(id), inserted };
return { make_iterator<false>(id), inserted };
}

[[nodiscard]] constexpr iterator find(const Key& key)
Expand All @@ -52,12 +56,45 @@ class AVLTree
if (id == -1) {
return end();
} else {
return make_iterator(id);
return make_iterator<false>(id);
}
}

[[nodiscard]] constexpr iterator end() noexcept { return make_iterator(m_end); }
[[nodiscard]] constexpr iterator begin() noexcept { return make_iterator(m_begin); }
[[nodiscard]] constexpr iterator end() noexcept { return make_iterator<false>(m_end); }
[[nodiscard]] constexpr iterator begin() noexcept { return make_iterator<false>(m_begin); }

[[nodiscard]] constexpr const_iterator end() const noexcept { return make_iterator<true>(m_end); }
[[nodiscard]] constexpr const_iterator begin() const noexcept { return make_iterator<true>(m_begin); }

[[nodiscard]] constexpr const_iterator cend() const noexcept { return make_iterator<true>(m_end); }
[[nodiscard]] constexpr const_iterator cbegin() const noexcept { return make_iterator<true>(m_begin); }

[[nodiscard]] constexpr reverse_iterator rbegin() noexcept
{
return std::make_reverse_iterator(make_iterator<false>(m_end));
}
[[nodiscard]] constexpr reverse_iterator rend() noexcept
{
return std::make_reverse_iterator(make_iterator<false>(m_begin));
}

[[nodiscard]] constexpr reverse_iterator rbegin() const noexcept
{
return std::make_reverse_iterator(make_iterator<true>(m_end));
}
[[nodiscard]] constexpr reverse_iterator rend() const noexcept
{
return std::make_reverse_iterator(make_iterator<true>(m_begin));
}

[[nodiscard]] constexpr reverse_iterator crbegin() const noexcept
{
return std::make_reverse_iterator(make_iterator<true>(m_end));
}
[[nodiscard]] constexpr reverse_iterator crend() const noexcept
{
return std::make_reverse_iterator(make_iterator<true>(m_begin));
}

[[nodiscard]] constexpr bool contains(const Key& key) const { return internal_find(m_root, key) != -1; }

Expand Down Expand Up @@ -87,8 +124,9 @@ class AVLTree
{
}

const Key& key() const { return m_key; }
Value& value() { return m_value; }
[[nodiscard]] constexpr const Key& key() const { return m_key; }
[[nodiscard]] constexpr Value& value() { return m_value; }
[[nodiscard]] constexpr const Value& value() const { return m_value; }
};

std::size_t m_size{ 0 };
Expand All @@ -99,19 +137,24 @@ class AVLTree
node_id_type m_begin{ -1 };
node_id_type m_end{ -1 };

template<bool IsConst>
class Iterator
{
private:
using node_type = std::conditional_t<IsConst, const Node, Node>;

public:
using iterator_category = std::bidirectional_iterator_tag;
using value_type = Node;
using value_type = node_type;
using difference_type = std::int32_t;
using pointer = Node*;
using reference = Node&;
using pointer = node_type*;
using reference = node_type&;

Iterator() = default;
Iterator(AVLTree* tree, node_id_type id) : m_tree{ tree }, m_id{ id } {}

constexpr reference operator*() const { return m_tree->m_nodes[m_id]; }
constexpr pointer operator->() { return &m_tree->m_nodes[m_id]; }
constexpr pointer operator->() const { return &m_tree->m_nodes[m_id]; }

friend constexpr bool operator==(const Iterator& lhs, const Iterator& rhs)
{
Expand All @@ -123,7 +166,7 @@ class AVLTree
{
// BUG(rHermes): Should we return the value it had before iterating?
auto it = *this;
++(*this);
this->operator++();
return it;
}
// Prefix increment
Expand All @@ -143,7 +186,10 @@ class AVLTree
// ok, let's be a bit smart here. We want to go to the right. There are two ways this will go.
// Either we are going up, or we are going down.
if (nodes[m_id].m_right == -1) {
while (nodes[nodes[m_id].m_parent].m_right == m_id) {
auto prev = m_id;
m_id = nodes[m_id].m_parent;
while (nodes[m_id].m_right == prev) {
prev = m_id;
m_id = nodes[m_id].m_parent;
}
} else {
Expand Down Expand Up @@ -173,12 +219,15 @@ class AVLTree

if (m_id == m_tree->m_end) {
m_id = nodes[m_id].m_parent;
return *this;
}

if (nodes[m_id].m_left == -1) {

// We need to go up again, until we are on the right side.
while (nodes[nodes[m_id].m_parent].m_left == m_id) {
auto prev = m_id;
m_id = nodes[m_id].m_parent;
while (nodes[m_id].m_left == prev) {
prev = m_id;
m_id = nodes[m_id].m_parent;
}
} else {
Expand All @@ -201,7 +250,11 @@ class AVLTree
// iterators a bit differently.
};

[[nodiscard]] constexpr Iterator make_iterator(node_id_type id) { return Iterator{ this, id }; }
template<bool IsConst>
[[nodiscard]] constexpr Iterator<IsConst> make_iterator(node_id_type id)
{
return Iterator<IsConst>{ this, id };
}

template<typename K, typename V>
[[nodiscard]] constexpr node_id_type get_node(K&& key, V&& value)
Expand Down
22 changes: 22 additions & 0 deletions tests/avl_tree_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ TEST_CASE("AVL Tree iterator tests")
{
ds::AVLTree<int, int> tree;

SUBCASE("Forward iterator has to be a bidirectional iterator")
{
REQUIRE_UNARY(std::bidirectional_iterator<decltype(tree)::iterator>);
}

SUBCASE("Reverse iterator has to be a bidirectional iterator")
{
REQUIRE_UNARY(std::bidirectional_iterator<decltype(tree)::reverse_iterator>);
}

constexpr int N = 10;
// We have to insert a 100 elements.
for (int i = 0; i < N; i++) {
Expand All @@ -82,6 +92,18 @@ TEST_CASE("AVL Tree iterator tests")
i++;
}
}

SUBCASE("Reverse iteration should work")
{
auto it = tree.rbegin();
int i = N - 1;
while (it != tree.rend()) {
REQUIRE_EQ(it->key(), i);
REQUIRE_EQ(it->value(), N - i);
it++;
i--;
}
}
}

TEST_SUITE_END();

0 comments on commit 212de76

Please sign in to comment.