From 1665662bdd44c917d16ab649bde495981d5dfc84 Mon Sep 17 00:00:00 2001 From: Mikhail Bautin <552936+mbautin@users.noreply.github.com> Date: Sat, 21 Sep 2024 01:09:08 -0700 Subject: [PATCH] Automatic commit by thirdparty_tool: update usearch to commit 240fe9c298100f9e37a2d7377b1595be6ba1f412. Used commit of the usearch repository: https://github.com/unum-cloud/usearch/commits/240fe9c298100f9e37a2d7377b1595be6ba1f412 Latest commit in the include subdirectory of the usearch repository: ad8656ec01043e4a33e38088e5a7ea549eac3201 --- .../usearch/usearch/index.hpp | 1174 +++++++++++++---- .../usearch/usearch/index_dense.hpp | 694 ++++++---- .../usearch/usearch/index_plugins.hpp | 668 +++++++--- 3 files changed, 1873 insertions(+), 663 deletions(-) diff --git a/src/inline-thirdparty/usearch/usearch/index.hpp b/src/inline-thirdparty/usearch/usearch/index.hpp index d31c1554791d..6dbdb5e95291 100644 --- a/src/inline-thirdparty/usearch/usearch/index.hpp +++ b/src/inline-thirdparty/usearch/usearch/index.hpp @@ -1,23 +1,24 @@ /** - * @file index.hpp - * @author Ash Vardanian - * @brief Single-header Vector Search. - * @date 2023-04-26 - * - * @copyright Copyright (c) 2023 + * @file index.hpp + * @author Ash Vardanian + * @brief Single-header Vector Search engine. + * @date April 26, 2023 */ #ifndef UNUM_USEARCH_HPP #define UNUM_USEARCH_HPP #define USEARCH_VERSION_MAJOR 2 -#define USEARCH_VERSION_MINOR 11 -#define USEARCH_VERSION_PATCH 0 +#define USEARCH_VERSION_MINOR 15 +#define USEARCH_VERSION_PATCH 1 // Inferring C++ version // https://stackoverflow.com/a/61552074 #if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L) #define USEARCH_DEFINED_CPP17 #endif +#if ((defined(_MSVC_LANG) && _MSVC_LANG >= 202002L) || __cplusplus >= 202002L) +#define USEARCH_DEFINED_CPP20 +#endif // Inferring target OS: Windows, MacOS, or Linux #if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__) @@ -35,7 +36,11 @@ #define USEARCH_DEFINED_GCC #endif -#if defined(__clang__) || defined(_MSC_VER) +// The `#pragma region` and `#pragma endregion` are not supported by GCC 12 and older. +// But they are supported by GCC 13, all recent Clang versions, and MSVC. +#if defined(__GNUC__) && ((__GNUC__ > 13) || (__GNUC__ == 13 && __GNUC_MINOR__ >= 0)) +#define USEARCH_USE_PRAGMA_REGION +#elif defined(__clang__) || defined(_MSC_VER) #define USEARCH_USE_PRAGMA_REGION #endif @@ -90,17 +95,39 @@ #include // `std::thread` #include // `std::pair` +// Helper macros for concatenation and stringification +#define usearch_concat_helper_m(a, b) a##b +#define usearch_concat_m(a, b) usearch_concat_helper_m(a, b) +#define usearch_stringify_helper_m(x) #x +#define usearch_stringify_m(x) usearch_stringify_helper_m(x) + // Prefetching #if defined(USEARCH_DEFINED_GCC) // https://gcc.gnu.org/onlinedocs/gcc/Other-Builtins.html // Zero means we are only going to read from that memory. // Three means high temporal locality and suggests to keep // the data in all layers of cache. 
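// A typical use (illustrative, not part of the upstream header): issue
// `usearch_prefetch_m(candidate_ptr)` a few iterations before the pointer is
// dereferenced, so the bytes are already in cache when they are needed; on
// compilers without a known prefetch intrinsic the macro expands to nothing.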
-#define prefetch_m(ptr) __builtin_prefetch((void*)(ptr), 0, 3)
+#define usearch_prefetch_m(ptr) __builtin_prefetch((void*)(ptr), 0, 3)
 #elif defined(USEARCH_DEFINED_X86)
-#define prefetch_m(ptr) _mm_prefetch((void*)(ptr), _MM_HINT_T0)
+#define usearch_prefetch_m(ptr) _mm_prefetch((void*)(ptr), _MM_HINT_T0)
+#else
+#define usearch_prefetch_m(ptr)
+#endif
+
+// Function profiling
+// Expand `__COUNTER__` only once per call site, so the `.globl`/`.global`
+// directive and the label it marks agree on the same symbol name.
+#define usearch_profile_name_impl_m(directive, name, suffix)                                                          \
+    __asm__ volatile(directive " " usearch_stringify_m(usearch_concat_m(name, suffix)) "\n" usearch_stringify_m(     \
+        usearch_concat_m(name, suffix)) ":")
+#if defined(USEARCH_DEFINED_X86)
+#define usearch_profiled_m __attribute__((noinline))
+#define usearch_profile_name_m(name) usearch_profile_name_impl_m(".globl", name, __COUNTER__)
+#elif defined(USEARCH_DEFINED_ARM)
+#define usearch_profiled_m __attribute__((noinline))
+#define usearch_profile_name_m(name) usearch_profile_name_impl_m(".global", name, __COUNTER__)
 #else
-#define prefetch_m(ptr)
+#define usearch_profiled_m
+#define usearch_profile_name_m(name)
 #endif

 // Alignment
@@ -119,11 +146,25 @@
 #else
 #define usearch_assert_m(must_be_true, message)                                                                       \
     if (!(must_be_true)) {                                                                                            \
-        throw std::runtime_error(message);                                                                            \
+        __usearch_raise_runtime_error(message);                                                                       \
     }
 #define usearch_noexcept_m
 #endif

+extern "C" {
+/// @brief Helper function to simplify debugging - trace just one symbol - `__usearch_raise_runtime_error`.
+///        Assuming the `extern "C"` block, the name won't be mangled.
+inline static void __usearch_raise_runtime_error(char const* message) {
+    // On Windows we compile with the `/EHc` flag, which specifies that functions
+    // with C linkage do not throw C++ exceptions.
+#if !defined(__cpp_exceptions) || defined(USEARCH_DEFINED_WINDOWS)
+    std::terminate();
+#else
+    throw std::runtime_error(message);
+#endif
+}
+}
+
 namespace unum {
 namespace usearch {

@@ -172,6 +213,13 @@ template at exchange(at& obj, other_at&& n
     return old_value;
 }

+#if defined(USEARCH_DEFINED_CPP20)
+
+template <typename at> void destroy_at(at* obj) { std::destroy_at(obj); }
+template <typename at> void construct_at(at* obj) { std::construct_at(obj); }
+
+#else
+
 /// @brief The `std::destroy_at` alternative for C++11.
 template <typename at>
 typename std::enable_if<std::is_trivially_destructible<at>::value>::type destroy_at(at*) {}
@@ -188,6 +236,8 @@ typename std::enable_if::value>::type construct_at(at* o
     new (obj) at();
 }

+#endif
+
 /**
  * @brief A reference to a misaligned memory location with a specific type.
* It is needed to avoid Undefined Behavior when dereferencing addresses @@ -227,27 +277,57 @@ template class misaligned_ptr_gt { using pointer = misaligned_ptr_gt; using reference = misaligned_ref_gt; + misaligned_ptr_gt(byte_t* ptr) noexcept : ptr_(ptr) {} + reference operator*() const noexcept { return {ptr_}; } reference operator[](std::size_t i) noexcept { return reference(ptr_ + i * sizeof(element_t)); } value_type operator[](std::size_t i) const noexcept { return misaligned_load(ptr_ + i * sizeof(element_t)); } - misaligned_ptr_gt(byte_t* ptr) noexcept : ptr_(ptr) {} - misaligned_ptr_gt operator++(int) noexcept { return misaligned_ptr_gt(ptr_ + sizeof(element_t)); } - misaligned_ptr_gt operator--(int) noexcept { return misaligned_ptr_gt(ptr_ - sizeof(element_t)); } - misaligned_ptr_gt operator+(difference_type d) noexcept { return misaligned_ptr_gt(ptr_ + d * sizeof(element_t)); } - misaligned_ptr_gt operator-(difference_type d) noexcept { return misaligned_ptr_gt(ptr_ - d * sizeof(element_t)); } - - // clang-format off - misaligned_ptr_gt& operator++() noexcept { ptr_ += sizeof(element_t); return *this; } - misaligned_ptr_gt& operator--() noexcept { ptr_ -= sizeof(element_t); return *this; } - misaligned_ptr_gt& operator+=(difference_type d) noexcept { ptr_ += d * sizeof(element_t); return *this; } - misaligned_ptr_gt& operator-=(difference_type d) noexcept { ptr_ -= d * sizeof(element_t); return *this; } - // clang-format on - - bool operator==(misaligned_ptr_gt const& other) noexcept { return ptr_ == other.ptr_; } - bool operator!=(misaligned_ptr_gt const& other) noexcept { return ptr_ != other.ptr_; } + misaligned_ptr_gt& operator++() noexcept { + ptr_ += sizeof(element_t); + return *this; + } + misaligned_ptr_gt& operator--() noexcept { + ptr_ -= sizeof(element_t); + return *this; + } + misaligned_ptr_gt operator++(int) noexcept { + misaligned_ptr_gt tmp = *this; + ++(*this); + return tmp; + } + misaligned_ptr_gt operator--(int) noexcept { + misaligned_ptr_gt tmp = *this; + --(*this); + return tmp; + } + misaligned_ptr_gt operator+(difference_type d) const noexcept { + return misaligned_ptr_gt(ptr_ + d * sizeof(element_t)); + } + misaligned_ptr_gt operator-(difference_type d) const noexcept { + return misaligned_ptr_gt(ptr_ - d * sizeof(element_t)); + } + difference_type operator-(const misaligned_ptr_gt& other) const noexcept { + return (ptr_ - other.ptr_) / sizeof(element_t); + } + + misaligned_ptr_gt& operator+=(difference_type d) noexcept { + ptr_ += d * sizeof(element_t); + return *this; + } + misaligned_ptr_gt& operator-=(difference_type d) noexcept { + ptr_ -= d * sizeof(element_t); + return *this; + } + + bool operator==(misaligned_ptr_gt const& other) const noexcept { return ptr_ == other.ptr_; } + bool operator!=(misaligned_ptr_gt const& other) const noexcept { return ptr_ != other.ptr_; } + bool operator<(misaligned_ptr_gt const& other) const noexcept { return ptr_ < other.ptr_; } + bool operator<=(misaligned_ptr_gt const& other) const noexcept { return ptr_ <= other.ptr_; } + bool operator>(misaligned_ptr_gt const& other) const noexcept { return ptr_ > other.ptr_; } + bool operator>=(misaligned_ptr_gt const& other) const noexcept { return ptr_ >= other.ptr_; } }; /** @@ -283,7 +363,8 @@ template > for (std::size_t i = 0; i != size_; ++i) construct_at(data_ + i); } - ~buffer_gt() noexcept { + ~buffer_gt() noexcept { reset(); } + void reset() noexcept { if (!std::is_trivially_destructible::value) for (std::size_t i = 0; i != size_; ++i) destroy_at(data_ + i); @@ -323,7 
+404,8 @@ class error_t {
     char const* message_{};

   public:
-    error_t(char const* message = nullptr) noexcept : message_(message) {}
+    error_t() noexcept : message_(nullptr) {}
+    error_t(char const* message) noexcept : message_(message) {}
     error_t& operator=(char const* message) noexcept {
         message_ = message;
         return *this;
@@ -336,11 +418,18 @@ class error_t {
         std::swap(message_, other.message_);
         return *this;
     }
+
+    /// @brief Checks if there was an error.
     explicit operator bool() const noexcept { return message_ != nullptr; }
+
+    /// @brief Returns the error message.
     char const* what() const noexcept { return message_; }
+
+    /// @brief Releases the error message, meaning the caller takes ownership.
     char const* release() noexcept { return exchange(message_, nullptr); }

 #if defined(__cpp_exceptions) || defined(__EXCEPTIONS)
+    /// @brief Destructor raises an exception if an error was recorded.
     ~error_t() noexcept(false) {
 #if defined(USEARCH_DEFINED_CPP17)
         if (message_ && std::uncaught_exceptions() == 0)
@@ -349,12 +438,17 @@ class error_t {
 #endif
         raise();
     }
+
+    /// @brief Throws an exception to be caught by `try` / `catch`.
     void raise() noexcept(false) {
         if (message_)
             throw std::runtime_error(exchange(message_, nullptr));
     }
 #else
+    /// @brief Destructor terminates if an error was recorded.
     ~error_t() noexcept { raise(); }
+
+    /// @brief Terminates if an error was recorded.
     void raise() noexcept {
         if (message_)
             std::terminate();
@@ -367,7 +461,7 @@ class error_t {
  * or an error. It's used to avoid raising exceptions and to gracefully propagate
  * the error.
  *
- * @tparam result_at The type of the expected result.
+ * @tparam result_at   The type of the expected result.
 */
template <typename result_at> struct expected_gt {
    result_at result;
@@ -501,6 +595,64 @@ using bitset_t = bitset_gt<>;
  * @brief Similar to `std::priority_queue`, but allows raw access to the underlying
  *        memory, in case you want to shuffle it or sort. Good for collections
  *        from 100s to 10'000s of elements.
+ *
+ * In a max-heap, the heap property ensures that the value of each node is greater
+ * than or equal to the values of its children. This means that the largest element
+ * is always at the root of the heap.
+ *
+ * @section Heap Structures
+ *
+ * There are several designs of heaps. Binary heaps are the simplest and most common
+ * variant that is easy to implement as a succinct array. However, they are not the
+ * most efficient for all operations. Most importantly, @b melding (merging) of
+ * two heaps has linear time complexity.
+ *
+ * +-----------------+---------+-----------+---------+--------------+---------+
+ * | Operation       | find-max| delete-max| insert  | increase-key | meld    |
+ * +-----------------+---------+-----------+---------+--------------+---------+
+ * | Binary          | Θ(1)    | Θ(log n)  | O(log n)| O(log n)     | Θ(n)    |
+ * | Leftist         | Θ(1)    | Θ(log n)  | O(log n)| Θ(log n)     | Θ(log n)|
+ * | Binomial        | Θ(1)    | Θ(log n)  | Θ(1)    | Θ(log n)     | O(log n)|
+ * | Skew binomial   | Θ(1)    | Θ(log n)  | Θ(1)    | O(log n)     | O(log n)|
+ * | Pairing         | Θ(1)    | O(log n)  | Θ(1)    | o(log n)     | Θ(1)    |
+ * | Rank-pairing    | Θ(1)    | O(log n)  | Θ(1)    | Θ(1)         | Θ(1)    |
+ * | Fibonacci       | Θ(1)    | O(log n)  | Θ(1)    | Θ(1)         | Θ(1)    |
+ * | Strict Fibonacci| Θ(1)    | O(log n)  | Θ(1)    | Θ(1)         | Θ(1)    |
+ * | Brodal          | Θ(1)    | Θ(log n)  | Θ(1)    | Θ(1)         | Θ(1)    |
+ * | 2–3 heap        | Θ(1)    | O(log n)  | Θ(1)    | Θ(1)         | O(log n)|
+ * +-----------------+---------+-----------+---------+--------------+---------+
+ *
+ * It's well known that improved priority-queue structures translate into better
+ * graph-traversal algorithms. For example, Dijkstra's algorithm can be sped up
+ * by using a Fibonacci heap for arbitrary weights. For integer weights bounded
+ * by L, Schrijver reported the following time complexities in 2004:
+ *
+ * +------------+-------------------------------------+----------------------------+--------------------------+
+ * | Weights    | Algorithm                           | Time complexity            | Author                   |
+ * +------------+-------------------------------------+----------------------------+--------------------------+
+ * | R          |                                     | O(V^2 EL)                  | Ford 1956                |
+ * | R          | Bellman–Ford algorithm              | O(VE)                      | Shimbel 1955, Bellman    |
+ * |            |                                     |                            | 1958, Moore 1959         |
+ * | R          |                                     | O(V^2 log V)               | Dantzig 1960             |
+ * | R          | Dijkstra's with list                | O(V^2)                     | Leyzorek et al. 1957,    |
+ * |            |                                     |                            | Dijkstra 1959...         |
+ * | R          | Dijkstra's with binary heap         | O((E + V) log V)           | Johnson 1977             |
+ * | R          | Dijkstra's with Fibonacci heap      | O(E + V log V)             | Fredman & Tarjan 1984,   |
+ * |            |                                     |                            | Fredman & Tarjan 1987    |
+ * | R          | Quantum Dijkstra                    | O(√VE log^2 V)             | Dürr et al. 2006         |
+ * | R          | Dial's algorithm (Dijkstra's using  | O(E + LV)                  | Dial 1969                |
+ * |            | a bucket queue with L buckets)      |                            |                          |
+ * | N          |                                     | O(E log log L)             | Johnson 1981, Karlsson & |
+ * |            |                                     |                            | Poblete 1983             |
+ * | N          | Gabow's algorithm                   | O(E log_E/V L)             | Gabow 1983, Gabow 1985   |
+ * | N          |                                     | O(E + V √log L)            | Ahuja et al. 1990        |
+ * | N          | Thorup                              | O(E + V log log V)         | Thorup 2004              |
+ * +------------+-------------------------------------+----------------------------+--------------------------+
+ *
+ * Possible improvements:
+ * - Randomized meldable heaps: https://en.wikipedia.org/wiki/Randomized_meldable_heap
+ * - D-ary heaps: https://en.wikipedia.org/wiki/D-ary_heap
+ * - B-heap: https://en.wikipedia.org/wiki/B-heap
 */
template <typename element_at,                                //
          typename comparator_at = std::less<void>,           // `<void>` is needed before C++14.
          typename allocator_at = std::allocator<element_at>> //
@@ -551,13 +703,25 @@ class max_heap_gt {
     inline bool empty() const noexcept { return !size_; }
     inline std::size_t size() const noexcept { return size_; }
     inline std::size_t capacity() const noexcept { return capacity_; }
+    inline element_t* data() noexcept { return elements_; }
+    inline element_t const* data() const noexcept { return elements_; }
+    inline void clear() noexcept { size_ = 0; }
+    inline void shrink(std::size_t n) noexcept { size_ = (std::min)(n, size_); }

     /// @brief  Selects the largest element in the heap.
     /// @return Reference to the stored element.
inline element_t const& top() const noexcept { return elements_[0]; } - inline void clear() noexcept { size_ = 0; } - bool reserve(std::size_t new_capacity) noexcept { + /// @brief Invalidates the "max-heap" property, transforming into ascending range. + inline void sort_ascending() noexcept { std::sort_heap(elements_, elements_ + size_, &less); } + + /** + * @brief Ensures the heap has enough capacity for the specified number of elements. + * @param new_capacity The desired minimum capacity. + * @return True if the capacity was successfully increased, false otherwise. + */ + usearch_profiled_m bool reserve(std::size_t new_capacity) noexcept { + usearch_profile_name_m(max_heap_reserve); if (new_capacity < capacity_) return true; @@ -577,6 +741,11 @@ class max_heap_gt { return new_elements; } + /** + * @brief Inserts an element into the heap. + * @param element The element to be inserted. + * @return True if the element was successfully inserted, false otherwise. + */ bool insert(element_t&& element) noexcept { if (!reserve(size_ + 1)) return false; @@ -585,13 +754,33 @@ class max_heap_gt { return true; } - inline void insert_reserved(element_t&& element) noexcept { + /** + * @brief Inserts an element into the heap without reserving additional space. + * @param element The element to be inserted. + */ + usearch_profiled_m void insert_reserved(element_t&& element) noexcept { + usearch_profile_name_m(max_heap_insert_reserved); new (&elements_[size_]) element_t(element); size_++; shift_up(size_ - 1); } - inline element_t pop() noexcept { + /** + * @brief Inserts multiple elements into the heap. + * @param elements Pointer to the elements to be inserted. + * @return True if the elements were successfully inserted, false otherwise. + */ + inline bool insert_many(element_t const* elements) noexcept { + // Wikipedia describes a procedure, due to Floyd, which constructs a heap from an array in linear time. + // It also mentions a procedure for merging two heaps, of sizes 𝑛 and 𝑘, in time 𝑂(𝑘+log𝑘log𝑛). + // Altogether, we can add 𝑘 elements to a heap of length 𝑛 in time 𝑂(𝑘+log𝑘log𝑛): first build a heap containing + // 𝑘 elements to be inserted (takes 𝑂(𝑘) time), then merge that with the heap of size 𝑛 (takes 𝑂(𝑘+log𝑘log𝑛) + // time). Compare this to repeated insertion, which would run in time 𝑂(𝑘log𝑛). + return false; + } + + usearch_profiled_m element_t pop() noexcept { + usearch_profile_name_m(max_heap_pop); element_t result = top(); std::swap(elements_[0], elements_[size_ - 1]); size_--; @@ -600,24 +789,29 @@ class max_heap_gt { return result; } - /** @brief Invalidates the "max-heap" property, transforming into ascending range. 
*/ - inline void sort_ascending() noexcept { std::sort_heap(elements_, elements_ + size_, &less); } - inline void shrink(std::size_t n) noexcept { size_ = (std::min)(n, size_); } - - inline element_t* data() noexcept { return elements_; } - inline element_t const* data() const noexcept { return elements_; } - private: - inline std::size_t parent_idx(std::size_t i) const noexcept { return (i - 1u) / 2u; } - inline std::size_t left_child_idx(std::size_t i) const noexcept { return (i * 2u) + 1u; } - inline std::size_t right_child_idx(std::size_t i) const noexcept { return (i * 2u) + 2u; } + static std::size_t parent_idx(std::size_t i) noexcept { return (i - 1u) / 2u; } + static std::size_t left_child_idx(std::size_t i) noexcept { return (i * 2u) + 1u; } + static std::size_t right_child_idx(std::size_t i) noexcept { return (i * 2u) + 2u; } static bool less(element_t const& a, element_t const& b) noexcept { return comparator_t{}(a, b); } + /** + * @brief Shifts an element up to maintain the heap property. + * This operation is called when a new element is @b added at the end of the heap. + * The element is moved up until the heap property is restored. + * @param i Index of the element to be shifted up. + */ void shift_up(std::size_t i) noexcept { for (; i && less(elements_[parent_idx(i)], elements_[i]); i = parent_idx(i)) std::swap(elements_[parent_idx(i)], elements_[i]); } + /** + * @brief Shifts an element down to maintain the heap property. + * This operation is called when the root element is @b removed and the last element is moved to the root. + * The element is moved down until the heap property is restored. + * @param i Index of the element to be shifted down. + */ void shift_down(std::size_t i) noexcept { std::size_t max_idx = i; @@ -762,9 +956,7 @@ class sorted_buffer_gt { #endif /** - * @brief Five-byte integer type to address node clouds with over 4B entries. - * - * @note Avoid usage in 32bit environment + * @brief Five-byte integer type to address node clouds with over 4B entries. 
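+ *
+ * A round-trip sketch (illustrative; assumes a 64-bit little-endian build,
+ * matching the `memcpy`-based layout below):
+ * @code{.cpp}
+ *     uint40_t id(std::uint64_t(5000000000ull)); // value above the 32-bit range
+ *     std::size_t restored = id;                 // widens back without loss
+ * @endcode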
 */
class usearch_pack_m uint40_t {
    unsigned char octets[5];
@@ -776,7 +968,10 @@
  public:
    inline uint40_t() noexcept { broadcast(0); }
-   inline uint40_t(std::uint32_t n) noexcept { std::memcpy(&octets[1], &n, 4); }
+   inline uint40_t(std::uint32_t n) noexcept {
+       std::memcpy(&octets, &n, 4);
+       octets[4] = 0;
+   }

 #ifdef USEARCH_64BIT_ENV
    inline uint40_t(std::uint64_t n) noexcept { std::memcpy(octets, &n, 5); }
@@ -793,22 +988,38 @@ class usearch_pack_m uint40_t {
        std::memcpy(octets, &n, 5);
 #else
        std::memcpy(octets, &n, 4);
-#endif
+       octets[4] = 0;
+#endif // USEARCH_64BIT_ENV
    }
-#endif
+#endif // USEARCH_DEFINED_CLANG && USEARCH_DEFINED_APPLE

    inline operator std::size_t() const noexcept {
        std::size_t result = 0;
 #ifdef USEARCH_64BIT_ENV
        std::memcpy(&result, octets, 5);
 #else
-       std::memcpy(&result, octets + 1, 4);
+       std::memcpy(&result, octets, 4);
 #endif
        return result;
    }

    inline static uint40_t max() noexcept { return uint40_t{}.broadcast(0xFF); }
    inline static uint40_t min() noexcept { return uint40_t{}.broadcast(0); }
+
+   inline bool operator==(uint40_t const& other) const noexcept { return std::memcmp(octets, other.octets, 5) == 0; }
+   inline bool operator!=(uint40_t const& other) const noexcept { return !(*this == other); }
+   inline bool operator>(uint40_t const& other) const noexcept { return other < *this; }
+   inline bool operator<=(uint40_t const& other) const noexcept { return !(*this > other); }
+   inline bool operator>=(uint40_t const& other) const noexcept { return !(*this < other); }
+   inline bool operator<(uint40_t const& other) const noexcept {
+       // Octets are stored little-endian, so compare from the most significant byte down.
+       for (int i = 0; i < 5; ++i) {
+           if (octets[4 - i] < other.octets[4 - i])
+               return true;
+           if (octets[4 - i] > other.octets[4 - i])
+               return false;
+       }
+       return false;
+   }
};

 #if defined(USEARCH_DEFINED_WINDOWS)
@@ -817,12 +1028,33 @@ class usearch_pack_m uint40_t {

 static_assert(sizeof(uint40_t) == 5, "uint40_t must be exactly 5 bytes");

-// clang-format off
-template <typename key_at, typename std::enable_if<std::is_integral<key_at>::value>::type* = nullptr> key_at default_free_value() { return std::numeric_limits<key_at>::max(); }
-template <typename key_at, typename std::enable_if<std::is_same<key_at, uint40_t>::value>::type* = nullptr> uint40_t default_free_value() { return uint40_t::max(); }
-template <typename key_at, typename std::enable_if<!std::is_integral<key_at>::value && !std::is_same<key_at, uint40_t>::value>::type* = nullptr> key_at default_free_value() { return key_at(); }
-// clang-format on
+/**
+ * @brief Reflection-helper to get the default "unused" value for a given type.
+ *        Needed to initialize hash-sets and bit-sets.
+ */
+template <typename element_at> struct default_free_value_gt {
+    template <typename sfinae_element_at = element_at,
+              typename std::enable_if<std::is_integral<sfinae_element_at>::value>::type* = nullptr>
+    static sfinae_element_at value() noexcept {
+        return std::numeric_limits<sfinae_element_at>::max();
+    }
+    template <typename sfinae_element_at = element_at,
+              typename std::enable_if<!std::is_integral<sfinae_element_at>::value>::type* = nullptr>
+    static sfinae_element_at value() noexcept {
+        return element_at();
+    }
+};
+
+template <> struct default_free_value_gt<uint40_t> {
+    static uint40_t value() noexcept { return uint40_t::max(); }
+};
+
+template <typename element_at> element_at default_free_value() { return default_free_value_gt<element_at>::value(); }
+
+/**
+ * @brief Adapter to allow defining arbitrary hash functions for keys and slots.
+ *        It's added, as overloading `std::hash` is not recommended by the standard.
+ */
template <typename element_at> struct hash_gt {
    std::size_t operator()(element_at const& element) const noexcept { return std::hash<element_at>{}(element); }
};

@@ -833,6 +1065,7 @@ template <> struct hash_gt {
 /**
  * @brief Minimalistic hash-set implementation to track visited nodes during graph traversal.
+ *        In our primary use case, it's a sparse alternative to a bit-set.
  *
  * It doesn't support deletion of separate objects, but supports `clear`-ing all at once.
* It expects `reserve` to be called ahead of all insertions, so no resizes are needed. @@ -900,6 +1133,10 @@ class growing_hash_set_gt { growing_hash_set_gt(growing_hash_set_gt const&) = delete; growing_hash_set_gt& operator=(growing_hash_set_gt const&) = delete; + /** + * @brief Checks if the element is already in the hash-set. + * @return `true` if the element is already in the hash-set. + */ inline bool test(element_t const& elem) const noexcept { std::size_t index = hasher_(elem) & (capacity_ - 1); while (slots_[index] != default_free_value()) { @@ -912,7 +1149,7 @@ class growing_hash_set_gt { } /** - * + * @brief Inserts an element into the hash-set. * @return Similar to `bitset_gt`, returns the previous value. */ inline bool set(element_t const& elem) noexcept { @@ -929,6 +1166,10 @@ class growing_hash_set_gt { return false; } + /** + * @brief Extends the capacity of the hash-set. + * @return `true` if enough capacity is available, `false` if memory allocation failed. + */ bool reserve(std::size_t new_capacity) noexcept { new_capacity = (new_capacity * 5u) / 3u; if (new_capacity <= capacity_) @@ -1011,7 +1252,7 @@ class ring_gt { size_t size() const noexcept { if (empty_) return 0; - else if (head_ >= tail_) + else if (head_ > tail_) return head_ - tail_; else return capacity_ - (tail_ - head_); @@ -1056,7 +1297,9 @@ class ring_gt { return true; } - void push(element_t const& value) noexcept { + void push(element_t const& value) usearch_noexcept_m { + usearch_assert_m(capacity() > 0, "Ring buffer is not initialized"); + usearch_assert_m(size() < capacity(), "Ring buffer is full"); elements_[head_] = value; head_ = (head_ + 1) % capacity_; empty_ = false; @@ -1064,7 +1307,7 @@ class ring_gt { bool try_push(element_t const& value) noexcept { if (head_ == tail_ && !empty_) - return false; // elements_ is full + return false; // `elements_` is full return push(value); return true; @@ -1117,20 +1360,49 @@ struct index_config_t { std::size_t connectivity_base = default_connectivity() * 2; inline index_config_t() = default; - inline index_config_t(std::size_t c) noexcept - : connectivity(c ? c : default_connectivity()), connectivity_base(c ? c * 2 : default_connectivity() * 2) {} - inline index_config_t(std::size_t c, std::size_t cb) noexcept - : connectivity(c), connectivity_base((std::max)(c, cb)) {} + inline index_config_t(std::size_t c, std::size_t cb = 0) noexcept : connectivity(c), connectivity_base(cb) {} + + /** + * @brief Validates the configuration settings, updating them in-place. + * @return Error message, if any. + */ + inline error_t validate() noexcept { + if (connectivity == 0) + connectivity = default_connectivity(); + if (connectivity_base == 0) + connectivity_base = connectivity * 2; + if (connectivity < 2) + return "Connectivity must be at least 2, otherwise the index degenerates into ropes"; + if (connectivity_base < connectivity) + return "Base layer should be at least as connected as the rest of the graph"; + return {}; + } + + /** + * @brief Immutable function to check if the configuration is valid. + * @return `true` if the configuration is valid. + */ + inline bool is_valid() const noexcept { return connectivity >= 2 && connectivity_base >= connectivity; } }; +/** + * @brief Growth settings for the index container. + * Includes the upper bound for `::members` capacity, + * and the number of read/write threads expected to work with the index. + */ struct index_limits_t { + /// @brief Maximum number of entries in the index. 
    std::size_t members = 0;
+   /// @brief Max number of threads simultaneously updating entries.
    std::size_t threads_add = std::thread::hardware_concurrency();
+   /// @brief Max number of threads simultaneously searching entries.
    std::size_t threads_search = std::thread::hardware_concurrency();

    inline index_limits_t(std::size_t n, std::size_t t) noexcept : members(n), threads_add(t), threads_search(t) {}
    inline index_limits_t(std::size_t n = 0) noexcept : index_limits_t(n, std::thread::hardware_concurrency()) {}
+
+   /// @brief Returns the upper limit for the number of threads.
    inline std::size_t threads() const noexcept { return (std::max)(threads_add, threads_search); }
+
+   /// @brief Returns the concurrency level of the index - the minimum of the thread counts.
    inline std::size_t concurrency() const noexcept { return (std::min)(threads_add, threads_search); }
};

@@ -1408,8 +1680,10 @@ class input_file_t {
     serialization_result_t read(void* begin, std::size_t length) noexcept {
         serialization_result_t result;
         std::size_t read = std::fread(begin, length, 1, file_);
-        if (length && !read)
-            return result.failed(std::feof(file_) ? "End of file reached!" : std::strerror(errno));
+        if (length && !read) {
+            bool reached_eof = std::feof(file_);
+            return result.failed(reached_eof ? "End of file reached!" : std::strerror(errno));
+        }
         return result;
     }
     void close() noexcept {
@@ -1567,6 +1841,14 @@ class memory_mapped_file_t {
     }
 };

+/**
+ * @brief Metadata header for the serialized index.
+ *
+ * This structure is very minimalistic by design. It contains no information
+ * about the capacity of the index, so you'll have to `reserve` after loading.
+ * It also contains no info on the metric or key types, so you'll have to store
+ * that information elsewhere, like we do in `index_dense_head_t`.
+ */
 struct index_serialized_header_t {
     std::uint64_t size = 0;
     std::uint64_t connectivity = 0;
@@ -1612,8 +1894,8 @@ template inline key_at get_key(member_ref_gt const& m)
  * be seen as a network of keys, accelerating approximate @b Value~>Key lookups.
  *
  * Unlike most implementations, this one is generic and can be used for any search,
- * not just within equi-dimensional vectors. Examples range from texts to similar Chess
- * positions.
+ * not just within equi-dimensional vectors. Examples range from Texts to similar Chess
+ * positions, Geo-Spatial Search, and even Graphs.
  *
  * @tparam key_at
  *      The type of primary objects stored in the index.
@@ -1699,6 +1981,7 @@ class index_gt { using dynamic_allocator_t = dynamic_allocator_at; using tape_allocator_t = tape_allocator_at; static_assert(sizeof(vector_key_t) >= sizeof(compressed_slot_t), "Having tiny keys doesn't make sense."); + static_assert(std::is_signed::value, "Distance must be a signed type, as we use the unary minus."); using member_cref_t = member_cref_gt; using member_ref_t = member_ref_gt; @@ -1709,10 +1992,17 @@ class index_gt { friend class index_gt; member_iterator_gt() noexcept {} - member_iterator_gt(index_t* index, std::size_t slot) noexcept : index_(index), slot_(slot) {} + member_iterator_gt(index_t* index, compressed_slot_t slot) noexcept : index_(index), slot_(slot) {} + + template ref_t call_key(std::true_type) const noexcept { + return ref_t{index_->node_at_(slot_).ckey(), slot_}; + } + template ref_t call_key(std::false_type) const noexcept { + return ref_t{index_->node_at_(slot_).key(), slot_}; + } index_t* index_{}; - std::size_t slot_{}; + compressed_slot_t slot_{}; public: using iterator_category = std::random_access_iterator_tag; @@ -1721,22 +2011,21 @@ class index_gt { using pointer = void; using reference = ref_t; - reference operator*() const noexcept { return {index_->node_at_(slot_).key(), slot_}; } - vector_key_t key() const noexcept { return index_->node_at_(slot_).key(); } + reference operator*() const noexcept { return call_key<0>(std::is_const()); } + vector_key_t key() const noexcept { return index_->node_at_(slot_).ckey(); } - friend inline std::size_t get_slot(member_iterator_gt const& it) noexcept { return it.slot_; } + friend inline compressed_slot_t get_slot(member_iterator_gt const& it) noexcept { return it.slot_; } friend inline vector_key_t get_key(member_iterator_gt const& it) noexcept { return it.key(); } - member_iterator_gt operator++(int) noexcept { return member_iterator_gt(index_, slot_ + 1); } - member_iterator_gt operator--(int) noexcept { return member_iterator_gt(index_, slot_ - 1); } - member_iterator_gt operator+(difference_type d) noexcept { return member_iterator_gt(index_, slot_ + d); } - member_iterator_gt operator-(difference_type d) noexcept { return member_iterator_gt(index_, slot_ - d); } - // clang-format off - member_iterator_gt& operator++() noexcept { slot_ += 1; return *this; } - member_iterator_gt& operator--() noexcept { slot_ -= 1; return *this; } - member_iterator_gt& operator+=(difference_type d) noexcept { slot_ += d; return *this; } - member_iterator_gt& operator-=(difference_type d) noexcept { slot_ -= d; return *this; } + member_iterator_gt operator++(int) noexcept { return member_iterator_gt(index_, static_cast(static_cast(slot_) + 1)); } + member_iterator_gt operator--(int) noexcept { return member_iterator_gt(index_, static_cast(static_cast(slot_) - 1)); } + member_iterator_gt operator+(difference_type d) noexcept { return member_iterator_gt(index_, static_cast(static_cast(slot_) + d)); } + member_iterator_gt operator-(difference_type d) noexcept { return member_iterator_gt(index_, static_cast(static_cast(slot_) - d)); } + member_iterator_gt& operator++() noexcept { slot_ = static_cast(static_cast(slot_) + 1); return *this; } + member_iterator_gt& operator--() noexcept { slot_ = static_cast(static_cast(slot_) - 1); return *this; } + member_iterator_gt& operator+=(difference_type d) noexcept { slot_ = static_cast(static_cast(slot_) + d); return *this; } + member_iterator_gt& operator-=(difference_type d) noexcept { slot_ = static_cast(static_cast(slot_) - d); return *this; } bool 
operator==(member_iterator_gt const& other) const noexcept { return index_ == other.index_ && slot_ == other.slot_; } bool operator!=(member_iterator_gt const& other) const noexcept { return index_ != other.index_ || slot_ != other.slot_; } // clang-format on @@ -1827,8 +2116,10 @@ class index_gt { node_t& operator=(node_t const&) = default; misaligned_ref_gt ckey() const noexcept { return {tape_}; } - misaligned_ref_gt key() const noexcept { return {tape_}; } - misaligned_ref_gt level() const noexcept { return {tape_ + sizeof(vector_key_t)}; } + misaligned_ref_gt ckey() noexcept { return {tape_}; } + misaligned_ref_gt key() const noexcept { return {tape_}; } + misaligned_ref_gt key() noexcept { return {tape_}; } + misaligned_ref_gt level() noexcept { return {tape_ + sizeof(vector_key_t)}; } void key(vector_key_t v) noexcept { return misaligned_store(tape_, v); } void level(level_t v) noexcept { return misaligned_store(tape_ + sizeof(vector_key_t), v); } @@ -1846,16 +2137,22 @@ class index_gt { class neighbors_ref_t { byte_t* tape_; - static constexpr std::size_t shift(std::size_t i = 0) { + static constexpr std::size_t shift(std::size_t i = 0) noexcept { return sizeof(neighbors_count_t) + sizeof(compressed_slot_t) * i; } public: + using iterator = misaligned_ptr_gt; + using const_iterator = misaligned_ptr_gt; + using value_type = compressed_slot_t; + neighbors_ref_t(byte_t* tape) noexcept : tape_(tape) {} misaligned_ptr_gt begin() noexcept { return tape_ + shift(); } misaligned_ptr_gt end() noexcept { return begin() + size(); } misaligned_ptr_gt begin() const noexcept { return tape_ + shift(); } misaligned_ptr_gt end() const noexcept { return begin() + size(); } + misaligned_ptr_gt cbegin() noexcept { return tape_ + shift(); } + misaligned_ptr_gt cend() noexcept { return cbegin() + size(); } compressed_slot_t operator[](std::size_t i) const noexcept { return misaligned_load(tape_ + shift(i)); } @@ -1863,7 +2160,7 @@ class index_gt { void clear() noexcept { neighbors_count_t n = misaligned_load(tape_); std::memset(tape_, 0, shift(n)); - // misaligned_store(tape_, 0); + misaligned_store(tape_, 0); } void push_back(compressed_slot_t slot) noexcept { neighbors_count_t n = misaligned_load(tape_); @@ -1883,29 +2180,55 @@ class index_gt { visits_hash_set_t visits{}; std::default_random_engine level_generator{}; std::size_t iteration_cycles{}; - std::size_t computed_distances_count{}; + std::size_t computed_distances{}; + std::size_t computed_distances_in_refines{}; + std::size_t computed_distances_in_reverse_refines{}; + /// @brief Heterogeneous distance calculation. template // inline distance_t measure(value_at const& first, entry_at const& second, metric_at&& metric) noexcept { static_assert( // std::is_same::value || std::is_same::value, "Unexpected type"); - computed_distances_count++; + computed_distances++; return metric(first, second); } + /// @brief Homogeneous distance calculation. template // inline distance_t measure(entry_at const& first, entry_at const& second, metric_at&& metric) noexcept { static_assert( // std::is_same::value || std::is_same::value, "Unexpected type"); - computed_distances_count++; + computed_distances++; return metric(first, second); } + + /// @brief Heterogeneous batch distance calculation. 
+ template // + inline void measure_batch(value_at const& first, entries_at const& second_entries, metric_at&& metric, + candidate_allowed_at&& candidate_allowed, transform_at&& transform, + callback_at&& callback) noexcept { + + using entry_t = typename std::remove_reference::type; + metric.batch(first, second_entries, candidate_allowed, transform, + [&](entry_t const& entry, distance_t distance) { + callback(entry, distance); + computed_distances++; + }); + } }; + /// @brief Number of "slots" available for `node_t` objects. Equals to @b `limits_.members`. + mutable std::atomic nodes_capacity_{}; + + /// @brief Number of "slots" already storing non-null nodes. + mutable std::atomic nodes_count_{}; + index_config_t config_{}; index_limits_t limits_{}; @@ -1915,12 +2238,6 @@ class index_gt { precomputed_constants_t pre_{}; memory_mapped_file_t viewed_file_{}; - /// @brief Number of "slots" available for `node_t` objects. Equals to @b `limits_.members`. - usearch_align_m mutable std::atomic nodes_capacity_{}; - - /// @brief Number of "slots" already storing non-null nodes. - usearch_align_m mutable std::atomic nodes_count_{}; - /// @brief Controls access to `max_level_` and `entry_slot_`. /// If any thread is updating those values, no other threads can `add()` or `search()`. std::mutex global_mutex_{}; @@ -1933,7 +2250,7 @@ class index_gt { using nodes_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - /// @brief C-style array of `node_t` smart-pointers. + /// @brief C-style array of `node_t` smart-pointers. Use `compressed_slot_t` for indexing. buffer_gt nodes_{}; /// @brief Mutex, that limits concurrent access to `nodes_`. @@ -1952,17 +2269,33 @@ class index_gt { index_config_t const& config() const noexcept { return config_; } index_limits_t const& limits() const noexcept { return limits_; } bool is_immutable() const noexcept { return bool(viewed_file_); } + explicit operator bool() const noexcept { return config_.is_valid(); } /** + * @brief Default index constructor, suitable only for stateless allocators. + * @warning Consider `index_gt::make` instead, or explicitly convert to `bool` to check if the index is valid. * @section Exceptions - * Doesn't throw, unless the ::metric's and ::allocators's throw on copy-construction. + * Doesn't throw, unless the ::dynamic_allocator's and ::tape_allocator's throw on move-construction. */ explicit index_gt( // - index_config_t config = {}, dynamic_allocator_t dynamic_allocator = {}, - tape_allocator_t tape_allocator = {}) noexcept - : config_(config), limits_(0, 0), dynamic_allocator_(std::move(dynamic_allocator)), - tape_allocator_(std::move(tape_allocator)), pre_(precompute_(config)), nodes_count_(0u), max_level_(-1), - entry_slot_(0u), nodes_(), nodes_mutexes_(), contexts_() {} + dynamic_allocator_t dynamic_allocator = {}, tape_allocator_t tape_allocator = {}) noexcept(false) + : nodes_capacity_(0u), nodes_count_(0u), config_(), limits_(0, 0), + dynamic_allocator_(std::move(dynamic_allocator)), tape_allocator_(std::move(tape_allocator)), + pre_(precompute_({})), max_level_(-1), entry_slot_(0u), nodes_(), nodes_mutexes_(), contexts_() {} + + /** + * @brief Default index constructor, suitable only for stateless allocators. + * @warning Consider `index_gt::make` instead, or explicitly convert to `bool` to check if the index is valid. + * @section Exceptions + * Doesn't throw, unless the ::dynamic_allocator's and ::tape_allocator's throw on move-construction. 
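+ *
+ * A minimal construction sketch (illustrative; template arguments left at their defaults):
+ * @code{.cpp}
+ *     index_config_t config(16); // connectivity 16, base layer 32 after validation
+ *     index_gt<> index(config);  // prefer `index_gt::make` where exceptions are unwelcome
+ * @endcode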
+ */ + explicit index_gt(index_config_t config, dynamic_allocator_t dynamic_allocator = {}, + tape_allocator_t tape_allocator = {}) noexcept(false) + : index_gt(dynamic_allocator, tape_allocator) { + config.validate(); + config_ = config; + pre_ = precompute_(config); + } /** * @brief Clones the structure with the same hyper-parameters, but without contents. @@ -1978,17 +2311,56 @@ class index_gt { return *this; } - struct copy_result_t { - error_t error; + struct state_result_t { index_gt index; + error_t error; explicit operator bool() const noexcept { return !error; } - copy_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); + state_result_t failed(error_t message) noexcept { return {std::move(index), std::move(message)}; } + operator index_gt&&() && { + if (error) + __usearch_raise_runtime_error(error.what()); + return std::move(index); } }; + using copy_result_t = state_result_t; + /** + * @brief The recommended way to initialize the index, as unlike the constructor, + * it can fail with an error message, without raising an exception. + * + * @param[in] config The configuration specs of the index. + * @param[in] dynamic_allocator The allocator for temporary buffers and thread contexts, like priority queues. + * @param[in] tape_allocator The allocator for the primary allocations of nodes and vectors. + */ + static state_result_t make( // + index_config_t config = {}, dynamic_allocator_t dynamic_allocator = {}, + tape_allocator_t tape_allocator = {}) noexcept { + + state_result_t result; + result.error = config.validate(); + if (result.error) + return result; + + index_gt index; + index.config_ = std::move(config); + index.dynamic_allocator_ = std::move(dynamic_allocator); + index.tape_allocator_ = std::move(tape_allocator); + index.pre_ = precompute_(index.config_); + index.nodes_count_ = 0u; + index.max_level_ = -1; + index.entry_slot_ = 0u; + + result.index = std::move(index); + return result; + } + + /** + * @brief The recommended way to copy the index, as unlike the copy-constructor, + * it can fail with an error message, without raising an exception. + * + * @param[in] config The configuration specs for the copy-operation. Currently unused. 
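+ *
+ * @code{.cpp}
+ *     auto copied = index.copy(); // illustrative usage
+ *     if (!copied)
+ *         std::fputs(copied.error.release(), stderr); // take ownership of the message
+ * @endcode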
+ */ copy_result_t copy(index_copy_config_t config = {}) const noexcept { copy_result_t result; index_gt& other = result.index; @@ -2010,17 +2382,17 @@ class index_gt { return result; } - member_citerator_t cbegin() const noexcept { return {this, 0}; } - member_citerator_t cend() const noexcept { return {this, size()}; } - member_citerator_t begin() const noexcept { return {this, 0}; } - member_citerator_t end() const noexcept { return {this, size()}; } - member_iterator_t begin() noexcept { return {this, 0}; } - member_iterator_t end() noexcept { return {this, size()}; } + member_citerator_t cbegin() const noexcept { return {this, static_cast(0u)}; } + member_citerator_t cend() const noexcept { return {this, static_cast(size())}; } + member_citerator_t begin() const noexcept { return {this, static_cast(0u)}; } + member_citerator_t end() const noexcept { return {this, static_cast(size())}; } + member_iterator_t begin() noexcept { return {this, static_cast(0u)}; } + member_iterator_t end() noexcept { return {this, static_cast(size())}; } - member_ref_t at(std::size_t slot) noexcept { return {nodes_[slot].key(), slot}; } - member_cref_t at(std::size_t slot) const noexcept { return {nodes_[slot].ckey(), slot}; } - member_iterator_t iterator_at(std::size_t slot) noexcept { return {this, slot}; } - member_citerator_t citerator_at(std::size_t slot) const noexcept { return {this, slot}; } + member_ref_t at(compressed_slot_t slot) noexcept { return {nodes_[slot].key(), slot}; } + member_cref_t at(compressed_slot_t slot) const noexcept { return {nodes_[slot].ckey(), slot}; } + member_iterator_t iterator_at(compressed_slot_t slot) noexcept { return {this, slot}; } + member_citerator_t citerator_at(compressed_slot_t slot) const noexcept { return {this, slot}; } dynamic_allocator_t const& dynamic_allocator() const noexcept { return dynamic_allocator_; } tape_allocator_t const& tape_allocator() const noexcept { return tape_allocator_; } @@ -2095,13 +2467,20 @@ class index_gt { * @brief Increases the `capacity()` of the index to allow adding more vectors. * @return `true` on success, `false` on memory allocation errors. */ - bool reserve(index_limits_t limits) usearch_noexcept_m { + bool try_reserve(index_limits_t limits) usearch_noexcept_m { if (limits.threads_add <= limits_.threads_add // && limits.threads_search <= limits_.threads_search // && limits.members <= limits_.members) return true; + // In some cases, we don't want to update the number of members, + // just want to make sure that future reserves use the new thread limits. + if (!limits.members && !size()) { + limits_ = limits; + return true; + } + nodes_mutexes_t new_mutexes(limits.members); buffer_gt new_nodes(limits.members); buffer_gt new_contexts(limits.threads()); @@ -2120,6 +2499,13 @@ class index_gt { return true; } + /** + * @brief Increases the `capacity()` of the index to allow adding more vectors. + * @warning Unlike STL, won't throw exceptions on memory allocations, so check the return value. + * @return `true` on success, `false` on memory allocation errors. 
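+ *
+ * @code{.cpp}
+ *     index_limits_t limits(1000000, 8); // illustrative: 1M members, 8 threads
+ *     if (!index.reserve(limits))
+ *         return false; // allocation failed, the index is left unchanged
+ * @endcode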
+ */ + bool reserve(index_limits_t limits) usearch_noexcept_m { return try_reserve(limits); } + #if defined(USEARCH_USE_PRAGMA_REGION) #pragma endregion @@ -2131,7 +2517,9 @@ class index_gt { std::size_t new_size{}; std::size_t visited_members{}; std::size_t computed_distances{}; - std::size_t slot{}; + std::size_t computed_distances_in_refines{}; + std::size_t computed_distances_in_reverse_refines{}; + compressed_slot_t slot{}; explicit operator bool() const noexcept { return !error; } add_result_t failed(error_t message) noexcept { @@ -2176,15 +2564,15 @@ class index_gt { top_candidates_t const* top_{}; friend class index_gt; - inline search_result_t(index_gt const& index, top_candidates_t& top) noexcept - : nodes_(index.nodes_), top_(&top) {} + inline search_result_t(index_gt const& index, top_candidates_t const* top) noexcept + : nodes_(index.nodes_), top_(top) {} public: - /** @brief Number of search results found. */ + /** @brief Number of search results found. */ std::size_t count{}; - /** @brief Number of graph nodes traversed. */ + /** @brief Number of graph nodes traversed. */ std::size_t visited_members{}; - /** @brief Number of times the distances were computed. */ + /** @brief Number of times the distances were computed. */ std::size_t computed_distances{}; error_t error{}; @@ -2216,6 +2604,16 @@ class index_gt { node_t node = nodes_[candidate.slot]; return {member_cref_t{node.ckey(), candidate.slot}, candidate.distance}; } + + /** + * @brief Extracts the search results into a user-provided buffer, that unlike `dump_to`, + * may already contain some data, so the new and old results are merged together. + * @return The number of results stored in the buffer. + * @param[in] keys The buffer to store the keys of the search results. + * @param[in] distances The buffer to store the distances to the search results. + * @param[in] old_count The number of results already stored in the buffers. + * @param[in] max_count The maximum number of results that can be stored in the buffers. + */ inline std::size_t merge_into( // vector_key_t* keys, distance_t* distances, // std::size_t old_count, std::size_t max_count) const noexcept { @@ -2237,6 +2635,13 @@ class index_gt { } return merged_count; } + + /** + * @brief Extracts the search results into a user-provided buffer. + * @return The number of results stored in the buffer. + * @param[in] keys The buffer to store the keys of the search results. + * @param[in] distances The buffer to store the distances to the search results. + */ inline std::size_t dump_to(vector_key_t* keys, distance_t* distances) const noexcept { for (std::size_t i = 0; i != count; ++i) { match_t result = operator[](i); @@ -2245,6 +2650,12 @@ class index_gt { } return count; } + + /** + * @brief Extracts the search results into a user-provided buffer. + * @return The number of results stored in the buffer. + * @param[in] keys The buffer to store the keys of the search results. 
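+ *
+ * @code{.cpp}
+ *     vector_key_t keys[10]; // illustrative buffer, must fit at least `count` entries
+ *     std::size_t exported = result.dump_to(keys); // `result` is a prior `search(...)` outcome
+ * @endcode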
+ */ inline std::size_t dump_to(vector_key_t* keys) const noexcept { for (std::size_t i = 0; i != count; ++i) { match_t result = operator[](i); @@ -2318,57 +2729,76 @@ class index_gt { // Determining how much memory to allocate for the node depends on the target level std::unique_lock new_level_lock(global_mutex_); - level_t max_level_copy = max_level_; // Copy under lock - std::size_t entry_idx_copy = entry_slot_; // Copy under lock - level_t target_level = choose_random_level_(context.level_generator); + level_t max_level_copy = max_level_; // Copy under lock + compressed_slot_t entry_slot_copy = static_cast(entry_slot_); // Copy under lock + level_t new_target_level = choose_random_level_(context.level_generator); // Make sure we are not overflowing std::size_t capacity = nodes_capacity_.load(); - std::size_t new_slot = nodes_count_.fetch_add(1); - if (new_slot >= capacity) { + std::size_t old_size = nodes_count_.fetch_add(1); + if (old_size >= capacity) { nodes_count_.fetch_sub(1); return result.failed("Reserve capacity ahead of insertions!"); } // Allocate the neighbors - node_t node = node_make_(key, target_level); - if (!node) { + node_t new_node = node_make_(key, new_target_level); + if (!new_node) { nodes_count_.fetch_sub(1); return result.failed("Out of memory!"); } - if (target_level <= max_level_copy) + if (new_target_level <= max_level_copy) new_level_lock.unlock(); - nodes_[new_slot] = node; - result.new_size = new_slot + 1; - result.slot = new_slot; - callback(at(new_slot)); - node_lock_t new_lock = node_lock_(new_slot); + nodes_[old_size] = new_node; + result.new_size = old_size + 1; + compressed_slot_t new_slot = result.slot = static_cast(old_size); + callback(at(result.slot)); // Do nothing for the first element - if (!new_slot) { - entry_slot_ = new_slot; - max_level_ = target_level; + if (!old_size) { + entry_slot_ = result.slot; + max_level_ = new_target_level; return result; } // Pull stats - result.computed_distances = context.computed_distances_count; + result.computed_distances = context.computed_distances; + result.computed_distances_in_refines = context.computed_distances_in_refines; + result.computed_distances_in_reverse_refines = context.computed_distances_in_reverse_refines; result.visited_members = context.iteration_cycles; - connect_node_across_levels_( // - value, metric, prefetch, // - new_slot, entry_idx_copy, max_level_copy, target_level, // - config, context); + // Go down the level, tracking only the closest match + compressed_slot_t closest_slot = search_for_one_( // + value, metric, prefetch, // + entry_slot_copy, max_level_copy, new_target_level, context); + + // From `new_target_level` down - perform proper extensive search + for (level_t level = (std::min)(new_target_level, max_level_copy); level >= 0; --level) { + // TODO: Handle out of memory conditions + search_to_insert_(value, metric, prefetch, closest_slot, level, config.expansion, context); + candidates_view_t closest_view; + { + node_lock_t new_lock = node_lock_(new_slot); + neighbors_(new_node, level).clear(); + closest_view = form_links_to_closest_(metric, new_slot, level, context); + closest_slot = closest_view[0].slot; + } + form_reverse_links_(metric, new_slot, closest_view, value, level, context); + } // Normalize stats - result.computed_distances = context.computed_distances_count - result.computed_distances; + result.computed_distances = context.computed_distances - result.computed_distances; + result.computed_distances_in_refines = + context.computed_distances_in_refines - 
result.computed_distances_in_refines; + result.computed_distances_in_reverse_refines = + context.computed_distances_in_reverse_refines - result.computed_distances_in_reverse_refines; result.visited_members = context.iteration_cycles - result.visited_members; // Updating the entry point if needed - if (target_level > max_level_copy) { + if (new_target_level > max_level_copy) { entry_slot_ = new_slot; - max_level_ = target_level; + max_level_ = new_target_level; } return result; } @@ -2376,6 +2806,10 @@ class index_gt { /** * @brief Update an existing entry. Thread-safe. Supports @b heterogeneous lookups. * + * ! It's assumed that different threads aren't updating the same entry at the same time. + * ! The state won't be corrupted, but no transactional guarantees are provided and the + * ! resulting value & neighbors list may be inconsistent. + * * @tparam metric_at * A function responsible for computing the distance @b (dis-similarity) between two objects. * It should be callable into distinctly different scenarios: @@ -2407,9 +2841,13 @@ class index_gt { callback_at&& callback = callback_at{}, // prefetch_at&& prefetch = prefetch_at{}) usearch_noexcept_m { + // Someone is gonna fuzz this, so let's make sure we cover the basics + if (!config.expansion) + config.expansion = default_expansion_add(); + usearch_assert_m(!is_immutable(), "Can't add to an immutable index"); add_result_t result; - std::size_t old_slot = iterator.slot_; + compressed_slot_t updated_slot = iterator.slot_; // Make sure we have enough local memory to perform this request context_t& context = contexts_[config.thread]; @@ -2427,30 +2865,60 @@ class index_gt { if (!next.reserve(config.expansion)) return result.failed("Out of memory!"); - node_lock_t new_lock = node_lock_(old_slot); - node_t node = node_at_(old_slot); + node_t updated_node = node_at_(updated_slot); + level_t updated_node_level = updated_node.level(); - level_t node_level = node.level(); - span_bytes_t node_bytes = node_bytes_(node); - std::memset(node_bytes.data(), 0, node_bytes.size()); - node.level(node_level); + // Copy entry coordinates under locks + level_t max_level_copy; + compressed_slot_t entry_slot_copy; + { + std::unique_lock new_level_lock(global_mutex_); + max_level_copy = max_level_; // Copy under lock + entry_slot_copy = static_cast(entry_slot_); // Copy under lock + } // Pull stats - result.computed_distances = context.computed_distances_count; + result.computed_distances = context.computed_distances; result.visited_members = context.iteration_cycles; - connect_node_across_levels_( // - value, metric, prefetch, // - old_slot, entry_slot_, max_level_, node_level, // - config, context); - node.key(key); + // Go down the level, tracking only the closest match; + // It may even be equal to the `updated_slot` + compressed_slot_t closest_slot = + // If we are updating the entry node itself, it won't contain any neighbors, + // so we should traverse a level down to find the closest match. + updated_node_level == max_level_copy // + ? 
entry_slot_copy + : search_for_one_( // + value, metric, prefetch, // + entry_slot_copy, max_level_copy, updated_node_level, context); + + // From `updated_node_level` down - perform proper extensive search + for (level_t level = (std::min)(updated_node_level, max_level_copy); level >= 0; --level) { + if (!search_to_update_(value, metric, prefetch, closest_slot, updated_slot, level, config.expansion, + context)) + return result.failed("Out of memory!"); + + candidates_view_t closest_view; + { + node_lock_t updated_lock = node_lock_(updated_slot); + // TODO: Go through existing neighbors removing reverse links + // for (compressed_slot_t slot : neighbors_(updated_node, level)) + // remove_link_(slot, updated_slot, level); + neighbors_(updated_node, level).clear(); + closest_view = form_links_to_closest_(metric, updated_slot, level, context); + if (closest_view.size()) + closest_slot = closest_view[0].slot; + } + form_reverse_links_(metric, updated_slot, closest_view, value, level, context); + } + updated_node.key(key); // Normalize stats - result.computed_distances = context.computed_distances_count - result.computed_distances; + result.computed_distances = context.computed_distances - result.computed_distances; result.visited_members = context.iteration_cycles - result.visited_members; - result.slot = old_slot; + result.slot = updated_slot; - callback(at(old_slot)); + callback(at(updated_slot)); return result; } @@ -2475,16 +2943,28 @@ class index_gt { metric_at&& metric, // index_search_config_t config = {}, // predicate_at&& predicate = predicate_at{}, // - prefetch_at&& prefetch = prefetch_at{}) const noexcept { + prefetch_at&& prefetch = prefetch_at{}) const usearch_noexcept_m { - context_t& context = contexts_[config.thread]; - top_candidates_t& top = context.top_candidates; - search_result_t result{*this, top}; - if (!nodes_count_) + // Someone is gonna fuzz this, so let's make sure we cover the basics + if (!wanted) + return search_result_t{}; + + // Expansion factor set to zero is equivalent to the default value + if (!config.expansion) + config.expansion = default_expansion_search(); + + // Using references is cleaner, but would result in UBSan false positives + context_t* context_ptr = contexts_.data() ? contexts_.data() + config.thread : nullptr; + top_candidates_t* top_ptr = context_ptr ? 
&context_ptr->top_candidates : nullptr; + search_result_t result{*this, top_ptr}; + if (!nodes_count_.load(std::memory_order_relaxed)) return result; + usearch_assert_m(contexts_.size() > config.thread, "Thread index out of bounds"); + context_t& context = *context_ptr; + top_candidates_t& top = *top_ptr; // Go down the level, tracking only the closest match - result.computed_distances = context.computed_distances_count; + result.computed_distances = context.computed_distances; result.visited_members = context.iteration_cycles; if (config.exact) { @@ -2494,12 +2974,14 @@ class index_gt { } else { next_candidates_t& next = context.next_candidates; std::size_t expansion = (std::max)(config.expansion, wanted); + usearch_assert_m(expansion > 0, "Expansion factor can't be a zero!"); if (!next.reserve(expansion)) return result.failed("Out of memory!"); if (!top.reserve(expansion)) return result.failed("Out of memory!"); - std::size_t closest_slot = search_for_one_(query, metric, prefetch, entry_slot_, max_level_, 0, context); + compressed_slot_t closest_slot = search_for_one_( + query, metric, prefetch, static_cast(entry_slot_), max_level_, 0, context); // For bottom layer we need a more optimized procedure if (!search_to_find_in_base_(query, metric, predicate, prefetch, closest_slot, expansion, context)) @@ -2510,7 +2992,7 @@ class index_gt { top.shrink(wanted); // Normalize stats - result.computed_distances = context.computed_distances_count - result.computed_distances; + result.computed_distances = context.computed_distances - result.computed_distances; result.visited_members = context.iteration_cycles - result.visited_members; result.count = top.size(); return result; @@ -2545,7 +3027,7 @@ class index_gt { return result.failed("No clusters to identify"); // Go down the level, tracking only the closest match - result.computed_distances = context.computed_distances_count; + result.computed_distances = context.computed_distances; result.visited_members = context.iteration_cycles; next_candidates_t& next = context.next_candidates; @@ -2553,12 +3035,13 @@ class index_gt { if (!next.reserve(expansion)) return result.failed("Out of memory!"); - result.cluster.member = at(search_for_one_(query, metric, prefetch, entry_slot_, max_level_, - static_cast(level - 1), context)); + result.cluster.member = + at(search_for_one_(query, metric, prefetch, static_cast(entry_slot_), max_level_, + static_cast(level <= 0 ? 0 : level - 1), context)); result.cluster.distance = context.measure(query, result.cluster.member, metric); // Normalize stats - result.computed_distances = context.computed_distances_count - result.computed_distances; + result.computed_distances = context.computed_distances - result.computed_distances; result.visited_members = context.iteration_cycles - result.visited_members; (void)predicate; @@ -2578,6 +3061,9 @@ class index_gt { std::size_t allocated_bytes{}; }; + /** + * @brief Aggregates stats on the number of nodes, edges, and memory usage across all levels. + */ stats_t stats() const noexcept { stats_t result{}; @@ -2596,10 +3082,17 @@ class index_gt { return result; } + /** + * @brief Aggregates stats on the number of nodes, edges, and memory usage up to a specific level. + * + * The `level` parameter is zero-based, where `0` is the base level. + * For example, `level=1` will include the base level and the first level of connections. + */ stats_t stats(std::size_t level) const noexcept { stats_t result{}; - std::size_t neighbors_bytes = !level ? 
pre_.neighbors_base_bytes : pre_.neighbors_bytes; + std::size_t max_edges_per_node = !level ? config_.connectivity_base : config_.connectivity; + for (std::size_t i = 0; i != size(); ++i) { node_t node = node_at_(i); if (static_cast(node.level()) < level) @@ -2610,11 +3103,17 @@ class index_gt { result.allocated_bytes += node_head_bytes_() + neighbors_bytes; } - std::size_t max_edges_per_node = level ? config_.connectivity_base : config_.connectivity; result.max_edges = result.nodes * max_edges_per_node; return result; } + /** + * @brief Aggregates stats on the number of nodes, edges, and memory usage up to a specific level, + * simultaneously exporting the stats for each level into the `stats_per_level` C-style array. + * + * The `max_level` parameter is zero-based, where `0` is the base level. + * For example, `max_level=1` will include the base level and the first level of connections. + */ stats_t stats(stats_t* stats_per_level, std::size_t max_level) const noexcept { std::size_t head_bytes = node_head_bytes_(); @@ -2744,6 +3243,7 @@ class index_gt { serialization_result_t result; // Remove previously stored objects + index_limits_t old_limits = limits_; reset(); // Pull basic metadata @@ -2768,9 +3268,15 @@ class index_gt { // Submit metadata config_.connectivity = header.connectivity; config_.connectivity_base = header.connectivity_base; + error_t error = config_.validate(); + if (error) + return result.failed(std::move(error)); + pre_ = precompute_(config_); index_limits_t limits; limits.members = header.size; + limits.threads_add = (std::max)(1, old_limits.threads_add); + limits.threads_search = (std::max)(1, old_limits.threads_search); if (!reserve(limits)) { reset(); return result.failed("Out of memory"); @@ -2820,8 +3326,12 @@ class index_gt { }, std::forward(progress)); - if (!stream_result) + if (!stream_result) { + // Drop generic messages like "end of file reached" in favor + // of more specific messages from the stream + io_result.error.release(); return stream_result; + } return io_result; } @@ -2869,8 +3379,12 @@ class index_gt { }, std::forward(progress)); - if (!stream_result) + if (!stream_result) { + // Drop generic messages like "end of file reached" in favor + // of more specific messages from the stream + io_result.error.release(); return stream_result; + } return io_result; } @@ -2909,6 +3423,7 @@ class index_gt { progress_at&& progress = {}) noexcept { // Remove previously stored objects + index_limits_t old_limits = limits_; reset(); serialization_result_t result = file.open_if_not(); @@ -2935,6 +3450,10 @@ class index_gt { config_.connectivity = header.connectivity; config_.connectivity_base = header.connectivity_base; + error_t error = config_.validate(); + if (error) + return result.failed(std::move(error)); + pre_ = precompute_(config_); misaligned_ptr_gt levels{(byte_t*)file.data() + offset + sizeof(header)}; offsets[0u] = offset + sizeof(header) + sizeof(level_t) * header.size; @@ -2950,6 +3469,8 @@ class index_gt { // Submit metadata and reserve memory index_limits_t limits; limits.members = header.size; + limits.threads_add = (std::max)(1, old_limits.threads_add); + limits.threads_search = (std::max)(1, old_limits.threads_search); if (!reserve(limits)) { reset(); return result.failed("Out of memory"); @@ -2977,16 +3498,12 @@ class index_gt { * and links to them, while also generating a more efficient mapping, * putting the more frequently used entries closer together. 
* - * - * Scans the whole collection, removing the links leading towards - * banned entries. This essentially isolates some nodes from the rest - * of the graph, while keeping their outgoing links, in case the node - * is structurally relevant and has a crucial role in the index. - * It won't reclaim the memory. - * - * @param[in] allow_member Predicate to mark nodes for isolation. + * @param[in] values A []-subscriptable object, providing access to the values. + * @param[in] metric Callable object measuring distance between any ::values and present objects. + * @param[in] slot_transition Callable object to inform changes in slot assignments. * @param[in] executor Thread-pool to execute the job in parallel. * @param[in] progress Callback to report the execution progress. + * @param[in] prefetch Callable object to prefetch data into the cache. */ template (old_slot), // - static_cast(cluster), // - node_at_(old_slot).level()}; + compressed_slot_t old_slot = static_cast(old_slot_as_uint); + compressed_slot_t cluster = search_for_one_( // + values[citerator_at(old_slot)], // + metric, prefetch, // + static_cast(entry_slot_), max_level_, 0, context); + slots_and_levels[old_slot] = {old_slot, cluster, node_at_(old_slot).level()}; ++processed; if (thread_idx == 0) do_tasks = progress(processed.load(), total); @@ -3193,6 +3708,7 @@ class index_gt { inline neighbors_ref_t neighbors_base_(node_t node) const noexcept { return {node.neighbors_tape()}; } inline neighbors_ref_t neighbors_non_base_(node_t node, level_t level) const noexcept { + usearch_assert_m(level > 0 && level <= node.level(), "Linking to missing level"); return {node.neighbors_tape() + pre_.neighbors_base_bytes + (level - 1) * pre_.neighbors_bytes}; } @@ -3212,90 +3728,90 @@ class index_gt { return {nodes_mutexes_, slot}; } - template - void connect_node_across_levels_( // - value_at&& value, metric_at&& metric, prefetch_at&& prefetch, // - std::size_t node_slot, std::size_t entry_slot, level_t max_level, level_t target_level, // - index_update_config_t const& config, context_t& context) usearch_noexcept_m { - - // Go down the level, tracking only the closest match - std::size_t closest_slot = search_for_one_( // - value, metric, prefetch, // - entry_slot, max_level, target_level, context); - - // From `target_level` down perform proper extensive search - for (level_t level = (std::min)(target_level, max_level); level >= 0; --level) { - // TODO: Handle out of memory conditions - search_to_insert_(value, metric, prefetch, closest_slot, node_slot, level, config.expansion, context); - closest_slot = connect_new_node_(metric, node_slot, level, context); - reconnect_neighbor_nodes_(metric, node_slot, value, level, context); + struct node_conditional_lock_t { + nodes_mutexes_t& mutexes; + std::size_t slot; + inline ~node_conditional_lock_t() noexcept { + if (slot != std::numeric_limits::max()) + mutexes.atomic_reset(slot); } + }; + + inline node_conditional_lock_t node_try_conditional_lock_(std::size_t slot, bool condition, + bool& failed_to_acquire) const noexcept { + failed_to_acquire = condition ? nodes_mutexes_.atomic_set(slot) : false; + return {nodes_mutexes_, failed_to_acquire ? 
std::numeric_limits::max() : slot}; } - template - std::size_t connect_new_node_( // + template + candidates_view_t form_links_to_closest_( // metric_at&& metric, std::size_t new_slot, level_t level, context_t& context) usearch_noexcept_m { node_t new_node = node_at_(new_slot); top_candidates_t& top = context.top_candidates; + usearch_assert_m(top.size() || !require_non_empty_ak, "No candidates found"); + candidates_view_t top_view = + refine_(metric, config_.connectivity, top, context, context.computed_distances_in_refines); + usearch_assert_m(top_view.size() || !require_non_empty_ak, "This would lead to isolated nodes"); // Outgoing links from `new_slot`: neighbors_ref_t new_neighbors = neighbors_(new_node, level); - { - usearch_assert_m(!new_neighbors.size(), "The newly inserted element should have blank link list"); - candidates_view_t top_view = refine_(metric, config_.connectivity, top, context); - - for (std::size_t idx = 0; idx != top_view.size(); idx++) { - usearch_assert_m(!new_neighbors[idx], "Possible memory corruption"); - usearch_assert_m(level <= node_at_(top_view[idx].slot).level(), "Linking to missing level"); - new_neighbors.push_back(top_view[idx].slot); - } + usearch_assert_m(!new_neighbors.size(), "The newly inserted element should have blank link list"); + for (std::size_t idx = 0; idx != top_view.size(); idx++) { + usearch_assert_m(!new_neighbors[idx], "Possible memory corruption"); + usearch_assert_m(level <= node_at_(top_view[idx].slot).level(), "Linking to missing level"); + new_neighbors.push_back(top_view[idx].slot); } - return new_neighbors[0]; + return top_view; } template - void reconnect_neighbor_nodes_( // - metric_at&& metric, std::size_t new_slot, value_at&& value, level_t level, - context_t& context) usearch_noexcept_m { + void form_reverse_links_( // + metric_at&& metric, compressed_slot_t new_slot, candidates_view_t new_neighbors, value_at&& value, + level_t level, context_t& context) usearch_noexcept_m { - node_t new_node = node_at_(new_slot); top_candidates_t& top = context.top_candidates; - neighbors_ref_t new_neighbors = neighbors_(new_node, level); + std::size_t const connectivity_max = level ? config_.connectivity : config_.connectivity_base; // Reverse links from the neighbors: - std::size_t const connectivity_max = level ? config_.connectivity : config_.connectivity_base; - for (compressed_slot_t close_slot : new_neighbors) { + for (auto new_neighbor : new_neighbors) { + compressed_slot_t close_slot = new_neighbor.slot; if (close_slot == new_slot) continue; node_lock_t close_lock = node_lock_(close_slot); node_t close_node = node_at_(close_slot); - neighbors_ref_t close_header = neighbors_(close_node, level); - usearch_assert_m(close_header.size() <= connectivity_max, "Possible corruption"); + + // The node may have no neighbors only in one case, when it's the first one in the index, + // but that is problematic to track in multi-threaded environments, where the order of insertion + // is not guaranteed. + // usearch_assert_m(close_header.size() || new_slot == 1, "Possible corruption - isolated node"); + usearch_assert_m(close_header.size() <= connectivity_max, "Possible corruption - overflow"); usearch_assert_m(close_slot != new_slot, "Self-loops are impossible"); usearch_assert_m(level <= close_node.level(), "Linking to missing level"); // If `new_slot` is already present in the neighboring connections of `close_slot` // then no need to modify any connections or run the heuristics. 
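            // The two branches below handle the reverse link: if `close_slot` still has
            // spare capacity, the new back-link is simply appended; otherwise its neighbor
            // list is rebuilt by re-running the diversity heuristic over the existing
            // neighbors plus the new node, evicting whichever candidate ranks worst.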
if (close_header.size() < connectivity_max) { - close_header.push_back(static_cast(new_slot)); + close_header.push_back(new_slot); continue; } // To fit a new connection we need to drop an existing one. top.clear(); - usearch_assert_m((top.reserve(close_header.size() + 1)), "The memory must have been reserved in `add`"); - top.insert_reserved( - {context.measure(value, citerator_at(close_slot), metric), static_cast(new_slot)}); + usearch_assert_m((top.capacity() >= (close_header.size() + 1)), + "The memory must have been reserved in `add`"); + top.insert_reserved({context.measure(value, citerator_at(close_slot), metric), new_slot}); for (compressed_slot_t successor_slot : close_header) top.insert_reserved( {context.measure(citerator_at(close_slot), citerator_at(successor_slot), metric), successor_slot}); // Export the results: close_header.clear(); - candidates_view_t top_view = refine_(metric, connectivity_max, top, context); + candidates_view_t top_view = + refine_(metric, connectivity_max, top, context, context.computed_distances_in_reverse_refines); + usearch_assert_m(top_view.size(), "This would lead to isolated nodes"); for (std::size_t idx = 0; idx != top_view.size(); idx++) close_header.push_back(top_view[idx].slot); } @@ -3337,12 +3853,12 @@ class index_gt { using pointer = misaligned_ptr_gt; using reference = misaligned_ref_gt; - reference operator*() const noexcept { return slot(); } + value_type operator*() const noexcept { return neighbors_[current_]; } candidates_iterator_t(index_gt const& index, neighbors_ref_t neighbors, visits_hash_set_t& visits, std::size_t progress) noexcept : index_(index), neighbors_(neighbors), visits_(visits), current_(progress) {} candidates_iterator_t operator++(int) noexcept { - return candidates_iterator_t(index_, visits_, neighbors_, current_ + 1).skip_missing(); + return candidates_iterator_t(index_, neighbors_, visits_, current_ + 1).skip_missing(); } candidates_iterator_t& operator++() noexcept { ++current_; @@ -3352,7 +3868,7 @@ class index_gt { bool operator==(candidates_iterator_t const& other) noexcept { return current_ == other.current_; } bool operator!=(candidates_iterator_t const& other) noexcept { return current_ != other.current_; } - vector_key_t key() const noexcept { return index_->node_at_(slot()).key(); } + vector_key_t key() const noexcept { return index_.node_at_(slot()).key(); } compressed_slot_t slot() const noexcept { return neighbors_[current_]; } friend inline std::size_t get_slot(candidates_iterator_t const& it) noexcept { return it.slot(); } friend inline vector_key_t get_key(candidates_iterator_t const& it) noexcept { return it.key(); } @@ -3370,16 +3886,16 @@ class index_gt { }; template - std::size_t search_for_one_( // + compressed_slot_t search_for_one_( // value_at&& query, metric_at&& metric, prefetch_at&& prefetch, // - std::size_t closest_slot, level_t begin_level, level_t end_level, context_t& context) const noexcept { + compressed_slot_t closest_slot, level_t begin_level, level_t end_level, context_t& context) const noexcept { visits_hash_set_t& visits = context.visits; visits.clear(); // Optional prefetching if (!is_dummy()) - prefetch(citerator_at(closest_slot), citerator_at(closest_slot + 1)); + prefetch(citerator_at(closest_slot), citerator_at(closest_slot) + 1); distance_t closest_dist = context.measure(query, citerator_at(closest_slot), metric); for (level_t level = begin_level; level > end_level; --level) { @@ -3404,6 +3920,7 @@ class index_gt { changed = true; } } + context.iteration_cycles++; } 
            while (changed);
        }
    }

@@ -3418,8 +3935,7 @@ class index_gt {
    template <typename value_at, typename metric_at, typename prefetch_at>
    bool search_to_insert_( //
        value_at&& query, metric_at&& metric, prefetch_at&& prefetch, //
-       std::size_t start_slot, std::size_t new_slot, level_t level, std::size_t top_limit,
-       context_t& context) noexcept {
+       compressed_slot_t start_slot, level_t level, std::size_t top_limit, context_t& context) noexcept {

        visits_hash_set_t& visits = context.visits;
        next_candidates_t& next = context.next_candidates; // pop min, push
@@ -3428,18 +3944,21 @@ class index_gt {
        visits.clear();
        next.clear();
        top.clear();
+
+       // At the very least we are going to explore the starting node and its neighbors
        if (!visits.reserve(config_.connectivity_base + 1u))
            return false;

        // Optional prefetching
        if (!is_dummy<prefetch_at>())
-           prefetch(citerator_at(start_slot), citerator_at(start_slot + 1));
+           prefetch(citerator_at(start_slot), citerator_at(start_slot) + 1);

        distance_t radius = context.measure(query, citerator_at(start_slot), metric);
-       next.insert_reserved({-radius, static_cast<compressed_slot_t>(start_slot)});
-       top.insert_reserved({radius, static_cast<compressed_slot_t>(start_slot)});
-       visits.set(static_cast<compressed_slot_t>(start_slot));
+       next.insert_reserved({-radius, start_slot});
+       top.insert_reserved({radius, start_slot});
+       visits.set(start_slot);

+       // The primary loop of the graph traversal
        while (!next.empty()) {

            candidate_t candidacy = next.top();
@@ -3450,8 +3969,6 @@ class index_gt {
            context.iteration_cycles++;

            compressed_slot_t candidate_slot = candidacy.slot;
-           if (new_slot == candidate_slot)
-               continue;
            node_t candidate_ref = node_at_(candidate_slot);
            node_lock_t candidate_lock = node_lock_(candidate_slot);
            neighbors_ref_t candidate_neighbors = neighbors_(candidate_ref, level);
@@ -3470,6 +3987,8 @@ class index_gt {
                if (visits.set(successor_slot))
                    continue;

+               // We don't access the neighbors of the `successor_slot` node,
+               // so we don't have to lock it.
                // node_lock_t successor_lock = node_lock_(successor_slot);
                distance_t successor_dist = context.measure(query, citerator_at(successor_slot), metric);
                if (top.size() < top_limit || successor_dist < radius) {
@@ -3484,6 +4003,95 @@ class index_gt {
        return true;
    }

+   /**
+    *  @brief  Traverses a layer of a graph, to find the best neighbors list for the updated node.
+    *          Locks the nodes in the process, assuming other threads are updating neighbor lists.
+    *  @return `true` if the procedure succeeded, `false` if it ran out of memory.
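+    *
+    *  The traversal mirrors `search_to_insert_`, but try-locks each visited node and
+    *  skips nodes currently held by concurrent writers instead of blocking on them.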
+    */
+    template <typename value_at, typename metric_at, typename prefetch_at>
+    bool search_to_update_( //
+        value_at&& query, metric_at&& metric, prefetch_at&& prefetch, //
+        compressed_slot_t start_slot, compressed_slot_t updated_slot, level_t level, std::size_t top_limit,
+        context_t& context) noexcept {
+
+        visits_hash_set_t& visits = context.visits;
+        next_candidates_t& next = context.next_candidates; // pop min, push
+        top_candidates_t& top = context.top_candidates;    // pop max, push
+
+        visits.clear();
+        next.clear();
+        top.clear();
+
+        // At the very least we are going to explore the starting node and its neighbors
+        if (!visits.reserve(config_.connectivity_base + 1u))
+            return false;
+
+        // Optional prefetching
+        if (!is_dummy<prefetch_at>())
+            prefetch(citerator_at(start_slot), citerator_at(start_slot) + 1);
+
+        distance_t radius = context.measure(query, citerator_at(start_slot), metric);
+        next.insert_reserved({-radius, start_slot});
+        visits.set(start_slot);
+        if (start_slot != updated_slot)
+            top.insert_reserved({radius, start_slot});
+
+        // The primary loop of the graph traversal
+        while (!next.empty()) {
+
+            candidate_t candidacy = next.top();
+            if ((-candidacy.distance) > radius && top.size() == top_limit)
+                break;
+
+            next.pop();
+            context.iteration_cycles++;
+
+            compressed_slot_t candidate_slot = candidacy.slot;
+            node_t candidate_ref = node_at_(candidate_slot);
+
+            // The trickiest part of update-heavy workloads is mitigating deadlocks
+            // between connected nodes during traversal. A "good enough" solution would be
+            // to skip concurrent access, assuming the other "close" node is going to add
+            // this one when forming reverse connections.
+            bool failed_to_acquire = false;
+            node_conditional_lock_t candidate_lock =
+                node_try_conditional_lock_(candidate_slot, updated_slot != candidate_slot, failed_to_acquire);
+            if (failed_to_acquire)
+                continue;
+            neighbors_ref_t candidate_neighbors = neighbors_(candidate_ref, level);
+
+            // Optional prefetching
+            if (!is_dummy<prefetch_at>()) {
+                candidates_range_t missing_candidates{*this, candidate_neighbors, visits};
+                prefetch(missing_candidates.begin(), missing_candidates.end());
+            }
+
+            // Assume the worst-case when reserving memory
+            if (!visits.reserve(visits.size() + candidate_neighbors.size()))
+                return false;
+
+            for (compressed_slot_t successor_slot : candidate_neighbors) {
+                if (visits.set(successor_slot))
+                    continue;
+
+                // We don't access the neighbors of the `successor_slot` node,
+                // so we don't have to lock it.
+                // node_conditional_lock_t successor_lock =
+                //     node_try_conditional_lock_(successor_slot, updated_slot != successor_slot);
+                distance_t successor_dist = context.measure(query, citerator_at(successor_slot), metric);
+                if (top.size() < top_limit || successor_dist < radius) {
+                    // This can substantially grow our priority queue:
+                    next.insert({-successor_dist, successor_slot});
+                    // This will automatically evict poor matches:
+                    if (updated_slot != successor_slot)
+                        top.insert({successor_dist, successor_slot}, top_limit);
+                    radius = top.top().distance;
+                }
+            }
+        }
+        return true;
+    }
+
    /**
     *  @brief  Traverses the @b base layer of a graph, to find a close match.
     *          Doesn't lock any nodes, assuming read-only simultaneous access.
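// A minimal, self-contained sketch of the try-lock-and-skip idea used by
// `search_to_update_` above. The names `node_bitset_t` and `try_visit` are
// hypothetical stand-ins, not part of the usearch API: one atomic flag guards
// each node, and a traversal simply skips any node a concurrent updater
// currently holds, so no thread ever blocks in the middle of a search.
#include <atomic>
#include <cstddef>
#include <vector>

struct node_bitset_t {
    std::vector<std::atomic<bool>> flags;
    explicit node_bitset_t(std::size_t count) : flags(count) {}
    // Returns `true` if this call flipped the flag from free to taken.
    bool try_acquire(std::size_t slot) noexcept {
        return !flags[slot].exchange(true, std::memory_order_acquire);
    }
    void release(std::size_t slot) noexcept {
        flags[slot].store(false, std::memory_order_release);
    }
};

// Visits a node only if no other writer holds it - the same progress-over-precision
// trade-off `node_try_conditional_lock_` makes in the traversal above.
template <typename callback_at>
bool try_visit(node_bitset_t& mutexes, std::size_t slot, callback_at&& callback) {
    if (!mutexes.try_acquire(slot))
        return false; // Held by a concurrent updater - skip, don't block.
    callback(slot);
    mutexes.release(slot);
    return true;
}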
@@ -3492,7 +4100,7 @@ class index_gt {
    template <typename value_at, typename metric_at, typename predicate_at, typename prefetch_at>
    bool search_to_find_in_base_( //
        value_at&& query, metric_at&& metric, predicate_at&& predicate, prefetch_at&& prefetch, //
-       std::size_t start_slot, std::size_t expansion, context_t& context) const noexcept {
+       compressed_slot_t start_slot, std::size_t expansion, context_t& context) const usearch_noexcept_m {

        visits_hash_set_t& visits = context.visits;
        next_candidates_t& next = context.next_candidates; // pop min, push
@@ -3507,20 +4115,24 @@ class index_gt {

        // Optional prefetching
        if (!is_dummy<prefetch_at>())
-           prefetch(citerator_at(start_slot), citerator_at(start_slot + 1));
+           prefetch(citerator_at(start_slot), citerator_at(start_slot) + 1);

        distance_t radius = context.measure(query, citerator_at(start_slot), metric);
-       next.insert_reserved({-radius, static_cast<compressed_slot_t>(start_slot)});
-       visits.set(static_cast<compressed_slot_t>(start_slot));
+       usearch_assert_m(next.capacity(), "The `max_heap_gt` must have been reserved in the search entry point");
+       next.insert_reserved({-radius, start_slot});
+       visits.set(start_slot);

        // Don't populate the top list if the predicate is not satisfied
-       if (is_dummy<predicate_at>() || predicate(member_cref_t{node_at_(start_slot).ckey(), start_slot}))
-           top.insert_reserved({radius, static_cast<compressed_slot_t>(start_slot)});
+       if (is_dummy<predicate_at>() || predicate(member_cref_t{node_at_(start_slot).ckey(), start_slot})) {
+           usearch_assert_m(top.capacity(),
+                            "The `sorted_buffer_gt` must have been reserved in the search entry point");
+           top.insert_reserved({radius, start_slot});
+       }

        while (!next.empty()) {

            candidate_t candidate = next.top();
-           if ((-candidate.distance) > radius)
+           if ((-candidate.distance) > radius && top.size() == top_limit)
                break;

            next.pop();
@@ -3569,12 +4181,13 @@ class index_gt {
            top.clear();
            top.reserve(count);
            for (std::size_t i = 0; i != size(); ++i) {
+               auto slot = static_cast<compressed_slot_t>(i);
                if (!is_dummy<predicate_at>())
-                   if (!predicate(at(i)))
+                   if (!predicate(at(slot)))
                        continue;

-               distance_t distance = context.measure(query, citerator_at(i), metric);
-               top.insert(candidate_t{distance, static_cast<compressed_slot_t>(i)}, count);
+               distance_t distance = context.measure(query, citerator_at(slot), metric);
+               top.insert(candidate_t{distance, slot}, count);
            }
        }

@@ -3584,22 +4197,27 @@ class index_gt {
     *          to keep only the neighbors that are far from each other.
     */
    template <typename metric_at>
-   candidates_view_t refine_( //
-       metric_at&& metric,    //
-       std::size_t needed, top_candidates_t& top, context_t& context) const noexcept {
+   candidates_view_t refine_( //
+       metric_at&& metric,    //
+       std::size_t needed, top_candidates_t& top, context_t& context, //
+       std::size_t& refines_counter) const noexcept {

-       top.sort_ascending();
+       // Avoid expensive computation if the set is already small
        candidate_t* top_data = top.data();
        std::size_t const top_count = top.size();
        if (top_count < needed)
            return {top_data, top_count};

+       // Sort before processing
+       top.sort_ascending();
+
        std::size_t submitted_count = 1;
        std::size_t consumed_count = 1; /// Always equal or greater than `submitted_count`.
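        // Diversity heuristic from the original HNSW paper: candidates are consumed
        // in order of increasing distance to the base point, and one is kept only if
        // it's closer to the base point than to every already-kept neighbor, so the
        // surviving edges spread out in distinct directions around the node.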
while (submitted_count < needed && consumed_count < top_count) { candidate_t candidate = top_data[consumed_count]; bool good = true; - for (std::size_t idx = 0; idx < submitted_count; idx++) { + std::size_t idx = 0; + for (; idx < submitted_count; idx++) { candidate_t submitted = top_data[idx]; distance_t inter_result_dist = context.measure( // citerator_at(candidate.slot), // @@ -3610,6 +4228,7 @@ class index_gt { break; } } + refines_counter += idx; if (good) { top_data[submitted_count] = top_data[consumed_count]; @@ -3791,11 +4410,12 @@ static join_result_t join( // bool woman_is_free = husband_slot == missing_slot; if (woman_is_free) { // Engagement - man_to_woman_slots[free_man_slot] = woman.slot; + man_to_woman_slots[free_man_slot] = static_cast(woman.slot); woman_to_man_slots[woman.slot] = free_man_slot; engagements++; } else { - distance_t distance_from_husband = women_metric(women_values[woman.slot], men_values[husband_slot]); + distance_t distance_from_husband = + women_metric(women_values[static_cast(woman.slot)], men_values[husband_slot]); distance_t distance_from_candidate = match.distance; if (distance_from_husband > distance_from_candidate) { // Break-up @@ -3805,7 +4425,7 @@ static join_result_t join( // men_locks.atomic_reset(husband_slot); // New Engagement - man_to_woman_slots[free_man_slot] = woman.slot; + man_to_woman_slots[free_man_slot] = static_cast(woman.slot); woman_to_man_slots[woman.slot] = free_man_slot; engagements++; @@ -3830,7 +4450,7 @@ static join_result_t join( // for (std::size_t man_slot = 0; man_slot != men.size(); ++man_slot) { compressed_slot_t woman_slot = man_to_woman_slots[man_slot]; if (woman_slot != missing_slot) { - man_key_t man = men.at(man_slot).key; + man_key_t man = men.at(static_cast(man_slot)).key; woman_key_t woman = women.at(woman_slot).key; man_to_woman[man] = woman; woman_to_man[woman] = man; @@ -3850,3 +4470,9 @@ static join_result_t join( // } // namespace unum #endif + +// This file is part of the usearch inline third-party dependency of YugabyteDB. +// Git repo: https://github.com/unum-cloud/usearch +// Git commit: 240fe9c298100f9e37a2d7377b1595be6ba1f412 +// +// See also src/inline-thirdparty/README.md. diff --git a/src/inline-thirdparty/usearch/usearch/index_dense.hpp b/src/inline-thirdparty/usearch/usearch/index_dense.hpp index 58851829d691..b294f0fa2bb4 100644 --- a/src/inline-thirdparty/usearch/usearch/index_dense.hpp +++ b/src/inline-thirdparty/usearch/usearch/index_dense.hpp @@ -1,11 +1,13 @@ +/** + * @file index_dense.hpp + * @author Ash Vardanian + * @brief Single-header Vector Search engine for equi-dimensional dense vectors. + * @date July 26, 2023 + */ #pragma once +#include "index.hpp" #include // `aligned_alloc` -#include // `std::function` -#include // `std::iota` -#include // `std::thread` -#include // `std::vector` - #include #include @@ -89,10 +91,32 @@ struct index_dense_head_result_t { } }; +/** + * @brief Configuration settings for the construction of dense + * equidimensional vector indexes. + * + * Unlike the underlying `index_gt` class, incorporates the + * `::expansion_add` and `::expansion_search` parameters passed + * separately for the lower-level engine. + */ struct index_dense_config_t : public index_config_t { std::size_t expansion_add = default_expansion_add(); std::size_t expansion_search = default_expansion_search(); + + /** + * @brief Excludes vectors from the serialized file. + * This is handy when you want to store the vectors in a separate file. + * + * ! For advanced users only. 
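+     * ! Only the raw vectors are skipped; the proximity graph itself is always serialized.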
+ */ bool exclude_vectors = false; + + /** + * @brief Allows you to store multiple vectors per key. + * This is handy when a large document is chunked into many parts. + * + * ! May degrade the performance of iterators. + */ bool multi = false; /** @@ -101,17 +125,37 @@ struct index_dense_config_t : public index_config_t { * the vectors-to-keys mappings. * * ! This configuration parameter doesn't affect the serialized file, - * ! and is not preserved between runs. Makes sense for small vector - * ! representations that fit ina single cache line. + * ! and is not preserved between runs. Makes sense for smaller vectors + * ! that fit in a couple of cache lines. + * + * The trade-off is that some methods won't be available, like `get`, `rename`, + * and `remove`. The basic functionality, like `add` and `search` will work as + * expected even with `enable_key_lookups = false`. + * + * If both `!multi && !enable_key_lookups`, the "duplicate entry" checks won't + * be performed and no errors will be raised. */ bool enable_key_lookups = true; - index_dense_config_t(index_config_t base) noexcept : index_config_t(base) {} + inline index_dense_config_t(index_config_t base) noexcept : index_config_t(base) {} + + inline index_dense_config_t(std::size_t c = 0, std::size_t ea = 0, std::size_t es = 0) noexcept + : index_config_t(c), expansion_add(ea), expansion_search(es) {} - index_dense_config_t(std::size_t c = default_connectivity(), std::size_t ea = default_expansion_add(), - std::size_t es = default_expansion_search()) noexcept - : index_config_t(c), expansion_add(ea ? ea : default_expansion_add()), - expansion_search(es ? es : default_expansion_search()) {} + /** + * @brief Validates the configuration settings, updating them in-place. + * @return Error message, if any. + */ + inline error_t validate() noexcept { + error_t error = index_config_t::validate(); + if (error) + return error; + if (expansion_add == 0) + expansion_add = default_expansion_add(); + if (expansion_search == 0) + expansion_search = default_expansion_search(); + return {}; + } }; struct index_dense_clustering_config_t { @@ -163,6 +207,45 @@ struct index_dense_metadata_result_t { } }; +/** + * @brief Fixes serialized scalar-kind codes for pre-v2.10 versions, until we can upgrade to v3. + * The old enum `scalar_kind_t` is defined without explicit constants from 0. + */ +inline scalar_kind_t convert_pre_2_10_scalar_kind(scalar_kind_t scalar_kind) noexcept { + switch (static_cast::type>(scalar_kind)) { + case 0: return scalar_kind_t::unknown_k; + case 1: return scalar_kind_t::b1x8_k; + case 2: return scalar_kind_t::u40_k; + case 3: return scalar_kind_t::uuid_k; + case 4: return scalar_kind_t::f64_k; + case 5: return scalar_kind_t::f32_k; + case 6: return scalar_kind_t::f16_k; + case 7: return scalar_kind_t::f8_k; + case 8: return scalar_kind_t::u64_k; + case 9: return scalar_kind_t::u32_k; + case 10: return scalar_kind_t::u8_k; + case 11: return scalar_kind_t::i64_k; + case 12: return scalar_kind_t::i32_k; + case 13: return scalar_kind_t::i16_k; + case 14: return scalar_kind_t::i8_k; + default: return scalar_kind; + } +} + +/** + * @brief Fixes the metadata for pre-v2.10 versions, until we can upgrade to v3. 
+ * Originates from: https://github.com/unum-cloud/usearch/issues/423 + */ +inline void fix_pre_2_10_metadata(index_dense_head_t& head) { + if (head.version_major == 2 && head.version_minor < 10) { + head.kind_scalar = convert_pre_2_10_scalar_kind(head.kind_scalar); + head.kind_key = convert_pre_2_10_scalar_kind(head.kind_key); + head.kind_compressed_slot = convert_pre_2_10_scalar_kind(head.kind_compressed_slot); + head.version_minor = 10; + head.version_patch = 0; + } +} + /** * @brief Extracts metadata from a pre-constructed index on disk, * without loading it or mapping the whole binary file. @@ -180,8 +263,10 @@ inline index_dense_metadata_result_t index_dense_metadata_from_path(char const* // Check if the file immediately starts with the index, instead of vectors result.config.exclude_vectors = true; - if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) { + fix_pre_2_10_metadata(result.head); return result; + } if (std::fseek(file.get(), 0L, SEEK_END) != 0) return result.failed("Can't infer file size"); @@ -207,8 +292,10 @@ inline index_dense_metadata_result_t index_dense_metadata_from_path(char const* result.config.exclude_vectors = false; result.config.use_64_bit_dimensions = false; - if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) { + fix_pre_2_10_metadata(result.head); return result; + } } // Check if it starts with 64-bit @@ -222,8 +309,10 @@ inline index_dense_metadata_result_t index_dense_metadata_from_path(char const* // Check if it starts with 64-bit result.config.exclude_vectors = false; result.config.use_64_bit_dimensions = true; - if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) { + fix_pre_2_10_metadata(result.head); return result; + } } return result.failed("Not a dense USearch index!"); @@ -232,7 +321,7 @@ inline index_dense_metadata_result_t index_dense_metadata_from_path(char const* /** * @brief Extracts metadata from a pre-constructed index serialized into an in-memory buffer. */ -inline index_dense_metadata_result_t index_dense_metadata_from_buffer(memory_mapped_file_t file, +inline index_dense_metadata_result_t index_dense_metadata_from_buffer(memory_mapped_file_t const& file, std::size_t offset = 0) noexcept { index_dense_metadata_result_t result; @@ -240,7 +329,7 @@ inline index_dense_metadata_result_t index_dense_metadata_from_buffer(memory_map if (offset + sizeof(index_dense_head_buffer_t) >= file.size()) return result.failed("End of file reached!"); - byte_t* const file_data = file.data() + offset; + byte_t const* file_data = file.data() + offset; std::size_t const file_size = file.size() - offset; std::memcpy(&result.head_buffer, file_data, sizeof(index_dense_head_buffer_t)); @@ -317,8 +406,6 @@ class index_dense_gt { using tape_allocator_t = memory_mapping_allocator_gt<64>; private: - /// @brief Schema: input buffer, bytes in input buffer, output buffer. - using cast_t = std::function; /// @brief Punned index. 
using index_t = index_gt< // distance_t, vector_key_t, compressed_slot_t, // @@ -353,20 +440,11 @@ class index_dense_gt { index_dense_config_t config_; index_t* typed_ = nullptr; - mutable std::vector cast_buffer_; - struct casts_t { - cast_t from_b1x8; - cast_t from_i8; - cast_t from_f16; - cast_t from_f32; - cast_t from_f64; - - cast_t to_b1x8; - cast_t to_i8; - cast_t to_f16; - cast_t to_f32; - cast_t to_f64; - } casts_; + using cast_buffer_t = buffer_gt; + + /// @brief Temporary memory for every thread to store a casted vector. + mutable cast_buffer_t cast_buffer_; + casts_punned_t casts_; /// @brief An instance of a potentially stateful `metric_t` used to initialize copies and forks. metric_t metric_; @@ -375,11 +453,17 @@ class index_dense_gt { /// @brief Allocator for the copied vectors, aligned to widest double-precision scalars. vectors_tape_allocator_t vectors_tape_allocator_; + using vectors_lookup_allocator_t = aligned_allocator_gt; + using vectors_lookup_t = buffer_gt; + /// @brief For every managed `compressed_slot_t` stores a pointer to the allocated vector copy. - mutable std::vector vectors_lookup_; + mutable vectors_lookup_t vectors_lookup_; + + using available_threads_allocator_t = aligned_allocator_gt; + using available_threads_t = ring_gt; - /// @brief Originally forms and array of integers [0, threads], marking all - mutable std::vector available_threads_; + /// @brief Originally forms and array of integers [0, threads], marking all as available. + mutable available_threads_t available_threads_; /// @brief Mutex, controlling concurrent access to `available_threads_`. mutable std::mutex available_threads_mutex_; @@ -402,8 +486,8 @@ class index_dense_gt { struct lookup_key_hash_t { using is_transparent = void; - std::size_t operator()(key_and_slot_t const& k) const noexcept { return std::hash{}(k.key); } - std::size_t operator()(vector_key_t const& k) const noexcept { return std::hash{}(k); } + std::size_t operator()(key_and_slot_t const& k) const noexcept { return hash_gt{}(k.key); } + std::size_t operator()(vector_key_t const& k) const noexcept { return hash_gt{}(k); } }; struct lookup_key_same_t { @@ -428,13 +512,57 @@ class index_dense_gt { /// @brief A constant for the reserved key value, used to mark deleted entries. vector_key_t free_key_ = default_free_value(); + /// @brief Locks the thread for the duration of the operation. + struct thread_lock_t { + index_dense_gt const& parent; + std::size_t thread_id = 0; + bool engaged = false; + + ~thread_lock_t() usearch_noexcept_m { + if (engaged) + parent.thread_unlock_(thread_id); + } + + thread_lock_t(thread_lock_t const&) = delete; + thread_lock_t& operator=(thread_lock_t const&) = delete; + + thread_lock_t(index_dense_gt const& parent, std::size_t thread_id, bool engaged = true) noexcept + : parent(parent), thread_id(thread_id), engaged(engaged) {} + thread_lock_t(thread_lock_t&& other) noexcept + : parent(other.parent), thread_id(other.thread_id), engaged(other.engaged) { + other.engaged = false; + } + }; + public: - using search_result_t = typename index_t::search_result_t; using cluster_result_t = typename index_t::cluster_result_t; using add_result_t = typename index_t::add_result_t; using stats_t = typename index_t::stats_t; using match_t = typename index_t::match_t; + /** + * @brief A search result, containing the found keys and distances. 
+ * + * As the `index_dense_gt` manages the thread-pool on its own, the search result + * preserves the thread-lock to avoid undefined behaviors, when other threads + * start overwriting the results. + */ + struct search_result_t : public index_t::search_result_t { + inline search_result_t(index_dense_gt const& parent) noexcept + : index_t::search_result_t(), lock_(parent, 0, false) {} + search_result_t failed(error_t message) noexcept { + this->error = std::move(message); + return std::move(*this); + } + + private: + friend class index_dense_gt; + thread_lock_t lock_; + + inline search_result_t(typename index_t::search_result_t result, thread_lock_t lock) noexcept + : index_t::search_result_t(std::move(result)), lock_(std::move(lock)) {} + }; + index_dense_gt() = default; index_dense_gt(index_dense_gt&& other) : config_(std::move(other.config_)), @@ -485,51 +613,76 @@ class index_dense_gt { typed_ = nullptr; } + struct state_result_t { + index_dense_gt index; + error_t error; + + explicit operator bool() const noexcept { return !error; } + state_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + operator index_dense_gt&&() && { + if (error) + __usearch_raise_runtime_error(error.what()); + return std::move(index); + } + }; + using copy_result_t = state_result_t; + /** * @brief Constructs an instance of ::index_dense_gt. * @param[in] metric One of the provided or an @b ad-hoc metric, type-punned. * @param[in] config The index configuration (optional). * @param[in] free_key The key used for freed vectors (optional). - * @return An instance of ::index_dense_gt. + * @return An instance of ::index_dense_gt or error, wrapped in a `state_result_t`. + * + * ! If the `metric` isn't provided in this method, it has to be set with + * ! the `change_metric` method before the index can be used. Alternatively, + * ! if you are loading an existing index, the metric will be set automatically. */ - static index_dense_gt make( // - metric_t metric, // + static state_result_t make( // + metric_t metric = {}, // index_dense_config_t config = {}, // vector_key_t free_key = default_free_value()) { - scalar_kind_t scalar_kind = metric.scalar_kind(); - std::size_t hardware_threads = std::thread::hardware_concurrency(); - - index_dense_gt result; - result.config_ = config; - result.cast_buffer_.resize(hardware_threads * metric.bytes_per_vector()); - result.casts_ = make_casts_(scalar_kind); - result.metric_ = metric; - result.free_key_ = free_key; - - // Fill the thread IDs. - result.available_threads_.resize(hardware_threads); - std::iota(result.available_threads_.begin(), result.available_threads_.end(), 0ul); - - // Available since C11, but only C++17, so we use the C version. + if (metric.missing()) + return state_result_t{}.failed("Metric won't be initialized!"); + error_t error = config.validate(); + if (error) + return state_result_t{}.failed(std::move(error)); index_t* raw = index_allocator_t{}.allocate(1); + if (!raw) + return state_result_t{}.failed("Failed to allocate memory for the index!"); + + state_result_t result; + index_dense_gt& index = result.index; + index.config_ = config; + index.free_key_ = free_key; + + // In some cases the metric is not provided, and will be set later. 
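+        // Until `change_metric` supplies one, `add` and `search` must not be called.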
+ if (metric) { + scalar_kind_t scalar_kind = metric.scalar_kind(); + index.casts_ = casts_punned_t::make(scalar_kind); + index.metric_ = metric; + } + new (raw) index_t(config); - result.typed_ = raw; + index.typed_ = raw; return result; } - static index_dense_gt make(char const* path, bool view = false) { - index_dense_metadata_result_t meta = index_dense_metadata_from_path(path); - if (!meta) - return {}; - metric_punned_t metric(meta.head.dimensions, meta.head.kind_metric, meta.head.kind_scalar); - index_dense_gt result = make(metric); - if (!result) - return result; - if (view) - result.view(path); - else - result.load(path); + /** + * @brief Constructs an instance of ::index_dense_gt from a serialized binary file. + * @param[in] path The path to the binary file. + * @param[in] view Whether to map the file into memory or load it. + * @return An instance of ::index_dense_gt or error, wrapped in a `state_result_t`. + */ + static state_result_t make(char const* path, bool view = false) { + state_result_t result; + serialization_result_t serialization_result = view ? result.index.view(path) : result.index.load(path); + if (!serialization_result) + return result.failed(std::move(serialization_result.error)); return result; } @@ -537,19 +690,23 @@ class index_dense_gt { std::size_t connectivity() const { return typed_->connectivity(); } std::size_t size() const { return typed_->size() - free_keys_.size(); } std::size_t capacity() const { return typed_->capacity(); } - std::size_t max_level() const noexcept { return typed_->max_level(); } + std::size_t max_level() const { return typed_->max_level(); } index_dense_config_t const& config() const { return config_; } index_limits_t const& limits() const { return typed_->limits(); } bool multi() const { return config_.multi; } + std::size_t currently_available_threads() const { + std::unique_lock available_threads_lock(available_threads_mutex_); + return available_threads_.size(); + } // The metric and its properties metric_t const& metric() const { return metric_; } void change_metric(metric_t metric) { metric_ = std::move(metric); } - scalar_kind_t scalar_kind() const noexcept { return metric_.scalar_kind(); } - std::size_t bytes_per_vector() const noexcept { return metric_.bytes_per_vector(); } - std::size_t scalar_words() const noexcept { return metric_.scalar_words(); } - std::size_t dimensions() const noexcept { return metric_.dimensions(); } + scalar_kind_t scalar_kind() const { return metric_.scalar_kind(); } + std::size_t bytes_per_vector() const { return metric_.bytes_per_vector(); } + std::size_t scalar_words() const { return metric_.scalar_words(); } + std::size_t dimensions() const { return metric_.dimensions(); } // Fetching and changing search criteria std::size_t expansion_add() const { return config_.expansion_add; } @@ -559,8 +716,6 @@ class index_dense_gt { member_citerator_t cbegin() const { return typed_->cbegin(); } member_citerator_t cend() const { return typed_->cend(); } - member_citerator_t begin() const { return typed_->begin(); } - member_citerator_t end() const { return typed_->end(); } member_iterator_t begin() { return typed_->begin(); } member_iterator_t end() { return typed_->end(); } @@ -598,41 +753,41 @@ class index_dense_gt { }; // clang-format off - add_result_t add(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_b1x8); } - add_result_t add(vector_key_t key, i8_t const* vector, 
std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_i8); } - add_result_t add(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f16); } - add_result_t add(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f32); } - add_result_t add(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f64); } - - search_result_t search(b1x8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_b1x8); } - search_result_t search(i8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_i8); } - search_result_t search(f16_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f16); } - search_result_t search(f32_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f32); } - search_result_t search(f64_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f64); } - - template search_result_t filtered_search(b1x8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from_b1x8); } - template search_result_t filtered_search(i8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from_i8); } - template search_result_t filtered_search(f16_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from_f16); } - template search_result_t filtered_search(f32_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from_f32); } - template search_result_t filtered_search(f64_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from_f64); } - - std::size_t get(vector_key_t key, b1x8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_b1x8); } - std::size_t get(vector_key_t key, i8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_i8); } - std::size_t get(vector_key_t key, f16_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, 
vectors_count, casts_.to_f16); } - std::size_t get(vector_key_t key, f32_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_f32); } - std::size_t get(vector_key_t key, f64_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_f64); } - - cluster_result_t cluster(b1x8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_b1x8); } - cluster_result_t cluster(i8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_i8); } - cluster_result_t cluster(f16_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_f16); } - cluster_result_t cluster(f32_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_f32); } - cluster_result_t cluster(f64_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_f64); } - - aggregated_distances_t distance_between(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_b1x8); } - aggregated_distances_t distance_between(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_i8); } - aggregated_distances_t distance_between(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_f16); } - aggregated_distances_t distance_between(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_f32); } - aggregated_distances_t distance_between(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_f64); } + add_result_t add(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.b1x8); } + add_result_t add(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.i8); } + add_result_t add(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.f16); } + add_result_t add(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.f32); } + add_result_t add(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.f64); } + + search_result_t search(b1x8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.b1x8); } + search_result_t search(i8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.i8); } + 
search_result_t search(f16_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.f16); } + search_result_t search(f32_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.f32); } + search_result_t search(f64_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.f64); } + + template search_result_t filtered_search(b1x8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from.b1x8); } + template search_result_t filtered_search(i8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from.i8); } + template search_result_t filtered_search(f16_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from.f16); } + template search_result_t filtered_search(f32_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from.f32); } + template search_result_t filtered_search(f64_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from.f64); } + + std::size_t get(vector_key_t key, b1x8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.b1x8); } + std::size_t get(vector_key_t key, i8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.i8); } + std::size_t get(vector_key_t key, f16_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.f16); } + std::size_t get(vector_key_t key, f32_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.f32); } + std::size_t get(vector_key_t key, f64_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.f64); } + + cluster_result_t cluster(b1x8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.b1x8); } + cluster_result_t cluster(i8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.i8); } + cluster_result_t cluster(f16_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.f16); } + cluster_result_t cluster(f32_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.f32); } + cluster_result_t cluster(f64_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, 
casts_.from.f64); } + + aggregated_distances_t distance_between(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.b1x8); } + aggregated_distances_t distance_between(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.i8); } + aggregated_distances_t distance_between(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.f16); } + aggregated_distances_t distance_between(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.f32); } + aggregated_distances_t distance_between(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.f64); } // clang-format on /** @@ -641,6 +796,7 @@ class index_dense_gt { * exporting the mean, maximum, and minimum values. */ aggregated_distances_t distance_between(vector_key_t a, vector_key_t b, std::size_t = any_thread()) const { + usearch_assert_m(config().enable_key_lookups, "Key lookups are disabled!"); shared_lock_t lock(slot_lookup_mutex_); aggregated_distances_t result; if (!multi()) { @@ -714,9 +870,8 @@ class index_dense_gt { cluster_config.thread = lock.thread_id; cluster_config.expansion = config_.expansion_search; metric_proxy_t metric{*this}; - auto allow = [free_key_ = this->free_key_](member_cref_t const& member) noexcept { - return member.key != free_key_; - }; + vector_key_t free_key_copy = free_key_; + auto allow = [free_key_copy](member_cref_t const& member) noexcept { return member.key != free_key_copy; }; // Find the closest cluster for any vector under that key. while (key_range.first != key_range.second) { @@ -736,16 +891,54 @@ class index_dense_gt { /** * @brief Reserves memory for the index and the keyed lookup. * @return `true` if the memory reservation was successful, `false` otherwise. + * + * ! No update or search operations should be running during this operation. */ - bool reserve(index_limits_t limits) { - { + bool try_reserve(index_limits_t limits) { + + // The slot lookup system will generally prefer power-of-two sizes. + if (config_.enable_key_lookups) { unique_lock_t lock(slot_lookup_mutex_); - slot_lookup_.reserve(limits.members); - vectors_lookup_.resize(limits.members); + if (!slot_lookup_.try_reserve(limits.members)) + return false; + limits.members = slot_lookup_.capacity(); } + + // Once the `slot_lookup_` grows, let's use its capacity as the new + // target for the `vectors_lookup_` to synchronize allocations and + // expensive index re-organizations. + if (limits.members != vectors_lookup_.size()) { + vectors_lookup_t new_vectors_lookup(limits.members); + if (!new_vectors_lookup) + return false; + if (vectors_lookup_.size() > 0) + std::memcpy(new_vectors_lookup.data(), vectors_lookup_.data(), + vectors_lookup_.size() * sizeof(byte_t*)); + vectors_lookup_ = std::move(new_vectors_lookup); + } + + // During reserve, no insertions may be happening, so we can safely overwrite the whole collection. + std::unique_lock available_threads_lock(available_threads_mutex_); + available_threads_.clear(); + if (!available_threads_.reserve(limits.threads())) + return false; + for (std::size_t i = 0; i < limits.threads(); i++) + available_threads_.push(i); + + // Allocate a buffer for the casted vectors. 
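+        // One `bytes_per_vector` stripe per thread, so concurrent calls never
+        // share scratch space while casting query and inserted vectors.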
+ cast_buffer_t cast_buffer(limits.threads() * metric_.bytes_per_vector()); + if (!cast_buffer) + return false; + cast_buffer_ = std::move(cast_buffer); + return typed_->reserve(limits); } + void reserve(index_limits_t limits) { + if (!try_reserve(limits)) + __usearch_raise_runtime_error("failed to reserve memory"); + } + /** * @brief Erases all the vectors from the index. * @@ -758,7 +951,7 @@ class index_dense_gt { std::unique_lock free_lock(free_keys_mutex_); typed_->clear(); slot_lookup_.clear(); - vectors_lookup_.clear(); + vectors_lookup_.reset(); free_keys_.clear(); vectors_tape_allocator_.reset(); } @@ -771,19 +964,18 @@ class index_dense_gt { * If the index is memory-mapped - releases the mapping and the descriptor. */ void reset() { - unique_lock_t lookup_lock(slot_lookup_mutex_); + unique_lock_t lookup_lock(slot_lookup_mutex_); std::unique_lock free_lock(free_keys_mutex_); std::unique_lock available_threads_lock(available_threads_mutex_); - typed_->reset(); + + if (typed_) + typed_->reset(); slot_lookup_.clear(); - vectors_lookup_.clear(); + vectors_lookup_.reset(); free_keys_.clear(); vectors_tape_allocator_.reset(); - - // Reset the thread IDs. - available_threads_.resize(std::thread::hardware_concurrency()); - std::iota(available_threads_.begin(), available_threads_.end(), 0ul); + available_threads_.reset(); } /** @@ -862,7 +1054,7 @@ class index_dense_gt { /** * @brief Estimate the binary length (in bytes) of the serialized index. */ - std::size_t serialized_length(serialization_config_t config = {}) const noexcept { + std::size_t serialized_length(serialization_config_t config = {}) const { std::size_t dimensions_length = 0; std::size_t matrix_length = 0; if (!config.exclude_vectors) { @@ -884,6 +1076,7 @@ class index_dense_gt { progress_at&& progress = {}) { // Discard all previous memory allocations of `vectors_tape_allocator_` + index_limits_t old_limits = typed_ ? 
typed_->limits() : index_limits_t{}; reset(); // Infer the new index size @@ -908,7 +1101,9 @@ class index_dense_gt { matrix_cols = dimensions[1]; } // Load the vectors one after another - vectors_lookup_.resize(matrix_rows); + vectors_lookup_ = vectors_lookup_t(matrix_rows); + if (!vectors_lookup_) + return result.failed("Failed to allocate memory to address vectors"); for (std::uint64_t slot = 0; slot != matrix_rows; ++slot) { byte_t* vector = vectors_tape_allocator_.allocate(matrix_cols); if (!input(vector, matrix_cols)) @@ -927,6 +1122,9 @@ class index_dense_gt { if (std::memcmp(buffer, default_magic(), std::strlen(default_magic())) != 0) return result.failed("Magic header mismatch - the file isn't an index"); + // fix pre-2.10 headers + fix_pre_2_10_metadata(head); + // Validate the software version if (head.version_major != USEARCH_VERSION_MAJOR) return result.failed("File format may be different, please rebuild"); @@ -939,16 +1137,38 @@ class index_dense_gt { config_.multi = head.multi; metric_ = metric_t::builtin(head.dimensions, head.kind_metric, head.kind_scalar); - cast_buffer_.resize(available_threads_.size() * metric_.bytes_per_vector()); - casts_ = make_casts_(head.kind_scalar); + cast_buffer_ = cast_buffer_t(available_threads_.size() * metric_.bytes_per_vector()); + if (!cast_buffer_) + return result.failed("Failed to allocate memory for the casts"); + casts_ = casts_punned_t::make(head.kind_scalar); } // Pull the actual proximity graph + if (!typed_) { + index_t* raw = index_allocator_t{}.allocate(1); + if (!raw) + return result.failed("Failed to allocate memory for the index"); + new (raw) index_t(config_); + typed_ = raw; + } result = typed_->load_from_stream(std::forward(input), std::forward(progress)); if (!result) return result; if (typed_->size() != static_cast(matrix_rows)) return result.failed("Index size and the number of vectors doesn't match"); + old_limits.members = static_cast(matrix_rows); + if (!typed_->try_reserve(old_limits)) + return result.failed("Failed to reserve memory for the index"); + + // After the index is loaded, we may have to resize the `available_threads_` to + // match the limits of the underlying engine. + available_threads_t available_threads; + std::size_t max_threads = old_limits.threads(); + if (!available_threads.reserve(max_threads)) + return result.failed("Failed to allocate memory for the available threads!"); + for (std::size_t i = 0; i < max_threads; i++) + available_threads.push(i); + available_threads_ = std::move(available_threads); reindex_keys_(); return result; @@ -966,6 +1186,7 @@ class index_dense_gt { progress_at&& progress = {}) { // Discard all previous memory allocations of `vectors_tape_allocator_` + index_limits_t old_limits = typed_ ? 
typed_->limits() : index_limits_t{}; reset(); serialization_result_t result = file.open_if_not(); @@ -1013,6 +1234,9 @@ class index_dense_gt { if (std::memcmp(buffer, default_magic(), std::strlen(default_magic())) != 0) return result.failed("Magic header mismatch - the file isn't an index"); + // fix pre-2.10 headers + fix_pre_2_10_metadata(head); + // Validate the software version if (head.version_major != USEARCH_VERSION_MAJOR) return result.failed("File format may be different, please rebuild"); @@ -1025,24 +1249,48 @@ class index_dense_gt { config_.multi = head.multi; metric_ = metric_t::builtin(head.dimensions, head.kind_metric, head.kind_scalar); - cast_buffer_.resize(available_threads_.size() * metric_.bytes_per_vector()); - casts_ = make_casts_(head.kind_scalar); + cast_buffer_ = cast_buffer_t(available_threads_.size() * metric_.bytes_per_vector()); + if (!cast_buffer_) + return result.failed("Failed to allocate memory for the casts"); + casts_ = casts_punned_t::make(head.kind_scalar); offset += sizeof(buffer); } // Pull the actual proximity graph + if (!typed_) { + index_t* raw = index_allocator_t{}.allocate(1); + if (!raw) + return result.failed("Failed to allocate memory for the index"); + new (raw) index_t(config_); + typed_ = raw; + } result = typed_->view(std::move(file), offset, std::forward(progress)); if (!result) return result; if (typed_->size() != static_cast(matrix_rows)) return result.failed("Index size and the number of vectors doesn't match"); + old_limits.members = static_cast(matrix_rows); + if (!typed_->try_reserve(old_limits)) + return result.failed("Failed to reserve memory for the index"); // Address the vectors - vectors_lookup_.resize(matrix_rows); + vectors_lookup_ = vectors_lookup_t(matrix_rows); + if (!vectors_lookup_) + return result.failed("Failed to allocate memory to address vectors"); if (!config.exclude_vectors) for (std::uint64_t slot = 0; slot != matrix_rows; ++slot) vectors_lookup_[slot] = (byte_t*)vectors_buffer.data() + matrix_cols * slot; + // After the index is loaded, we may have to resize the `available_threads_` to + // match the limits of the underlying engine. + available_threads_t available_threads; + std::size_t max_threads = old_limits.threads(); + if (!available_threads.reserve(max_threads)) + return result.failed("Failed to allocate memory for the available threads!"); + for (std::size_t i = 0; i < max_threads; i++) + available_threads.push(i); + available_threads_ = std::move(available_threads); + reindex_keys_(); return result; } @@ -1175,6 +1423,7 @@ class index_dense_gt { * @return `true` if the key is present in the index, `false` otherwise. */ bool contains(vector_key_t key) const { + usearch_assert_m(config().enable_key_lookups, "Key lookups are disabled"); shared_lock_t lock(slot_lookup_mutex_); return slot_lookup_.contains(key_and_slot_t::any_slot(key)); } @@ -1184,6 +1433,7 @@ class index_dense_gt { * @return Zero if nothing is found, a positive integer otherwise. */ std::size_t count(vector_key_t key) const { + usearch_assert_m(config().enable_key_lookups, "Key lookups are disabled"); shared_lock_t lock(slot_lookup_mutex_); return slot_lookup_.count(key_and_slot_t::any_slot(key)); } @@ -1208,6 +1458,7 @@ class index_dense_gt { * If an error occurred during the removal operation, `result.error` will contain an error message. 
*/ labeling_result_t remove(vector_key_t key) { + usearch_assert_m(config().enable_key_lookups, "Key lookups are disabled"); labeling_result_t result; unique_lock_t lookup_lock(slot_lookup_mutex_); @@ -1218,7 +1469,8 @@ class index_dense_gt { // Grow the removed entries ring, if needed std::size_t matching_count = std::distance(matching_slots.first, matching_slots.second); std::unique_lock free_lock(free_keys_mutex_); - if (!free_keys_.reserve(free_keys_.size() + matching_count)) + std::size_t free_count_old = free_keys_.size(); + if (!free_keys_.reserve(free_count_old + matching_count)) return result.failed("Can't allocate memory for a free-list"); // A removed entry would be: @@ -1232,6 +1484,7 @@ class index_dense_gt { } slot_lookup_.erase(key); result.completed = matching_count; + usearch_assert_m(free_keys_.size() == free_count_old + matching_count, "Free keys count mismatch"); return result; } @@ -1246,6 +1499,7 @@ class index_dense_gt { */ template labeling_result_t remove(keys_iterator_at keys_begin, keys_iterator_at keys_end) { + usearch_assert_m(config().enable_key_lookups, "Key lookups are disabled"); labeling_result_t result; unique_lock_t lookup_lock(slot_lookup_mutex_); @@ -1332,17 +1586,6 @@ class index_dense_gt { }); } - struct copy_result_t { - index_dense_gt index; - error_t error; - - explicit operator bool() const noexcept { return !error; } - copy_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } - }; - /** * @brief Copies the ::index_dense_gt @b with all the data in it. * @param config The copy configuration (optional). @@ -1365,19 +1608,22 @@ class index_dense_gt { copy.free_keys_.push(free_keys_[i]); // Allocate buffers and move the vectors themselves - if (!config.force_vector_copy && copy.config_.exclude_vectors) - copy.vectors_lookup_ = vectors_lookup_; - else { - copy.vectors_lookup_.resize(vectors_lookup_.size()); - for (std::size_t slot = 0; slot != vectors_lookup_.size(); ++slot) + copy.vectors_lookup_ = vectors_lookup_t(vectors_lookup_.size()); + if (!copy.vectors_lookup_) + return result.failed("Out of memory!"); + if (!config.force_vector_copy && copy.config_.exclude_vectors) { + std::memcpy(copy.vectors_lookup_.data(), vectors_lookup_.data(), vectors_lookup_.size() * sizeof(byte_t*)); + } else { + std::size_t slots_count = typed_result.index.size(); + for (std::size_t slot = 0; slot != slots_count; ++slot) copy.vectors_lookup_[slot] = copy.vectors_tape_allocator_.allocate(copy.metric_.bytes_per_vector()); - if (std::count(copy.vectors_lookup_.begin(), copy.vectors_lookup_.end(), nullptr)) + if (std::count(copy.vectors_lookup_.begin(), copy.vectors_lookup_.begin() + slots_count, nullptr)) return result.failed("Out of memory!"); - for (std::size_t slot = 0; slot != vectors_lookup_.size(); ++slot) + for (std::size_t slot = 0; slot != slots_count; ++slot) std::memcpy(copy.vectors_lookup_[slot], vectors_lookup_[slot], metric_.bytes_per_vector()); } - copy.slot_lookup_ = slot_lookup_; + copy.slot_lookup_ = slot_lookup_; // TODO: Handle out of memory *copy.typed_ = std::move(typed_result.index); return result; } @@ -1387,22 +1633,34 @@ class index_dense_gt { * @return A similarly configured ::index_dense_gt instance. 
*/ copy_result_t fork() const { + + cast_buffer_t cast_buffer(cast_buffer_.size()); + if (!cast_buffer) + return state_result_t{}.failed("Failed to allocate memory for the casts!"); + available_threads_t available_threads; + std::size_t max_threads = limits().threads(); + if (!available_threads.reserve(max_threads)) + return state_result_t{}.failed("Failed to allocate memory for the available threads!"); + for (std::size_t i = 0; i < max_threads; i++) + available_threads.push(i); + index_t* raw = index_allocator_t{}.allocate(1); + if (!raw) + return state_result_t{}.failed("Failed to allocate memory for the index!"); + copy_result_t result; index_dense_gt& other = result.index; - + index_limits_t other_limits = limits(); + other_limits.members = 0; other.config_ = config_; - other.cast_buffer_ = cast_buffer_; + other.cast_buffer_ = std::move(cast_buffer); other.casts_ = casts_; other.metric_ = metric_; - other.available_threads_ = available_threads_; + other.available_threads_ = std::move(available_threads); other.free_key_ = free_key_; - index_t* raw = index_allocator_t{}.allocate(1); - if (!raw) - return result.failed("Can't allocate the index"); - new (raw) index_t(config()); + raw->try_reserve(other_limits); other.typed_ = raw; return result; } @@ -1461,7 +1719,10 @@ class index_dense_gt { compaction_result_t compact(executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { compaction_result_t result; - std::vector new_vectors_lookup(vectors_lookup_.size()); + vectors_lookup_t new_vectors_lookup(vectors_lookup_.size()); + if (!new_vectors_lookup) + return result.failed("Out of memory!"); + vectors_tape_allocator_t new_vectors_allocator; auto track_slot_change = [&](vector_key_t, compressed_slot_t old_slot, compressed_slot_t new_slot) { @@ -1698,40 +1959,30 @@ class index_dense_gt { } private: - struct thread_lock_t { - index_dense_gt const& parent; - std::size_t thread_id; - bool engaged; - - ~thread_lock_t() { - if (engaged) - parent.thread_unlock_(thread_id); - } - }; - - thread_lock_t thread_lock_(std::size_t thread_id) const { + thread_lock_t thread_lock_(std::size_t thread_id) const usearch_noexcept_m { if (thread_id != any_thread()) return {*this, thread_id, false}; available_threads_mutex_.lock(); - thread_id = available_threads_.back(); - available_threads_.pop_back(); + usearch_assert_m(available_threads_.size(), "No available threads to lock"); + available_threads_.try_pop(thread_id); available_threads_mutex_.unlock(); return {*this, thread_id, true}; } - void thread_unlock_(std::size_t thread_id) const { + void thread_unlock_(std::size_t thread_id) const usearch_noexcept_m { available_threads_mutex_.lock(); - available_threads_.push_back(thread_id); + usearch_assert_m(available_threads_.size() < available_threads_.capacity(), "Too many threads unlocked"); + available_threads_.push(thread_id); available_threads_mutex_.unlock(); } template add_result_t add_( // vector_key_t key, scalar_at const* vector, // - std::size_t thread, bool force_vector_copy, cast_t const& cast) { + std::size_t thread, bool force_vector_copy, cast_punned_t const& cast) { - if (!multi() && contains(key)) + if (!multi() && config().enable_key_lookups && contains(key)) return add_result_t{}.failed("Duplicate keys not allowed in high-level wrappers"); // Cast the vector, if needed for compatibility with `metric_` @@ -1755,8 +2006,10 @@ class index_dense_gt { // Perform the insertion or the update bool reuse_node = free_slot != default_free_value(); auto on_success = 
[&](member_ref_t member) { - unique_lock_t slot_lock(slot_lookup_mutex_); - slot_lookup_.try_emplace(key_and_slot_t{key, static_cast(member.slot)}); + if (config_.enable_key_lookups) { + unique_lock_t slot_lock(slot_lookup_mutex_); + slot_lookup_.try_emplace(key_and_slot_t{key, static_cast(member.slot)}); + } if (copy_vector) { if (!reuse_node) vectors_lookup_[member.slot] = vectors_tape_allocator_.allocate(metric_.bytes_per_vector()); @@ -1777,7 +2030,7 @@ class index_dense_gt { template search_result_t search_(scalar_at const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread, - bool exact, cast_t const& cast) const { + bool exact, cast_punned_t const& cast) const { // Cast the vector, if needed for compatibility with `metric_` thread_lock_t lock = thread_lock_(thread); @@ -1794,23 +2047,26 @@ class index_dense_gt { search_config.expansion = config_.expansion_search; search_config.exact = exact; + vector_key_t free_key_copy = free_key_; if (std::is_same::type, dummy_predicate_t>::value) { - auto allow = [free_key_ = this->free_key_](member_cref_t const& member) noexcept { - return member.key != free_key_; + auto allow = [free_key_copy](member_cref_t const& member) noexcept { + return (vector_key_t)member.key != free_key_copy; }; - return typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow); + auto typed_result = typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow); + return search_result_t{std::move(typed_result), std::move(lock)}; } else { - auto allow = [free_key_ = this->free_key_, &predicate](member_cref_t const& member) noexcept { - return member.key != free_key_ && predicate(member.key); + auto allow = [free_key_copy, &predicate](member_cref_t const& member) noexcept { + return (vector_key_t)member.key != free_key_copy && predicate(member.key); }; - return typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow); + auto typed_result = typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow); + return search_result_t{std::move(typed_result), std::move(lock)}; } } template cluster_result_t cluster_( // scalar_at const* vector, std::size_t level, // - std::size_t thread, cast_t const& cast) const { + std::size_t thread, cast_punned_t const& cast) const { // Cast the vector, if needed for compatibility with `metric_` thread_lock_t lock = thread_lock_(thread); @@ -1826,16 +2082,15 @@ class index_dense_gt { cluster_config.thread = lock.thread_id; cluster_config.expansion = config_.expansion_search; - auto allow = [free_key_ = this->free_key_](member_cref_t const& member) noexcept { - return member.key != free_key_; - }; + vector_key_t free_key_copy = free_key_; + auto allow = [free_key_copy](member_cref_t const& member) noexcept { return member.key != free_key_copy; }; return typed_->cluster(vector_data, level, metric_proxy_t{*this}, cluster_config, allow); } template aggregated_distances_t distance_between_( // vector_key_t key, scalar_at const* vector, // - std::size_t thread, cast_t const& cast) const { + std::size_t thread, cast_punned_t const& cast) const { // Cast the vector, if needed for compatibility with `metric_` thread_lock_t lock = thread_lock_(thread); @@ -1848,6 +2103,7 @@ class index_dense_gt { } // Check if such `key` is even present. 
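+        // Note: this distance helper resolves keys through the slot lookup table, so it is
+        // only usable when the index was configured with `enable_key_lookups` set to `true`.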
+ usearch_assert_m(config().enable_key_lookups, "Key lookups are disabled!"); shared_lock_t slots_lock(slot_lookup_mutex_); auto key_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); aggregated_distances_t result; @@ -1884,7 +2140,8 @@ class index_dense_gt { std::size_t count_total = typed_->size(); std::size_t count_removed = 0; for (std::size_t i = 0; i != count_total; ++i) { - member_cref_t member = typed_->at(i); + auto member_slot = static_cast(i); + member_cref_t member = typed_->at(member_slot); count_removed += member.key == free_key_; } @@ -1900,16 +2157,18 @@ class index_dense_gt { free_keys_.clear(); free_keys_.reserve(count_removed); for (std::size_t i = 0; i != typed_->size(); ++i) { - member_cref_t member = typed_->at(i); + auto member_slot = static_cast(i); + member_cref_t member = typed_->at(member_slot); if (member.key == free_key_) - free_keys_.push(static_cast(i)); + free_keys_.push(member_slot); else if (config_.enable_key_lookups) - slot_lookup_.try_emplace(key_and_slot_t{vector_key_t(member.key), static_cast(i)}); + slot_lookup_.try_emplace(key_and_slot_t{vector_key_t(member.key), member_slot}); } } template - std::size_t get_(vector_key_t key, scalar_at* reconstructed, std::size_t vectors_limit, cast_t const& cast) const { + std::size_t get_(vector_key_t key, scalar_at* reconstructed, std::size_t vectors_limit, + cast_punned_t const& cast) const { if (!multi()) { compressed_slot_t slot; @@ -1944,35 +2203,6 @@ class index_dense_gt { return count_exported; } } - - template static casts_t make_casts_() { - casts_t result; - - result.from_b1x8 = cast_gt{}; - result.from_i8 = cast_gt{}; - result.from_f16 = cast_gt{}; - result.from_f32 = cast_gt{}; - result.from_f64 = cast_gt{}; - - result.to_b1x8 = cast_gt{}; - result.to_i8 = cast_gt{}; - result.to_f16 = cast_gt{}; - result.to_f32 = cast_gt{}; - result.to_f64 = cast_gt{}; - - return result; - } - - static casts_t make_casts_(scalar_kind_t scalar_kind) { - switch (scalar_kind) { - case scalar_kind_t::f64_k: return make_casts_(); - case scalar_kind_t::f32_k: return make_casts_(); - case scalar_kind_t::f16_k: return make_casts_(); - case scalar_kind_t::i8_k: return make_casts_(); - case scalar_kind_t::b1x8_k: return make_casts_(); - default: return {}; - } - } }; using index_dense_t = index_dense_gt<>; @@ -2008,7 +2238,7 @@ static join_result_t join( // man_to_woman_at&& man_to_woman = man_to_woman_at{}, // woman_to_man_at&& woman_to_man = woman_to_man_at{}, // executor_at&& executor = executor_at{}, // - progress_at&& progress = progress_at{}) noexcept { + progress_at&& progress = progress_at{}) { return men.join( // women, config, // @@ -2020,3 +2250,9 @@ static join_result_t join( // } // namespace usearch } // namespace unum + +// This file is part of the usearch inline third-party dependency of YugabyteDB. +// Git repo: https://github.com/unum-cloud/usearch +// Git commit: 240fe9c298100f9e37a2d7377b1595be6ba1f412 +// +// See also src/inline-thirdparty/README.md. 
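The `index_dense_gt` changes above replace the throwing `reserve` with a fallible `try_reserve` and route allocations through checked buffers. Below is a minimal usage sketch of the updated API, assuming the `index_dense_t` alias and the `state_result_t`-returning `make` factory from this header; the dimension count, metric choice, and key value are illustrative:

    #include <cstdio>
    #include <vector>

    #include <usearch/index_dense.hpp>

    using namespace unum::usearch;

    int main() {
        // A dense index over 256-dimensional f32 vectors with cosine distance.
        metric_punned_t metric(256, metric_kind_t::cos_k, scalar_kind_t::f32_k);
        auto state = index_dense_t::make(metric);
        if (!state)
            return 1;
        index_dense_t& index = state.index;

        // Size the graph, the keyed lookup, and the per-thread cast buffers up front.
        // Unlike the older `reserve`, `try_reserve` reports allocation failure instead of throwing.
        index_limits_t limits;
        limits.members = 1000;
        limits.threads_add = 4;
        limits.threads_search = 4;
        if (!index.try_reserve(limits))
            return 1;

        std::vector<float> vector(256, 0.5f);
        if (!index.add(42, vector.data()))
            return 1;

        auto results = index.search(vector.data(), 5);
        for (std::size_t i = 0; i != results.size(); ++i)
            std::printf("key %llu\n", (unsigned long long)results[i].member.key);
        return 0;
    }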
diff --git a/src/inline-thirdparty/usearch/usearch/index_plugins.hpp b/src/inline-thirdparty/usearch/usearch/index_plugins.hpp index a5539a3a6b04..2a6839c4fa3b 100644 --- a/src/inline-thirdparty/usearch/usearch/index_plugins.hpp +++ b/src/inline-thirdparty/usearch/usearch/index_plugins.hpp @@ -3,13 +3,9 @@ #include // `_Float16` #include // `aligned_alloc` +#include // `std::atomic` #include // `std::strncmp` -#include // `std::iota` #include // `std::thread` -#include // `std::vector` - -#include // `std::atomic` -#include // `std::thread` #include // `expected_gt` and macros @@ -46,13 +42,24 @@ #if USEARCH_USE_SIMSIMD // Propagate the `f16` settings -#define SIMSIMD_NATIVE_F16 !USEARCH_USE_FP16LIB -#define SIMSIMD_DYNAMIC_DISPATCH 0 +#if defined(USEARCH_CAN_COMPILE_FP16) || defined(USEARCH_CAN_COMPILE_FLOAT16) +#define SIMSIMD_NATIVE_F16 1 +#endif +// Propagate the `bf16` settings +#if defined(USEARCH_CAN_COMPILE_BF16) || defined(USEARCH_CAN_COMPILE_BFLOAT16) +#define SIMSIMD_NATIVE_BF16 1 +#endif // No problem, if some of the functions are unused or undefined #pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wall" +#pragma GCC diagnostic ignored "-Wunused" #pragma GCC diagnostic ignored "-Wunused-function" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wunused-variable" +#pragma GCC diagnostic ignored "-Wunused-but-set-variable" #pragma warning(push) -#pragma warning(disable : 4101) +#pragma warning(disable : 4101) // "Unused variables" +#pragma warning(disable : 4068) // "Unknown pragmas", when MSVC tries to read GCC pragmas #include #pragma warning(pop) #pragma GCC diagnostic pop @@ -69,19 +76,10 @@ struct uuid_t { }; class f16_bits_t; -class i8_converted_t; +class bf16_bits_t; -#if !USEARCH_USE_FP16LIB -#if defined(USEARCH_DEFINED_ARM) -using f16_native_t = __fp16; -#else -using f16_native_t = _Float16; -#endif -using f16_t = f16_native_t; -#else -using f16_native_t = void; using f16_t = f16_bits_t; -#endif +using bf16_t = bf16_bits_t; using f64_t = double; using f32_t = float; @@ -96,6 +94,9 @@ using i32_t = std::int32_t; using i16_t = std::int16_t; using i8_t = std::int8_t; +/** + * @brief Enumerates the most commonly used distance metrics, mostly for dense vector representations. + */ enum class metric_kind_t : std::uint8_t { unknown_k = 0, // Classics: @@ -108,19 +109,26 @@ enum class metric_kind_t : std::uint8_t { haversine_k = 'h', divergence_k = 'd', - // Sets: - jaccard_k = 'j', + // Dense Sets: hamming_k = 'b', tanimoto_k = 't', sorensen_k = 's', + + // Sparse Sets: + jaccard_k = 'j', }; +/** + * @brief Enumerates the most commonly used scalar types, mostly for dense vector representations. + * Doesn't include logical types, like complex numbers or quaternions. + */ enum class scalar_kind_t : std::uint8_t { unknown_k = 0, // Custom: b1x8_k = 1, u40_k = 2, uuid_k = 3, + bf16_k = 4, // Common: f64_k = 10, f32_k = 11, @@ -137,12 +145,9 @@ enum class scalar_kind_t : std::uint8_t { i8_k = 23, }; -enum class prefetching_kind_t { - none_k, - cpu_k, - io_uring_k, -}; - +/** + * @brief Maps a scalar type to its corresponding scalar_kind_t enumeration value. 
+ */ template scalar_kind_t scalar_kind() noexcept { if (std::is_same()) return scalar_kind_t::b1x8_k; @@ -156,6 +161,8 @@ template scalar_kind_t scalar_kind() noexcept { return scalar_kind_t::f32_k; if (std::is_same()) return scalar_kind_t::f16_k; + if (std::is_same()) + return scalar_kind_t::bf16_k; if (std::is_same()) return scalar_kind_t::i8_k; if (std::is_same()) @@ -177,55 +184,119 @@ template scalar_kind_t scalar_kind() noexcept { return scalar_kind_t::unknown_k; } +/** + * @brief Converts an angle from degrees to radians. + */ template at angle_to_radians(at angle) noexcept { return angle * at(3.14159265358979323846) / at(180); } +/** + * @brief Readability helper to compute the square of a given value. + */ template at square(at value) noexcept { return value * value; } +/** + * @brief Clamps a value between a lower and upper bound using a custom comparator. Similar to `std::clamp`. + * https://en.cppreference.com/w/cpp/algorithm/clamp + */ template inline at clamp(at v, at lo, at hi, compare_at comp) noexcept { return comp(v, lo) ? lo : comp(hi, v) ? hi : v; } + +/** + * @brief Clamps a value between a lower and upper bound. Similar to `std::clamp`. + * https://en.cppreference.com/w/cpp/algorithm/clamp + */ template inline at clamp(at v, at lo, at hi) noexcept { return usearch::clamp(v, lo, hi, std::less{}); } -inline bool str_equals(char const* begin, std::size_t len, char const* other_begin) noexcept { - std::size_t other_len = std::strlen(other_begin); - return len == other_len && std::strncmp(begin, other_begin, len) == 0; +/** + * @brief Compares two strings for equality, given a length for the first string. + */ +inline bool str_equals(char const* first_begin, std::size_t first_len, char const* second_begin) noexcept { + std::size_t second_len = std::strlen(second_begin); + return first_len == second_len && std::strncmp(first_begin, second_begin, first_len) == 0; } +/** + * @brief Returns the number of bits required to represent a scalar type. + */ inline std::size_t bits_per_scalar(scalar_kind_t scalar_kind) noexcept { switch (scalar_kind) { + case scalar_kind_t::uuid_k: return 128; + case scalar_kind_t::u40_k: return 40; + case scalar_kind_t::bf16_k: return 16; + case scalar_kind_t::b1x8_k: return 1; + case scalar_kind_t::u64_k: return 64; + case scalar_kind_t::i64_k: return 64; case scalar_kind_t::f64_k: return 64; + case scalar_kind_t::u32_k: return 32; + case scalar_kind_t::i32_k: return 32; case scalar_kind_t::f32_k: return 32; + case scalar_kind_t::u16_k: return 16; + case scalar_kind_t::i16_k: return 16; case scalar_kind_t::f16_k: return 16; + case scalar_kind_t::u8_k: return 8; case scalar_kind_t::i8_k: return 8; - case scalar_kind_t::b1x8_k: return 1; + case scalar_kind_t::f8_k: return 8; default: return 0; } } +/** + * @brief Returns the number of bits in a scalar word for a given scalar type. + * Equivalent to `bits_per_scalar` for types that are not bit-packed. 
+ */ inline std::size_t bits_per_scalar_word(scalar_kind_t scalar_kind) noexcept { switch (scalar_kind) { + case scalar_kind_t::uuid_k: return 128; + case scalar_kind_t::u40_k: return 40; + case scalar_kind_t::bf16_k: return 16; + case scalar_kind_t::b1x8_k: return 8; + case scalar_kind_t::u64_k: return 64; + case scalar_kind_t::i64_k: return 64; case scalar_kind_t::f64_k: return 64; + case scalar_kind_t::u32_k: return 32; + case scalar_kind_t::i32_k: return 32; case scalar_kind_t::f32_k: return 32; + case scalar_kind_t::u16_k: return 16; + case scalar_kind_t::i16_k: return 16; case scalar_kind_t::f16_k: return 16; + case scalar_kind_t::u8_k: return 8; case scalar_kind_t::i8_k: return 8; - case scalar_kind_t::b1x8_k: return 8; + case scalar_kind_t::f8_k: return 8; default: return 0; } } +/** + * @brief Returns the string name of a given scalar type. + */ inline char const* scalar_kind_name(scalar_kind_t scalar_kind) noexcept { switch (scalar_kind) { + case scalar_kind_t::uuid_k: return "uuid"; + case scalar_kind_t::u40_k: return "u40"; + case scalar_kind_t::bf16_k: return "bf16"; + case scalar_kind_t::b1x8_k: return "b1x8"; + case scalar_kind_t::u64_k: return "u64"; + case scalar_kind_t::i64_k: return "i64"; + case scalar_kind_t::f64_k: return "f64"; + case scalar_kind_t::u32_k: return "u32"; + case scalar_kind_t::i32_k: return "i32"; case scalar_kind_t::f32_k: return "f32"; + case scalar_kind_t::u16_k: return "u16"; + case scalar_kind_t::i16_k: return "i16"; case scalar_kind_t::f16_k: return "f16"; - case scalar_kind_t::f64_k: return "f64"; + case scalar_kind_t::u8_k: return "u8"; case scalar_kind_t::i8_k: return "i8"; - case scalar_kind_t::b1x8_k: return "b1x8"; + case scalar_kind_t::f8_k: return "f8"; default: return ""; } } +/** + * @brief Returns the string name of a given distance metric. + */ inline char const* metric_kind_name(metric_kind_t metric) noexcept { switch (metric) { case metric_kind_t::unknown_k: return "unknown"; @@ -239,9 +310,13 @@ inline char const* metric_kind_name(metric_kind_t metric) noexcept { case metric_kind_t::hamming_k: return "hamming"; case metric_kind_t::tanimoto_k: return "tanimoto"; case metric_kind_t::sorensen_k: return "sorensen"; + default: return ""; } - return ""; } + +/** + * @brief Parses a string to identify the corresponding `scalar_kind_t` enumeration value. + */ inline expected_gt scalar_kind_from_name(char const* name, std::size_t len) { expected_gt parsed; if (str_equals(name, len, "f32")) @@ -250,17 +325,27 @@ inline expected_gt scalar_kind_from_name(char const* name, std::s parsed.result = scalar_kind_t::f64_k; else if (str_equals(name, len, "f16")) parsed.result = scalar_kind_t::f16_k; + else if (str_equals(name, len, "bf16")) + parsed.result = scalar_kind_t::bf16_k; else if (str_equals(name, len, "i8")) parsed.result = scalar_kind_t::i8_k; + else if (str_equals(name, len, "b1")) + parsed.result = scalar_kind_t::b1x8_k; else - parsed.failed("Unknown type, choose: f32, f16, f64, i8"); + parsed.failed("Unknown type, choose: f64, f32, f16, bf16, i8, b1"); return parsed; } +/** + * @brief Parses a string to identify the corresponding `scalar_kind_t` enumeration value. + */ inline expected_gt scalar_kind_from_name(char const* name) { return scalar_kind_from_name(name, std::strlen(name)); } +/** + * @brief Parses a string to identify the corresponding `metric_kind_t` enumeration value. 
+ */ inline expected_gt metric_from_name(char const* name, std::size_t len) { expected_gt parsed; if (str_equals(name, len, "l2sq") || str_equals(name, len, "euclidean_sq")) { @@ -287,28 +372,85 @@ inline expected_gt metric_from_name(char const* name, std::size_t return parsed; } +/** + * @brief Parses a string to identify the corresponding `metric_kind_t` enumeration value. + */ inline expected_gt metric_from_name(char const* name) { return metric_from_name(name, std::strlen(name)); } +/** + * @brief Convenience function to upcast a half-precision floating point number to a single-precision one. + */ inline float f16_to_f32(std::uint16_t u16) noexcept { -#if !USEARCH_USE_FP16LIB - f16_native_t f16; +#if USEARCH_USE_FP16LIB + return fp16_ieee_to_fp32_value(u16); +#elif USEARCH_USE_SIMSIMD + return simsimd_uncompress_f16((simsimd_f16_t const*)&u16); +#else +#warning "It's recommended to use SimSIMD and fp16lib for half-precision numerics" + _Float16 f16; std::memcpy(&f16, &u16, sizeof(std::uint16_t)); return float(f16); -#else - return fp16_ieee_to_fp32_value(u16); #endif } +/** + * @brief Convenience function to downcast a single-precision floating point number to a half-precision one. + */ inline std::uint16_t f32_to_f16(float f32) noexcept { -#if !USEARCH_USE_FP16LIB - f16_native_t f16 = f16_native_t(f32); +#if USEARCH_USE_FP16LIB + return fp16_ieee_from_fp32_value(f32); +#elif USEARCH_USE_SIMSIMD + std::uint16_t result; + simsimd_compress_f16(f32, (simsimd_f16_t*)&result); + return result; +#else +#warning "It's recommended to use SimSIMD and fp16lib for half-precision numerics" + _Float16 f16 = _Float16(f32); std::uint16_t u16; std::memcpy(&u16, &f16, sizeof(std::uint16_t)); return u16; +#endif +} + +/** + * @brief Convenience function to upcast a brain-floating point number to a single-precision one. + * https://github.com/ashvardanian/SimSIMD/blob/ff51434d90c66f916e94ff05b24530b127aa4cff/include/simsimd/types.h#L395-L410 + */ +inline float bf16_to_f32(std::uint16_t u16) noexcept { +#if USEARCH_USE_SIMSIMD + return simsimd_uncompress_bf16((simsimd_bf16_t const*)&u16); #else - return fp16_ieee_from_fp32_value(f32); + union float_or_unsigned_int_t { + float f; + unsigned int i; + }; + union float_or_unsigned_int_t result_union; + result_union.i = u16 << 16; // Zero extends the mantissa + return result_union.f; +#endif +} + +/** + * @brief Convenience function to downcast a single-precision floating point number to a brain-floating point one. 
+ * https://github.com/ashvardanian/SimSIMD/blob/ff51434d90c66f916e94ff05b24530b127aa4cff/include/simsimd/types.h#L412-L425 + */ +inline std::uint16_t f32_to_bf16(float f32) noexcept { +#if USEARCH_USE_SIMSIMD + std::uint16_t result; + simsimd_compress_bf16(f32, (simsimd_bf16_t*)&result); + return result; +#else + union float_or_unsigned_int_t { + float f; + unsigned int i; + }; + union float_or_unsigned_int_t value; + value.f = f32; + value.i >>= 16; + value.i &= 0xFFFF; + return (unsigned short)value.i; #endif } @@ -330,23 +472,25 @@ class f16_bits_t { inline operator float() const noexcept { return f16_to_f32(uint16_); } inline explicit operator bool() const noexcept { return f16_to_f32(uint16_) > 0.5f; } - inline f16_bits_t(i8_converted_t) noexcept; + inline f16_bits_t(std::int8_t v) noexcept : uint16_(f32_to_f16(v)) {} inline f16_bits_t(bool v) noexcept : uint16_(f32_to_f16(v)) {} inline f16_bits_t(float v) noexcept : uint16_(f32_to_f16(v)) {} inline f16_bits_t(double v) noexcept : uint16_(f32_to_f16(static_cast(v))) {} + inline bool operator<(f16_bits_t const& other) const noexcept { return float(*this) < float(other); } + inline f16_bits_t operator+(f16_bits_t other) const noexcept { return {float(*this) + float(other)}; } inline f16_bits_t operator-(f16_bits_t other) const noexcept { return {float(*this) - float(other)}; } inline f16_bits_t operator*(f16_bits_t other) const noexcept { return {float(*this) * float(other)}; } inline f16_bits_t operator/(f16_bits_t other) const noexcept { return {float(*this) / float(other)}; } - inline f16_bits_t operator+(float other) const noexcept { return {float(*this) + other}; } - inline f16_bits_t operator-(float other) const noexcept { return {float(*this) - other}; } - inline f16_bits_t operator*(float other) const noexcept { return {float(*this) * other}; } - inline f16_bits_t operator/(float other) const noexcept { return {float(*this) / other}; } - inline f16_bits_t operator+(double other) const noexcept { return {float(*this) + other}; } - inline f16_bits_t operator-(double other) const noexcept { return {float(*this) - other}; } - inline f16_bits_t operator*(double other) const noexcept { return {float(*this) * other}; } - inline f16_bits_t operator/(double other) const noexcept { return {float(*this) / other}; } + inline float operator+(float other) const noexcept { return float(*this) + other; } + inline float operator-(float other) const noexcept { return float(*this) - other; } + inline float operator*(float other) const noexcept { return float(*this) * other; } + inline float operator/(float other) const noexcept { return float(*this) / other; } + inline double operator+(double other) const noexcept { return float(*this) + other; } + inline double operator-(double other) const noexcept { return float(*this) - other; } + inline double operator*(double other) const noexcept { return float(*this) * other; } + inline double operator/(double other) const noexcept { return float(*this) / other; } inline f16_bits_t& operator+=(float v) noexcept { uint16_ = f32_to_f16(v + f16_to_f32(uint16_)); @@ -369,6 +513,65 @@ class f16_bits_t { } }; +/** + * @brief Numeric type for brain-floating point half-precision floating point. + * If hardware support isn't available, falls back to a hardware + * agnostic in-software implementation. 
+ */ +class bf16_bits_t { + std::uint16_t uint16_{}; + + public: + inline bf16_bits_t() noexcept : uint16_(0) {} + inline bf16_bits_t(bf16_bits_t&&) = default; + inline bf16_bits_t& operator=(bf16_bits_t&&) = default; + inline bf16_bits_t(bf16_bits_t const&) = default; + inline bf16_bits_t& operator=(bf16_bits_t const&) = default; + + inline operator float() const noexcept { return bf16_to_f32(uint16_); } + inline explicit operator bool() const noexcept { return bf16_to_f32(uint16_) > 0.5f; } + + inline bf16_bits_t(std::int8_t v) noexcept : uint16_(f32_to_bf16(v)) {} + inline bf16_bits_t(bool v) noexcept : uint16_(f32_to_bf16(v)) {} + inline bf16_bits_t(float v) noexcept : uint16_(f32_to_bf16(v)) {} + inline bf16_bits_t(double v) noexcept : uint16_(f32_to_bf16(static_cast(v))) {} + + inline bool operator<(bf16_bits_t const& other) const noexcept { return float(*this) < float(other); } + + inline bf16_bits_t operator+(bf16_bits_t other) const noexcept { return {float(*this) + float(other)}; } + inline bf16_bits_t operator-(bf16_bits_t other) const noexcept { return {float(*this) - float(other)}; } + inline bf16_bits_t operator*(bf16_bits_t other) const noexcept { return {float(*this) * float(other)}; } + inline bf16_bits_t operator/(bf16_bits_t other) const noexcept { return {float(*this) / float(other)}; } + inline float operator+(float other) const noexcept { return float(*this) + other; } + inline float operator-(float other) const noexcept { return float(*this) - other; } + inline float operator*(float other) const noexcept { return float(*this) * other; } + inline float operator/(float other) const noexcept { return float(*this) / other; } + inline double operator+(double other) const noexcept { return float(*this) + other; } + inline double operator-(double other) const noexcept { return float(*this) - other; } + inline double operator*(double other) const noexcept { return float(*this) * other; } + inline double operator/(double other) const noexcept { return float(*this) / other; } + + inline bf16_bits_t& operator+=(float v) noexcept { + uint16_ = f32_to_bf16(v + bf16_to_f32(uint16_)); + return *this; + } + + inline bf16_bits_t& operator-=(float v) noexcept { + uint16_ = f32_to_bf16(v - bf16_to_f32(uint16_)); + return *this; + } + + inline bf16_bits_t& operator*=(float v) noexcept { + uint16_ = f32_to_bf16(v * bf16_to_f32(uint16_)); + return *this; + } + + inline bf16_bits_t& operator/=(float v) noexcept { + uint16_ = f32_to_bf16(v / bf16_to_f32(uint16_)); + return *this; + } +}; + /** * @brief An STL-based executor or a "thread-pool" for parallel execution. * Isn't efficient for small batches, as it recreates the threads on every call. 
@@ -378,14 +581,16 @@ class executor_stl_t { struct jthread_t { std::thread native_; + bool initialized_ = false; jthread_t() = default; jthread_t(jthread_t&&) = default; jthread_t(jthread_t const&) = delete; - template jthread_t(callable_at&& func) : native_([=]() { func(); }) {} + template + jthread_t(callable_at&& func) : native_([=]() { func(); }), initialized_(true) {} ~jthread_t() { - if (native_.joinable()) + if (initialized_ && native_.joinable()) native_.join(); } }; @@ -410,13 +615,13 @@ class executor_stl_t { */ template void fixed(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { - std::vector threads_pool; + buffer_gt threads_pool(threads_count_ - 1); // Allocate space for threads minus the main thread std::size_t tasks_per_thread = tasks; std::size_t threads_count = (std::min)(threads_count_, tasks); if (threads_count > 1) { tasks_per_thread = (tasks / threads_count) + ((tasks % threads_count) != 0); for (std::size_t thread_idx = 1; thread_idx < threads_count; ++thread_idx) { - threads_pool.emplace_back([=]() { + new (&threads_pool[thread_idx - 1]) jthread_t([=]() { for (std::size_t task_idx = thread_idx * tasks_per_thread; task_idx < (std::min)(tasks, thread_idx * tasks_per_thread + tasks_per_thread); ++task_idx) thread_aware_function(thread_idx, task_idx); @@ -435,14 +640,14 @@ class executor_stl_t { */ template void dynamic(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { - std::vector threads_pool; + buffer_gt threads_pool(threads_count_ - 1); std::size_t tasks_per_thread = tasks; std::size_t threads_count = (std::min)(threads_count_, tasks); std::atomic_bool stop{false}; if (threads_count > 1) { tasks_per_thread = (tasks / threads_count) + ((tasks % threads_count) != 0); for (std::size_t thread_idx = 1; thread_idx < threads_count; ++thread_idx) { - threads_pool.emplace_back([=, &stop]() { + new (&threads_pool[thread_idx - 1]) jthread_t([=, &stop]() { for (std::size_t task_idx = thread_idx * tasks_per_thread; task_idx < (std::min)(tasks, thread_idx * tasks_per_thread + tasks_per_thread) && !stop.load(std::memory_order_relaxed); @@ -467,9 +672,9 @@ class executor_stl_t { void parallel(thread_aware_function_at&& thread_aware_function) noexcept(false) { if (threads_count_ == 1) return thread_aware_function(0); - std::vector threads_pool; + buffer_gt threads_pool(threads_count_ - 1); for (std::size_t thread_idx = 1; thread_idx < threads_count_; ++thread_idx) - threads_pool.emplace_back([=]() { thread_aware_function(thread_idx); }); + new (&threads_pool[thread_idx - 1]) jthread_t([=]() { thread_aware_function(thread_idx); }); thread_aware_function(0); } }; @@ -557,7 +762,8 @@ using executor_default_t = executor_stl_t; #endif /** - * @brief Uses OS-specific APIs for aligned memory allocations. + * @brief Uses OS-specific APIs for aligned memory allocations. + * Available since C11, but only C++17, so we wrap the C version. */ template // class aligned_allocator_gt { @@ -574,12 +780,18 @@ class aligned_allocator_gt { pointer allocate(size_type length) const { std::size_t length_bytes = alignment_ak * divide_round_up(length * sizeof(value_type)); + // Avoid overflow + if (length > length_bytes) + return nullptr; std::size_t alignment = alignment_ak; - // void* result = nullptr; - // int status = posix_memalign(&result, alignment, length_bytes); - // return status == 0 ? 
(pointer)result : nullptr; #if defined(USEARCH_DEFINED_WINDOWS) return (pointer)_aligned_malloc(length_bytes, alignment); +#elif defined(USEARCH_DEFINED_APPLE) + // Apple Clang keeps complaining that `aligned_alloc` is only available + // with macOS 10.15 and newer, so let's use `posix_memalign` there. + void* result = nullptr; + int status = posix_memalign(&result, alignment, length_bytes); + return status == 0 ? (pointer)result : nullptr; #else return (pointer)aligned_alloc(alignment, length_bytes); #endif @@ -596,6 +808,10 @@ class aligned_allocator_gt { using aligned_allocator_t = aligned_allocator_gt<>; +/** + * @brief A simple RAM-page allocator that uses the OS-specific APIs for memory allocation. + * Shouldn't be used frequently, as system calls are slow. + */ class page_allocator_t { public: static constexpr std::size_t page_size() { return 4096; } @@ -859,7 +1075,7 @@ template class shared_lock_gt { * avoiding unnecessary conversions. */ template struct cast_gt { - inline bool operator()(byte_t const* input, std::size_t dim, byte_t* output) const { + static bool try_(byte_t const* input, std::size_t dim, byte_t* output) noexcept { from_scalar_at const* typed_input = reinterpret_cast(input); to_scalar_at* typed_output = reinterpret_cast(output); auto converter = [](from_scalar_at from) { return to_scalar_at(from); }; @@ -869,29 +1085,34 @@ template struct cast_gt { }; template <> struct cast_gt { - bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } + static bool try_(byte_t const*, std::size_t, byte_t*) noexcept { return false; } }; template <> struct cast_gt { - bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } + static bool try_(byte_t const*, std::size_t, byte_t*) noexcept { return false; } }; template <> struct cast_gt { - bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } + static bool try_(byte_t const*, std::size_t, byte_t*) noexcept { return false; } +}; + +template <> struct cast_gt { + static bool try_(byte_t const*, std::size_t, byte_t*) noexcept { return false; } }; template <> struct cast_gt { - bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } + static bool try_(byte_t const*, std::size_t, byte_t*) noexcept { return false; } }; template <> struct cast_gt { - bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } + static bool try_(byte_t const*, std::size_t, byte_t*) noexcept { return false; } }; -template struct cast_gt { - inline bool operator()(byte_t const* input, std::size_t dim, byte_t* output) const { +template struct cast_to_b1x8_gt { + inline static bool try_(byte_t const* input, std::size_t dim, byte_t* output) noexcept { from_scalar_at const* typed_input = reinterpret_cast(input); unsigned char* typed_output = reinterpret_cast(output); + std::memset(typed_output, 0, dim / CHAR_BIT); for (std::size_t i = 0; i != dim; ++i) // Converting from scalar types to boolean isn't trivial and depends on the type. // The most common case is to consider all positive values as `true` and all others as `false`. 
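The `cast_to_b1x8_gt` hunk above zeroes the output buffer and then packs one scalar per bit, treating positive values as `true`. A self-contained sketch of that packing rule, assuming the MSB-first bit order of the b1x8 convention (the `pack_bits` helper name is made up here for illustration):

    #include <climits>
    #include <cstddef>
    #include <cstdio>

    // Positive inputs become 1-bits; bit i lands in byte i / CHAR_BIT,
    // counted from the most significant bit of that byte.
    static void pack_bits(float const* input, std::size_t dim, unsigned char* output) {
        for (std::size_t i = 0; i != dim / CHAR_BIT; ++i)
            output[i] = 0;
        for (std::size_t i = 0; i != dim; ++i)
            if (input[i] > 0)
                output[i / CHAR_BIT] |= (unsigned char)(128 >> (i % CHAR_BIT));
    }

    int main() {
        float vector[8] = {0.5f, -0.1f, 2.f, 0.f, 1.f, -3.f, 0.25f, 0.f};
        unsigned char packed[1];
        pack_bits(vector, 8, packed);
        std::printf("0x%02x\n", packed[0]); // Alternating bits: prints 0xaa.
        return 0;
    }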
@@ -908,8 +1129,8 @@ template struct cast_gt { } }; -template struct cast_gt { - inline bool operator()(byte_t const* input, std::size_t dim, byte_t* output) const { +template struct cast_from_b1x8_gt { + static bool try_(byte_t const* input, std::size_t dim, byte_t* output) noexcept { unsigned char const* typed_input = reinterpret_cast(input); to_scalar_at* typed_output = reinterpret_cast(output); for (std::size_t i = 0; i != dim; ++i) @@ -920,52 +1141,134 @@ template struct cast_gt { } }; +template struct cast_to_i8_gt { + inline static bool try_(byte_t const* input, std::size_t dim, byte_t* output) noexcept { + from_scalar_at const* typed_input = reinterpret_cast(input); + std::int8_t* typed_output = reinterpret_cast(output); + // Unlike other casting mechanisms, switching to small range integers is a two step procedure. + // First we want to estimate the magnitude of the vector to scale into [-1.0, 1.0] interval, + // instead of clamping. And then we scale the values into the [-127, 127] range. + // ! This makes an assumption, that the distance metric is dot-product-like, which may not + // ! be true in many cases, so it's recommended to avoid automatic casting from floats to + // ! integers. + double magnitude = 0.0; + for (std::size_t i = 0; i != dim; ++i) + magnitude += (double)typed_input[i] * (double)typed_input[i]; + magnitude = std::sqrt(magnitude); + for (std::size_t i = 0; i != dim; ++i) + typed_output[i] = + static_cast(usearch::clamp(typed_input[i] * 127.0 / magnitude, -127.0, 127.0)); + return true; + } +}; + +template struct cast_from_i8_gt { + static bool try_(byte_t const* input, std::size_t dim, byte_t* output) noexcept { + std::int8_t const* typed_input = reinterpret_cast(input); + to_scalar_at* typed_output = reinterpret_cast(output); + for (std::size_t i = 0; i != dim; ++i) + typed_output[i] = static_cast(typed_input[i]) / 127.f; + return true; + } +}; + +template <> struct cast_gt : public cast_from_i8_gt {}; +template <> struct cast_gt : public cast_from_i8_gt {}; +template <> struct cast_gt : public cast_from_i8_gt {}; +template <> struct cast_gt : public cast_from_i8_gt {}; + +template <> struct cast_gt : public cast_to_i8_gt {}; +template <> struct cast_gt : public cast_to_i8_gt {}; +template <> struct cast_gt : public cast_to_i8_gt {}; +template <> struct cast_gt : public cast_to_i8_gt {}; + +template <> struct cast_gt : public cast_from_b1x8_gt {}; +template <> struct cast_gt : public cast_from_b1x8_gt {}; +template <> struct cast_gt : public cast_from_b1x8_gt {}; +template <> struct cast_gt : public cast_from_b1x8_gt {}; + +template <> struct cast_gt : public cast_to_b1x8_gt {}; +template <> struct cast_gt : public cast_to_b1x8_gt {}; +template <> struct cast_gt : public cast_to_b1x8_gt {}; +template <> struct cast_gt : public cast_to_b1x8_gt {}; + +template <> struct cast_gt : public cast_from_b1x8_gt {}; +template <> struct cast_gt : public cast_to_b1x8_gt {}; + /** - * @brief Numeric type for uniformly-distributed floating point - * values within [-1,1] range, quantized to integers [-100,100]. + * @brief Type-punned array casting function. + * Arguments: input buffer, bytes in input buffer, output buffer. + * Returns `true` if the casting was performed successfully, `false` otherwise. 
*/ -class i8_converted_t { - std::int8_t int8_{}; +using cast_punned_t = bool (*)(byte_t const*, std::size_t, byte_t*); - public: - constexpr static f32_t divisor_k = 100.f; - constexpr static std::int8_t min_k = -100; - constexpr static std::int8_t max_k = 100; - - inline i8_converted_t() noexcept : int8_(0) {} - inline i8_converted_t(bool v) noexcept : int8_(v ? max_k : 0) {} - - inline i8_converted_t(i8_converted_t&&) = default; - inline i8_converted_t& operator=(i8_converted_t&&) = default; - inline i8_converted_t(i8_converted_t const&) = default; - inline i8_converted_t& operator=(i8_converted_t const&) = default; - - inline operator f16_t() const noexcept { return static_cast(f32_t(int8_) / divisor_k); } - inline operator f32_t() const noexcept { return f32_t(int8_) / divisor_k; } - inline operator f64_t() const noexcept { return f64_t(int8_) / divisor_k; } - inline explicit operator bool() const noexcept { return int8_ > (max_k / 2); } - inline explicit operator std::int8_t() const noexcept { return int8_; } - inline explicit operator std::int16_t() const noexcept { return int8_; } - inline explicit operator std::int32_t() const noexcept { return int8_; } - inline explicit operator std::int64_t() const noexcept { return int8_; } - - inline i8_converted_t(f16_t v) - : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) {} - inline i8_converted_t(f32_t v) - : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) {} - inline i8_converted_t(f64_t v) - : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) {} -}; +/** + * @brief A collection of casting functions for typical vector types. + * Covers to/from conversions for boolean, integer, half-precision, + * single-precision, and double-precision scalars. + */ +struct casts_punned_t { + struct group_t { + cast_punned_t b1x8{}; + cast_punned_t i8{}; + cast_punned_t f16{}; + cast_punned_t f32{}; + cast_punned_t f64{}; + + cast_punned_t operator[](scalar_kind_t scalar_kind) const noexcept { + switch (scalar_kind) { + case scalar_kind_t::f64_k: return f64; + case scalar_kind_t::f32_k: return f32; + case scalar_kind_t::f16_k: return f16; + case scalar_kind_t::bf16_k: return f16; + case scalar_kind_t::i8_k: return i8; + case scalar_kind_t::b1x8_k: return b1x8; + default: return nullptr; + } + } + + } from, to; + + template static casts_punned_t make() noexcept { + casts_punned_t result; -f16_bits_t::f16_bits_t(i8_converted_t v) noexcept : uint16_(f32_to_f16(v)) {} + result.from.b1x8 = &cast_gt::try_; + result.from.i8 = &cast_gt::try_; + result.from.f16 = &cast_gt::try_; + result.from.f32 = &cast_gt::try_; + result.from.f64 = &cast_gt::try_; -template <> struct cast_gt : public cast_gt {}; -template <> struct cast_gt : public cast_gt {}; -template <> struct cast_gt : public cast_gt {}; + result.to.b1x8 = &cast_gt::try_; + result.to.i8 = &cast_gt::try_; + result.to.f16 = &cast_gt::try_; + result.to.f32 = &cast_gt::try_; + result.to.f64 = &cast_gt::try_; -template <> struct cast_gt : public cast_gt {}; -template <> struct cast_gt : public cast_gt {}; -template <> struct cast_gt : public cast_gt {}; + return result; + } + + static casts_punned_t make(scalar_kind_t scalar_kind) noexcept { + switch (scalar_kind) { + case scalar_kind_t::f64_k: return casts_punned_t::make(); + case scalar_kind_t::f32_k: return casts_punned_t::make(); + case scalar_kind_t::f16_k: return casts_punned_t::make(); + case scalar_kind_t::bf16_k: return casts_punned_t::make(); + case scalar_kind_t::i8_k: return casts_punned_t::make(); + case 
scalar_kind_t::b1x8_k: return casts_punned_t::make(); + default: return {}; + } + } +}; + +/* Don't complain if the vectorization of the inner loops fails: + * + * > warning: loop not vectorized: the optimizer was unable to perform the requested transformation; + * > the transformation might be disabled or specified as part of an unsupported transformation ordering + */ +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif /** * @brief Inner (Dot) Product distance. @@ -1151,18 +1454,21 @@ template struct using scalar_t = scalar_at; using result_t = result_at; static_assert(!std::is_floating_point::value, "Jaccard distance requires integral scalars"); + static_assert(std::is_floating_point::value, "Jaccard distance returns a fraction"); inline result_t operator()( // scalar_t const* a, scalar_t const* b, std::size_t a_length, std::size_t b_length) const noexcept { - result_t intersection{}; + std::size_t intersection{}; std::size_t i{}; std::size_t j{}; while (i != a_length && j != b_length) { - intersection += a[i] == b[j]; - i += a[i] < b[j]; - j += a[i] >= b[j]; + scalar_t ai = a[i]; + scalar_t bj = b[j]; + intersection += ai == bj; + i += ai < bj; + j += ai >= bj; } - return 1 - intersection / (a_length + b_length - intersection); + return 1 - static_cast(intersection) / (a_length + b_length - intersection); } }; @@ -1238,7 +1544,10 @@ template struct metric_ } }; -struct cos_i8_t { +/** + * @brief Cosine (Angular) distance for signed 8-bit integers using 16-bit intermediates. + */ +struct metric_cos_i8_t { using scalar_t = i8_t; using result_t = f32_t; @@ -1264,7 +1573,11 @@ struct cos_i8_t { } }; -struct l2sq_i8_t { +/** + * @brief Squared Euclidean (L2) distance for signed 8-bit integers using 16-bit intermediates. + * Square root is avoided at the end, as it won't affect the ordering. + */ +struct metric_l2sq_i8_t { using scalar_t = i8_t; using result_t = f32_t; @@ -1290,7 +1603,8 @@ struct l2sq_i8_t { template struct metric_haversine_gt { using scalar_t = scalar_at; using result_t = result_at; - static_assert(!std::is_integral::value, "Latitude and longitude must be floating-node"); + static_assert(!std::is_integral::value && !std::is_same::value, + "Latitude and longitude must be floating-node"); inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t = 2) const noexcept { result_t lat_a = a[0], lon_a = a[1]; @@ -1343,9 +1657,9 @@ class metric_punned_t { /// Distance function that takes two arrays and some callback state and returns a scalar. using metric_array_array_state_t = result_t (*)(uptr_t, uptr_t, uptr_t); /// Distance function callback, like `metric_array_array_size_t`, but depends on member variables. - using metric_rounted_t = result_t (metric_punned_t::*)(uptr_t, uptr_t) const; + using metric_routed_t = result_t (metric_punned_t::*)(uptr_t, uptr_t) const; - metric_rounted_t metric_routed_ = nullptr; + metric_routed_t metric_routed_ = nullptr; uptr_t metric_ptr_ = 0; uptr_t metric_third_arg_ = 0; @@ -1439,7 +1753,7 @@ class metric_punned_t { } /** - * @brief Creates a metric using the provided function pointer for a statefull metric. + * @brief Creates a metric using the provided function pointer for a stateful metric. * The third argument is the state that will be passed to the metric function. * * @param metric_uintptr The function pointer to the metric function. @@ -1448,9 +1762,9 @@ class metric_punned_t { * @param scalar_kind The kind of scalar to use. 
* @return A metric object that can be used to compute distances between vectors. */ - inline static metric_punned_t statefull(std::uintptr_t metric_uintptr, std::uintptr_t metric_state, - metric_kind_t metric_kind = metric_kind_t::unknown_k, - scalar_kind_t scalar_kind = scalar_kind_t::unknown_k) noexcept { + inline static metric_punned_t stateful(std::uintptr_t metric_uintptr, std::uintptr_t metric_state, + metric_kind_t metric_kind = metric_kind_t::unknown_k, + scalar_kind_t scalar_kind = scalar_kind_t::unknown_k) noexcept { metric_punned_t metric; metric.metric_routed_ = &metric_punned_t::invoke_array_array_third; metric.metric_ptr_ = metric_uintptr; @@ -1467,7 +1781,7 @@ class metric_punned_t { inline explicit operator bool() const noexcept { return metric_routed_ && metric_ptr_; } /** - * @brief Checks fi we've failed to initialized the metric with provided arguments. + * @brief Checks if we've failed to initialize the metric with provided arguments. * * It's different from `operator bool()` when it comes to explicitly uninitialized metrics. * It's a common case, where a NULL state is created only to be overwritten later, when @@ -1483,10 +1797,17 @@ class metric_punned_t { switch (isa_kind_) { case simsimd_cap_serial_k: return "serial"; case simsimd_cap_neon_k: return "neon"; + case simsimd_cap_neon_i8_k: return "neon_i8"; + case simsimd_cap_neon_f16_k: return "neon_f16"; + case simsimd_cap_neon_bf16_k: return "neon_bf16"; case simsimd_cap_sve_k: return "sve"; + case simsimd_cap_sve_i8_k: return "sve_i8"; + case simsimd_cap_sve_f16_k: return "sve_f16"; + case simsimd_cap_sve_bf16_k: return "sve_bf16"; case simsimd_cap_haswell_k: return "haswell"; case simsimd_cap_skylake_k: return "skylake"; case simsimd_cap_ice_k: return "ice"; + case simsimd_cap_genoa_k: return "genoa"; case simsimd_cap_sapphire_k: return "sapphire"; default: return "unknown"; } @@ -1521,6 +1842,7 @@ class metric_punned_t { case scalar_kind_t::f32_k: datatype = simsimd_datatype_f32_k; break; case scalar_kind_t::f64_k: datatype = simsimd_datatype_f64_k; break; case scalar_kind_t::f16_k: datatype = simsimd_datatype_f16_k; break; + case scalar_kind_t::bf16_k: datatype = simsimd_datatype_bf16_k; break; case scalar_kind_t::i8_k: datatype = simsimd_datatype_i8_k; break; case scalar_kind_t::b1x8_k: datatype = simsimd_datatype_b8_k; break; default: break; @@ -1533,8 +1855,8 @@ class metric_punned_t { std::memcpy(&metric_ptr_, &simd_metric, sizeof(simd_metric)); metric_routed_ = metric_kind_ == metric_kind_t::ip_k - ? reinterpret_cast(&metric_punned_t::invoke_simsimd_reverse) - : reinterpret_cast(&metric_punned_t::invoke_simsimd); + ? reinterpret_cast(&metric_punned_t::invoke_simsimd_reverse) + : reinterpret_cast(&metric_punned_t::invoke_simsimd); isa_kind_ = simd_kind; return true; } @@ -1542,9 +1864,14 @@ class metric_punned_t { static simsimd_capability_t static_capabilities = simsimd_capabilities(); return configure_with_simsimd(static_capabilities); } - result_t invoke_simsimd(uptr_t a, uptr_t b) const noexcept { + +#if defined(USEARCH_DEFINED_CLANG) || defined(USEARCH_DEFINED_GCC) + __attribute__((no_sanitize("all"))) +#endif + result_t + invoke_simsimd(uptr_t a, uptr_t b) const noexcept { simsimd_distance_t result; - // Here `reinterpret_cast` raises warning... we know what we are doing! + // Here `reinterpret_cast` raises warning and UBSan reports an issue... we know what we are doing! 
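+        // The pointer was stored bit-for-bit via `std::memcpy` in `configure_with_simsimd`,
+        // so recovering it from `uptr_t` here round-trips the exact same bits, even though
+        // the sanitizers cannot prove that to themselves.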
         auto function_pointer = (simsimd_metric_punned_t)(metric_ptr_);
         function_pointer(reinterpret_cast<void const*>(a), reinterpret_cast<void const*>(b), metric_third_arg_,
                          &result);
@@ -1568,9 +1895,10 @@ class metric_punned_t {
         switch (metric_kind_) {
         case metric_kind_t::ip_k: {
             switch (scalar_kind_) {
-            case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_ip_gt<f32_t>>; break;
-            case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_ip_gt<f16_t, f32_t>>; break;
+            case scalar_kind_t::bf16_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_ip_gt<bf16_t, f32_t>>; break;
             case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_ip_gt<i8_t, f32_t>>; break;
+            case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_ip_gt<f16_t, f32_t>>; break;
+            case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_ip_gt<f32_t>>; break;
             case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_ip_gt<f64_t>>; break;
             default: metric_ptr_ = 0; break;
             }
@@ -1578,9 +1906,10 @@
         }
         case metric_kind_t::cos_k: {
             switch (scalar_kind_) {
-            case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_gt<f32_t>>; break;
+            case scalar_kind_t::bf16_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_gt<bf16_t, f32_t>>; break;
+            case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_i8_t>; break;
             case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_gt<f16_t, f32_t>>; break;
-            case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_gt<i8_t, f32_t>>; break;
+            case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_gt<f32_t>>; break;
             case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_gt<f64_t>>; break;
             default: metric_ptr_ = 0; break;
             }
@@ -1588,9 +1917,10 @@
         }
         case metric_kind_t::l2sq_k: {
             switch (scalar_kind_) {
-            case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_gt<f32_t>>; break;
+            case scalar_kind_t::bf16_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_gt<bf16_t, f32_t>>; break;
+            case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_i8_t>; break;
             case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_gt<f16_t, f32_t>>; break;
-            case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_gt<i8_t, f32_t>>; break;
+            case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_gt<f32_t>>; break;
             case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_gt<f64_t>>; break;
             default: metric_ptr_ = 0; break;
             }
@@ -1598,6 +1928,9 @@
         }
         case metric_kind_t::pearson_k: {
             switch (scalar_kind_) {
+            case scalar_kind_t::bf16_k:
+                metric_ptr_ = (uptr_t)&equidimensional_<metric_pearson_gt<bf16_t, f32_t>>;
+                break;
             case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_pearson_gt<i8_t, f32_t>>; break;
             case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_pearson_gt<f16_t, f32_t>>; break;
             case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_pearson_gt<f32_t>>; break;
@@ -1608,9 +1941,8 @@
         }
         case metric_kind_t::haversine_k: {
             switch (scalar_kind_) {
-            case scalar_kind_t::f16_k:
-                metric_ptr_ = (uptr_t)&equidimensional_<metric_haversine_gt<f16_t, f32_t>>;
-                break;
+            case scalar_kind_t::bf16_k: metric_ptr_ = 0; break; //< Half-precision 2D vectors are silly.
+            case scalar_kind_t::f16_k: metric_ptr_ = 0; break; //< Half-precision 2D vectors are silly.
             case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_haversine_gt<f32_t>>; break;
             case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_haversine_gt<f64_t>>; break;
             default: metric_ptr_ = 0; break;
@@ -1619,6 +1951,9 @@
         }
         case metric_kind_t::divergence_k: {
             switch (scalar_kind_) {
+            case scalar_kind_t::bf16_k:
+                metric_ptr_ = (uptr_t)&equidimensional_<metric_divergence_gt<bf16_t, f32_t>>;
+                break;
             case scalar_kind_t::f16_k:
                 metric_ptr_ = (uptr_t)&equidimensional_<metric_divergence_gt<f16_t, f32_t>>;
                 break;
@@ -1643,36 +1978,42 @@
         }
     }
 };
 
+/* Allow complaining about vectorization after this point. */
+#if defined(__clang__)
+#pragma clang diagnostic pop
+#endif
+
 /**
  * @brief View over a potentially-strided memory buffer, containing a row-major matrix.
 */
 template <typename scalar_at> //
-class vectors_view_gt {
+class matrix_slice_gt {
     using scalar_t = scalar_at;
+    using byte_addressable_t = typename std::conditional<std::is_const<scalar_at>::value, byte_t const, byte_t>::type;
 
-    scalar_t const* begin_{};
+    scalar_t* begin_{};
     std::size_t dimensions_{};
     std::size_t count_{};
     std::size_t stride_bytes_{};
 
   public:
-    vectors_view_gt() noexcept = default;
-    vectors_view_gt(vectors_view_gt const&) noexcept = default;
-    vectors_view_gt& operator=(vectors_view_gt const&) noexcept = default;
+    matrix_slice_gt() noexcept = default;
+    matrix_slice_gt(matrix_slice_gt const&) noexcept = default;
+    matrix_slice_gt& operator=(matrix_slice_gt const&) noexcept = default;
 
-    vectors_view_gt(scalar_t const* begin, std::size_t dimensions, std::size_t count = 1) noexcept
-        : vectors_view_gt(begin, dimensions, count, dimensions * sizeof(scalar_at)) {}
+    matrix_slice_gt(scalar_t* begin, std::size_t dimensions, std::size_t count = 1) noexcept
+        : matrix_slice_gt(begin, dimensions, count, dimensions * sizeof(scalar_at)) {}
 
-    vectors_view_gt(scalar_t const* begin, std::size_t dimensions, std::size_t count, std::size_t stride_bytes) noexcept
+    matrix_slice_gt(scalar_t* begin, std::size_t dimensions, std::size_t count, std::size_t stride_bytes) noexcept
         : begin_(begin), dimensions_(dimensions), count_(count), stride_bytes_(stride_bytes) {}
 
     explicit operator bool() const noexcept { return begin_; }
     std::size_t size() const noexcept { return count_; }
     std::size_t dimensions() const noexcept { return dimensions_; }
     std::size_t stride() const noexcept { return stride_bytes_; }
-    scalar_t const* data() const noexcept { return begin_; }
-    scalar_t const* at(std::size_t i) const noexcept {
-        return reinterpret_cast<scalar_t const*>(reinterpret_cast<byte_t const*>(begin_) + i * stride_bytes_);
+    scalar_t* data() const noexcept { return begin_; }
+    scalar_t* at(std::size_t i) const noexcept {
+        return reinterpret_cast<scalar_t*>(reinterpret_cast<byte_addressable_t*>(begin_) + i * stride_bytes_);
     }
 };
 
@@ -1681,7 +2022,7 @@ struct exact_offset_and_distance_t {
     f32_t distance;
 };
 
-using exact_search_results_t = vectors_view_gt<exact_offset_and_distance_t>;
+using exact_search_results_t = matrix_slice_gt<exact_offset_and_distance_t>;
 
 /**
  * @brief Helper-structure for exact search operations.
@@ -1702,9 +2043,9 @@ class exact_search_t {
 
   public:
     template <typename executor_at = executor_default_t, typename progress_at = dummy_progress_t>
-    exact_search_results_t operator()(                            //
-        vectors_view_gt<byte_t const> dataset, vectors_view_gt<byte_t const> queries, //
-        std::size_t wanted, metric_punned_t const& metric,        //
+    exact_search_results_t operator()(                             //
+        matrix_slice_gt<byte_t const> dataset, matrix_slice_gt<byte_t const> queries, //
+        std::size_t wanted, metric_punned_t const& metric,         //
         executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) {
         return operator()( //
             metric,        //
@@ -1873,7 +2214,7 @@ class flat_hash_multi_set_gt {
         // Allocate new memory
         data_ = (char*)allocator_t{}.allocate(other.buckets_ * bytes_per_bucket());
         if (!data_)
-            throw std::bad_alloc();
+            __usearch_raise_runtime_error("failed memory allocation");
 
         // Copy metadata
         buckets_ = other.buckets_;
@@ -1913,7 +2254,7 @@ class flat_hash_multi_set_gt {
         // Allocate new memory
         data_ = (char*)allocator_t{}.allocate(other.buckets_ * bytes_per_bucket());
         if (!data_)
-            throw std::bad_alloc();
+            __usearch_raise_runtime_error("failed memory allocation");
 
         // Copy metadata
         buckets_ = other.buckets_;
@@ -1953,6 +2294,7 @@ class flat_hash_multi_set_gt {
         clear(); // Clear all elements
         if (data_)
             allocator_t{}.deallocate(data_, buckets_ * bytes_per_bucket());
+        data_ = nullptr;
         buckets_ = 0;
         populated_slots_ = 0;
         capacity_slots_ = 0;
@@ -2285,7 +2627,7 @@ class flat_hash_multi_set_gt {
 
     void reserve(std::size_t capacity) {
         if (!try_reserve(capacity))
-            throw std::bad_alloc();
+            __usearch_raise_runtime_error("failed to reserve memory");
     }
 
     bool try_emplace(element_t const& element) noexcept {
@@ -2315,3 +2657,9 @@ class flat_hash_multi_set_gt {
 
 } // namespace usearch
 } // namespace unum
+
+// This file is part of the usearch inline third-party dependency of YugabyteDB.
+// Git repo: https://github.com/unum-cloud/usearch
+// Git commit: 240fe9c298100f9e37a2d7377b1595be6ba1f412
+//
+// See also src/inline-thirdparty/README.md.
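
For readers wiring a custom distance through the renamed `metric_punned_t::stateful` factory above: the callback and
its state both travel as `std::uintptr_t` values, and the state is forwarded as the callback's third argument via
`invoke_array_array_third`. Below is a minimal sketch under stated assumptions: it presumes the vendored headers are
included as `usearch/index_plugins.hpp`, that the punned result type resolves to `float`, and the `dimensions_state_t`
struct and `custom_l2sq` function are hypothetical names invented for illustration, not part of the library.

    #include <cstddef>
    #include <cstdint>
    #include <usearch/index_plugins.hpp>

    using namespace unum::usearch;

    // Hypothetical state blob: just carries the dimensionality the callback should assume.
    struct dimensions_state_t {
        std::size_t dimensions;
    };

    // Shape matches `metric_array_array_state_t`: two vectors and the state pointer, all punned to integers.
    static float custom_l2sq(std::uintptr_t a_ptr, std::uintptr_t b_ptr, std::uintptr_t state_ptr) {
        float const* a = reinterpret_cast<float const*>(a_ptr);
        float const* b = reinterpret_cast<float const*>(b_ptr);
        auto const* state = reinterpret_cast<dimensions_state_t const*>(state_ptr);
        float sum = 0;
        for (std::size_t i = 0; i != state->dimensions; ++i)
            sum += (a[i] - b[i]) * (a[i] - b[i]);
        return sum;
    }

    int main() {
        dimensions_state_t state{3};
        metric_punned_t metric = metric_punned_t::stateful( //
            reinterpret_cast<std::uintptr_t>(&custom_l2sq), //
            reinterpret_cast<std::uintptr_t>(&state),       //
            metric_kind_t::l2sq_k, scalar_kind_t::f32_k);
        float a[3] = {1, 2, 3}, b[3] = {1, 2, 4};
        return metric && metric((byte_t const*)a, (byte_t const*)b) == 1.0f ? 0 : 1;
    }

Because `stateful` stores the callback and the state verbatim and routes every invocation through the stored
member-function pointer, no global variable is needed to smuggle per-index state into the distance function.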