From 3bbc9852ec4907099604f4642fd73d0201a4af2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Teodor=20Sp=C3=A6ren?= Date: Thu, 11 Jul 2024 10:28:25 +0200 Subject: [PATCH] ds: Work on implementing delete for AVL tree This is taking quite some work, but I've invested in a Graphviz renderer for the AVL tree, which I hope will help me figure out what is going wrong in the rebalancing part. --- README.md | 12 +- include/hage/ds/avl_tree.hpp | 470 +++++++++++++++++++++++++++++++---- tests/avl_tree_tests.cpp | 72 +++++- 3 files changed, 498 insertions(+), 56 deletions(-) diff --git a/README.md b/README.md index e585d70..a68b124 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,8 @@ This library includes datastructures that I used from time to time. Planned datastructures: - Slotmap implementation, take inspiration from rust - Index based linked list, with a skip list -- Interval map, built on top of +- Interval map, built on top of a map structure. +- A skip list #### AVLTree @@ -75,20 +76,29 @@ An AVLTree implementation, currently in the process of being completed. Features ##### AVLTREE - Add `generation` tag to all nodes in debug mode, and check this tag when accessing through an iterator, to detect dangling pointers! +- Add a check in debug mode that the iterator belongs to this tree. + - Just a simple check to see if `m_tree` is equal to `this`. + - Actually, according to the standard library, iterators must remain valid. + - I can do this by introducing a "tree_id" which I get from a thread-safe function, + and compare it to that. This gets 90% of the effect for 10% of the effort. - Add examples and tests for this - Add allocator support. - Add support for erasing! - Add some statistics? - Add support for custom comperators - Add support for multimap? + - Look at how boost does this - Add support for sets also, that doesn't take up extra space. - This would be a good generalization, we can just vary the node type. 
+- Figure out a way to remove the need for a default constructable value type. ##### General - Also implement a redblack tree - Add some benchmarks for all trees - Add a B-Tree? +- Add WAVL tree? +- Add 2-3 trees? ### Utility library diff --git a/include/hage/ds/avl_tree.hpp b/include/hage/ds/avl_tree.hpp index 6b7aadc..31b9702 100644 --- a/include/hage/ds/avl_tree.hpp +++ b/include/hage/ds/avl_tree.hpp @@ -2,7 +2,9 @@ #include #include +#include #include +#include #include namespace hage::ds { @@ -70,6 +72,13 @@ class AVLTree } } + constexpr iterator erase(iterator pos) + { + auto id = internal_erase(pos.m_id); + // std::ranges::copy_backward + return make_iterator(id); + } + [[nodiscard]] constexpr bool contains(const Key& key) const { return internal_find(m_root, key) != -1; } [[nodiscard]] constexpr size_type size() const { return m_size; } @@ -123,6 +132,24 @@ class AVLTree return std::make_reverse_iterator(make_const_iterator(m_begin)); } + [[nodiscard]] std::string print_tree() const + { + std::ostringstream out; + internal_print(out, m_root, 0); + return out.str(); + } + + [[nodiscard]] std::string print_dot_tree() const + { + DotPrinter printer; + printer.add_default_node_attrib("ordering", "out"); + printer.add_default_node_attrib("colorscheme", "piyg5"); + printer.add_default_node_attrib("style", "filled"); + + internal_dot_print(printer, m_root, 0); + return printer.print(); + } + private: class Node { @@ -195,71 +222,20 @@ class AVLTree // Prefix increment constexpr Iterator& operator++() { - const auto& nodes = m_tree->m_nodes; - - if (m_id == m_tree->m_end) { - throw std::runtime_error("Invalid iterator usage, trying to increment beyond end"); - } - - if (m_id == nodes[m_tree->m_end].m_parent) { - m_id = m_tree->m_end; - return *this; - } - - // ok, let's be a bit smart here. We want to go to the right. There are two ways this will go. - // Either we are going up, or we are going down. 
- if (nodes[m_id].m_right == -1) { - auto prev = m_id; - m_id = nodes[m_id].m_parent; - while (nodes[m_id].m_right == prev) { - prev = m_id; - m_id = nodes[m_id].m_parent; - } - } else { - m_id = nodes[m_id].m_right; - while (nodes[m_id].m_left != -1) { - m_id = nodes[m_id].m_left; - } - } - + m_id = m_tree->next_node(m_id); return *this; } constexpr Iterator operator--(int) { auto it = *this; - --(*this); + this->operator--(); return it; } constexpr Iterator& operator--() { - const auto& nodes = m_tree->m_nodes; - - if (m_id == m_tree->m_begin) { - throw std::runtime_error("We tried to decrement a begin iterator"); - } - - if (m_id == m_tree->m_end) { - m_id = nodes[m_id].m_parent; - return *this; - } - - if (nodes[m_id].m_left == -1) { - // We need to go up again, until we are on the right side. - auto prev = m_id; - m_id = nodes[m_id].m_parent; - while (nodes[m_id].m_left == prev) { - prev = m_id; - m_id = nodes[m_id].m_parent; - } - } else { - m_id = nodes[m_id].m_left; - while (nodes[m_id].m_right != -1) { - m_id = nodes[m_id].m_right; - } - } - + m_id = m_tree->prev_node(m_id); return *this; } @@ -604,6 +580,392 @@ class AVLTree yc->m_balance = 0; return Y; } + + [[nodiscard]] constexpr node_id_type next_node(node_id_type cur) const + { + if (cur == m_end) { + throw std::runtime_error("Invalid iterator usage, trying to increment beyond end"); + } + + if (cur == m_nodes[m_end].m_parent) { + return m_end; + } + + // ok, let's be a bit smart here. We want to go to the right. There are two ways this will go. + // Either we are going up, or we are going down. 
+ if (m_nodes[cur].m_right == -1) { + auto prev = cur; + cur = m_nodes[cur].m_parent; + while (m_nodes[cur].m_right == prev) { + prev = cur; + cur = m_nodes[cur].m_parent; + } + } else { + cur = m_nodes[cur].m_right; + while (m_nodes[cur].m_left != -1) { + cur = m_nodes[cur].m_left; + } + } + + return cur; + } + + [[nodiscard]] constexpr node_id_type prev_node(node_id_type cur) const + { + if (cur == m_begin) { + throw std::runtime_error("We tried to decrement a begin iterator"); + } + + if (cur == m_end) { + return m_nodes[cur].m_parent; + } + + if (m_nodes[cur].m_left == -1) { + // We need to go up again, until we are on the right side. + auto prev = cur; + cur = m_nodes[cur].m_parent; + while (m_nodes[cur].m_left == prev) { + prev = cur; + cur = m_nodes[cur].m_parent; + } + } else { + cur = m_nodes[cur].m_left; + while (m_nodes[cur].m_right != -1) { + cur = m_nodes[cur].m_right; + } + } + + return cur; + } + + [[nodiscard]] constexpr node_id_type internal_erase(const node_id_type root) + { + auto shiftNodes = [&](const node_id_type orig, const node_id_type replace) { + auto parent = m_nodes[orig].m_parent; + if (parent == -1) { + m_root = replace; + } else if (orig == m_nodes[parent].m_left) { + m_nodes[parent].m_left = replace; + } else { + m_nodes[parent].m_right = replace; + } + + if (replace != -1) { + m_nodes[replace].m_parent = parent; + } + + return replace; + }; + + const auto retValue = next_node(root); + + // If we are going to delete the end root, we need to update. + if (root == m_nodes[m_end].m_parent) + m_nodes[m_end].m_parent = m_nodes[root].m_parent; + + if (root == m_begin) { + m_begin = retValue; + } + + // ok, we are going to + + node_id_type newRoot = -1; + + auto& node = m_nodes[root]; + if (node.m_left == -1) { + newRoot = shiftNodes(root, node.m_right); + } else if (node.m_right == -1) { + newRoot = shiftNodes(root, node.m_left); + } else { + // This is where we need change also internal balancing nodes. 
+ + auto next = retValue; + // ah, first we remove ourselves from this spot in the tree, + // and then we move ourselves into the tree. + if (m_nodes[next].m_parent != root) { + auto par = m_nodes[next].m_parent; + if (par != m_end) { + m_nodes[par].m_balance++; + } + + // This is correct, we need to run up a rebalance from this spot, until you reach the parent then? + + shiftNodes(next, m_nodes[next].m_right); + m_nodes[next].m_right = node.m_right; + m_nodes[node.m_right].m_parent = next; + } else { + // hey + std::ignore = 1; + } + newRoot = shiftNodes(root, next); + m_nodes[next].m_left = node.m_left; + m_nodes[node.m_left].m_parent = next; + } + + // remove the root node. + free_node(root); + --m_size; + + if (newRoot == -1) { + return retValue; + } + + // OK, time to rebalance the tree, let's hope I can do it. With wikipedia as my guide it can work. + node_id_type G = -1; + for (auto parentNode = m_nodes[newRoot].m_parent; parentNode != -1; parentNode = G) { + G = m_nodes[parentNode].m_parent; + int b = 0; + + // BF(X) is not yet updated. 
+ if (newRoot == m_nodes[parentNode].m_left) { // it's the left tree which decreases + if (0 < m_nodes[parentNode].m_balance) { // X is right heavy + // THe temporary BF(X) == +2 amd we need to rebalance + // we need to rebalance + auto Z = m_nodes[parentNode].m_right; + b = m_nodes[Z].m_balance; + if (b < 0) { + newRoot = rotate_right_left(parentNode, Z); + } else { + newRoot = rotate_left(parentNode, Z); + } + } else { + auto pre = m_nodes[parentNode].m_balance++; + if (pre == 0) + break; + + newRoot = parentNode; + continue; + } + } else { + // We are in the right subtree + if (m_nodes[parentNode].m_balance < 0) { + // parentNode is left heavy and we need to rebalance + auto Z = m_nodes[parentNode].m_left; + b = m_nodes[Z].m_balance; + + if (0 < b) { + newRoot = rotate_left_right(parentNode, Z); + } else { + newRoot = rotate_right(parentNode, Z); + } + } else { + auto pre = m_nodes[parentNode].m_balance--; + if (pre == 0) + break; + + newRoot = parentNode; + continue; + } + } + + // After a rotation adapt parent link: + // newRoot is the new root of the rotated subtree + m_nodes[newRoot].m_parent = G; + if (G != -1) { + if (parentNode == m_nodes[G].m_left) { + m_nodes[G].m_left = newRoot; + } else { + m_nodes[G].m_right = newRoot; + } + } else { + m_root = newRoot; + } + + if (b == 0) + break; + } + + // Now let's think here. 
+ return retValue; + } + + constexpr void internal_print(std::ostringstream& out, const node_id_type cur, const int depth) const + { + std::string padding(depth * 2, '_'); + + if (cur == -1) { + out << padding << "[NO NODE]\n"; + return; + } + + const auto& node = m_nodes[cur]; + + out << padding << "[ .id=" << cur << ", .bal=" << static_cast(node.m_balance); + out << ", .key=" << node.m_key; + // out << ", .val=" << node.m_value; + out << "]"; + if (cur == m_root) { + out << " (root)"; + } + + if (cur == m_begin) { + out << " (begin)"; + } + + if (cur == m_nodes[m_end].m_parent) { + out << " (end)"; + } + out << "\n"; + + internal_print(out, node.m_left, depth + 1); + internal_print(out, node.m_right, depth + 1); + } + + class DotPrinter + { + int m_nullNodes{ 0 }; + std::unordered_map m_defaultNodeAttrib; + std::unordered_map> m_nodes; + std::vector> m_edges; + + std::string get_null_id() + { + std::stringstream ss; + ss << "null"; + ss << m_nullNodes++; + return ss.str(); + } + + void do_format_attributes(std::ostream& out, const std::unordered_map& attribs) const + { + if (!attribs.empty()) { + out << " ["; + bool first = true; + for (const auto& [name, val] : attribs) { + if (!first) { + out << ","; + } else { + first = false; + } + out << name << "=" << val; + } + out << "]"; + } + } + + public: + void add_default_node_attrib(const std::string& key, const std::string& val) { m_defaultNodeAttrib[key] = val; } + + std::string get_node_id(int id) + { + std::stringstream ss; + ss << "node"; + ss << id; + return ss.str(); + } + + void add_node_attrib(const std::string& sid, const std::string& key, const std::string& val, bool isHtml) + { + if (isHtml) { + m_nodes[sid][key] = "<" + val + ">"; + } else { + m_nodes[sid][key] = "\"" + val + "\""; + } + } + + void add_node_attrib(const std::string& sid, const std::string& key, const int val, bool isHtml) + { + std::ostringstream out; + out << val; + add_node_attrib(sid, key, out.str(), isHtml); + } + + std::string 
add_node(int id, int balance, int key) + { + const auto sid = get_node_id(id); + std::ostringstream label; + // label << "id = " << id << ", bal = " << balance << ", key = " << key; + label << key << "" << id << ""; + add_node_attrib(sid, "label", label.str(), true); + + label.str(std::string()); + // label << id << " (" << balance << ")"; + label << "(" << balance << ")"; + add_node_attrib(sid, "tooltip", label.str(), false); + + // add_node_attrib(sid, "style", "filled"); + add_node_attrib(sid, "fillcolor", 3 + balance, false); + + return sid; + } + + std::string get_null_node() + { + const auto sid = get_null_id(); + add_node_attrib(sid, "shape", "point", false); + return sid; + } + + void add_edge(const std::string& src, const std::string& dst) { m_edges.emplace_back(src, dst); } + + [[nodiscard]] std::string print() const + { + std::ostringstream out; + out << "digraph BST {\n"; + + out << "\tnode"; + do_format_attributes(out, m_defaultNodeAttrib); + out << ";\n"; + + // Nodes + for (const auto& [sid, attribs] : m_nodes) { + out << "\t" << sid; + do_format_attributes(out, attribs); + out << ";\n"; + } + + // Edges + for (const auto& [src, dst] : m_edges) { + out << src << " -> " << dst << "\n"; + } + + out << "}\n"; + + return out.str(); + } + }; + + constexpr void internal_dot_print(DotPrinter& printer, const node_id_type cur, const int depth) const + { + const auto& node = m_nodes[cur]; + + const auto sid = printer.add_node(cur, node.m_balance, node.key()); + + std::string xlabel; + if (cur == m_root) { + xlabel += " (root)"; + } + + if (cur == m_begin) { + xlabel += " (begin)"; + } + + if (cur == m_nodes[m_end].m_parent) { + xlabel += " (end)"; + } + + if (!xlabel.empty()) { + printer.add_node_attrib(sid, "xlabel", xlabel, false); + } + + if (node.m_left != -1) { + const auto kid = printer.get_node_id(node.m_left); + printer.add_edge(sid, kid); + internal_dot_print(printer, node.m_left, depth + 1); + } else { + const auto kid = printer.get_null_node(); + 
printer.add_edge(sid, kid); + } + + if (node.m_right != -1) { + const auto kid = printer.get_node_id(node.m_right); + printer.add_edge(sid, kid); + internal_dot_print(printer, node.m_right, depth + 1); + } else { + const auto kid = printer.get_null_node(); + printer.add_edge(sid, kid); + } + } }; } // namespace hage::ds \ No newline at end of file diff --git a/tests/avl_tree_tests.cpp b/tests/avl_tree_tests.cpp index 6251cbe..4501dc7 100644 --- a/tests/avl_tree_tests.cpp +++ b/tests/avl_tree_tests.cpp @@ -2,6 +2,12 @@ #include +#include +#include +#include +#include +#include + using namespace hage; TEST_SUITE_BEGIN("data_structures"); @@ -55,6 +61,18 @@ TEST_CASE("Simple avl tests") REQUIRE_EQ(it->value(), 10); } + SUBCASE("Erase should work") + { + auto it = tree.find(10); + REQUIRE_NE(it, tree.end()); + REQUIRE_UNARY(tree.contains(10)); + + auto it2 = tree.erase(it); + REQUIRE_NE(it2, tree.end()); + REQUIRE_EQ(it2->key(), 100); + REQUIRE_UNARY_FALSE(tree.contains(10)); + } + SUBCASE("Clearing should return all nodes to zero") { tree.clear(); @@ -80,7 +98,7 @@ TEST_CASE("AVL Tree iterator tests") REQUIRE_UNARY(std::bidirectional_iterator); } - constexpr int N = 10; + constexpr int N = 100; // We have to insert a 100 elements. for (int i = 0; i < N; i++) { auto [it, inserted] = tree.try_emplace(i, N - i); @@ -140,4 +158,56 @@ TEST_CASE("AVL Tree iterator tests") } } +TEST_CASE("AVL erase tests") +{ + ds::AVLTree tree; + std::map mirror; + + constexpr int N = 20; + // We have to insert a 100 elements. 
+ for (int i = 0; i < N; i++) { + auto [it, inserted] = tree.try_emplace(i, N - i); + REQUIRE_UNARY(inserted); + + mirror.emplace(i, N - i); + + auto itEnd = tree.end(); + auto itPrev = std::prev(itEnd); + REQUIRE_EQ(it, itPrev); + } + + // Ok, here we go + std::vector toDelete(N); + std::iota(toDelete.begin(), toDelete.end(), 0); + + SUBCASE("Erasing randomly should work") + { + std::ranlux48 rng{ 24 }; + std::ranges::shuffle(toDelete, rng); + for (auto toDel : toDelete) { + INFO("Doing: ", toDel); + + std::cout << "Going to delete: " << toDel << "\n"; + std::cout << tree.print_dot_tree(); + + auto it1 = tree.find(toDel); + REQUIRE_NE(it1, tree.end()); + auto it2 = mirror.find(toDel); + + REQUIRE_EQ(it1->key(), it2->first); + REQUIRE_EQ(it1->value(), it2->second); + + auto it3 = tree.erase(it1); + auto it4 = mirror.erase(it2); + + if (it3 == tree.end()) { + REQUIRE_EQ(it4, mirror.end()); + } else { + REQUIRE_EQ(it3->key(), it4->first); + REQUIRE_EQ(it3->value(), it4->second); + } + } + } +} + TEST_SUITE_END();