diff --git a/CMakeLists.txt b/CMakeLists.txt index 911305b32af2..ac6a74bd0ef1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -607,6 +607,7 @@ include_directories(src) include_directories("src/inline-thirdparty/usearch") include_directories("src/inline-thirdparty/fp16") +include_directories("src/inline-thirdparty/hnswlib") enable_testing() diff --git a/src/yb/common/vector_types.h b/src/yb/common/vector_types.h index 3d6867e8f38a..67fe36e93e83 100644 --- a/src/yb/common/vector_types.h +++ b/src/yb/common/vector_types.h @@ -18,6 +18,7 @@ #include #include + #include "yb/gutil/integral_types.h" #include "yb/gutil/macros.h" @@ -28,18 +29,4 @@ using Int32Vector = std::vector; using UInt64Vector = std::vector; using UInt8Vector = std::vector; -// This MUST match the Vector struct definition in -// src/postgres/third-party-extensions/pgvector/src/vector.h. -struct YSQLVector { - // Commented out as this field is not transferred over the wire for all - // Varlens. - // int32 vl_len_; /* varlena header (do not touch directly!) */ - int16 dim; /* number of dimensions */ - int16 unused; - float elems[]; - - private: - DISALLOW_COPY_AND_ASSIGN(YSQLVector); -}; - } // namespace yb diff --git a/src/yb/docdb/pgsql_operation.cc b/src/yb/docdb/pgsql_operation.cc index ea662578bd0a..15f1a9a87f2e 100644 --- a/src/yb/docdb/pgsql_operation.cc +++ b/src/yb/docdb/pgsql_operation.cc @@ -1953,7 +1953,7 @@ Result PgsqlReadOperation::ExecuteVectorSearch( auto query_vec = request_.vector_idx_options().vector().binary_value(); - auto ysql_query_vec = pointer_cast(query_vec.data()); + auto ysql_query_vec = pointer_cast(query_vec.data()); SCHECK_EQ(ysql_query_vec->dim, dims, InvalidArgument, "Vector dimensions mismatch"); @@ -1999,7 +1999,7 @@ Result PgsqlReadOperation::ExecuteVectorSearch( if (!vec_value.has_value()) continue; // Add the vector to the ANN store auto vec = VERIFY_RESULT(VectorANN::GetVectorFromYSQLWire( - *pointer_cast(vec_value->binary_value().data()), + *pointer_cast(vec_value->binary_value().data()), vec_value->binary_value().size())); auto doc_iter = down_cast(table_iter_.get()); ann_store->Add(vec, doc_iter->GetRowKey()); diff --git a/src/yb/tools/CMakeLists.txt b/src/yb/tools/CMakeLists.txt index 4b84a3e3514a..0be9dd4666f3 100644 --- a/src/yb/tools/CMakeLists.txt +++ b/src/yb/tools/CMakeLists.txt @@ -237,6 +237,4 @@ ADD_YB_TEST_LIBRARY(yb-backup-test_base DEPS ${YB_BACKUP_TEST_BASE_DEPS}) add_executable(hnsw_tool hnsw_tool.cc) -# We use yb_test_util here for TestThreadHolder. -# hnsw_tool is a test tool in a way, so this is OK. -target_link_libraries(hnsw_tool boost_program_options yb_util yb_docdb yb_vector yb_test_util) +target_link_libraries(hnsw_tool boost_program_options yb_util yb_docdb yb_vector) diff --git a/src/yb/tools/hnsw_tool.cc b/src/yb/tools/hnsw_tool.cc index d29b871f5b65..f28d3626b931 100644 --- a/src/yb/tools/hnsw_tool.cc +++ b/src/yb/tools/hnsw_tool.cc @@ -25,18 +25,24 @@ #include "yb/util/string_util.h" #include "yb/util/test_thread_holder.h" -#include "yb/vector/hnsw_options.h" -#include "yb/vector/benchmark_data.h" +#include "yb/vector/ann_methods.h" #include "yb/vector/ann_validation.h" -#include "yb/vector/graph_repr_defs.h" -#include "yb/vector/usearch_wrapper.h" +#include "yb/vector/benchmark_data.h" #include "yb/vector/distance.h" +#include "yb/vector/graph_repr_defs.h" +#include "yb/vector/hnsw_options.h" #include "yb/vector/hnsw_util.h" +#include "yb/vector/hnswlib_wrapper.h" +#include "yb/vector/sharded_index.h" +#include "yb/vector/usearch_wrapper.h" +#include "yb/vector/vector_index_wrapper_util.h" #include "yb/tools/tool_arguments.h" namespace po = boost::program_options; +using namespace std::literals; + namespace yb::tools { // Rather than constantly adding needed identifiers from the vectorindex namespace here, it seem to @@ -82,10 +88,12 @@ struct BenchmarkArguments { size_t num_threads = 0; size_t num_validation_queries = 0; size_t num_vectors_to_insert = 0; - size_t report_num_keys = 1000; + size_t report_num_keys = 2500; + size_t num_index_shards = 1; std::string build_vecs_path; std::string ground_truth_path; std::string query_vecs_path; + ANNMethodKind ann_method; std::string ToString() const { return YB_STRUCT_TO_STRING( @@ -127,6 +135,10 @@ std::unique_ptr BenchmarkOptions() { BOOST_PP_STRINGIZE(field_name), po::value(&args.field_name)->default_value( \ args.field_name) +#define OPTIONAL_ARG_FIELD_WITH_LOWER_BOUND(field_name, lower_bound) \ + OPTIONAL_ARG_FIELD(field_name)->notifier( \ + OptionLowerBound(BOOST_PP_STRINGIZE(field_name), lower_bound)) + #define BOOL_SWITCH_ARG_FIELD(field_name) \ BOOST_PP_STRINGIZE(field_name), po::bool_switch(&args.field_name) @@ -138,7 +150,16 @@ std::unique_ptr BenchmarkOptions() { #define HNSW_OPTION_BOOL_ARG(field_name) \ BOOST_PP_STRINGIZE(field_name), po::bool_switch(&args.hnsw_options.field_name) + const auto ann_method_help = + Format("Approximate nearest neighbor search method to use. Possible values: $0.", + ValidEnumValuesCommaSeparatedForHelp()); + const auto distance_kind_help = + Format("What kind of distance function (metric) to use. Possible values: $0." + + ValidEnumValuesCommaSeparatedForHelp()); + result->desc.add_options() + (OPTIONAL_ARG_FIELD(ann_method), + ann_method_help.c_str() /* Boost copies the string internally */) (OPTIONAL_ARG_FIELD(num_vectors_to_insert), "Number of vectors to use for building the index. This is used if no input file is " "specified.") @@ -149,8 +170,6 @@ std::unique_ptr BenchmarkOptions() { (OPTIONAL_ARG_FIELD(ground_truth_path), "Input file containing integer vectors of correct nearest neighbor vector identifiers " "(0-based in the input dataset) for each query.") - ("input_file_name_fvec", po::value(&args.num_vectors_to_insert), - "Number of randomly generated vectors to add") (OPTIONAL_ARG_FIELD(k), "Number of results to retrieve with each validation query") (OPTIONAL_ARG_FIELD(num_validation_queries), @@ -158,12 +177,11 @@ std::unique_ptr BenchmarkOptions() { ("dimensions", po::value(&args.hnsw_options.dimensions), "Number of dimensions for automatically generated vectors. Required if no input file " "is specified.") - ("report_num_keys", - po::value(&args.report_num_keys)->notifier(OptionLowerBound("report_num_keys", 1)), + (OPTIONAL_ARG_FIELD_WITH_LOWER_BOUND(report_num_keys, 1), "Report progress after each batch of this many keys is inserted. 0 to disable reporting.") (HNSW_OPTION_BOOL_ARG(extend_candidates), "Whether to extend the set of candidates with their neighbors before executing the " - "neihgborhood selection heuristic.") + "neighborhood selection heuristic.") (HNSW_OPTION_BOOL_ARG(keep_pruned_connections), "Whether to keep the maximum number of discarded candidates with the minimum distance to " "the base element in the neighborhood selection heuristic.") @@ -186,6 +204,8 @@ std::unique_ptr BenchmarkOptions() { (HNSW_OPTION_ARG(robust_prune_alpha), "The parameter inspired by DiskANN that controls the neighborhood pruning procedure. " "Higher values result in fewer candidates being pruned. Typically between 1.0 and 1.6.") + (HNSW_OPTION_ARG(distance_kind), + distance_kind_help.c_str()) (OPTIONAL_ARG_FIELD(num_threads), "Number of threads to use for indexing and validation. Defaults to the number of CPU " "cores.") @@ -198,7 +218,10 @@ std::unique_ptr BenchmarkOptions() { "result sets using brute-force precise nearest neighbor search. Could be slow.") (OPTIONAL_ARG_FIELD(max_memory_for_loading_vectors_mb), "Maximum amount of memory to use for loading raw input vectors. Used to avoid memory " - "overflow on large datasets. Specify 0 to disable."); + "overflow on large datasets. Specify 0 to disable.") + (OPTIONAL_ARG_FIELD_WITH_LOWER_BOUND(num_index_shards, 1), + "For experiments that try to take advantage of a large number of cores, this allows to " + "create multiple instances of the vector index and insert into them concurrently."); #undef OPTIONAL_ARG_FIELD #undef BOOL_SWITCH_ARG_FIELD @@ -230,25 +253,47 @@ Result DetermineCoordinateKind(BenchmarkArguments& args) { std::unique_ptr CreateRandomFloatVectorSource( size_t num_vectors, size_t dimensions) { - return vectorindex::CreateUniformRandomVectorSource(num_vectors, dimensions, 0.0f, 1.0f); + return CreateUniformRandomVectorSource(num_vectors, dimensions, 0.0f, 1.0f); } -// We instantiate this template as soon as we determine what coordinate type we are working with. -template +// We instantiate this template as soon as we determine what coordinate type and distance result +// type we are working with. +// +// Because in some cases the input coordinate type is not supoprted by the index implementation, +// we are using separate "input" and "indexed" vector types and distance result types. +template class BenchmarkTool { public: - // Usearch HNSW currently does not support other types of vectors, so we cast the input vectors to - // float for now. See also: https://github.com/unum-cloud/usearch/issues/469 - using HNSWVectorType = FloatVector; - using HNSWImpl = UsearchIndex; - - explicit BenchmarkTool(const BenchmarkArguments& args) : args_(args) {} + explicit BenchmarkTool( + const BenchmarkArguments& args, + std::unique_ptr> index_factory) + : args_(args), + index_factory_(std::move(index_factory)) { + } Status Execute() { - LOG(INFO) << "Uisng input file coordinate type: " << args_.coordinate_kind; - + SCHECK_EQ(args_.coordinate_kind, + CoordinateTypeTraits::kKind, + RuntimeError, + "InputVector template argument does not match the inferred coordinate type"); + + LOG(INFO) << "Using ANN method: " << args_.ann_method; + LOG(INFO) << "Using input file coordinate type: " << args_.coordinate_kind; + LOG(INFO) << "Vector index internally uses the coordinate type: " + << CoordinateTypeTraits::kKind; + + LOG(INFO) << "Using distance result type in the input data: " + << CoordinateTypeTraits::kKind; + LOG(INFO) << "Using distance result type in the index implementation: " + << CoordinateTypeTraits::kKind; + if (args_.num_index_shards > 1) { + LOG(INFO) << "Using " << args_.num_index_shards << " index shards"; + } indexed_vector_source_ = VERIFY_RESULT(CreateVectorSource( - args_.build_vecs_path, "vectors to build index on", args_.num_vectors_to_insert)); + args_.build_vecs_path, "vectors to build index on", args_.num_vectors_to_insert)); query_vector_source_ = VERIFY_RESULT(CreateVectorSource( args_.query_vecs_path, "vectors to query", args_.num_validation_queries)); RETURN_NOT_OK(LoadPrecomputedGroundTruth()); @@ -265,7 +310,14 @@ class BenchmarkTool { PrintConfiguration(); - hnsw_ = std::make_unique(hnsw_options()); + if (args_.num_index_shards > 1) { + index_factory_ = + std::make_unique>( + args_.num_index_shards, std::move(index_factory_)); + } + + index_factory_->SetOptions(hnsw_options()); + vector_index_ = index_factory_->Create(); RETURN_NOT_OK(BuildIndex()); @@ -301,19 +353,19 @@ class BenchmarkTool { return hnsw_options().dimensions; } - Result>> CreateVectorSource( + Result>> CreateVectorSource( const std::string& vectors_file_path, const std::string& description, size_t num_vectors_to_use) { if (!vectors_file_path.empty()) { - auto vec_reader = VERIFY_RESULT(OpenVecsFile(vectors_file_path, description)); + auto vec_reader = VERIFY_RESULT(OpenVecsFile(vectors_file_path, description)); RETURN_NOT_OK(vec_reader->Open()); RETURN_NOT_OK(SetDimensions(vec_reader->dimensions())); return vec_reader; } if (num_vectors_to_use > 0) { - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return CreateRandomFloatVectorSource(args_.num_validation_queries, dimensions()); } return STATUS(InvalidArgument, @@ -355,7 +407,7 @@ class BenchmarkTool { if (ground_truth_vec.size() != args_.k) { return STATUS_FORMAT( IllegalState, - "Provided ground truth vector has $0 dimensions but the configured number k of top " + "Provided ground truth vector has size of $0 but the configured number k of top " "results is $1", ground_truth_vec.size(), args_.k); } @@ -382,27 +434,38 @@ class BenchmarkTool { } Status Validate() { - std::vector query_vectors; - for (;;) { - auto query = VERIFY_RESULT(query_vector_source_->Next()); - if (query.empty()) { - break; - } - query_vectors.push_back(query); - } + std::vector query_vectors = VERIFY_RESULT(query_vector_source_->LoadVectors()); + + auto distance_fn = GetDistanceFunction( + args_.hnsw_options.distance_kind); + + auto vertex_id_to_query_distance_fn = + [this, &distance_fn](VertexId vertex_id, const InputVector& v) -> InputDistanceResult { + // Avoid vector_cast on the critical path of the brute force search here. + return distance_fn(input_vectors_[VertexIdToInputVectorIndex(vertex_id)], v); + }; - std::vector float_query_vectors = ToFloatVectorOfVectors(query_vectors); + VectorIndexReaderIf* reader; + using Adapter = VectorIndexReaderAdapter< + IndexedVector, IndexedDistanceResult, InputVector, InputDistanceResult>; + std::optional adapter; - vectorindex::GroundTruth ground_truth( - [this](VertexId vertex_id, const FloatVector& v) -> float { - const auto& vertex_v = input_vectors_[VertexIdToInputVectorIndex(vertex_id)]; - return distance::DistanceL2Squared(vertex_v, v); - }, + if constexpr (std::is_same_v) { + reader = vector_index_.get(); + } else { + // In case the index uses a different vector type, create an adapter to map the results from + // the indexed type back to the input type. + adapter.emplace(*vector_index_.get()); + reader = &adapter.value(); + } + // The ground truth evaluation is always done in the input coordinate type. + GroundTruth ground_truth( + vertex_id_to_query_distance_fn, args_.k, - float_query_vectors, + query_vectors, loaded_ground_truth_, args_.validate_ground_truth, - *hnsw_, + *reader, // The set of vertex ids to recompute ground truth with. // // In case ground truth is specified as an input file, it must have been computed using all @@ -450,16 +513,23 @@ class BenchmarkTool { auto elapsed_usec = (MonoTime::Now() - load_start_time).ToMicroseconds(); double n_log_n_constant = elapsed_usec * 1.0 / num_inserted / log(num_inserted); double elapsed_time_sec = elapsed_usec / 1000000.0; + size_t remaining_points = max_num_vectors_to_insert() - num_inserted; + auto keys_per_sec = num_inserted / elapsed_time_sec; LOG(INFO) << "n: " << num_inserted << ", " - << "elapsed time (seconds): " << elapsed_time_sec << ", " + << "elapsed time: " << StringPrintf("%.1f", elapsed_time_sec) << " sec, " << "O(n*log(n)) constant: " << n_log_n_constant << ", " - << "keys per second: " << (num_inserted / elapsed_time_sec); + << "remaining points: " << remaining_points << ", " + << "keys per second: " << static_cast(keys_per_sec) << ", " + << "time remaining: " + << StringPrintf("%.1f", keys_per_sec > 0 ? remaining_points / keys_per_sec : 0) + << " sec"; } Status PrepareInputVectors() { size_t num_vectors_to_load = max_num_vectors_to_insert(); double total_mem_required_mb = - num_vectors_to_load * sizeof(typename Vector::value_type) * dimensions() / 1024.0 / 1024; + num_vectors_to_load * sizeof(typename InputVector::value_type) * + dimensions() / 1024.0 / 1024; if (args_.max_memory_for_loading_vectors_mb != 0 && total_mem_required_mb > args_.max_memory_for_loading_vectors_mb) { return STATUS_FORMAT( @@ -491,12 +561,7 @@ class BenchmarkTool { Status InsertOneVector(VertexId vertex_id, MonoTime load_start_time) { const auto& v = GetVectorByVertexId(vertex_id); - Status s; - if constexpr (std::is_same::value) { - s = hnsw_->Insert(vertex_id, v); - } else { - s = hnsw_->Insert(vertex_id, ToFloatVector(v)); - } + Status s = vector_index_->Insert(vertex_id, vector_cast(v)); if (s.ok()) { auto new_num_inserted = num_vectors_inserted_.fetch_add(1, std::memory_order_acq_rel) + 1; ReportIndexingProgress(load_start_time, new_num_inserted); @@ -533,7 +598,7 @@ class BenchmarkTool { } Status BuildIndex() { - hnsw_->Reserve(num_points_to_insert()); + RETURN_NOT_OK(vector_index_->Reserve(num_points_to_insert())); return InsertVectors(); } @@ -566,7 +631,7 @@ class BenchmarkTool { return index; } - const Vector& GetVectorByVertexId(VertexId vertex_id) { + const InputVector& GetVectorByVertexId(VertexId vertex_id) { auto vector_index = VertexIdToInputVectorIndex(vertex_id); return input_vectors_[vector_index]; } @@ -574,12 +639,13 @@ class BenchmarkTool { BenchmarkArguments args_; // Source from which we take vectors to build the index on. - std::unique_ptr> indexed_vector_source_; + std::unique_ptr> indexed_vector_source_; // Source for vectors to run validation queries on. - std::unique_ptr> query_vector_source_; + std::unique_ptr> query_vector_source_; - std::unique_ptr hnsw_; + std::unique_ptr> index_factory_; + std::unique_ptr> vector_index_; // Atomics used in multithreaded index construction. std::atomic num_vectors_inserted_{0}; // Total # vectors inserted. @@ -592,23 +658,89 @@ class BenchmarkTool { std::vector all_vertex_ids_; // Raw input vectors in the order they appeared in the input file. - std::vector input_vectors_; + std::vector input_vectors_; }; +template +std::optional BenchmarkExecuteHelper( + const BenchmarkArguments& args, + CoordinateKind input_coordinate_kind) { + using InputDistanceResult = typename DistanceTraits::Result; + using IndexedDistanceResult = typename DistanceTraits::Result; + if (args.ann_method == ann_method_kind && + args.hnsw_options.distance_kind == distance_kind && + input_coordinate_kind == CoordinateTypeTraits::kKind) { + using IndexFactory = typename ANNMethodTraits::template IndexFactory< + IndexedVector, + typename DistanceTraits::Result>; + return BenchmarkTool( + args, + std::make_unique() + ).Execute(); + } + return std::nullopt; +} + Status BenchmarkExecute(const BenchmarkArguments& args) { auto args_copy = args; args_copy.FinalizeDefaults(); - auto coordinate_kind = VERIFY_RESULT(DetermineCoordinateKind(args_copy)); - switch (coordinate_kind) { - case CoordinateKind::kFloat32: - return BenchmarkTool>(args_copy).Execute(); - case CoordinateKind::kUInt8: - return BenchmarkTool>(args_copy).Execute(); - default: - return STATUS_FORMAT( - InvalidArgument, - "Input files with coordinate type $0 are not supported", coordinate_kind); - } + + LOG(INFO) << "Distance kind: " << args_copy.hnsw_options.distance_kind; + + // The input coordinate type is based on input file extensions. + auto input_coordinate_kind = VERIFY_RESULT(DetermineCoordinateKind(args_copy)); + + // Determining the right template arguments is a bit tricky. We have a few supported combinations + // of the ANN method, distance function, input vector type, and the indexed vector type that the + // method has to use in case the ANN method doesn't support the input vector type. To avoid + // error-prone code duplication, we use a macro that expands to a bunch of if statements. + +#define YB_VECTOR_INDEX_BENCHMARK_SUPPORTED_CASES \ + /* method, distance, input type, indexed type */ \ + /* Euclidean distance */ \ + ((Usearch, L2Squared, float, float )) \ + ((Usearch, L2Squared, uint8_t, float )) \ + ((Hnswlib, L2Squared, float, float )) \ + ((Hnswlib, L2Squared, uint8_t, uint8_t)) \ + /* Cosine similarity */ \ + ((Usearch, Cosine, float, float )) \ + ((Usearch, Cosine, uint8_t, float )) \ + /* Inner product */ \ + ((Usearch, InnerProduct, float, float )) \ + ((Usearch, InnerProduct, uint8_t, float )) \ + ((Hnswlib, InnerProduct, float, float )) \ + ((Hnswlib, InnerProduct, uint8_t, uint8_t)) + +#define YB_VECTOR_INDEX_BENCHMARK_HELPER(method, distance_enum_element, input_type, indexed_type) \ + if (auto status = BenchmarkExecuteHelper< \ + ANNMethodKind::BOOST_PP_CAT(k, method), \ + distance_enum_element, \ + std::vector, \ + std::vector>(args_copy, input_coordinate_kind); status.has_value()) { \ + return *status; \ + } + +#define YB_VECTOR_INDEX_BENCHMARK_FOR_EACH_HELPER(r, data, elem) \ + YB_VECTOR_INDEX_BENCHMARK_HELPER( \ + BOOST_PP_TUPLE_ELEM(4, 0, elem), \ + DistanceKind::BOOST_PP_CAT(k, BOOST_PP_TUPLE_ELEM(4, 1, elem)), \ + BOOST_PP_TUPLE_ELEM(4, 2, elem), \ + BOOST_PP_TUPLE_ELEM(4, 3, elem)) + + BOOST_PP_SEQ_FOR_EACH(YB_VECTOR_INDEX_BENCHMARK_FOR_EACH_HELPER, _, + YB_VECTOR_INDEX_BENCHMARK_SUPPORTED_CASES) + + return STATUS_FORMAT( + InvalidArgument, + "Unsupported combination of ANN method $0, distance kind $1, and input coordinate type $2", + args_copy.ann_method, + args_copy.hnsw_options.distance_kind, + input_coordinate_kind); + + return Status::OK(); } YB_TOOL_ARGUMENTS(HnswAction, HNSW_ACTIONS); diff --git a/src/yb/tools/tool_arguments.h b/src/yb/tools/tool_arguments.h index 84aec42f76fb..a645b0dccab8 100644 --- a/src/yb/tools/tool_arguments.h +++ b/src/yb/tools/tool_arguments.h @@ -13,6 +13,10 @@ #pragma once +#include +#include + +#include #include #include @@ -252,7 +256,7 @@ Status CommonHelpExecute(const CommonHelpArguments& args) { } // ------------------------------------------------------------------------------------------------ -// Helpers for specifying valid ranges of options +// Helpers for specifying valid ranges of options, and various custom types of options // ------------------------------------------------------------------------------------------------ template @@ -269,5 +273,36 @@ auto OptionLowerBound(const char* option_name, OptionType lower_bound) -> }; } +template +std::string ValidEnumValuesCommaSeparatedForHelp() { + std::vector string_values; + bool all_start_with_k = true; // Initialize to true to check if all start with 'k' + + // Collect the enum values as strings and check the starting character + for (auto element : List(static_cast(nullptr))) { + const auto s = ToString(element); + string_values.push_back(s); + if (s.size() <= 1 || s[0] != 'k') { + all_start_with_k = false; + } + } + + std::vector final_values; + + if (all_start_with_k) { + // If all start with 'k', modify each string by stripping 'k' + final_values.reserve(string_values.size()); + for (auto& s : string_values) { + final_values.push_back(s.substr(1)); // Move stripped string + } + } else { + // If not all start with 'k', move the entire vector + final_values = std::move(string_values); + } + + // Use Boost to join the final values into a comma-separated string + return boost::algorithm::join(final_values, ", "); +} + } // namespace tools } // namespace yb diff --git a/src/yb/util/CMakeLists.txt b/src/yb/util/CMakeLists.txt index b636cf08ba7d..c67ec66a7fd2 100644 --- a/src/yb/util/CMakeLists.txt +++ b/src/yb/util/CMakeLists.txt @@ -429,6 +429,7 @@ ADD_YB_TEST(user-test) ADD_YB_TEST(uuid-test) ADD_YB_TEST(varint-test) ADD_YB_TEST(write_buffer-test) +ADD_YB_TEST(enums-test) ####################################### # jsonwriter_test_proto diff --git a/src/yb/util/enums-test.cc b/src/yb/util/enums-test.cc new file mode 100644 index 000000000000..6b4541e1752d --- /dev/null +++ b/src/yb/util/enums-test.cc @@ -0,0 +1,57 @@ +// Copyright (c) YugabyteDB, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations +// under the License. +// + +#include "yb/util/enums.h" + +#include "yb/util/test_util.h" + +namespace yb { + +class EnumsTest : public YBTest { +}; + +YB_DEFINE_ENUM( + TestEnum, + (kElement1) + (kItem2) + (kWidget3)) + +TEST_F(EnumsTest, FromStream) { + std::vector> test_cases{ + {"kelement1", TestEnum::kElement1}, + {"keleMent1", TestEnum::kElement1}, + {"Element1", TestEnum::kElement1}, + {"kitem2", TestEnum::kItem2}, + {"kItEm2", TestEnum::kItem2}, + {"IteM2", TestEnum::kItem2}, + {"kwidGet3", TestEnum::kWidget3}, + {"kwiDget3", TestEnum::kWidget3}, + {"Widget3", TestEnum::kWidget3}, + }; + + for (const auto& [input_str, expected_enum] : test_cases) { + SCOPED_TRACE("input_str: " + input_str); + std::istringstream input_stream(input_str); + TestEnum value; + input_stream >> value; + ASSERT_FALSE(input_stream.fail()); + ASSERT_EQ(value, expected_enum); + } + + std::istringstream invalid_input1("foo"); + TestEnum value5; + invalid_input1 >> value5; + ASSERT_TRUE(invalid_input1.fail()); +} + +} // namespace yb diff --git a/src/yb/util/enums.h b/src/yb/util/enums.h index 11734f34a96a..bf7d4f810dc9 100644 --- a/src/yb/util/enums.h +++ b/src/yb/util/enums.h @@ -15,6 +15,7 @@ #include #include +#include #include #include @@ -30,6 +31,7 @@ #include "yb/util/math_util.h" // For constexpr_max #include "yb/util/result.h" +#include "yb/util/string_util.h" namespace yb { @@ -104,10 +106,6 @@ class AllEnumItemsIterable { BOOST_PP_TUPLE_ELEM(2, 0, data):: \ BOOST_PP_CAT(BOOST_PP_APPLY(BOOST_PP_TUPLE_ELEM(2, 1, data)), YB_ENUM_ITEM_NAME(elem)) -#define YB_ENUM_LIST_ITEM(s, data, elem) \ - BOOST_PP_TUPLE_ELEM(2, 0, data):: \ - BOOST_PP_CAT(BOOST_PP_APPLY(BOOST_PP_TUPLE_ELEM(2, 1, data)), YB_ENUM_ITEM_NAME(elem)) - #define YB_ENUM_CASE_NAME(s, data, elem) \ case BOOST_PP_TUPLE_ELEM(2, 0, data):: \ BOOST_PP_CAT(BOOST_PP_APPLY(BOOST_PP_TUPLE_ELEM(2, 1, data)), YB_ENUM_ITEM_NAME(elem)): \ @@ -133,7 +131,6 @@ class AllEnumItemsIterable { } \ return nullptr; \ } \ - \ inline __attribute__((unused)) std::string ToString(enum_name value) { \ const char* c_str = ToCString(value); \ if (c_str != nullptr) \ @@ -144,7 +141,10 @@ class AllEnumItemsIterable { inline __attribute__((unused)) std::ostream& operator<<(std::ostream& out, enum_name value) { \ return out << ToString(value); \ } \ - \ + inline __attribute__((unused)) std::istream& operator>>(std::istream& in, enum_name& value) { \ + ::yb::detail::EnumFromInputStreamHelper(in, value); \ + return in; \ + } \ constexpr __attribute__((unused)) size_t BOOST_PP_CAT(kElementsIn, enum_name) = \ BOOST_PP_SEQ_SIZE(list); \ constexpr __attribute__((unused)) size_t BOOST_PP_CAT(k, BOOST_PP_CAT(enum_name, MapSize)) = \ @@ -436,4 +436,30 @@ Result ParseEnumInsensitive(const std::string& str) { return ParseEnumInsensitive(str.c_str()); } + +namespace detail { + +template +void EnumFromInputStreamHelper(std::istream& in, EnumType& value) { + std::string token; + in >> token; + if (in.fail()) { + return; + } + auto parse_result = ParseEnumInsensitive(token); + if (parse_result.ok()) { + value = parse_result.get(); + return; + } + // The vast majority of enums are defined with kFoo, kBar, etc. as their values. + parse_result = ParseEnumInsensitive("k" + token); + if (parse_result.ok()) { + value = parse_result.get(); + return; + } + in.setstate(std::ios_base::failbit); +} + +} // namespace detail + } // namespace yb diff --git a/src/yb/vector/CMakeLists.txt b/src/yb/vector/CMakeLists.txt index 786cc124adc4..400be2a4481f 100644 --- a/src/yb/vector/CMakeLists.txt +++ b/src/yb/vector/CMakeLists.txt @@ -18,9 +18,9 @@ set(YB_PCH_PREFIX vector) set(VECTOR_SRCS ann_validation.cc benchmark_data.cc - distance.cc hnsw_options.cc hnsw_util.cc + hnswlib_wrapper.cc usearch_wrapper.cc vectorann.cc ) diff --git a/src/yb/vector/ann_methods.h b/src/yb/vector/ann_methods.h new file mode 100644 index 000000000000..deb39975e3af --- /dev/null +++ b/src/yb/vector/ann_methods.h @@ -0,0 +1,51 @@ +// Copyright (c) YugabyteDB, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations +// under the License. +// + +#pragma once + +#include "yb/util/enums.h" +#include "yb/util/result.h" + +#include "yb/vector/coordinate_types.h" +#include "yb/vector/hnsw_options.h" +#include "yb/vector/hnswlib_wrapper.h" +#include "yb/vector/usearch_wrapper.h" +#include "yb/vector/vector_index_if.h" + +namespace yb::vectorindex { + +YB_DEFINE_ENUM( + ANNMethodKind, + (kUsearch) + (kHnswlib)); + +template +struct ANNMethodTraits { + static constexpr ANNMethodKind kKind = method_kind; +}; + +// TODO: use preprocessing to reduce duplication below. + +template<> +struct ANNMethodTraits { + template + using IndexFactory = UsearchIndexFactory; +}; + +template<> +struct ANNMethodTraits { + template + using IndexFactory = HnswlibIndexFactory; +}; + +} // namespace yb::vectorindex diff --git a/src/yb/vector/ann_validation.cc b/src/yb/vector/ann_validation.cc index 4fa7f0c73b69..857aa92ef581 100644 --- a/src/yb/vector/ann_validation.cc +++ b/src/yb/vector/ann_validation.cc @@ -25,9 +25,9 @@ namespace yb::vectorindex { -template -GroundTruth::GroundTruth( - const VertexIdToVectorDistanceFunction& distance_fn, +template +GroundTruth::GroundTruth( + const VertexIdToVectorDistanceFunction& distance_fn, size_t k, const std::vector& queries, const PrecomputedGroundTruthMatrix& precomputed_ground_truth, @@ -46,8 +46,12 @@ GroundTruth::GroundTruth( } } -template -Result GroundTruth::EvaluateRecall(size_t num_threads) { +template +Result GroundTruth::EvaluateRecall( + size_t num_threads) { + SCHECK_GE(num_threads, static_cast(1), + InvalidArgument, "Number of threads must be at least 1"); + // For each index i we maintain the total overlap (set intersection cardinality) between the // top (i + 1) ground truth results and the top (i + 1) ANN results, computed across all queries. // We apply denominators at the end. @@ -56,9 +60,6 @@ Result GroundTruth::EvaluateRecall(size_t num_threads) { total_overlap[i].store(0); } - SCHECK_GE(num_threads, static_cast(1), - InvalidArgument, "Number of threads must be at least 1"); - RETURN_NOT_OK(ProcessInParallel( num_threads, /* start_index= */ static_cast(0), @@ -81,8 +82,8 @@ Result GroundTruth::EvaluateRecall(size_t num_threads) { return result; } -template -Status GroundTruth::ProcessQuery( +template +Status GroundTruth::ProcessQuery( size_t query_index, AtomicUInt64Vector& total_overlap_counters) { auto& query = queries_[query_index]; @@ -93,10 +94,10 @@ Status GroundTruth::ProcessQuery( return Status::OK(); } - // At this point, we need to compute precise results either because precomputed results are not // available, or because we want to validate those precomputed results. - auto our_correct_top_k = BruteForcePreciseNearestNeighbors(query, vertex_ids_, distance_fn_, k_); + auto our_correct_top_k = BruteForcePreciseNearestNeighbors( + query, vertex_ids_, distance_fn_, k_); if (!precomputed_ground_truth_.empty() && validate_precomputed_ground_truth_) { // Compare the ground truth we've just computed to the the precomputed ground truth. @@ -120,24 +121,26 @@ Status GroundTruth::ProcessQuery( return Status::OK(); } -template -VerticesWithDistances GroundTruth::AugmentWithDistances( +template +VerticesWithDistances +GroundTruth::AugmentWithDistances( const std::vector& vertex_ids, const Vector& query) { - VerticesWithDistances result; + VerticesWithDistances result; result.reserve(vertex_ids.size()); for (auto vertex_id : vertex_ids) { - result.push_back(VertexWithDistance(vertex_id, distance_fn_(vertex_id, query))); + result.push_back(VertexWithDistance( + vertex_id, distance_fn_(vertex_id, query))); } return result; } -template -void GroundTruth::DoApproxSearchAndUpdateStats( +template +void GroundTruth::DoApproxSearchAndUpdateStats( const Vector& query, const std::vector& correct_result, AtomicUInt64Vector& total_overlap_counters) { - auto approx_result = index_reader_.Search(query, k_); + auto approx_result = index_reader_.Search(vector_cast(query), k_); std::unordered_set approx_set; for (const auto& approx_entry : approx_result) { approx_set.insert(approx_entry.vertex_id); @@ -155,83 +158,8 @@ void GroundTruth::DoApproxSearchAndUpdateStats( } } -bool ResultSetsEquivalent(const VerticesWithDistances& a, const VerticesWithDistances& b) { - if (a.size() != b.size()) { - return false; - } - const auto k = a.size(); - bool matches = true; - for (size_t i = 0; i < k; ++i) { - if (a[i] != b[i]) { - matches = false; - break; - } - } - if (matches) { - return true; - } - - // Sort both result sets by increasing distance, and for the same distance, increasing vertex id. - auto a_sorted = a; - std::sort(a_sorted.begin(), a_sorted.end()); - auto b_sorted = b; - std::sort(b_sorted.begin(), b_sorted.end()); - - size_t discrepancy_index = k; - for (size_t i = 0; i < k; ++i) { - if (a_sorted[i] != b_sorted[i]) { - discrepancy_index = i; - break; - } - } - if (discrepancy_index == k) { - // The arrays became the same after sorting. - return true; - } - - // We allow a situation where vertex ids are different as long as distances are the same until - // the end of both result sets. In that case we still consider the two result sets equivalent. - float expected_distance = a_sorted[discrepancy_index].distance; - for (size_t i = discrepancy_index; i < k; ++i) { - float a_dist = a_sorted[i].distance; - float b_dist = b_sorted[i].distance; - if (a_dist != expected_distance && b_dist != expected_distance) { - return false; - } - } - return true; -} - -std::string ResultSetDifferenceStr(const VerticesWithDistances& a, const VerticesWithDistances& b) { - if (a.size() != b.size()) { - // This should not occur, so no details here. - return Format("Result set size: $0 vs. $1", a.size(), b.size()); - } - const size_t k = a.size(); - - auto a_sorted = a; - std::sort(a_sorted.begin(), a_sorted.end()); - auto b_sorted = b; - std::sort(b_sorted.begin(), b_sorted.end()); - - std::ostringstream diff_str; - - bool found_differences = false; - for (size_t j = 0; j < k; ++j) { - if (a_sorted[j] != b_sorted[j]) { - if (found_differences) { - diff_str << "\n"; - } - found_differences = true; - diff_str << " " << a_sorted[j].ToString() << " vs. " << b_sorted[j].ToString(); - } - } - if (found_differences) { - return diff_str.str(); - } - return "No differences"; -} - -YB_INSTANTIATE_TEMPLATE_FOR_ALL_VECTOR_TYPES(GroundTruth); +template class GroundTruth; +template class GroundTruth; +template class GroundTruth; } // namespace yb::vectorindex diff --git a/src/yb/vector/ann_validation.h b/src/yb/vector/ann_validation.h index 97df0a1ada28..b6d286b8eea0 100644 --- a/src/yb/vector/ann_validation.h +++ b/src/yb/vector/ann_validation.h @@ -37,14 +37,15 @@ using PrecomputedGroundTruthMatrix = std::vector>; // only a minor part of the overall time spent querying the vector index. using AtomicUInt64Vector = std::vector>; -// Computes ground truth of approximate nearest neighbor search. Parameterized by the query. -template +// Computes ground truth of approximate nearest neighbor search. +template class GroundTruth { public: - using IndexReader = VectorIndexReaderIf; + using IndexReader = VectorIndexReaderIf; GroundTruth( - const VertexIdToVectorDistanceFunction& distance_fn, + const VertexIdToVectorDistanceFunction& distance_fn, size_t k, const std::vector& queries, const PrecomputedGroundTruthMatrix& precomputed_ground_truth, @@ -73,10 +74,11 @@ class GroundTruth { const std::vector& correct_result, AtomicUInt64Vector& total_overlap_counters); - VerticesWithDistances AugmentWithDistances( - const std::vector& vertex_ids, const Vector& query); + // This works on queries convertered from input vector io indexed vector format. + VerticesWithDistances AugmentWithDistances( + const std::vector& vertex_ids, const Vector& converted_query); - VertexIdToVectorDistanceFunction distance_fn_; + VertexIdToVectorDistanceFunction distance_fn_; size_t k_; const std::vector& queries_; const PrecomputedGroundTruthMatrix& precomputed_ground_truth_; @@ -91,8 +93,85 @@ class GroundTruth { // same distance to the query. This corresponds to a situation when a group of items with the same // distance to the query, ordered differently in the two result sets, was cut in the middle by the // result set boundary. -bool ResultSetsEquivalent(const VerticesWithDistances& a, const VerticesWithDistances& b); - -std::string ResultSetDifferenceStr(const VerticesWithDistances& a, const VerticesWithDistances& b); +template +bool ResultSetsEquivalent(const VerticesWithDistances& a, + const VerticesWithDistances& b) { + if (a.size() != b.size()) { + return false; + } + const auto k = a.size(); + bool matches = true; + for (size_t i = 0; i < k; ++i) { + if (a[i] != b[i]) { + matches = false; + break; + } + } + if (matches) { + return true; + } + + // Sort both result sets by increasing distance, and for the same distance, increasing vertex id. + auto a_sorted = a; + std::sort(a_sorted.begin(), a_sorted.end()); + auto b_sorted = b; + std::sort(b_sorted.begin(), b_sorted.end()); + + size_t discrepancy_index = k; + for (size_t i = 0; i < k; ++i) { + if (a_sorted[i] != b_sorted[i]) { + discrepancy_index = i; + break; + } + } + if (discrepancy_index == k) { + // The arrays became the same after sorting. + return true; + } + + // We allow a situation where vertex ids are different as long as distances are the same until + // the end of both result sets. In that case we still consider the two result sets equivalent. + float expected_distance = a_sorted[discrepancy_index].distance; + for (size_t i = discrepancy_index; i < k; ++i) { + float a_dist = a_sorted[i].distance; + float b_dist = b_sorted[i].distance; + if (a_dist != expected_distance && b_dist != expected_distance) { + return false; + } + } + return true; +} + +template +std::string ResultSetDifferenceStr(const VerticesWithDistances& a, + const VerticesWithDistances& b) { + if (a.size() != b.size()) { + // This should not occur, so no details here. + return Format("Result set size: $0 vs. $1", a.size(), b.size()); + } + const size_t k = a.size(); + + auto a_sorted = a; + std::sort(a_sorted.begin(), a_sorted.end()); + auto b_sorted = b; + std::sort(b_sorted.begin(), b_sorted.end()); + + std::ostringstream diff_str; + + bool found_differences = false; + for (size_t j = 0; j < k; ++j) { + if (a_sorted[j] != b_sorted[j]) { + if (found_differences) { + diff_str << "\n"; + } + found_differences = true; + diff_str << " " << a_sorted[j].ToString() << " vs. " << b_sorted[j].ToString(); + } + } + if (found_differences) { + return diff_str.str(); + } + return "No differences"; +} } // namespace yb::vectorindex diff --git a/src/yb/vector/benchmark_data.h b/src/yb/vector/benchmark_data.h index 35571d1f47ae..0395f8979f31 100644 --- a/src/yb/vector/benchmark_data.h +++ b/src/yb/vector/benchmark_data.h @@ -247,10 +247,6 @@ class VecsFileReader : public VectorSource { mutable Vector current_vector_; }; -using BvecsFileReader = VecsFileReader>; -using FvecsFileReader = VecsFileReader; -using IvecsFileReader = VecsFileReader; - // Determine coordinate kind by file name (.fvecs/.bvecs/.ivecs). Result GetCoordinateKindFromVecsFileName(const std::string& vecs_file_path); diff --git a/src/yb/vector/coordinate_types.h b/src/yb/vector/coordinate_types.h index 5a836bff6199..8ea2b86d8076 100644 --- a/src/yb/vector/coordinate_types.h +++ b/src/yb/vector/coordinate_types.h @@ -31,7 +31,7 @@ namespace yb::vectorindex { // The usearch counterpart of this is scalar_kind_t in index_plugins.hpp. // Columns: // 1. This goes into enum element naming, e.g. kFloat64 or kFloat32. -// 2. The corresponding (mostly) standard C/C++ data type. +// 2. The corresponding standard C/C++ data type or a ..._t type from stdint.h. // 3. The prefix of what usearch calls this type (e.g. f64_k for the scalar_kind_t enum element and // f64_t for the typedef). This is also a convenient short identifer. #define YB_COORDINATE_TYPE_INFO \ @@ -54,22 +54,6 @@ namespace yb::vectorindex { #define YB_EXTRACT_COORDINATE_TYPE(tuple) BOOST_PP_TUPLE_ELEM(3, 1, tuple) #define YB_EXTRACT_COORDINATE_TYPE_SHORT_NAME(tuple) BOOST_PP_TUPLE_ELEM(3, 2, tuple) -#define YB_EXTRACT_COORDINATE_TYPE_WITH_COMMA(r, data, i, coordinate_info_tuple) \ - BOOST_PP_COMMA_IF(i) YB_EXTRACT_COORDINATE_TYPE(coordinate_info_tuple) - -// Comma-separated list of scalar types -#define YB_COORDINATE_TYPES_COMMA_SEPARATED \ - BOOST_PP_SEQ_FOR_EACH_I(YB_EXTRACT_COORDINATE_TYPE_WITH_COMMA, _, YB_COORDINATE_TYPE_INFO) - -#define YB_VECTOR_TYPE_WITH_COMMA(r, data, i, coordinate_info_tuple) \ - BOOST_PP_COMMA_IF(i) std::vector - -// Comma-separated list of vector types -#define YB_VECTOR_TYPES_COMMA_SEPARATED \ - BOOST_PP_SEQ_FOR_EACH_I(YB_VECTOR_TYPE_WITH_COMMA, _, YB_COORDINATE_TYPE_INFO) - -#undef YB_MAKE_VECTOR_TYPE_WITH_COMMA - // ------------------------------------------------------------------------------------------------ // CoordinateKind enum // ------------------------------------------------------------------------------------------------ @@ -132,12 +116,13 @@ struct CoordinateTypeTraits { using Vector = std::vector; }; -#define YB_DEFINE_COORDINATE_TYPE_TRAITS(capitalized_name, scalar_type_name, short_type_name) \ +#define YB_DEFINE_COORDINATE_TYPE_TRAITS( \ + capitalized_name, \ + scalar_type_name, \ + short_type_name) \ template <> \ struct CoordinateTypeTraits { \ - static constexpr CoordinateKind Kind() { \ - return CoordinateKind::BOOST_PP_CAT(k, capitalized_name); \ - } \ + static constexpr CoordinateKind kKind = CoordinateKind::BOOST_PP_CAT(k, capitalized_name); \ static constexpr const char* ShortTypeNameStr() { \ /* Short type name such as f32, u8, etc. */ \ return BOOST_PP_STRINGIZE(short_type_name); \ @@ -194,4 +179,37 @@ ReturnType HandleCoordinateKindSwitch(CoordinateKind coordinate_kind, Functor&& template_name, \ YB_COORDINATE_TYPE_INFO) + +// ------------------------------------------------------------------------------------------------ +// Vector cast +// ------------------------------------------------------------------------------------------------ + +template +ToVector vector_cast(const FromVector& from_vector) { + if constexpr (std::is_same_v) { + return from_vector; + } + ToVector to_vector; + to_vector.reserve(from_vector.size()); + for (const auto& from_element : from_vector) { + to_vector.push_back(static_cast(from_element)); + } + return to_vector; +} + +// Cast for vector of vectors +template +auto vector_cast(const std::vector& v) { + if constexpr (std::is_same_v) { + return v; + } + std::vector result; + result.reserve(v.size()); + for (const auto& subvector : v) { + result.push_back(vector_cast(subvector)); + } + return result; +} + + } // namespace yb::vectorindex diff --git a/src/yb/vector/distance.cc b/src/yb/vector/distance.cc deleted file mode 100644 index b88a50d8948e..000000000000 --- a/src/yb/vector/distance.cc +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) YugabyteDB, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations -// under the License. -// - -#include "yb/vector/distance.h" - -#include - -#include "yb/util/logging.h" - -namespace yb::vectorindex { - -namespace distance { - -} // namespace distance - -std::vector VertexIdsOnly(const VerticesWithDistances& vertices_with_distances) { - std::vector result; - result.reserve(vertices_with_distances.size()); - for (const auto& v_dist : vertices_with_distances) { - result.push_back(v_dist.vertex_id); - } - return result; - -} -} // namespace yb::vectorindex diff --git a/src/yb/vector/distance.h b/src/yb/vector/distance.h index 5b1c30156636..ba12c4fc488c 100644 --- a/src/yb/vector/distance.h +++ b/src/yb/vector/distance.h @@ -14,8 +14,10 @@ #pragma once #include +#include #include "yb/util/enums.h" +#include "yb/util/tostring.h" #include "yb/common/vector_types.h" #include "yb/vector/graph_repr_defs.h" @@ -23,33 +25,110 @@ namespace yb::vectorindex { +YB_DEFINE_ENUM( + DistanceKind, + + // Squared Euclidean (L2) distance -- sum of squares between coordinate differences. + (kL2Squared) + + // Inner product: 1 - (x dot product y). + (kInnerProduct) + + // Cosine (Angular) distance. Similar to Inner Product, but it has to normalize the vectors + // first, so it is more expensive to compute. + (kCosine)); + +template +concept ValidDistanceResultType = + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v; + +template +struct DistanceTraits { + using Scalar = typename Vector::value_type; + + // The "default" distance result is float for 32-bit or smaller scalars, and double for 64-bit + // scalars. + using Result = typename std::conditional<(sizeof(Scalar) <= 4), float, double>::type; + + static_assert(ValidDistanceResultType); +}; + +#define YB_OVERRIDE_DISTANCE_RESULT_TYPE(scalar, distance_kind, distance_result_type) \ + template<> \ + struct DistanceTraits, distance_kind> { \ + using Result = distance_result_type; \ + static_assert(ValidDistanceResultType); \ + }; + +// L2 square distance for signed or unsigned byte vectors could be computed as uint32_t for vectors +// with 33025 or fewer dimensions: (2**31-1)/(255*255) > 33025. If we were to use float here, we +// might run into precision issues. +// +// TODO: temporarily using int32_t for consistency with hnswlib. +YB_OVERRIDE_DISTANCE_RESULT_TYPE(int8_t, DistanceKind::kL2Squared, int32_t) +YB_OVERRIDE_DISTANCE_RESULT_TYPE(uint8_t, DistanceKind::kL2Squared, int32_t) + +// int32_t result type can support up to (2**32-1)/(255*255) > 66051 dimensions when calculating +// inner product of uint8_t vectors. We need a signed result here because we have to negate the +// result of the inner product of two vectors. +YB_OVERRIDE_DISTANCE_RESULT_TYPE(uint8_t, DistanceKind::kInnerProduct, int32_t); + +// int32_t result type can support up to (2**31)/(128*128) = 131072 dimensions when calculating +// inner product of int8_t vectors. +YB_OVERRIDE_DISTANCE_RESULT_TYPE(int8_t, DistanceKind::kInnerProduct, int32_t); + namespace distance { -template -inline float DistanceL2Squared(const Vector1Type& a, const Vector2Type& b) { - float sum = 0; +// Reference implementations of distance functions. These are NOT optimized for performance. + +template +inline auto DistanceL2Squared(const Vector& a, const Vector& b) { + using DistanceResult = typename DistanceTraits::Result; CHECK_EQ(a.size(), b.size()); + DistanceResult sum = 0; for (size_t i = 0; i < a.size(); ++i) { - float diff = a[i] - b[i]; + // Be paranoid about subtraction underflow because DistanceResult might be unsigned. + DistanceResult ai = a[i]; + DistanceResult bi = b[i]; + DistanceResult diff = ai > bi ? ai - bi : bi - ai; sum += diff * diff; } return sum; } template -inline float DistanceCosine(const Vector& a, const Vector& b) { +inline auto DistanceInnerProduct( + const Vector& a, const Vector& b) { + using DistanceResult = typename DistanceTraits::Result; + CHECK_EQ(a.size(), b.size()); + DistanceResult ab = 0; + for (size_t i = 0; i < a.size(); ++i) { + ab += static_cast(a[i]) * static_cast(b[i]); + } + return 1 - ab; +} + +template +inline auto DistanceCosine( + const Vector& a, const Vector& b) { // Adapted from metric_cos_gt in index_plugins.hpp (usearch). + using DistanceResult = typename DistanceTraits::Result; CHECK_EQ(a.size(), b.size()); - float ab = 0, a2 = 0, b2 = 0; + DistanceResult ab = 0; + DistanceResult a2 = 0; + DistanceResult b2 = 0; for (size_t i = 0; i < a.size(); ++i) { - float ai = a[i]; - float bi = b[i]; + DistanceResult ai = a[i]; + DistanceResult bi = b[i]; ab += ai * bi; a2 += ai * ai; b2 += bi * bi; } - float result_if_zero[2][2]; + DistanceResult result_if_zero[2][2]; result_if_zero[0][0] = 1 - ab / (std::sqrt(a2) * std::sqrt(b2)); result_if_zero[0][1] = result_if_zero[1][0] = 1; result_if_zero[1][1] = 0; @@ -59,45 +138,34 @@ inline float DistanceCosine(const Vector& a, const Vector& b) { } // namespace distance -YB_DEFINE_ENUM( - VectorDistanceType, - (kL2Squared) - (kCosine)); - -template -using DistanceFunction = std::function; +template +using DistanceFunction = std::function; // A variant of a distance function that knows how to resolve a vertex id to a vector, and then // compute the distance. -template +template using VertexIdToVectorDistanceFunction = - std::function; - -template -DistanceFunction GetDistanceImpl(VectorDistanceType distance_type) { - switch (distance_type) { - case VectorDistanceType::kL2Squared: - return distance::DistanceL2Squared; - case VectorDistanceType::kCosine: - return distance::DistanceCosine; - } - FATAL_INVALID_ENUM_VALUE(VectorDistanceType, distance_type); -} + std::function; +template struct VertexWithDistance { - VertexId vertex_id = 0; - float distance = 0.0f; + VertexId vertex_id = kInvalidVertexId; + DistanceResult distance{}; - // Deleted constructor to prevent wrong initialization order - VertexWithDistance(float, VertexId) = delete; + // Constructor with the wrong order. Only delete it if DistanceResult is not uint64_t. + template ::value, int>::type = 0> + VertexWithDistance(DistanceResult, VertexId) = delete; VertexWithDistance() = default; // Constructor with the correct order - VertexWithDistance(VertexId vertex_id_, float distance_) + VertexWithDistance(VertexId vertex_id_, DistanceResult distance_) : vertex_id(vertex_id_), distance(distance_) {} - std::string ToString() const; + std::string ToString() const { + return YB_STRUCT_TO_STRING(vertex_id, distance); + } bool operator ==(const VertexWithDistance& other) const { return vertex_id == other.vertex_id && distance == other.distance; @@ -119,8 +187,31 @@ struct VertexWithDistance { } }; -using VerticesWithDistances = std::vector; +template +using VerticesWithDistances = std::vector>; + +template +std::vector VertexIdsOnly( + const VerticesWithDistances& vertices_with_distances) { + std::vector result; + result.reserve(vertices_with_distances.size()); + for (const auto& v_dist : vertices_with_distances) { + result.push_back(v_dist.vertex_id); + } + return result; +} -std::vector VertexIdsOnly(const VerticesWithDistances& vertices_with_distances); +template +DistanceFunction GetDistanceFunction(DistanceKind distance_kind) { + switch (distance_kind) { + case DistanceKind::kInnerProduct: + return distance::DistanceInnerProduct; + case DistanceKind::kL2Squared: + return distance::DistanceL2Squared; + case DistanceKind::kCosine: + return distance::DistanceCosine; + } + FATAL_INVALID_ENUM_VALUE(DistanceKind, distance_kind); +} } // namespace yb::vectorindex diff --git a/src/yb/vector/graph_repr_defs.h b/src/yb/vector/graph_repr_defs.h index f47615c3c1cb..7b63fd134601 100644 --- a/src/yb/vector/graph_repr_defs.h +++ b/src/yb/vector/graph_repr_defs.h @@ -28,6 +28,9 @@ using VertexId = uint64_t; constexpr VertexId kInvalidVertexId = 0; +template +concept VertexIdCompatible = std::is_unsigned_v && std::is_integral_v && sizeof(T) == 8; + using VectorIndexLevel = uint8_t; using VectorNodeNeighbors = std::set; diff --git a/src/yb/vector/hnsw_options.h b/src/yb/vector/hnsw_options.h index 07206ae0ca4b..1571c3fe0028 100644 --- a/src/yb/vector/hnsw_options.h +++ b/src/yb/vector/hnsw_options.h @@ -61,7 +61,7 @@ struct HNSWOptions { // This is not used by usearch. float robust_prune_alpha = 1.0; - VectorDistanceType distance_type = VectorDistanceType::kL2Squared; + DistanceKind distance_kind = DistanceKind::kL2Squared; std::string ToString() const; }; diff --git a/src/yb/vector/hnsw_util.cc b/src/yb/vector/hnsw_util.cc index db64d6f1816c..b06b55be0367 100644 --- a/src/yb/vector/hnsw_util.cc +++ b/src/yb/vector/hnsw_util.cc @@ -47,8 +47,4 @@ VectorIndexLevel SelectRandomLevel(double ml, VectorIndexLevel max_level) { return narrow_cast(std::min(level, max_level)); } -std::string VertexWithDistance::ToString() const { - return YB_STRUCT_TO_STRING(vertex_id, distance); -} - } // namespace yb::vectorindex diff --git a/src/yb/vector/hnsw_util.h b/src/yb/vector/hnsw_util.h index 5561d24a0c6a..78a34dabb08d 100644 --- a/src/yb/vector/hnsw_util.h +++ b/src/yb/vector/hnsw_util.h @@ -35,24 +35,4 @@ namespace yb::vectorindex { // 3. If ml = 1/log(4) (~0.7213), then p = 1 - 1/4 (~0.75), and the expected level is ~1.333. VectorIndexLevel SelectRandomLevel(double ml, VectorIndexLevel max_level); -template -FloatVector ToFloatVector(const Vector& v) { - FloatVector fv; - fv.reserve(v.size()); - for (auto x : v) { - fv.push_back(static_cast(x)); - } - return fv; -} - -template -std::vector ToFloatVectorOfVectors(const std::vector& v) { - std::vector result; - result.reserve(v.size()); - for (const auto& subvector : v) { - result.push_back(ToFloatVector(subvector)); - } - return result; -} - } // namespace yb::vectorindex diff --git a/src/yb/vector/hnswlib_wrapper.cc b/src/yb/vector/hnswlib_wrapper.cc new file mode 100644 index 000000000000..63795d7129c5 --- /dev/null +++ b/src/yb/vector/hnswlib_wrapper.cc @@ -0,0 +1,137 @@ +// Copyright (c) YugabyteDB, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations +// under the License. +// + +#include "yb/vector/hnswlib_wrapper.h" + +#include + +#pragma GCC diagnostic push + +// For https://gist.githubusercontent.com/mbautin/db70c2fcaa7dd97081b0c909d72a18a8/raw +#pragma GCC diagnostic ignored "-Wunused-function" + +#ifdef __clang__ +#pragma GCC diagnostic ignored "-Wshorten-64-to-32" +#endif + +#include "hnswlib/hnswlib.h" +#include "hnswlib/hnswalg.h" +#include "hnswlib/space_ip.h" +#include "hnswlib/space_l2.h" + +#pragma GCC diagnostic pop + +#include "yb/util/status.h" + +#include "yb/vector/distance.h" + +namespace yb::vectorindex { + +namespace detail { +template +class HnswlibIndexImpl { + public: + using Scalar = typename Vector::value_type; + + using HNSWImpl = typename hnswlib::HierarchicalNSW; + + explicit HnswlibIndexImpl(const HNSWOptions& options) + : options_(options) { + } + + Status Reserve(size_t num_vectors) { + if (hnsw_) { + return STATUS_FORMAT( + IllegalState, "Cannot reserve space for $0 vectors: Hnswlib index already initialized", + num_vectors); + } + RETURN_NOT_OK(CreateSpaceImpl()); + hnsw_ = std::make_unique( + space_.get(), + /* max_elements= */ num_vectors, + /* M= */ options_.max_neighbors_per_vertex, + /* ef_construction= */ options_.ef_construction); + return Status::OK(); + } + + Status Insert(VertexId vertex_id, const Vector& v) { + hnsw_->addPoint(v.data(), vertex_id); + return Status::OK(); + } + + std::vector> Search( + const Vector& query_vector, size_t max_num_results) { + std::vector> result; + auto tmp_result = hnsw_->searchKnnCloserFirst(query_vector.data(), max_num_results); + result.reserve(tmp_result.size()); + for (const auto& entry : tmp_result) { + // Being careful to avoid switching the order of distance and vertex id.. + const auto distance = entry.first; + static_assert(std::is_same_v, DistanceResult>); + + const auto label = entry.second; + static_assert(VertexIdCompatible); + + result.push_back(VertexWithDistance(label, distance)); + } + return result; + } + + Result GetVector(VertexId vertex_id) const { + return STATUS( + NotSupported, "Hnswlib wrapper currently does not allow retriving vectors by id"); + } + + private: + Status CreateSpaceImpl() { + switch (options_.distance_kind) { + case DistanceKind::kL2Squared: { + if constexpr (std::is_same::value) { + space_ = std::make_unique(options_.dimensions); + } else if constexpr (std::is_same>::value) { + space_ = std::make_unique(options_.dimensions); + } else { + return STATUS_FORMAT( + InvalidArgument, + "Unsupported combination of distance type and vector type: $0 and $1", + options_.distance_kind, CoordinateTypeTraits::Kind()); + } + + return Status::OK(); + } + default: + return STATUS_FORMAT( + InvalidArgument, "Unsupported distance type for Hnswlib: $0", + options_.distance_kind); + } + } + + HNSWOptions options_; + std::unique_ptr> space_; + std::unique_ptr hnsw_; +}; + +} // namespace detail + +template +HnswlibIndex::HnswlibIndex(const HNSWOptions& options) + : VectorIndexBase(std::make_unique(options)) { +} + +template +HnswlibIndex::~HnswlibIndex() = default; + +template class HnswlibIndex; +template class HnswlibIndex; + +} // namespace yb::vectorindex diff --git a/src/yb/vector/hnswlib_wrapper.h b/src/yb/vector/hnswlib_wrapper.h new file mode 100644 index 000000000000..5b13dd47ed60 --- /dev/null +++ b/src/yb/vector/hnswlib_wrapper.h @@ -0,0 +1,52 @@ +// Copyright (c) YugabyteDB, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations +// under the License. +// + +#pragma once + +#include + +#include "yb/util/result.h" + +#include "yb/vector/hnsw_options.h" +#include "yb/vector/coordinate_types.h" +#include "yb/vector/vector_index_if.h" +#include "yb/vector/vector_index_wrapper_util.h" + +namespace yb::vectorindex { + +namespace detail { +template +class HnswlibIndexImpl; +} // namespace detail + +template +class HnswlibIndex : public VectorIndexBase< + detail::HnswlibIndexImpl, Vector, DistanceResult> { + public: + explicit HnswlibIndex(const HNSWOptions& options); + virtual ~HnswlibIndex(); + private: + using Impl = detail::HnswlibIndexImpl; +}; + +template +class HnswlibIndexFactory : public VectorIndexFactory { + public: + HnswlibIndexFactory() = default; + + std::unique_ptr> Create() const override { + return std::make_unique>(this->hnsw_options_); + } +}; + +} // namespace yb::vectorindex diff --git a/src/yb/vector/sharded_index.h b/src/yb/vector/sharded_index.h new file mode 100644 index 000000000000..76a6a6101e7c --- /dev/null +++ b/src/yb/vector/sharded_index.h @@ -0,0 +1,114 @@ +// Copyright (c) YugabyteDB, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations +// under the License. +// + +#pragma once + +#include "yb/vector/coordinate_types.h" +#include "yb/vector/distance.h" +#include "yb/vector/vector_index_if.h" + +namespace yb::vectorindex { + +// Allows creating multiple instances of the vector index so we can saturate the capacity of the +// test system. +template +class ShardedVectorIndex : public VectorIndexIf { + public: + ShardedVectorIndex(const VectorIndexFactory& factory, + size_t num_shards) + : indexes_(num_shards), round_robin_counter_(0) { + for (auto& index : indexes_) { + index = factory.Create(); + } + } + + // Reserve capacity across all shards (each shard gets an equal portion, rounded up). + Status Reserve(size_t num_vectors) override { + size_t capacity_per_shard = (num_vectors + indexes_.size() - 1) / indexes_.size(); // Round up + for (auto& index : indexes_) { + RETURN_NOT_OK(index->Reserve(capacity_per_shard)); + } + return Status::OK(); + } + + // Insert a vector into the current shard using round-robin. + Status Insert(VertexId vertex_id, const Vector& vector) override { + size_t current_index = round_robin_counter_.fetch_add(1) % indexes_.size(); + return indexes_[current_index]->Insert(vertex_id, vector); + } + + // Retrieve a vector from any shard. + Result GetVector(VertexId vertex_id) const override { + for (const auto& index : indexes_) { + auto v = VERIFY_RESULT(index->GetVector(vertex_id)); + if (!v.empty()) { + return v; + } + } + return Vector(); // Return an empty vector if not found. + } + + // Search for the closest vectors across all shards. + std::vector> Search( + const Vector& query_vector, size_t max_num_results) const override { + std::vector> all_results; + for (const auto& index : indexes_) { + auto results = index->Search(query_vector, max_num_results); + all_results.insert(all_results.end(), results.begin(), results.end()); + } + + // Sort all_results by distance and keep the top max_num_results. + std::sort(all_results.begin(), all_results.end(), [](const auto& a, const auto& b) { + return a.distance < b.distance; + }); + + if (all_results.size() > max_num_results) { + all_results.resize(max_num_results); + } + + return all_results; + } + + private: + std::vector>> indexes_; + std::atomic round_robin_counter_; // Atomic counter for thread-safe round-robin insertion +}; + +template +class ShardedVectorIndexFactory : public VectorIndexFactory { + public: + // Constructor to initialize the number of shards and underlying factory. + ShardedVectorIndexFactory( + size_t num_shards, + std::unique_ptr> underlying_factory) + : num_shards_(num_shards), underlying_factory_(std::move(underlying_factory)) {} + + // Override the Create method to produce a ShardedVectorIndex. + std::unique_ptr> Create() const override { + // Create a new ShardedVectorIndex with the specified number of shards. + return std::make_unique>( + *underlying_factory_, num_shards_); + } + + // Override SetOptions to propagate options to the underlying factory. + void SetOptions(const HNSWOptions& options) override { + this->hnsw_options_ = options; // Store the options in this factory. + underlying_factory_->SetOptions(options); // Propagate to the underlying factory. + } + + private: + size_t num_shards_; + std::unique_ptr> underlying_factory_; +}; + +} // namespace yb::vectorindex diff --git a/src/yb/vector/usearch_wrapper.cc b/src/yb/vector/usearch_wrapper.cc index c3dd98b17299..04d4d3df61c6 100644 --- a/src/yb/vector/usearch_wrapper.cc +++ b/src/yb/vector/usearch_wrapper.cc @@ -40,14 +40,16 @@ index_dense_config_t CreateIndexDenseConfig(const HNSWOptions& options) { return config; } -metric_kind_t MetricKindFromDistanceType(VectorDistanceType distance_type) { - switch (distance_type) { - case VectorDistanceType::kL2Squared: +metric_kind_t MetricKindFromDistanceType(DistanceKind distance_kind) { + switch (distance_kind) { + case DistanceKind::kL2Squared: return metric_kind_t::l2sq_k; - case VectorDistanceType::kCosine: + case DistanceKind::kInnerProduct: + return metric_kind_t::ip_k; + case DistanceKind::kCosine: return metric_kind_t::cos_k; } - FATAL_INVALID_ENUM_VALUE(VectorDistanceType, distance_type); + FATAL_INVALID_ENUM_VALUE(DistanceKind, distance_kind); } scalar_kind_t ConvertCoordinateKind(CoordinateKind coordinate_kind) { @@ -62,24 +64,26 @@ scalar_kind_t ConvertCoordinateKind(CoordinateKind coordinate_kind) { FATAL_INVALID_ENUM_VALUE(CoordinateKind, coordinate_kind); } -template -class UsearchIndex::Impl { +namespace detail { + +template +class UsearchIndexImpl { public: - explicit Impl(const HNSWOptions& options) + explicit UsearchIndexImpl(const HNSWOptions& options) : dimensions_(options.dimensions), - distance_type_(options.distance_type), + distance_kind_(options.distance_kind), metric_(dimensions_, - MetricKindFromDistanceType(distance_type_), - ConvertCoordinateKind( - CoordinateTypeTraits::Kind())), + MetricKindFromDistanceType(distance_kind_), + ConvertCoordinateKind(CoordinateTypeTraits::kKind)), index_(decltype(index_)::make( metric_, CreateIndexDenseConfig(options))) { CHECK_GT(dimensions_, 0); } - void Reserve(size_t num_vectors) { + Status Reserve(size_t num_vectors) { index_.reserve(num_vectors); + return Status::OK(); } Status Insert(VertexId vertex_id, const Vector& v) { @@ -89,63 +93,44 @@ class UsearchIndex::Impl { return Status::OK(); } - std::vector Search(const Vector& query_vector, size_t max_num_results) { + std::vector> Search( + const Vector& query_vector, size_t max_num_results) { auto usearch_results = index_.search(query_vector.data(), max_num_results); - std::vector result_vec; + std::vector> result_vec; result_vec.reserve(usearch_results.size()); for (size_t i = 0; i < usearch_results.size(); ++i) { auto match = usearch_results[i]; - result_vec.push_back(VertexWithDistance(match.member.key, match.distance)); + result_vec.push_back(VertexWithDistance(match.member.key, match.distance)); } return result_vec; } - Vector GetVector(VertexId vertex_id) const { + Result GetVector(VertexId vertex_id) const { Vector result; result.resize(dimensions_); if (index_.get(vertex_id, result.data())) { return result; } - return {}; + return Vector(); } private: size_t dimensions_; - VectorDistanceType distance_type_; + DistanceKind distance_kind_; metric_punned_t metric_; index_dense_gt index_; }; -template -UsearchIndex::UsearchIndex(const HNSWOptions& options) - : impl_(std::make_unique(options)) { -} +} // namespace detail -template -UsearchIndex::~UsearchIndex() = default; - -template -void UsearchIndex::Reserve(size_t num_vectors) { - impl_->Reserve(num_vectors); +template +UsearchIndex::UsearchIndex(const HNSWOptions& options) + : VectorIndexBase(std::make_unique(options)) { } -template -Status UsearchIndex::Insert(VertexId vertex_id, const Vector& v) { - return impl_->Insert(vertex_id, v); -} - -template -std::vector UsearchIndex::Search( - const Vector& query_vector, size_t max_num_results) const { - return impl_->Search(query_vector, max_num_results); -} - -template -Vector UsearchIndex::GetVector(VertexId vertex_id) const { - return impl_->GetVector(vertex_id); -} +template +UsearchIndex::~UsearchIndex() = default; -BOOST_PP_SEQ_FOR_EACH( - YB_INSTANTIATE_TEMPLATE_FOR_VECTOR_OF, UsearchIndex, YB_USEARCH_SUPPORTED_COORDINATE_TYPES) +template class UsearchIndex; } // namespace yb::vectorindex diff --git a/src/yb/vector/usearch_wrapper.h b/src/yb/vector/usearch_wrapper.h index cc6f3f4db13a..2422e9a1a971 100644 --- a/src/yb/vector/usearch_wrapper.h +++ b/src/yb/vector/usearch_wrapper.h @@ -16,43 +16,40 @@ #include #include +#include "yb/util/result.h" #include "yb/util/status.h" #include "yb/vector/distance.h" #include "yb/vector/hnsw_options.h" #include "yb/vector/coordinate_types.h" #include "yb/vector/vector_index_if.h" - -// Derived from the available add() overloads in index_dense.hpp. -#define YB_USEARCH_SUPPORTED_COORDINATE_TYPES \ - (float) /* NOLINT */ \ - (double) /* NOLINT */ \ - (int8_t) +#include "yb/vector/vector_index_wrapper_util.h" namespace yb::vectorindex { -template -class UsearchIndex : public VectorIndexReaderIf, public VectorIndexWriterIf { +namespace detail { +template +class UsearchIndexImpl; +} // namespace detail + +template +class UsearchIndex : public VectorIndexBase< + detail::UsearchIndexImpl, Vector, DistanceResult> { public: explicit UsearchIndex(const HNSWOptions& options); virtual ~UsearchIndex(); - - void Reserve(size_t num_vectors) override; - - Status Insert(VertexId vertex_id, const Vector& vector) override; - - std::vector Search( - const Vector& query_vector, size_t max_num_results) const override; - - Vector GetVector(VertexId vertex_id) const override; - private: - class Impl; - - std::unique_ptr impl_; + using Impl = detail::UsearchIndexImpl; }; -// Maps the given coordinate kind to a type supported by the usearch. -CoordinateKind UsearchSupportedCoordinateKind(CoordinateKind coordinate_kind); +template +class UsearchIndexFactory : public VectorIndexFactory { + public: + UsearchIndexFactory() = default; + + std::unique_ptr> Create() const override { + return std::make_unique>(this->hnsw_options_); + } +}; } // namespace yb::vectorindex diff --git a/src/yb/vector/vector_index_if.h b/src/yb/vector/vector_index_if.h index 0cd6d5c989e9..b8c9a731c5b9 100644 --- a/src/yb/vector/vector_index_if.h +++ b/src/yb/vector/vector_index_if.h @@ -13,21 +13,24 @@ // Interface definitions for a vector index. +#pragma once + +#include "yb/util/result.h" + #include "yb/common/vector_types.h" -#include "yb/vector/distance.h" #include "yb/vector/coordinate_types.h" - -#pragma once +#include "yb/vector/distance.h" +#include "yb/vector/hnsw_options.h" namespace yb::vectorindex { -template +template class VectorIndexReaderIf { public: virtual ~VectorIndexReaderIf() = default; - virtual std::vector Search( + virtual std::vector> Search( const Vector& query_vector, size_t max_num_results) const = 0; }; @@ -37,14 +40,36 @@ class VectorIndexWriterIf { virtual ~VectorIndexWriterIf() = default; // Reserves capacity for this number of vectors. - virtual void Reserve(size_t num_vectors) = 0; + virtual Status Reserve(size_t num_vectors) = 0; virtual Status Insert(VertexId vertex_id, const Vector& vector) = 0; - // Returns the vector with the given id, or an empty vector if it does not exist. - virtual Vector GetVector(VertexId vertex_id) const = 0; + // Returns the vector with the given id, an empty vector if such VertexId does not exist, or + // a non-OK status if an error occurred. + virtual Result GetVector(VertexId vertex_id) const = 0; }; -using FloatVectorIndexReader = VectorIndexReaderIf; +template +class VectorIndexIf : public VectorIndexReaderIf, + public VectorIndexWriterIf { + public: + virtual ~VectorIndexIf() = default; +}; + +template +class VectorIndexFactory { + public: + virtual ~VectorIndexFactory() = default; + + virtual std::unique_ptr> Create() const = 0; + + // TODO: generalize this to non-HNSW algorithms/libraries. + virtual void SetOptions(const HNSWOptions& options) { + hnsw_options_ = options; + } + + protected: + HNSWOptions hnsw_options_; +}; } // namespace yb::vectorindex diff --git a/src/yb/vector/vector_index_wrapper_util.h b/src/yb/vector/vector_index_wrapper_util.h new file mode 100644 index 000000000000..9277884c14a1 --- /dev/null +++ b/src/yb/vector/vector_index_wrapper_util.h @@ -0,0 +1,96 @@ +// Copyright (c) YugabyteDB, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations +// under the License. +// + +#include "yb/vector/vector_index_if.h" + +#pragma once + +namespace yb::vectorindex { + +// A base class for vector index implementations implementing the pointer-to-implementation idiom. +template +class VectorIndexBase : public VectorIndexIf { + public: + explicit VectorIndexBase(std::unique_ptr impl) + : impl_(std::move(impl)) {} + + ~VectorIndexBase() override = default; + + // Implementations for the VectorIndexReaderIf interface + std::vector> Search( + const Vector& query_vector, size_t max_num_results) const override { + return impl_->Search(query_vector, max_num_results); + } + + // Implementations for the VectorIndexWriterIf interface + Status Reserve(size_t num_vectors) override { + return impl_->Reserve(num_vectors); + } + + Status Insert(VertexId vertex_id, const Vector& vector) override { + return impl_->Insert(vertex_id, vector); + } + + Result GetVector(VertexId vertex_id) const override { + return impl_->GetVector(vertex_id); + } + + protected: + std::unique_ptr impl_; +}; + +// An adapter that allows us to view an index reader with one vector type as an index reader with a +// different vector type. Casts the queries to the vector type supported by the index, and then +// casts the distance type in the results to the distance type expected by the caller. +template< + IndexableVectorType SourceVector, + ValidDistanceResultType SourceDistanceResult, + IndexableVectorType DestinationVector, + ValidDistanceResultType DestinationDistanceResult +> +class VectorIndexReaderAdapter + : public VectorIndexReaderIf { + public: + // Constructor takes the underlying vector index reader + explicit VectorIndexReaderAdapter( + const VectorIndexReaderIf& source_reader) + : source_reader_(source_reader) {} + + // Implementation of the Search function + std::vector> Search( + const DestinationVector& query_vector, size_t max_num_results) const override { + // Cast the query_vector to the SourceVector type + SourceVector cast_query_vector = vector_cast(query_vector); + + // Perform the search using the underlying source_reader + auto source_results = source_reader_.Search(cast_query_vector, max_num_results); + + // Prepare to convert results to the DestinationDistanceResult type + std::vector> destination_results; + destination_results.reserve(source_results.size()); + + for (const auto& source_result : source_results) { + DestinationDistanceResult cast_distance = static_cast( + source_result.distance); + destination_results.emplace_back(source_result.vertex_id, cast_distance); + } + + return destination_results; + } + + private: + const VectorIndexReaderIf& source_reader_; +}; + + +} // namespace yb::vectorindex diff --git a/src/yb/vector/vectorann.cc b/src/yb/vector/vectorann.cc index 6406667d12a8..daee9d11c761 100644 --- a/src/yb/vector/vectorann.cc +++ b/src/yb/vector/vectorann.cc @@ -17,7 +17,9 @@ #include #include "yb/util/memory/arena.h" + #include "yb/vector/vectorann.h" + namespace yb::vectorindex { template @@ -73,9 +75,10 @@ class DummyANN final : public VectorANN { std::vector GetTopKVectors( Vector query_vec, size_t k, double lb_distance, Slice lb_key, bool is_lb_inclusive) const override { + using DistanceResult = double; auto lower_bound = DocKeyWithDistance(lb_key, lb_distance); - auto dist_fn = GetDistanceImpl(VectorDistanceType::kL2Squared); + auto dist_fn = GetDistanceFunction(DistanceKind::kL2Squared); auto modified_dist = [this, &lower_bound, &is_lb_inclusive, dist_fn]( VertexId vertex_id, const Vector& v) -> float { auto& value = values_[vertex_id]; @@ -88,7 +91,8 @@ class DummyANN final : public VectorANN { return dist; }; - auto topk = BruteForcePreciseNearestNeighbors(query_vec, vertex_ids_, modified_dist, k); + auto topk = BruteForcePreciseNearestNeighbors( + query_vec, vertex_ids_, modified_dist, k); std::vector out; for (auto vertex_id : topk) { diff --git a/src/yb/vector/vectorann.h b/src/yb/vector/vectorann.h index 0d5ce8433ce4..45c9cbe43e80 100644 --- a/src/yb/vector/vectorann.h +++ b/src/yb/vector/vectorann.h @@ -14,15 +14,37 @@ #pragma once -#include "yb/common/vector_types.h" -#include "yb/rocksdb/status.h" +#include + +#include "yb/gutil/macros.h" + #include "yb/util/result.h" #include "yb/util/slice.h" + +#include "yb/common/vector_types.h" + #include "yb/vector/coordinate_types.h" #include "yb/vector/vectorann_util.h" +#include "yb/rocksdb/status.h" + namespace yb::vectorindex { +// This MUST match the Vector struct definition in +// src/postgres/third-party-extensions/pgvector/src/vector.h. +struct YSQLVector { + // Commented out as this field is not transferred over the wire for all + // Varlens. + // int32 vl_len_; /* varlena header (do not touch directly!) */ + int16_t dim; /* number of dimensions */ + int16_t unused; + float elems[]; + + private: + DISALLOW_COPY_AND_ASSIGN(YSQLVector); +}; + + // Base class for all ANN vector indexes. // The paging state of an ANN index's iterator must consist of the distance from the query diff --git a/src/yb/vector/vectorann_util.h b/src/yb/vector/vectorann_util.h index 787395611b59..2e433a34d30e 100644 --- a/src/yb/vector/vectorann_util.h +++ b/src/yb/vector/vectorann_util.h @@ -15,21 +15,18 @@ #pragma once #include + #include "yb/common/vector_types.h" + #include "yb/rocksdb/status.h" + #include "yb/util/result.h" #include "yb/util/slice.h" + #include "yb/vector/coordinate_types.h" #include "yb/vector/distance.h" namespace yb::vectorindex { -namespace detail { - -struct CompareDistanceForMinHeap { - bool operator()(const VertexWithDistance& a, const VertexWithDistance& b) { return a > b; } -}; - -} // namespace detail // A simple struct to hold a DocKey that's stored in the value of a vectorann entry and its distance class DocKeyWithDistance { @@ -57,17 +54,18 @@ class DocKeyWithDistance { bool operator>(const DocKeyWithDistance& other) const { return Compare(other) > 0; } }; -using MinDistanceQueue = std::priority_queue< - VertexWithDistance, std::vector, detail::CompareDistanceForMinHeap>; - // Our default comparator for VertexWithDistance already orders the pairs by increasing distance. -using MaxDistanceQueue = std::priority_queue>; +template +using MaxDistanceQueue = + std::priority_queue, + std::vector>>; + // Drain a max-queue of (vertex, distance) pairs and return a list of VertexWithDistance instances // ordered by increasing distance. -inline std::vector DrainMaxQueueToIncreasingDistanceList( - MaxDistanceQueue& queue) { - std::vector result_list; +template +auto DrainMaxQueueToIncreasingDistanceList(MaxDistanceQueue& queue) { + std::vector> result_list; while (!queue.empty()) { result_list.push_back(queue.top()); queue.pop(); @@ -81,14 +79,19 @@ inline std::vector DrainMaxQueueToIncreasingDistanceList( // Computes precise nearest neighbors for the given query by brute force search. In case of // multiple results having the same distance from the query, results with lower vertex ids are // preferred. -template -std::vector BruteForcePreciseNearestNeighbors( - const Vector& query, const std::vector& vertex_ids, - const VertexIdToVectorDistanceFunction& distance_fn, size_t num_results) { - MaxDistanceQueue queue; +template +std::vector> BruteForcePreciseNearestNeighbors( + const Vector& query, + const std::vector& vertex_ids, + const VertexIdToVectorDistanceFunction& distance_fn, + size_t num_results) { + if (num_results == 0) { + return {}; + } + MaxDistanceQueue queue; for (const auto& vertex_id : vertex_ids) { auto distance = distance_fn(vertex_id, query); - auto new_element = VertexWithDistance(vertex_id, distance); + auto new_element = VertexWithDistance(vertex_id, distance); if (queue.size() < num_results || new_element < queue.top()) { // Add a new element if there is a room in the result set, or if the new element is better // than the worst element of the result set. The comparsion is done using the (distance, @@ -105,10 +108,10 @@ std::vector BruteForcePreciseNearestNeighbors( auto result = DrainMaxQueueToIncreasingDistanceList(queue); CHECK_GE(result.size(), std::min(vertex_ids.size(), num_results)) << "Too few records returned by brute-force precise nearest neighbor search on a " - << "dataset with " << vertex_ids.size() - << " vectors. Requested number of results: " << num_results - << ", returned: " << result.size(); + << "dataset with " << vertex_ids.size() << " vectors. Requested number of results: " + << num_results << ", returned: " << result.size(); return result; } + } // namespace yb::vectorindex