From 1afddb3aa8215cd0a4026e8631425f532cf76092 Mon Sep 17 00:00:00 2001 From: Wenxuan Date: Mon, 18 Mar 2024 11:30:03 +0800 Subject: [PATCH] storage: Support vector index and ANN hint (#156) Signed-off-by: Wish --- .gitmodules | 3 + contrib/CMakeLists.txt | 2 + contrib/usearch-cmake/CMakeLists.txt | 11 + dbms/CMakeLists.txt | 1 + dbms/src/Columns/ColumnArray.h | 10 +- dbms/src/Debug/MockStorage.cpp | 4 + dbms/src/Flash/Coprocessor/DAGQueryInfo.h | 4 + .../Coprocessor/DAGStorageInterpreter.cpp | 1 + dbms/src/Flash/Coprocessor/TiDBTableScan.cpp | 3 + dbms/src/Flash/Coprocessor/TiDBTableScan.h | 4 + dbms/src/Interpreters/Context.cpp | 24 + dbms/src/Interpreters/Context.h | 5 + dbms/src/Server/Server.cpp | 5 + .../DeltaMerge/BitmapFilter/BitmapFilter.cpp | 6 +- .../DeltaMerge/BitmapFilter/BitmapFilter.h | 10 +- .../BitmapFilterBlockInputStream.cpp | 77 +- .../BitmapFilterBlockInputStream.h | 9 + .../BitmapFilter/BitmapFilterView.h | 78 + dbms/src/Storages/DeltaMerge/CMakeLists.txt | 1 + dbms/src/Storages/DeltaMerge/ColumnStat.h | 9 + .../Storages/DeltaMerge/DeltaMergeDefines.h | 16 +- dbms/src/Storages/DeltaMerge/File/DMFile.h | 5 + .../File/DMFileBlockInputStream.cpp | 131 +- .../DeltaMerge/File/DMFileBlockInputStream.h | 31 +- .../Storages/DeltaMerge/File/DMFileReader.h | 5 + .../DMFileWithVectorIndexBlockInputStream.h | 592 +++++ ...MFileWithVectorIndexBlockInputStream_fwd.h | 26 + .../Storages/DeltaMerge/File/DMFileWriter.cpp | 55 +- .../Storages/DeltaMerge/File/DMFileWriter.h | 8 +- .../File/VectorColumnFromIndexReader.cpp | 133 + .../File/VectorColumnFromIndexReader.h | 77 + .../File/VectorColumnFromIndexReader_fwd.h | 25 + .../DeltaMerge/File/dtpb/dmfile.proto | 9 + .../Storages/DeltaMerge/Filter/RSOperator.cpp | 6 + .../Storages/DeltaMerge/Filter/RSOperator.h | 4 + .../DeltaMerge/Filter/WithANNQueryInfo.h | 65 + dbms/src/Storages/DeltaMerge/Index/RSIndex.h | 9 +- .../Storages/DeltaMerge/Index/VectorIndex.cpp | 88 + .../Storages/DeltaMerge/Index/VectorIndex.h | 106 + .../Index/VectorIndexHNSW/Index.cpp | 226 ++ .../DeltaMerge/Index/VectorIndexHNSW/Index.h | 67 + .../VectorIndexHNSW/usearch_index_dense.h | 2241 +++++++++++++++++ .../DeltaMerge/Index/VectorIndex_fwd.h | 30 + dbms/src/Storages/DeltaMerge/ReadUtil.cpp | 31 + dbms/src/Storages/DeltaMerge/ReadUtil.h | 12 +- dbms/src/Storages/DeltaMerge/ScanContext.h | 47 + dbms/src/Storages/DeltaMerge/Segment.cpp | 5 +- .../DeltaMerge/SkippableBlockInputStream.h | 8 +- .../Storages/DeltaMerge/StableValueSpace.cpp | 16 +- .../Storages/DeltaMerge/StableValueSpace.h | 3 +- .../tests/gtest_dm_minmax_index.cpp | 1 + .../tests/gtest_dm_storage_delta_merge.cpp | 3 + .../tests/gtest_dm_vector_index.cpp | 1111 ++++++++ .../tests/gtest_segment_test_basic.cpp | 8 +- .../tests/gtest_segment_test_basic.h | 4 + dbms/src/Storages/StorageDeltaMerge.cpp | 25 + .../Storages/StorageDisaggregatedRemote.cpp | 1 + .../Storages/tests/gtest_filter_parser.cpp | 1 + .../tests/gtests_parse_push_down_filter.cpp | 1 + dbms/src/TiDB/Schema/TiDB.cpp | 35 + dbms/src/TiDB/Schema/TiDB.h | 3 + dbms/src/TiDB/Schema/VectorIndex.h | 105 + 62 files changed, 5598 insertions(+), 44 deletions(-) create mode 100644 contrib/usearch-cmake/CMakeLists.txt create mode 100644 dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilterView.h create mode 100644 dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.h create mode 100644 dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream_fwd.h create mode 100644 
dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.cpp create mode 100644 dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.h create mode 100644 dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader_fwd.h create mode 100644 dbms/src/Storages/DeltaMerge/Filter/WithANNQueryInfo.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex.cpp create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/usearch_index_dense.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex_fwd.h create mode 100644 dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp create mode 100644 dbms/src/TiDB/Schema/VectorIndex.h diff --git a/.gitmodules b/.gitmodules index 021da59cf22..d8452a9e647 100644 --- a/.gitmodules +++ b/.gitmodules @@ -140,3 +140,6 @@ [submodule "contrib/qpl"] path = contrib/qpl url = https://github.com/intel/qpl.git +[submodule "contrib/usearch"] + path = contrib/usearch + url = https://github.com/unum-cloud/usearch.git diff --git a/contrib/CMakeLists.txt b/contrib/CMakeLists.txt index 0ee205a1455..f99f955f7ca 100644 --- a/contrib/CMakeLists.txt +++ b/contrib/CMakeLists.txt @@ -185,3 +185,5 @@ endif () add_subdirectory(magic_enum) add_subdirectory(aws-cmake) + +add_subdirectory(usearch-cmake) diff --git a/contrib/usearch-cmake/CMakeLists.txt b/contrib/usearch-cmake/CMakeLists.txt new file mode 100644 index 00000000000..48501dff779 --- /dev/null +++ b/contrib/usearch-cmake/CMakeLists.txt @@ -0,0 +1,11 @@ +set(USEARCH_PROJECT_DIR "${TiFlash_SOURCE_DIR}/contrib/usearch") +set(USEARCH_SOURCE_DIR "${USEARCH_PROJECT_DIR}/include") + +add_library(_usearch INTERFACE) + +target_include_directories(_usearch SYSTEM INTERFACE + ${USEARCH_PROJECT_DIR}/simsimd/include + ${USEARCH_PROJECT_DIR}/fp16/include + ${USEARCH_SOURCE_DIR}) + +add_library(tiflash_contrib::usearch ALIAS _usearch) diff --git a/dbms/CMakeLists.txt b/dbms/CMakeLists.txt index a90ee8dbd4d..e43964caebe 100644 --- a/dbms/CMakeLists.txt +++ b/dbms/CMakeLists.txt @@ -194,6 +194,7 @@ target_link_libraries (dbms ${OPENSSL_CRYPTO_LIBRARY} ${BTRIE_LIBRARIES} absl::synchronization + tiflash_contrib::usearch tiflash_contrib::aws_s3 etcdpb diff --git a/dbms/src/Columns/ColumnArray.h b/dbms/src/Columns/ColumnArray.h index 4e3ab96eec3..344dd5a7d0e 100644 --- a/dbms/src/Columns/ColumnArray.h +++ b/dbms/src/Columns/ColumnArray.h @@ -165,16 +165,16 @@ class ColumnArray final : public COWPtrHelper size_t encodeIntoDatumData(size_t element_idx, WriteBuffer & writer) const; -private: - ColumnPtr data; - ColumnPtr offsets; - - size_t ALWAYS_INLINE offsetAt(size_t i) const { return i == 0 ? 0 : getOffsets()[i - 1]; } size_t ALWAYS_INLINE sizeAt(size_t i) const { return i == 0 ? getOffsets()[0] : (getOffsets()[i] - getOffsets()[i - 1]); } +private: + ColumnPtr data; + ColumnPtr offsets; + + size_t ALWAYS_INLINE offsetAt(size_t i) const { return i == 0 ? 0 : getOffsets()[i - 1]; } /// Multiply values if the nested column is ColumnVector. 
template diff --git a/dbms/src/Debug/MockStorage.cpp b/dbms/src/Debug/MockStorage.cpp index afb46dea2c4..eadcf8dba92 100644 --- a/dbms/src/Debug/MockStorage.cpp +++ b/dbms/src/Debug/MockStorage.cpp @@ -202,6 +202,7 @@ BlockInputStreamPtr MockStorage::getStreamFromDeltaMerge( auto scan_column_infos = mockColumnInfosToTiDBColumnInfos(table_schema_for_delta_merge[table_id]); query_info.dag_query = std::make_unique( filter_conditions->conditions, + tipb::ANNQueryInfo(), empty_pushed_down_filters, // Not care now scan_column_infos, runtime_filter_ids, @@ -230,6 +231,7 @@ BlockInputStreamPtr MockStorage::getStreamFromDeltaMerge( auto scan_column_infos = mockColumnInfosToTiDBColumnInfos(table_schema_for_delta_merge[table_id]); query_info.dag_query = std::make_unique( empty_filters, + tipb::ANNQueryInfo(), empty_pushed_down_filters, // Not care now scan_column_infos, runtime_filter_ids, @@ -261,6 +263,7 @@ void MockStorage::buildExecFromDeltaMerge( auto scan_column_infos = mockColumnInfosToTiDBColumnInfos(table_schema_for_delta_merge[table_id]); query_info.dag_query = std::make_unique( filter_conditions->conditions, + tipb::ANNQueryInfo(), empty_pushed_down_filters, // Not care now scan_column_infos, runtime_filter_ids, @@ -294,6 +297,7 @@ void MockStorage::buildExecFromDeltaMerge( auto scan_column_infos = mockColumnInfosToTiDBColumnInfos(table_schema_for_delta_merge[table_id]); query_info.dag_query = std::make_unique( empty_filters, + tipb::ANNQueryInfo(), empty_pushed_down_filters, // Not care now scan_column_infos, runtime_filter_ids, diff --git a/dbms/src/Flash/Coprocessor/DAGQueryInfo.h b/dbms/src/Flash/Coprocessor/DAGQueryInfo.h index 82e7ba570a1..dade77239e0 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryInfo.h +++ b/dbms/src/Flash/Coprocessor/DAGQueryInfo.h @@ -17,6 +17,7 @@ #include #include #include +#include #include @@ -28,6 +29,7 @@ struct DAGQueryInfo { DAGQueryInfo( const google::protobuf::RepeatedPtrField & filters_, + const tipb::ANNQueryInfo & ann_query_info_, const google::protobuf::RepeatedPtrField & pushed_down_filters_, const ColumnInfos & source_columns_, const std::vector & runtime_filter_ids_, @@ -35,6 +37,7 @@ struct DAGQueryInfo const TimezoneInfo & timezone_info_) : source_columns(source_columns_) , filters(filters_) + , ann_query_info(ann_query_info_) , pushed_down_filters(pushed_down_filters_) , runtime_filter_ids(runtime_filter_ids_) , rf_max_wait_time_ms(rf_max_wait_time_ms_) @@ -44,6 +47,7 @@ struct DAGQueryInfo const ColumnInfos & source_columns; // filters in dag request const google::protobuf::RepeatedPtrField & filters; + const tipb::ANNQueryInfo & ann_query_info; // filters have been push down to storage engine in dag request const google::protobuf::RepeatedPtrField & pushed_down_filters; diff --git a/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp index 1df0e984968..1d197bceba0 100644 --- a/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp @@ -914,6 +914,7 @@ std::unordered_map DAGStorageInterpreter::generateSele query_info.query = dagContext().dummy_ast; query_info.dag_query = std::make_unique( filter_conditions.conditions, + table_scan.getANNQueryInfo(), table_scan.getPushedDownFilters(), table_scan.getColumns(), table_scan.getRuntimeFilterIDs(), diff --git a/dbms/src/Flash/Coprocessor/TiDBTableScan.cpp b/dbms/src/Flash/Coprocessor/TiDBTableScan.cpp index ac590367e9c..8e3c67e34ff 100644 --- a/dbms/src/Flash/Coprocessor/TiDBTableScan.cpp 
+++ b/dbms/src/Flash/Coprocessor/TiDBTableScan.cpp @@ -30,6 +30,9 @@ TiDBTableScan::TiDBTableScan( , pushed_down_filters( is_partition_table_scan ? std::move(table_scan->partition_table_scan().pushed_down_filter_conditions()) : std::move(table_scan->tbl_scan().pushed_down_filter_conditions())) + , ann_query_info( + is_partition_table_scan ? std::move(table_scan->partition_table_scan().ann_query()) + : std::move(table_scan->tbl_scan().ann_query())) // Only No-partition table need keep order when tablescan executor required keep order. // If keep_order is not set, keep order for safety. , keep_order( diff --git a/dbms/src/Flash/Coprocessor/TiDBTableScan.h b/dbms/src/Flash/Coprocessor/TiDBTableScan.h index 572a52cbaf6..c1390ffa44e 100644 --- a/dbms/src/Flash/Coprocessor/TiDBTableScan.h +++ b/dbms/src/Flash/Coprocessor/TiDBTableScan.h @@ -45,6 +45,8 @@ class TiDBTableScan const google::protobuf::RepeatedPtrField & getPushedDownFilters() const { return pushed_down_filters; } + const tipb::ANNQueryInfo & getANNQueryInfo() const { return ann_query_info; } + private: const tipb::Executor * table_scan; String executor_id; @@ -65,6 +67,8 @@ class TiDBTableScan /// They will be executed on Storage layer. const google::protobuf::RepeatedPtrField pushed_down_filters; + const tipb::ANNQueryInfo ann_query_info; + bool keep_order; bool is_fast_scan; std::vector runtime_filter_ids; diff --git a/dbms/src/Interpreters/Context.cpp b/dbms/src/Interpreters/Context.cpp index 0512e6a853c..b3465ef0b05 100644 --- a/dbms/src/Interpreters/Context.cpp +++ b/dbms/src/Interpreters/Context.cpp @@ -57,6 +57,7 @@ #include #include #include +#include #include #include #include @@ -151,6 +152,7 @@ struct ContextShared mutable DBGInvoker dbg_invoker; /// Execute inner functions, debug only. mutable MarkCachePtr mark_cache; /// Cache of marks in compressed files. mutable DM::MinMaxIndexCachePtr minmax_index_cache; /// Cache of minmax index in compressed files. + mutable DM::VectorIndexCachePtr vector_index_cache; mutable DM::DeltaIndexManagerPtr delta_index_manager; /// Manage the Delta Indies of Segments. ProcessList process_list; /// Executing queries at the moment. ViewDependencies view_dependencies; /// Current dependencies @@ -1406,6 +1408,28 @@ void Context::dropMinMaxIndexCache() const shared->minmax_index_cache->reset(); } +void Context::setVectorIndexCache(size_t cache_size_in_bytes) +{ + auto lock = getLock(); + + RUNTIME_CHECK(!shared->vector_index_cache); + + shared->vector_index_cache = std::make_shared(cache_size_in_bytes); +} + +DM::VectorIndexCachePtr Context::getVectorIndexCache() const +{ + auto lock = getLock(); + return shared->vector_index_cache; +} + +void Context::dropVectorIndexCache() const +{ + auto lock = getLock(); + if (shared->vector_index_cache) + shared->vector_index_cache->reset(); +} + bool Context::isDeltaIndexLimited() const { // Don't need to use a lock here, as delta_index_manager should be set at starting up. 
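The new per-process vector index cache mirrors the existing mark and min-max index caches: it is sized in bytes at startup and consulted with a get-or-load pattern. A minimal sketch of how a reader might consult it, assuming an illustrative cache key and loader (the real usage is in DMFileWithVectorIndexBlockInputStream::loadVectorIndex further below):

auto cache = global_context.getVectorIndexCache();
VectorIndexPtr index;
if (cache)
    // getOrSet() either returns the cached index or invokes the loader and caches its result.
    index = cache->getOrSet(cache_key, load_vector_index);
else
    index = load_vector_index(); // cache disabled (vec_index_cache_size = 0)

Here load_vector_index stands for any callable returning a VectorIndexPtr; its body is shown later in this patch.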
diff --git a/dbms/src/Interpreters/Context.h b/dbms/src/Interpreters/Context.h index 190906677d1..42b1ee667a4 100644 --- a/dbms/src/Interpreters/Context.h +++ b/dbms/src/Interpreters/Context.h @@ -108,6 +108,7 @@ enum class PageStorageRunMode : UInt8; namespace DM { class MinMaxIndexCache; +class VectorIndexCache; class DeltaIndexManager; class GlobalStoragePool; class SharedBlockSchemas; @@ -397,6 +398,10 @@ class Context std::shared_ptr getMinMaxIndexCache() const; void dropMinMaxIndexCache() const; + void setVectorIndexCache(size_t cache_size_in_bytes); + std::shared_ptr getVectorIndexCache() const; + void dropVectorIndexCache() const; + bool isDeltaIndexLimited() const; void setDeltaIndexManager(size_t cache_size_in_bytes); std::shared_ptr getDeltaIndexManager() const; diff --git a/dbms/src/Server/Server.cpp b/dbms/src/Server/Server.cpp index b95e83e9551..2688c11478c 100644 --- a/dbms/src/Server/Server.cpp +++ b/dbms/src/Server/Server.cpp @@ -1428,6 +1428,11 @@ int Server::main(const std::vector & /*args*/) if (minmax_index_cache_size) global_context->setMinMaxIndexCache(minmax_index_cache_size); + // 1GiB vector index cache. + size_t vec_index_cache_size = config().getUInt64("vec_index_cache_size", 1ULL * 1024 * 1024 * 1024); + if (vec_index_cache_size) + global_context->setVectorIndexCache(vec_index_cache_size); + /// Size of max memory usage of DeltaIndex, used by DeltaMerge engine. /// - In non-disaggregated mode, its default value is 0, means unlimited, and it /// controls the number of total bytes keep in the memory. diff --git a/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilter.cpp b/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilter.cpp index eb5e34aa8f2..65bd9867517 100644 --- a/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilter.cpp +++ b/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilter.cpp @@ -73,10 +73,10 @@ void BitmapFilter::set(const UInt32 * data, UInt32 size, const FilterPtr & f) } } -void BitmapFilter::set(UInt32 start, UInt32 limit) +void BitmapFilter::set(UInt32 start, UInt32 limit, bool value) { RUNTIME_CHECK(start + limit <= filter.size(), start, limit, filter.size()); - std::fill(filter.begin() + start, filter.begin() + start + limit, true); + std::fill(filter.begin() + start, filter.begin() + start + limit, value); } bool BitmapFilter::get(IColumn::Filter & f, UInt32 start, UInt32 limit) const @@ -127,4 +127,4 @@ size_t BitmapFilter::count() const { return std::count(filter.cbegin(), filter.cend(), true); } -} // namespace DB::DM \ No newline at end of file +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilter.h b/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilter.h index 02ee40b8dc1..d86043b1681 100644 --- a/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilter.h +++ b/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilter.h @@ -28,9 +28,14 @@ class BitmapFilter void set(BlockInputStreamPtr & stream); void set(const ColumnPtr & col, const FilterPtr & f); void set(const UInt32 * data, UInt32 size, const FilterPtr & f); - void set(UInt32 start, UInt32 limit); + void set(UInt32 start, UInt32 limit, bool value = true); // If return true, all data is match and do not fill the filter. 
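A small sketch of the extended BitmapFilter API above, assuming illustrative row ids and sizes: mark individual rows as valid through the new value parameter of set(), then materialize a sub-range into an IColumn::Filter via get(), whose return value tells whether every row in the range matched.

BitmapFilter valid_rows(/*size=*/1000, /*default_value=*/false);
for (UInt32 row_id : {3U, 150U, 999U}) // e.g. row ids returned by a Top-K vector search
    valid_rows.set(row_id, /*limit=*/1, /*value=*/true);

IColumn::Filter f(100);
bool all_match = valid_rows.get(f, /*start=*/100, /*limit=*/100); // false: only row 150 is set in [100, 200)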
bool get(IColumn::Filter & f, UInt32 start, UInt32 limit) const; + inline bool get(UInt32 n) const + { + RUNTIME_CHECK(n < filter.size(), n, filter.size()); + return filter[n]; + } // filter[start, limit] & f -> f void rangeAnd(IColumn::Filter & f, UInt32 start, UInt32 limit) const; @@ -38,6 +43,7 @@ class BitmapFilter String toDebugString() const; size_t count() const; + inline size_t size() const { return filter.size(); } private: std::vector filter; @@ -45,4 +51,4 @@ class BitmapFilter }; using BitmapFilterPtr = std::shared_ptr; -} // namespace DB::DM \ No newline at end of file +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilterBlockInputStream.cpp b/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilterBlockInputStream.cpp index 6cbe450c324..26a9d52fc46 100644 --- a/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilterBlockInputStream.cpp +++ b/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilterBlockInputStream.cpp @@ -37,7 +37,34 @@ BitmapFilterBlockInputStream::BitmapFilterBlockInputStream( Block BitmapFilterBlockInputStream::readImpl(FilterPtr & res_filter, bool return_filter) { - auto [block, from_delta] = readBlock(stable, delta); + if (return_filter) + return readImpl(res_filter); + + // The caller wants a filtered result, so let's filter by ourselves. + + FilterPtr block_filter; + auto block = readImpl(block_filter); + if (!block) + return {}; + + // No rows in the block are filtered out, so simply do nothing. + if (!block_filter) + return block; + + // some rows should be filtered according to `block_filter`: + size_t passed_count = countBytesInFilter(*block_filter); + for (auto & col : block) + { + col.column = col.column->filter(*block_filter, passed_count); + } + return block; +} + +Block BitmapFilterBlockInputStream::readImpl(FilterPtr & res_filter) +{ + FilterPtr block_filter = nullptr; + auto [block, from_delta] = readBlockWithReturnFilter(stable, delta, block_filter); + if (block) { if (from_delta) @@ -45,26 +72,50 @@ Block BitmapFilterBlockInputStream::readImpl(FilterPtr & res_filter, bool return block.setStartOffset(block.startOffset() + stable_rows); } + String block_filter_value; + if (block_filter) + { + for (size_t i = 0; i < block_filter->size(); ++i) + { + block_filter_value += (*block_filter)[i] ? "1" : "0"; + } + } + filter.resize(block.rows()); bool all_match = bitmap_filter->get(filter, block.startOffset(), block.rows()); - if (!all_match) + + if (!block_filter) { - if (return_filter) - { + if (all_match) + res_filter = nullptr; + else res_filter = &filter; + } + else + { + RUNTIME_CHECK(filter.size() >= block_filter->size()); + + if (!all_match) + { + // We have both a `block_filter` and a bitmap filter in `filter`. + // filter ← filter & block_filter. + std::transform( // + filter.begin(), + filter.end(), + block_filter->begin(), + filter.begin(), + [](UInt8 a, UInt8 b) { return static_cast(a && b); }); } else { - size_t passed_count = countBytesInFilter(filter); - for (auto & col : block) - { - col.column = col.column->filter(filter, passed_count); - } + // We only have a `block_filter`. + // filter ← block_filter.
+ std::copy( // + block_filter->begin(), + block_filter->end(), + filter.begin()); } - } - else - { - res_filter = nullptr; + res_filter = &filter; } } return block; diff --git a/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilterBlockInputStream.h b/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilterBlockInputStream.h index 999d598fb17..6c83f4ab9b8 100644 --- a/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilterBlockInputStream.h +++ b/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilterBlockInputStream.h @@ -45,11 +45,20 @@ class BitmapFilterBlockInputStream : public IProfilingBlockInputStream FilterPtr filter_ignored; return readImpl(filter_ignored, false); } + // When all rows in block are not filtered out, // `res_filter` will be set to null. // The caller needs to do handle this situation. Block readImpl(FilterPtr & res_filter, bool return_filter) override; +private: + // When all rows in block are not filtered out, + // `res_filter` will be set to null. + // The caller needs to do handle this situation. + // This function always returns the filter to the caller. It does not + // filter the block. + Block readImpl(FilterPtr & res_filter); + private: Block header; SkippableBlockInputStreamPtr stable; diff --git a/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilterView.h b/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilterView.h new file mode 100644 index 00000000000..88765274423 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilterView.h @@ -0,0 +1,78 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace DB::DM +{ + +// BitmapFilterView provides a subset of a BitmapFilter. +// Accessing BitmapFilterView[i] becomes accessing filter[offset+i]. +class BitmapFilterView +{ +private: + BitmapFilterPtr filter; + UInt32 filter_offset; + UInt32 filter_size; + +public: + explicit BitmapFilterView(const BitmapFilterPtr & filter_, UInt32 offset_, UInt32 size_) + : filter(filter_) + , filter_offset(offset_) + , filter_size(size_) + { + RUNTIME_CHECK(filter_offset + filter_size <= filter->size(), filter_offset, filter_size, filter->size()); + } + + inline bool get(UInt32 n) const + { + RUNTIME_CHECK(n < filter_size); + return filter->get(filter_offset + n); + } + + inline bool operator[](UInt32 n) const { return get(n); } + + inline UInt32 size() const { return filter_size; } + + inline UInt32 offset() const { return filter_offset; } + + String toDebugString() const + { + String s(size(), '1'); + for (UInt32 i = 0; i < size(); i++) + { + if (!get(i)) + { + s[i] = '0'; + } + } + return s; + } + + // Return how many valid rows. 
+ size_t count() const + { + size_t n = 0; + for (UInt32 i = 0; i < size(); i++) + { + if (get(i)) + n++; + } + return n; + } +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/CMakeLists.txt b/dbms/src/Storages/DeltaMerge/CMakeLists.txt index 2f7326f0bae..5bef669c81f 100644 --- a/dbms/src/Storages/DeltaMerge/CMakeLists.txt +++ b/dbms/src/Storages/DeltaMerge/CMakeLists.txt @@ -21,6 +21,7 @@ add_subdirectory(./Remote/Proto) add_headers_and_sources(delta_merge .) add_headers_and_sources(delta_merge ./BitmapFilter) add_headers_and_sources(delta_merge ./Index) +add_headers_and_sources(delta_merge ./Index/VectorIndexHNSW) add_headers_and_sources(delta_merge ./Filter) add_headers_and_sources(delta_merge ./FilterParser) add_headers_and_sources(delta_merge ./File) diff --git a/dbms/src/Storages/DeltaMerge/ColumnStat.h b/dbms/src/Storages/DeltaMerge/ColumnStat.h index 0d0e08aad5d..07a251388e7 100644 --- a/dbms/src/Storages/DeltaMerge/ColumnStat.h +++ b/dbms/src/Storages/DeltaMerge/ColumnStat.h @@ -41,6 +41,8 @@ struct ColumnStat size_t array_sizes_bytes = 0; size_t array_sizes_mark_bytes = 0; + std::optional vector_index = std::nullopt; + dtpb::ColumnStat toProto() const { dtpb::ColumnStat stat; @@ -55,6 +57,10 @@ struct ColumnStat stat.set_index_bytes(index_bytes); stat.set_array_sizes_bytes(array_sizes_bytes); stat.set_array_sizes_mark_bytes(array_sizes_mark_bytes); + + if (vector_index.has_value()) + stat.mutable_vector_index()->CopyFrom(vector_index.value()); + return stat; } @@ -71,6 +77,9 @@ struct ColumnStat index_bytes = proto.index_bytes(); array_sizes_bytes = proto.array_sizes_bytes(); array_sizes_mark_bytes = proto.array_sizes_mark_bytes(); + + if (proto.has_vector_index()) + vector_index = proto.vector_index(); } // @deprecated. New fields should be added via protobuf. Use `toProto` instead diff --git a/dbms/src/Storages/DeltaMerge/DeltaMergeDefines.h b/dbms/src/Storages/DeltaMerge/DeltaMergeDefines.h index 47c15b611d7..639351bc33e 100644 --- a/dbms/src/Storages/DeltaMerge/DeltaMergeDefines.h +++ b/dbms/src/Storages/DeltaMerge/DeltaMergeDefines.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -89,11 +90,22 @@ struct ColumnDefine DataTypePtr type; Field default_value; - explicit ColumnDefine(ColId id_ = 0, String name_ = "", DataTypePtr type_ = nullptr, Field default_value_ = Field{}) + /// Note: ColumnDefine is used in both Write path and Read path. + /// In the read path, vector_index is usually not available. Use AnnQueryInfo for + /// read related vector index information. 
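On the write path, a vector index is requested by attaching a TiDB::VectorIndexInfoPtr to the ColumnDefine, as the extended constructor below shows. A hypothetical sketch follows; the contents of TiDB::VectorIndexInfo are not visible in this patch, so treating it as default-constructible is only an assumption (in practice it would come from the table schema):

auto vec_index_info = std::make_shared<TiDB::VectorIndexInfo>(); // assumption: normally filled from the TiDB table schema
ColumnDefine vec_cd(
    /*id_=*/100, // illustrative column id
    /*name_=*/"embedding",
    /*type_=*/DataTypeFactory::instance().get("Array(Float32)"),
    /*default_value_=*/Field{},
    /*vector_index_=*/vec_index_info);
// DMFileWriter (changed later in this patch) builds the vector index for this column while writing packs.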
+ TiDB::VectorIndexInfoPtr vector_index; + + explicit ColumnDefine( + ColId id_ = 0, + String name_ = "", + DataTypePtr type_ = nullptr, + Field default_value_ = Field{}, + TiDB::VectorIndexInfoPtr vector_index_ = nullptr) : id(id_) , name(std::move(name_)) , type(std::move(type_)) , default_value(std::move(default_value_)) + , vector_index(vector_index_) {} }; @@ -185,4 +197,4 @@ struct fmt::formatter { return fmt::format_to(ctx.out(), "{}/{}", cd.id, cd.type->getName()); } -}; \ No newline at end of file +}; diff --git a/dbms/src/Storages/DeltaMerge/File/DMFile.h b/dbms/src/Storages/DeltaMerge/File/DMFile.h index c6db708c98a..efa719478da 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFile.h +++ b/dbms/src/Storages/DeltaMerge/File/DMFile.h @@ -44,11 +44,16 @@ namespace DB { namespace DM { + +class DMFileWithVectorIndexBlockInputStream; + using DMFilePtr = std::shared_ptr; using DMFiles = std::vector; class DMFile : private boost::noncopyable { + friend class DMFileWithVectorIndexBlockInputStream; + public: enum Status : int { diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileBlockInputStream.cpp b/dbms/src/Storages/DeltaMerge/File/DMFileBlockInputStream.cpp index 917e4eea3d6..0d216ffe9e2 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileBlockInputStream.cpp +++ b/dbms/src/Storages/DeltaMerge/File/DMFileBlockInputStream.cpp @@ -14,6 +14,9 @@ #include #include +#include +#include +#include #include namespace DB::DM @@ -24,7 +27,10 @@ DMFileBlockInputStreamBuilder::DMFileBlockInputStreamBuilder(const Context & con { // init from global context const auto & global_context = context.getGlobalContext(); - setCaches(global_context.getMarkCache(), global_context.getMinMaxIndexCache()); + setCaches( + global_context.getMarkCache(), + global_context.getMinMaxIndexCache(), + global_context.getVectorIndexCache()); // init from settings setFromSettings(context.getSettingsRef()); } @@ -94,4 +100,127 @@ DMFileBlockInputStreamPtr DMFileBlockInputStreamBuilder::build( return std::make_shared(std::move(reader), max_sharing_column_bytes_for_all > 0); } + +SkippableBlockInputStreamPtr DMFileBlockInputStreamBuilder::build2( + const DMFilePtr & dmfile, + const ColumnDefines & read_columns, + const RowKeyRanges & rowkey_ranges, + const ScanContextPtr & scan_context) +{ + auto fallback = [&]() { + return build(dmfile, read_columns, rowkey_ranges, scan_context); + }; + + if (rs_filter == nullptr) + return fallback(); + + // Fast Scan and Clean Read do not affect our behavior. (TODO: Confirm?) + // if (is_fast_scan || enable_del_clean_read || enable_handle_clean_read) + // return fallback(); + + auto filter_with_ann = std::dynamic_pointer_cast(rs_filter); + if (filter_with_ann == nullptr) + return fallback(); + + auto ann_query_info = filter_with_ann->ann_query_info; + if (!ann_query_info) + return fallback(); + + if (!bitmap_filter.has_value()) + return fallback(); + + // Fast check: ANNQueryInfo is available in the whole read path. However, we may not be reading the vector column now. + bool is_matching_ann_query = false; + for (const auto & cd : read_columns) + { + if (cd.id == ann_query_info->column_id()) + { + is_matching_ann_query = true; + break; + } + } + if (!is_matching_ann_query) + return fallback(); + + Block header_layout = toEmptyBlock(read_columns); + + // Copy out the vector column for later use. Copy is intentionally performed after the + // fast check so that in fallback conditions we don't need unnecessary copies.
+ std::optional vec_column; + ColumnDefines rest_columns{}; + for (const auto & cd : read_columns) + { + if (cd.id == ann_query_info->column_id()) + vec_column.emplace(cd); + else + rest_columns.emplace_back(cd); + } + + // No vector index column is specified, just use the normal logic. + if (!vec_column.has_value()) + return fallback(); + + RUNTIME_CHECK(rest_columns.size() + 1 == read_columns.size(), rest_columns.size(), read_columns.size()); + + const auto & vec_column_stat = dmfile->getColumnStat(vec_column->id); + if (vec_column_stat.index_bytes == 0 || !vec_column_stat.vector_index.has_value()) + // Vector index is defined but does not exist on the data file, + // or there is no data at all + return fallback(); + + // All check passed. Let's read via vector index. + + DMFilePackFilter pack_filter = DMFilePackFilter::loadFrom( + dmfile, + index_cache, + /*set_cache_if_miss*/ true, + rowkey_ranges, + rs_filter, + read_packs, + file_provider, + read_limiter, + scan_context, + tracing_id); + + bool enable_read_thread = SegmentReaderPoolManager::instance().isSegmentReader(); + bool is_common_handle = !rowkey_ranges.empty() && rowkey_ranges[0].is_common_handle; + + DMFileReader rest_columns_reader( + dmfile, + rest_columns, + is_common_handle, + enable_handle_clean_read, + enable_del_clean_read, + is_fast_scan, + max_data_version, + std::move(pack_filter), + mark_cache, + enable_column_cache, + column_cache, + aio_threshold, + max_read_buffer_size, + file_provider, + read_limiter, + rows_threshold_per_read, + read_one_pack_every_time, + tracing_id, + enable_read_thread, + scan_context); + + DMFileWithVectorIndexBlockInputStreamPtr reader = DMFileWithVectorIndexBlockInputStream::create( + ann_query_info, + dmfile, + std::move(header_layout), + std::move(rest_columns_reader), + std::move(vec_column.value()), + file_provider, + read_limiter, + scan_context, + vector_index_cache, + bitmap_filter.value(), + tracing_id); + + return reader; +} + } // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileBlockInputStream.h b/dbms/src/Storages/DeltaMerge/File/DMFileBlockInputStream.h index e68caf94ffe..5f1c30b6ad9 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileBlockInputStream.h +++ b/dbms/src/Storages/DeltaMerge/File/DMFileBlockInputStream.h @@ -14,9 +14,12 @@ #pragma once +#include #include #include #include +#include +#include #include #include #include @@ -91,6 +94,20 @@ class DMFileBlockInputStreamBuilder const RowKeyRanges & rowkey_ranges, const ScanContextPtr & scan_context); + // Build the final stream ptr. The return value could be DMFileBlockInputStreamPtr or DMFileWithVectorIndexBlockInputStream. + // Empty `rowkey_ranges` means not filter by rowkey + // Should not use the builder again after `build` is called. + // In the following conditions DMFileWithVectorIndexBlockInputStream will be returned: + // 1. BitmapFilter is provided + // 2. ANNQueryInfo is available in the RSFilter + // 3. The vector column mentioned by ANNQueryInfo is in the read_columns + // 4. 
The vector column mentioned by ANNQueryInfo has a vector index file + SkippableBlockInputStreamPtr build2( + const DMFilePtr & dmfile, + const ColumnDefines & read_columns, + const RowKeyRanges & rowkey_ranges, + const ScanContextPtr & scan_context); + // **** filters **** // // Only set enable_handle_clean_read_ param to true when @@ -115,6 +132,12 @@ class DMFileBlockInputStreamBuilder return *this; } + DMFileBlockInputStreamBuilder & setBitmapFilter(const BitmapFilterView & bitmap_filter_) + { + bitmap_filter.emplace(bitmap_filter_); + return *this; + } + DMFileBlockInputStreamBuilder & setRSOperator(const RSOperatorPtr & filter_) { rs_filter = filter_; @@ -170,10 +193,12 @@ class DMFileBlockInputStreamBuilder } DMFileBlockInputStreamBuilder & setCaches( const MarkCachePtr & mark_cache_, - const MinMaxIndexCachePtr & index_cache_) + const MinMaxIndexCachePtr & index_cache_, + const VectorIndexCachePtr & vector_index_cache_) { mark_cache = mark_cache_; index_cache = index_cache_; + vector_index_cache = vector_index_cache_; return *this; } @@ -203,6 +228,10 @@ class DMFileBlockInputStreamBuilder size_t max_sharing_column_bytes_for_all = 0; String tracing_id; ReadTag read_tag = ReadTag::Internal; + + VectorIndexCachePtr vector_index_cache; + // Note: Currently this field is assigned only for Stable streams, not available for ColumnFileBig + std::optional bitmap_filter; }; /** diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileReader.h b/dbms/src/Storages/DeltaMerge/File/DMFileReader.h index 6df002044ac..48d52016946 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileReader.h +++ b/dbms/src/Storages/DeltaMerge/File/DMFileReader.h @@ -31,6 +31,9 @@ namespace DB { namespace DM { + +class DMFileWithVectorIndexBlockInputStream; + class RSOperator; using RSOperatorPtr = std::shared_ptr; @@ -38,6 +41,8 @@ inline static const size_t DMFILE_READ_ROWS_THRESHOLD = DEFAULT_MERGE_BLOCK_SIZE class DMFileReader { + friend class DMFileWithVectorIndexBlockInputStream; + public: static bool isCacheableColumn(const ColumnDefine & cd); // Read stream for single column diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.h b/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.h new file mode 100644 index 00000000000..20aa03110d7 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.h @@ -0,0 +1,592 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +namespace DB::DM +{ + +/** + * @brief DMFileWithVectorIndexBlockInputStream is similar to DMFileBlockInputStream. + * However, it can read data efficiently with the help of a vector index. + * + * General steps: + * 1. Read all PK, Version and Del Marks (respecting Pack filters). + * 2. Construct a bitmap of valid rows (in memory). This bitmap guides the reading of vector index to determine whether a row is valid or not.
+ * + * Note: Step 1 and 2 simply rely on the BitmapFilter to avoid repeat IOs. + * BitmapFilter is global, which provides row valid info for all DMFile + Delta. + * What we need is which rows are valid in THIS DMFile. + * To transform a global BitmapFilter result into a local one, RowOffsetTracker is used. + * + * 3. Perform a vector search for Top K vector rows. We now have K row_ids whose vector distance is close. + * 4. Map these row_ids to packids as the new pack filter. + * 5. Read from other columns with the new pack filter. + * For each read, join other columns and vector column together. + * + * Step 3~4 is performed lazily at first read. + * + * Before constructing this class, the caller must ensure that vector index + * exists on the corresponding column. If the index does not exist, the caller + * should use the standard DMFileBlockInputStream. + */ +class DMFileWithVectorIndexBlockInputStream : public SkippableBlockInputStream +{ +public: + static DMFileWithVectorIndexBlockInputStreamPtr create( + const ANNQueryInfoPtr & ann_query_info, + const DMFilePtr & dmfile, + Block && header_layout, + DMFileReader && reader, + ColumnDefine && vec_cd, + const FileProviderPtr & file_provider, + const ReadLimiterPtr & read_limiter, + const ScanContextPtr & scan_context, + const VectorIndexCachePtr & vec_index_cache, + const BitmapFilterView & valid_rows, + const String & tracing_id) + { + return std::make_shared( + ann_query_info, + dmfile, + std::move(header_layout), + std::move(reader), + std::move(vec_cd), + file_provider, + read_limiter, + scan_context, + vec_index_cache, + valid_rows, + tracing_id); + } + + explicit DMFileWithVectorIndexBlockInputStream( + const ANNQueryInfoPtr & ann_query_info_, + const DMFilePtr & dmfile_, + Block && header_layout_, + DMFileReader && reader_, + ColumnDefine && vec_cd_, + const FileProviderPtr & file_provider_, + const ReadLimiterPtr & read_limiter_, + const ScanContextPtr & scan_context_, + const VectorIndexCachePtr & vec_index_cache_, + const BitmapFilterView & valid_rows_, + const String & tracing_id) + : log(Logger::get(tracing_id)) + , ann_query_info(ann_query_info_) + , dmfile(dmfile_) + , header_layout(std::move(header_layout_)) + , reader(std::move(reader_)) + , vec_cd(std::move(vec_cd_)) + , file_provider(file_provider_) + , read_limiter(read_limiter_) + , scan_context(scan_context_) + , vec_index_cache(vec_index_cache_) + , valid_rows(valid_rows_) + { + RUNTIME_CHECK(ann_query_info); + RUNTIME_CHECK(vec_cd.id == ann_query_info->column_id()); + for (const auto & cd : reader.read_columns) + { + RUNTIME_CHECK(header_layout.has(cd.name), cd.name); + RUNTIME_CHECK(cd.id != vec_cd.id); + } + RUNTIME_CHECK(header_layout.has(vec_cd.name)); + RUNTIME_CHECK(header_layout.columns() == reader.read_columns.size() + 1); + + // Fill start_offset_to_pack_id + const auto & pack_stats = dmfile->getPackStats(); + start_offset_to_pack_id.reserve(pack_stats.size()); + UInt32 start_offset = 0; + for (size_t pack_id = 0, pack_id_max = pack_stats.size(); pack_id < pack_id_max; pack_id++) + { + start_offset_to_pack_id[start_offset] = pack_id; + start_offset += pack_stats[pack_id].rows; + } + + // Fill header + header = toEmptyBlock(reader.read_columns); + addColumnToBlock( + header, + vec_cd.id, + vec_cd.name, + vec_cd.type, + vec_cd.type->createColumn(), + vec_cd.default_value); + } + + ~DMFileWithVectorIndexBlockInputStream() override + { + if (!vec_column_reader) + return; + + scan_context->total_vector_idx_read_vec_time_ms + += 
static_cast(duration_read_from_vec_index_seconds * 1000); + scan_context->total_vector_idx_read_others_time_ms + += static_cast(duration_read_from_other_columns_seconds * 1000); + + LOG_DEBUG( // + log, + "Finished read DMFile with vector index for column dmf_{}/{}(id={}), " + "query_top_k={} load_index+result={:.2f}s read_from_index={:.2f}s read_from_others={:.2f}s", + dmfile->fileId(), + vec_cd.name, + vec_cd.id, + ann_query_info->top_k(), + duration_load_vec_index_and_results_seconds, + duration_read_from_vec_index_seconds, + duration_read_from_other_columns_seconds); + } + +public: + Block read() override + { + FilterPtr filter = nullptr; + return read(filter, false); + } + + // When all rows in block are not filtered out, + // `res_filter` will be set to null. + // The caller needs to do handle this situation. + Block read(FilterPtr & res_filter, bool return_filter) override + { + if (return_filter) + return readImpl(res_filter); + + // If return_filter == false, we must filter by ourselves. + + FilterPtr filter = nullptr; + auto res = readImpl(filter); + if (filter != nullptr) + { + for (auto & col : res) + col.column = col.column->filter(*filter, -1); + } + // filter == nullptr means all rows are valid and no need to filter. + + return res; + } + + // When all rows in block are not filtered out, + // `res_filter` will be set to null. + // The caller needs to do handle this situation. + Block readImpl(FilterPtr & res_filter) + { + load(); + + Block res; + if (!reader.read_columns.empty()) + res = readByFollowingOtherColumns(); + else + res = readByIndexReader(); + + // Filter the output rows. If no rows need to filter, res_filter is nullptr. + filter.resize(res.rows()); + bool all_match = valid_rows_after_search.get(filter, res.startOffset(), res.rows()); + + if unlikely (all_match) + res_filter = nullptr; + else + res_filter = &filter; + return res; + } + + bool getSkippedRows(size_t &) override + { + RUNTIME_CHECK_MSG(false, "DMFileWithVectorIndexBlockInputStream does not support getSkippedRows"); + } + + size_t skipNextBlock() override + { + RUNTIME_CHECK_MSG(false, "DMFileWithVectorIndexBlockInputStream does not support skipNextBlock"); + } + + Block readWithFilter(const IColumn::Filter &) override + { + // We don't support the normal late materialization, because + // we are already doing it. + RUNTIME_CHECK_MSG(false, "DMFileWithVectorIndexBlockInputStream does not support late materialization"); + } + + String getName() const override { return "DMFileWithVectorIndex"; } + + Block getHeader() const override { return header; } + +private: + // Only used in readByIndexReader() + size_t index_reader_next_pack_id = 0; + // Only used in readByIndexReader() + size_t index_reader_next_row_id = 0; + + // Read data totally from the VectorColumnFromIndexReader. This is used + // when there is no other column to read. 
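A hypothetical call-site sketch for the new build2() entry point and the FilterPtr contract described above; the variable names and surrounding plumbing are illustrative, only the builder and stream interfaces come from this patch:

DMFileBlockInputStreamBuilder builder(context);
auto stream = builder
                  .setRSOperator(rs_operator) // an RSOperator wrapped by WithANNQueryInfo, carrying the ANNQueryInfo
                  .setBitmapFilter(BitmapFilterView(segment_bitmap, /*offset_=*/0, dmfile->getRows()))
                  .build2(dmfile, read_columns, rowkey_ranges, scan_context);

// The stream reads via the vector index only when all preconditions hold; otherwise it is a
// plain DMFileBlockInputStream. Either way it can be consumed through the same interface:
FilterPtr filter = nullptr;
while (Block block = stream->read(filter, /*return_filter=*/true))
{
    size_t passed = filter ? countBytesInFilter(*filter) : block.rows(); // nullptr means every row matches
    // ... consume `passed` rows of `block` ...
}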
+ Block readByIndexReader() + { + const auto & pack_stats = dmfile->getPackStats(); + size_t all_packs = pack_stats.size(); + const auto & use_packs = reader.pack_filter.getUsePacksConst(); + + RUNTIME_CHECK(use_packs.size() == all_packs); + + // Skip as many packs as possible according to Pack Filter + while (index_reader_next_pack_id < all_packs) + { + if (use_packs[index_reader_next_pack_id]) + break; + index_reader_next_row_id += pack_stats[index_reader_next_pack_id].rows; + index_reader_next_pack_id++; + } + + if (index_reader_next_pack_id >= all_packs) + // Finished + return {}; + + auto read_pack_id = index_reader_next_pack_id; + auto block_start_row_id = index_reader_next_row_id; + auto read_rows = pack_stats[read_pack_id].rows; + + index_reader_next_pack_id++; + index_reader_next_row_id += read_rows; + + Block block; + block.setStartOffset(block_start_row_id); + + auto vec_column = vec_cd.type->createColumn(); + + Stopwatch w; + vec_column_reader->read(vec_column, read_pack_id, read_rows); + duration_read_from_vec_index_seconds += w.elapsedSeconds(); + + block.insert(ColumnWithTypeAndName{// + std::move(vec_column), + vec_cd.type, + vec_cd.name, + vec_cd.id}); + + return block; + } + + // Read data from other columns first, then read from VectorColumnFromIndexReader. This is used + // when there are other columns to read. + Block readByFollowingOtherColumns() + { + // First read other columns. + Stopwatch w; + auto block_others = reader.read(); + duration_read_from_other_columns_seconds += w.elapsedSeconds(); + + if (!block_others) + return {}; + + // Using vec_cd.type to construct a Column directly instead of using + // the type from dmfile, so that we don't need extra transforms + // (e.g. wrap with a Nullable). vec_column_reader is compatible with + // both Nullable and NotNullable. + auto vec_column = vec_cd.type->createColumn(); + + // Then read from vector index for the same pack. + w.restart(); + + vec_column_reader->read(vec_column, getPackIdFromBlock(block_others), block_others.rows()); + duration_read_from_vec_index_seconds += w.elapsedSeconds(); + + // Re-assemble block using the same layout as header_layout. + Block res = header_layout.cloneEmpty(); + // Note: the start offset counts from the beginning of THIS dmfile. It + // is not a global offset. 
+ res.setStartOffset(block_others.startOffset()); + for (const auto & elem : block_others) + { + RUNTIME_CHECK(res.has(elem.name)); + res.getByName(elem.name).column = std::move(elem.column); + } + RUNTIME_CHECK(res.has(vec_cd.name)); + res.getByName(vec_cd.name).column = std::move(vec_column); + + return res; + } + +private: + void load() + { + if (loaded) + return; + + Stopwatch w; + + loadVectorIndex(); + loadVectorSearchResult(); + + duration_load_vec_index_and_results_seconds = w.elapsedSeconds(); + + loaded = true; + } + + void loadVectorIndex() + { + bool is_index_load_from_cache = true; + + auto col_id = ann_query_info->column_id(); + + RUNTIME_CHECK(dmfile->useMetaV2()); // v3 + + // Check vector index exists on the column + const auto & column_stat = dmfile->getColumnStat(col_id); + RUNTIME_CHECK(column_stat.index_bytes > 0); + + const auto & type = column_stat.type; + RUNTIME_CHECK(VectorIndex::isSupportedType(*type)); + RUNTIME_CHECK(column_stat.vector_index.has_value()); + + const auto file_name_base = DMFile::getFileNameBase(col_id); + auto load_vector_index = [&]() { + is_index_load_from_cache = false; + + auto index_guard = S3::S3RandomAccessFile::setReadFileInfo( + {dmfile->getReadFileSize(col_id, dmfile->colIndexFileName(file_name_base)), scan_context}); + + auto info = dmfile->merged_sub_file_infos.find(dmfile->colIndexFileName(file_name_base)); + if (info == dmfile->merged_sub_file_infos.end()) + { + throw Exception( + fmt::format("Unknown index file {}", dmfile->colIndexPath(file_name_base)), + ErrorCodes::LOGICAL_ERROR); + } + + auto file_path = dmfile->mergedPath(info->second.number); + auto encryp_path = dmfile->encryptionMergedPath(info->second.number); + auto offset = info->second.offset; + auto data_size = info->second.size; + + auto buffer = ReadBufferFromFileProvider( + file_provider, + file_path, + encryp_path, + dmfile->getConfiguration()->getChecksumFrameLength(), + read_limiter); + buffer.seek(offset); + + // TODO: Read from file directly? + String raw_data; + raw_data.resize(data_size); + buffer.read(reinterpret_cast(raw_data.data()), data_size); + + auto buf = createReadBufferFromData( + std::move(raw_data), + dmfile->colDataPath(file_name_base), + dmfile->getConfiguration()->getChecksumFrameLength(), + dmfile->configuration->getChecksumAlgorithm(), + dmfile->configuration->getChecksumFrameLength()); + + auto index_kind = magic_enum::enum_cast(column_stat.vector_index->index_kind()); + RUNTIME_CHECK(index_kind.has_value()); + RUNTIME_CHECK(index_kind.value() != TiDB::VectorIndexKind::INVALID); + + auto index_distance_metric + = magic_enum::enum_cast(column_stat.vector_index->distance_metric()); + RUNTIME_CHECK(index_distance_metric.has_value()); + RUNTIME_CHECK(index_distance_metric.value() != TiDB::DistanceMetric::INVALID); + + auto index = VectorIndex::load(index_kind.value(), index_distance_metric.value(), *buf); + return index; + }; + + Stopwatch watch; + + if (vec_index_cache) + { + // TODO: Is cache key valid on Compute Node for different Write Nodes? 
+ vec_index = vec_index_cache->getOrSet(dmfile->colIndexCacheKey(file_name_base), load_vector_index); + } + else + { + // try load from the cache first + if (vec_index_cache) + vec_index = vec_index_cache->get(dmfile->colIndexCacheKey(file_name_base)); + if (vec_index == nullptr) + vec_index = load_vector_index(); + } + + double duration_load_index = watch.elapsedSeconds(); + RUNTIME_CHECK(vec_index != nullptr); + scan_context->total_vector_idx_load_time_ms += static_cast(duration_load_index * 1000); + if (is_index_load_from_cache) + scan_context->total_vector_idx_load_from_cache++; + else + scan_context->total_vector_idx_load_from_disk++; + + LOG_DEBUG( // + log, + "Loaded vector index for column dmf_{}/{}(id={}), index_size={} kind={} cost={:.2f}s from_cache={}", + dmfile->fileId(), + vec_cd.name, + vec_cd.id, + column_stat.index_bytes, + column_stat.vector_index->index_kind(), + duration_load_index, + is_index_load_from_cache); + } + + void loadVectorSearchResult() + { + Stopwatch watch; + + VectorIndex::SearchStatistics statistics; + auto results_rowid = vec_index->search(ann_query_info, valid_rows, statistics); + + double search_duration = watch.elapsedSeconds(); + scan_context->total_vector_idx_search_time_ms += static_cast(search_duration * 1000); + scan_context->total_vector_idx_search_discarded_nodes += statistics.discarded_nodes; + scan_context->total_vector_idx_search_visited_nodes += statistics.visited_nodes; + + size_t rows_after_mvcc = valid_rows.count(); + size_t rows_after_vector_search = results_rowid.size(); + + // After searching with the BitmapFilter, we create a bitmap + // to exclude rows that are not in the search result, because these rows + // are produced as [] or NULL, which is not a valid vector for future use. + // The bitmap will be used when returning the output to the caller. + { + valid_rows_after_search = BitmapFilter(valid_rows.size(), false); + for (auto rowid : results_rowid) + valid_rows_after_search.set(rowid, 1, true); + valid_rows_after_search.runOptimize(); + } + + vec_column_reader = std::make_shared( // + dmfile, + vec_index, + std::move(results_rowid)); + + // Vector index is very likely to filter out some packs. For example, + // if we query for Top 1, then only 1 pack will be remained. So we + // update the pack filter used by the DMFileReader to avoid reading + // unnecessary data for other columns. 
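The pack pruning that follows can be summarized by this standalone sketch: walk the sorted result row ids and the per-pack row counts in lockstep, and keep only the packs that received at least one Top-K hit. Plain vectors stand in for the real DMFile structures:

std::vector<UInt8> prunePacks(const std::vector<UInt32> & sorted_result_rowids, const std::vector<UInt32> & pack_rows)
{
    std::vector<UInt8> use_packs(pack_rows.size(), 0);
    size_t it = 0;
    UInt32 pack_start = 0;
    for (size_t pack_id = 0; pack_id < pack_rows.size(); ++pack_id)
    {
        const UInt32 pack_end = pack_start + pack_rows[pack_id];
        while (it < sorted_result_rowids.size() && sorted_result_rowids[it] < pack_end)
        {
            use_packs[pack_id] = 1; // at least one Top-K result row falls into this pack
            ++it;
        }
        pack_start = pack_end;
    }
    return use_packs; // e.g. a Top-1 search typically leaves a single pack enabled
}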
+ size_t valid_packs_before_search = 0; + size_t valid_packs_after_search = 0; + const auto & pack_stats = dmfile->getPackStats(); + auto & use_packs = reader.pack_filter.getUsePacks(); + + size_t results_it = 0; + const size_t results_it_max = results_rowid.size(); + + UInt32 pack_start = 0; + + for (size_t pack_id = 0, pack_id_max = dmfile->getPacks(); pack_id < pack_id_max; pack_id++) + { + if (use_packs[pack_id]) + valid_packs_before_search++; + + bool pack_has_result = false; + + UInt32 pack_end = pack_start + pack_stats[pack_id].rows; + while (results_it < results_it_max // + && results_rowid[results_it] >= pack_start // + && results_rowid[results_it] < pack_end) + { + pack_has_result = true; + results_it++; + } + + if (!pack_has_result) + use_packs[pack_id] = 0; + + if (use_packs[pack_id]) + valid_packs_after_search++; + + pack_start = pack_end; + } + + RUNTIME_CHECK_MSG(results_it == results_it_max, "All packs has been visited but not all results are consumed"); + + LOG_DEBUG( // + log, + "Finished vector search over column dmf_{}/{}(id={}), cost={:.2f}s " + "top_k_[query/visited/discarded/result]={}/{}/{}/{} " + "rows_[file/after_mvcc/after_search]={}/{}/{} " + "pack_[total/before_search/after_search]={}/{}/{}", + + dmfile->fileId(), + vec_cd.name, + vec_cd.id, + search_duration, + + ann_query_info->top_k(), + statistics.visited_nodes, // Visited nodes will be larger than query_top_k when there are MVCC rows + statistics.discarded_nodes, // How many nodes are skipped by MVCC + results_rowid.size(), + + dmfile->getRows(), + rows_after_mvcc, + rows_after_vector_search, + + pack_stats.size(), + valid_packs_before_search, + valid_packs_after_search); + } + + inline UInt32 getPackIdFromBlock(const Block & block) + { + // The start offset of a block is ensured to be aligned with the pack. + // This is how we know which pack the block comes from. + auto start_offset = block.startOffset(); + auto it = start_offset_to_pack_id.find(start_offset); + RUNTIME_CHECK(it != start_offset_to_pack_id.end()); + return it->second; + } + +private: + const LoggerPtr log; + + const ANNQueryInfoPtr ann_query_info; + const DMFilePtr dmfile; + + // The header_layout should contain columns from reader and vec_cd + Block header_layout; + // Vector column should be excluded in the reader + DMFileReader reader; + // Note: ColumnDefine comes from read path does not have vector_index fields. + const ColumnDefine vec_cd; + const FileProviderPtr file_provider; + const ReadLimiterPtr read_limiter; + const ScanContextPtr scan_context; + const VectorIndexCachePtr vec_index_cache; + const BitmapFilterView valid_rows; // TODO: Currently this does not support ColumnFileBig + + Block header; // Filled in constructor; + + std::unordered_map start_offset_to_pack_id; // Filled from reader in constructor + + // Set after load(). + VectorIndexPtr vec_index = nullptr; + // Set after load(). + VectorColumnFromIndexReaderPtr vec_column_reader = nullptr; + // Set after load(). Used to filter the output rows. 
+ BitmapFilter valid_rows_after_search{0, false}; + IColumn::Filter filter{}; + + bool loaded = false; + + double duration_load_vec_index_and_results_seconds = 0; + double duration_read_from_vec_index_seconds = 0; + double duration_read_from_other_columns_seconds = 0; +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream_fwd.h b/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream_fwd.h new file mode 100644 index 00000000000..6e88a873070 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream_fwd.h @@ -0,0 +1,26 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace DB::DM +{ + +class DMFileWithVectorIndexBlockInputStream; + +using DMFileWithVectorIndexBlockInputStreamPtr = std::shared_ptr; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileWriter.cpp b/dbms/src/Storages/DeltaMerge/File/DMFileWriter.cpp index e910b50d751..278e814d88f 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileWriter.cpp +++ b/dbms/src/Storages/DeltaMerge/File/DMFileWriter.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #ifndef NDEBUG @@ -58,11 +59,14 @@ DMFileWriter::DMFileWriter( for (auto & cd : write_columns) { + if (cd.vector_index) + RUNTIME_CHECK(VectorIndex::isSupportedType(*cd.type)); + // TODO: currently we only generate index for Integers, Date, DateTime types, and this should be configurable by user. /// for handle column always generate index auto type = removeNullable(cd.type); bool do_index = cd.id == EXTRA_HANDLE_COLUMN_ID || type->isInteger() || type->isDateOrDateTime(); - addStreams(cd.id, cd.type, do_index); + addStreams(cd.id, cd.type, do_index, cd.vector_index); dmfile->column_stats.emplace(cd.id, ColumnStat{cd.id, cd.type, /*avg_size=*/0}); } } @@ -111,12 +115,11 @@ DMFileWriter::WriteBufferFromFileBasePtr DMFileWriter::createPackStatsFile() options.max_compress_block_size); } -void DMFileWriter::addStreams(ColId col_id, DataTypePtr type, bool do_index) +void DMFileWriter::addStreams(ColId col_id, DataTypePtr type, bool do_index, TiDB::VectorIndexInfoPtr do_vector_index) { auto callback = [&](const IDataType::SubstreamPath & substream_path) { const auto stream_name = DMFile::getFileNameBase(col_id, substream_path); - bool substream_do_index - = do_index && !IDataType::isNullMap(substream_path) && !IDataType::isArraySizes(substream_path); + bool substream_can_index = !IDataType::isNullMap(substream_path) && !IDataType::isArraySizes(substream_path); auto stream = std::make_unique( dmfile, stream_name, @@ -125,7 +128,8 @@ void DMFileWriter::addStreams(ColId col_id, DataTypePtr type, bool do_index) options.max_compress_block_size, file_provider, write_limiter, - substream_do_index); + do_index && substream_can_index, + (do_vector_index && substream_can_index) ? 
do_vector_index : nullptr); column_streams.emplace(stream_name, std::move(stream)); }; @@ -230,6 +234,9 @@ void DMFileWriter::writeColumn( (col_id == EXTRA_HANDLE_COLUMN_ID || col_id == TAG_COLUMN_ID) ? nullptr : del_mark); } + if (stream->vector_index) + stream->vector_index->addBlock(column, del_mark); + /// There could already be enough data to compress into the new block. if (stream->compressed_buf->offset() >= options.min_compress_block_size) stream->compressed_buf->next(); @@ -289,7 +296,6 @@ void DMFileWriter::finalizeColumn(ColId col_id, DataTypePtr type) } }; #endif - auto callback = [&](const IDataType::SubstreamPath & substream) { const auto stream_name = DMFile::getFileNameBase(col_id, substream); auto & stream = column_streams.at(stream_name); @@ -350,6 +356,38 @@ void DMFileWriter::finalizeColumn(ColId col_id, DataTypePtr type) buffer->next(); } + if (stream->vector_index && !is_empty_file) + { + dmfile->checkMergedFile(merged_file, file_provider, write_limiter); + + auto fname = dmfile->colIndexFileName(stream_name); + + auto buffer = createWriteBufferFromFileBaseByWriterBuffer( + merged_file.buffer, + dmfile->configuration->getChecksumAlgorithm(), + dmfile->configuration->getChecksumFrameLength()); + + stream->vector_index->serializeBinary(*buffer); + + col_stat.index_bytes = buffer->getMaterializedBytes(); + + // Memorize what kind of vector index it is, so that we can correctly restore it when reading. + col_stat.vector_index = dtpb::ColumnVectorIndexInfo{}; + col_stat.vector_index->set_index_kind(String(magic_enum::enum_name(stream->vector_index->kind))); + col_stat.vector_index->set_distance_metric( + String(magic_enum::enum_name(stream->vector_index->distance_metric))); + + MergedSubFileInfo info{ + fname, + merged_file.file_info.number, + merged_file.file_info.size, + col_stat.index_bytes}; + dmfile->merged_sub_file_infos[fname] = info; + + merged_file.file_info.size += col_stat.index_bytes; + buffer->next(); + } + // write mark into merged_file_writer if (!is_empty_file) { @@ -457,6 +495,11 @@ void DMFileWriter::finalizeColumn(ColId col_id, DataTypePtr type) #endif } } + + if (stream->vector_index) + { + RUNTIME_CHECK_MSG(false, "Vector index is not compatible with V1 and V2 format"); + } } }; type->enumerateStreams(callback, {}); diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileWriter.h b/dbms/src/Storages/DeltaMerge/File/DMFileWriter.h index dedab933763..960d448d093 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileWriter.h +++ b/dbms/src/Storages/DeltaMerge/File/DMFileWriter.h @@ -23,6 +23,7 @@ #include #include #include +#include namespace DB { @@ -53,7 +54,8 @@ class DMFileWriter size_t max_compress_block_size, FileProviderPtr & file_provider, const WriteLimiterPtr & write_limiter_, - bool do_index) + bool do_index, + TiDB::VectorIndexInfoPtr do_vector_index) : plain_file(WriteBufferByFileProviderBuilder( dmfile->configuration.has_value(), file_provider, @@ -71,6 +73,7 @@ class DMFileWriter : std::unique_ptr( new CompressedWriteBuffer(*plain_file, compression_settings))) , minmaxes(do_index ? std::make_shared(*type) : nullptr) + , vector_index(do_vector_index ? VectorIndex::create(*do_vector_index) : nullptr) { if (!dmfile->useMetaV2()) { @@ -97,6 +100,7 @@ class DMFileWriter WriteBufferPtr compressed_buf; MinMaxIndexPtr minmaxes; + VectorIndexPtr vector_index; MarksInCompressedFilePtr marks; @@ -158,7 +162,7 @@ class DMFileWriter /// Add streams with specified column id. 
Since a single column may have more than one Stream, /// for example Nullable column has a NullMap column, we would track them with a mapping /// FileNameBase -> Stream. - void addStreams(ColId col_id, DataTypePtr type, bool do_index); + void addStreams(ColId col_id, DataTypePtr type, bool do_index, TiDB::VectorIndexInfoPtr do_vector_index); WriteBufferFromFileBasePtr createMetaFile(); WriteBufferFromFileBasePtr createMetaV2File(); diff --git a/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.cpp b/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.cpp new file mode 100644 index 00000000000..c263643c250 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.cpp @@ -0,0 +1,133 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +namespace DB::DM +{ + +std::vector VectorColumnFromIndexReader::calcPackStartRowID(const DMFile::PackStats & pack_stats) +{ + std::vector pack_start_rowid(pack_stats.size()); + UInt32 rowid = 0; + for (size_t i = 0, i_max = pack_stats.size(); i < i_max; i++) + { + pack_start_rowid[i] = rowid; + rowid += pack_stats[i].rows; + } + return pack_start_rowid; +} + +MutableColumnPtr VectorColumnFromIndexReader::calcResultsByPack( + std::vector && results, + const DMFile::PackStats & pack_stats, + const std::vector & pack_start_rowid) +{ + auto column = ColumnArray::create(ColumnUInt32::create()); + + // results must be in ascending order. + std::sort(results.begin(), results.end()); + + std::vector offsets_in_pack; + size_t results_it = 0; + const size_t results_it_max = results.size(); + for (size_t pack_id = 0, pack_id_max = pack_start_rowid.size(); pack_id < pack_id_max; pack_id++) + { + offsets_in_pack.clear(); + + UInt32 pack_start = pack_start_rowid[pack_id]; + UInt32 pack_end = pack_start + pack_stats[pack_id].rows; + + while (results_it < results_it_max // + && results[results_it] >= pack_start // + && results[results_it] < pack_end) + { + offsets_in_pack.push_back(results[results_it] - pack_start); + results_it++; + } + + column->insertData( + reinterpret_cast(offsets_in_pack.data()), + offsets_in_pack.size() * sizeof(UInt32)); + } + + RUNTIME_CHECK_MSG(results_it == results_it_max, "All packs has been visited but not all results are consumed"); + + return column; +} + +void VectorColumnFromIndexReader::read(MutableColumnPtr & column, size_t start_pack_id, UInt32 read_rows) +{ + std::vector value; + const auto * results_by_pack = checkAndGetColumn(this->results_by_pack.get()); + checkAndGetColumn(column.get()); + + size_t pack_id = start_pack_id; + UInt32 remaining_rows_in_pack = pack_stats[pack_id].rows; + + while (read_rows > 0) + { + if (remaining_rows_in_pack == 0) + { + // If this pack is drained but we still need to read more rows, let's read from next pack. 
+ pack_id++; + RUNTIME_CHECK(pack_id < pack_stats.size()); + remaining_rows_in_pack = pack_stats[pack_id].rows; + } + + UInt32 expect_result_rows = std::min(remaining_rows_in_pack, read_rows); + UInt32 filled_result_rows = 0; + + auto offsets_in_pack = results_by_pack->getDataAt(pack_id); + auto offsets_in_pack_n = results_by_pack->sizeAt(pack_id); + RUNTIME_CHECK(offsets_in_pack.size == offsets_in_pack_n * sizeof(UInt32)); + + // Note: offsets_in_pack_n may be 0, means there is no results in this pack. + for (size_t i = 0; i < offsets_in_pack_n; ++i) + { + UInt32 offset_in_pack = reinterpret_cast(offsets_in_pack.data)[i]; + RUNTIME_CHECK(filled_result_rows <= offset_in_pack); + if (offset_in_pack > filled_result_rows) + { + UInt32 nulls = offset_in_pack - filled_result_rows; + // Insert [] if column is Not Null, or NULL if column is Nullable + column->insertManyDefaults(nulls); + filled_result_rows += nulls; + } + RUNTIME_CHECK(filled_result_rows == offset_in_pack); + + // TODO: We could fill multiple rows if rowid is continuous. + VectorIndex::Key rowid = pack_start_rowid[pack_id] + offset_in_pack; + index->get(rowid, value); + column->insertData(reinterpret_cast(value.data()), value.size() * sizeof(Float32)); + filled_result_rows++; + } + + if (filled_result_rows < expect_result_rows) + { + size_t nulls = expect_result_rows - filled_result_rows; + // Insert [] if column is Not Null, or NULL if column is Nullable + column->insertManyDefaults(nulls); + filled_result_rows += nulls; + } + + RUNTIME_CHECK(filled_result_rows == expect_result_rows); + remaining_rows_in_pack -= filled_result_rows; + read_rows -= filled_result_rows; + } +} + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.h b/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.h new file mode 100644 index 00000000000..14e807385bd --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.h @@ -0,0 +1,77 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace DB::DM +{ + +/** + * @brief VectorColumnFromIndexReader reads vector column data from the index + * while maintaining the same column layout as if it was read from the DMFile. + * For example, when we want to read vector column data of row id [1, 5, 10], + * this reader will return [NULL, VEC, NULL, NULL, NULL, VEC, ....]. + * + * Note: The term "row id" in this class refers to the row offset in this DMFile. + * It is a file-level row id, not a global row id. + */ +class VectorColumnFromIndexReader +{ +private: + const DMFilePtr dmfile; // Keep a reference of dmfile to keep pack_stats valid. + const DMFile::PackStats & pack_stats; + const std::vector pack_start_rowid; + + const VectorIndexPtr index; + /// results_by_pack[i]=[a,b,c...] means pack[i]'s row offset [a,b,c,...] is contained in the result set. 
+ /// The rowid of a is pack_start_rowid[i]+a. + MutableColumnPtr /* ColumnArray of UInt32 */ results_by_pack; + +private: + static std::vector calcPackStartRowID(const DMFile::PackStats & pack_stats); + + static MutableColumnPtr calcResultsByPack( + std::vector && results, + const DMFile::PackStats & pack_stats, + const std::vector & pack_start_rowid); + +public: + /// VectorIndex::Key is the offset of the row in the DMFile (file-level row id), + /// including NULLs and delete marks. + explicit VectorColumnFromIndexReader( + const DMFilePtr & dmfile_, + const VectorIndexPtr & index_, + std::vector && results_) + : dmfile(dmfile_) + , pack_stats(dmfile_->getPackStats()) + , pack_start_rowid(calcPackStartRowID(pack_stats)) + , index(index_) + , results_by_pack(calcResultsByPack(std::move(results_), pack_stats, pack_start_rowid)) + {} + + void read(MutableColumnPtr & column, size_t start_pack_id, UInt32 read_rows); +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader_fwd.h b/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader_fwd.h new file mode 100644 index 00000000000..c5fcb54abe6 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader_fwd.h @@ -0,0 +1,25 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace DB::DM +{ + +class VectorColumnFromIndexReader; +using VectorColumnFromIndexReaderPtr = std::shared_ptr; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/File/dtpb/dmfile.proto b/dbms/src/Storages/DeltaMerge/File/dtpb/dmfile.proto index b1d02e7e584..77256bfff43 100644 --- a/dbms/src/Storages/DeltaMerge/File/dtpb/dmfile.proto +++ b/dbms/src/Storages/DeltaMerge/File/dtpb/dmfile.proto @@ -49,6 +49,13 @@ message ChecksumConfig { repeated ChecksumDebugInfo debug_info = 5; } +// Note: This message does not contain all fields of VectorIndexInfo, +// because this message is only used for reading the vector index carried with the column. 
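The row-id bookkeeping described in VectorColumnFromIndexReader above boils down to a prefix sum over the per-pack row counts. A standalone sketch follows (the pack sizes and the probed row id are made up for illustration; this is not code from the patch):

#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    // Rows per pack, as DMFile::PackStats would report them (example values).
    std::vector<uint32_t> pack_rows = {8192, 8192, 3000};

    // Same idea as calcPackStartRowID: prefix sum of pack sizes.
    std::vector<uint32_t> pack_start_rowid(pack_rows.size());
    uint32_t rowid = 0;
    for (size_t i = 0; i < pack_rows.size(); ++i)
    {
        pack_start_rowid[i] = rowid;
        rowid += pack_rows[i];
    }

    // A file-level row id returned by the vector index maps back to (pack, offset-in-pack).
    uint32_t result_rowid = 9000;
    size_t pack_id = 0;
    while (pack_id + 1 < pack_start_rowid.size() && pack_start_rowid[pack_id + 1] <= result_rowid)
        ++pack_id;
    uint32_t offset_in_pack = result_rowid - pack_start_rowid[pack_id];

    std::cout << "row " << result_rowid << " -> pack " << pack_id << ", offset " << offset_in_pack << '\n';
    // Prints: row 9000 -> pack 1, offset 808
    return 0;
}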
+message ColumnVectorIndexInfo { + optional string index_kind = 1; + optional string distance_metric = 2; +} + message ColumnStat { optional int64 col_id = 1; optional string type_name = 2; @@ -61,6 +68,8 @@ message ColumnStat { optional uint64 index_bytes = 9; optional uint64 array_sizes_bytes = 10; optional uint64 array_sizes_mark_bytes = 11; + + optional ColumnVectorIndexInfo vector_index = 101; } message ColumnStats { diff --git a/dbms/src/Storages/DeltaMerge/Filter/RSOperator.cpp b/dbms/src/Storages/DeltaMerge/Filter/RSOperator.cpp index 12dc0b894dd..fc12fbf6ffb 100644 --- a/dbms/src/Storages/DeltaMerge/Filter/RSOperator.cpp +++ b/dbms/src/Storages/DeltaMerge/Filter/RSOperator.cpp @@ -28,6 +28,7 @@ #include #include #include +#include namespace DB::DM { @@ -50,4 +51,9 @@ RSOperatorPtr createIsNull(const Attr & attr) RSOperatorPtr createUnsupported(const String & content, const String & reason, bool is_not) { return std::make_shared(content, reason, is_not); } // clang-format on +RSOperatorPtr wrapWithANNQueryInfo(const RSOperatorPtr & op, const ANNQueryInfoPtr & ann_query_info) +{ + return std::make_shared(op, ann_query_info); +} + } // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Filter/RSOperator.h b/dbms/src/Storages/DeltaMerge/Filter/RSOperator.h index 14387e00d26..db6eb817a76 100644 --- a/dbms/src/Storages/DeltaMerge/Filter/RSOperator.h +++ b/dbms/src/Storages/DeltaMerge/Filter/RSOperator.h @@ -18,6 +18,7 @@ #include #include #include +#include namespace DB::DM { @@ -142,4 +143,7 @@ RSOperatorPtr createIsNull(const Attr & attr); // RSOperatorPtr createUnsupported(const String & content, const String & reason, bool is_not); +/// Wrap with a ANNQueryInfo +RSOperatorPtr wrapWithANNQueryInfo(const RSOperatorPtr & op, const ANNQueryInfoPtr & ann_query_info); + } // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Filter/WithANNQueryInfo.h b/dbms/src/Storages/DeltaMerge/Filter/WithANNQueryInfo.h new file mode 100644 index 00000000000..b6b4675d52f --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Filter/WithANNQueryInfo.h @@ -0,0 +1,65 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace DB::DM +{ + +// HACK: We reused existing RSOperator path to pass ANNQueryInfo. +// This is for minimizing changed files in the Serverless. +// When we port back the implementation to open-source version, we should extract the ANNQueryInfo out. 
+class WithANNQueryInfo : public RSOperator +{ +public: + const RSOperatorPtr child; + const ANNQueryInfoPtr ann_query_info; + + explicit WithANNQueryInfo(const RSOperatorPtr & child_, const ANNQueryInfoPtr & ann_query_info_) + : RSOperator({child_}) + , child(child_) + , ann_query_info(ann_query_info_) + {} + + String name() override { return "ann"; } + + Attrs getAttrs() override + { + if (children[0]) + return children[0]->getAttrs(); + else + return {}; + } + + String toDebugString() override + { + if (children[0]) + return children[0]->toDebugString(); + else + return ""; + } + + RSResults roughCheck(size_t start_pack, size_t pack_count, const RSCheckParam & param) override + { + if (children[0]) + return children[0]->roughCheck(start_pack, pack_count, param); + else + return RSResults(pack_count, RSResult::Unknown); + } +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/RSIndex.h b/dbms/src/Storages/DeltaMerge/Index/RSIndex.h index 1cfa9aebf31..b171e146286 100644 --- a/dbms/src/Storages/DeltaMerge/Index/RSIndex.h +++ b/dbms/src/Storages/DeltaMerge/Index/RSIndex.h @@ -15,6 +15,7 @@ #pragma once #include +#include namespace DB { @@ -35,12 +36,18 @@ struct RSIndex DataTypePtr type; MinMaxIndexPtr minmax; EqualIndexPtr equal; + VectorIndexPtr vector; // TODO: Actually this is not a rough index. We put it here for convenience. RSIndex(const DataTypePtr & type_, const MinMaxIndexPtr & minmax_) : type(type_) , minmax(minmax_) {} + RSIndex(const DataTypePtr & type_, const VectorIndexPtr & vector_) + : type(type_) + , vector(vector_) + {} + RSIndex(const DataTypePtr & type_, const MinMaxIndexPtr & minmax_, const EqualIndexPtr & equal_) : type(type_) , minmax(minmax_) @@ -52,4 +59,4 @@ using ColumnIndexes = std::unordered_map; } // namespace DM -} // namespace DB \ No newline at end of file +} // namespace DB diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex.cpp b/dbms/src/Storages/DeltaMerge/Index/VectorIndex.cpp new file mode 100644 index 00000000000..652f5168d9d --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex.cpp @@ -0,0 +1,88 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
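A minimal usage sketch for the wrapper above (not part of the patch): `rs_filter` and `ann_query_info` are assumed to come from the existing filter-parsing and table-scan code.

RSOperatorPtr filter = rs_filter;
if (ann_query_info != nullptr)
    // Piggyback the ANN request on the rough-set filter chain. Code that is unaware of
    // vector search keeps working: WithANNQueryInfo forwards getAttrs(), toDebugString()
    // and roughCheck() to its child, and answers Unknown when there is no child.
    filter = wrapWithANNQueryInfo(rs_filter, ann_query_info);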
+ +#include +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ +extern const int INCORRECT_QUERY; +} // namespace DB::ErrorCodes + +namespace DB::DM +{ + +bool VectorIndex::isSupportedType(const IDataType & type) +{ + const auto * nullable = checkAndGetDataType(&type); + if (nullable) + return checkDataTypeArray(&*nullable->getNestedType()); + + return checkDataTypeArray(&type); +} + +VectorIndexPtr VectorIndex::create(const TiDB::VectorIndexInfo & index_info) +{ + RUNTIME_CHECK(index_info.dimension > 0); + RUNTIME_CHECK(index_info.dimension <= std::numeric_limits::max()); + + switch (index_info.kind) + { + case TiDB::VectorIndexKind::HNSW: + switch (index_info.distance_metric) + { + case TiDB::DistanceMetric::L2: + return std::make_shared>(index_info.dimension); + case TiDB::DistanceMetric::COSINE: + return std::make_shared>(index_info.dimension); + default: + throw Exception( + ErrorCodes::INCORRECT_QUERY, + "Unsupported vector index distance metric {}", + index_info.distance_metric); + } + default: + throw Exception(ErrorCodes::INCORRECT_QUERY, "Unsupported vector index {}", index_info.kind); + } +} + +VectorIndexPtr VectorIndex::load(TiDB::VectorIndexKind kind, TiDB::DistanceMetric distance_metric, ReadBuffer & istr) +{ + RUNTIME_CHECK(kind != TiDB::VectorIndexKind::INVALID); + RUNTIME_CHECK(distance_metric != TiDB::DistanceMetric::INVALID); + + switch (kind) + { + case TiDB::VectorIndexKind::HNSW: + switch (distance_metric) + { + case TiDB::DistanceMetric::L2: + return VectorIndexHNSW::deserializeBinary(istr); + case TiDB::DistanceMetric::COSINE: + return VectorIndexHNSW::deserializeBinary(istr); + default: + throw Exception( + ErrorCodes::INCORRECT_QUERY, + "Unsupported vector index distance metric {}", + distance_metric); + } + default: + throw Exception(ErrorCodes::INCORRECT_QUERY, "Unsupported vector index {}", kind); + } +} + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex.h new file mode 100644 index 00000000000..d249d7a91e7 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex.h @@ -0,0 +1,106 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace DM +{ + +class VectorIndex +{ +public: + /// The key is the row's offset in the DMFile. + using Key = UInt32; + + /// True bit means the row is valid and should be kept in the search result. + /// False bit lets the row filtered out and will search for more results. 
+ using RowFilter = BitmapFilterView; + + struct SearchStatistics + { + size_t visited_nodes = 0; + size_t discarded_nodes = 0; // Rows filtered out by MVCC + }; + + static bool isSupportedType(const IDataType & type); + + static VectorIndexPtr create(const TiDB::VectorIndexInfo & index_info); + + static VectorIndexPtr load(TiDB::VectorIndexKind kind, TiDB::DistanceMetric distance_metric, ReadBuffer & istr); + + VectorIndex(TiDB::VectorIndexKind kind_, TiDB::DistanceMetric distance_metric_) + : kind(kind_) + , distance_metric(distance_metric_) + {} + + virtual ~VectorIndex() = default; + + virtual void addBlock(const IColumn & column, const ColumnVector * del_mark) = 0; + + virtual void serializeBinary(WriteBuffer & ostr) const = 0; + + virtual size_t memoryUsage() const = 0; + + virtual std::vector search( // + const ANNQueryInfoPtr & queryInfo, + const RowFilter & valid_rows, + SearchStatistics & statistics) const + = 0; + + // Get the value (i.e. vector content) of a Key. + virtual void get(Key key, std::vector & out) const = 0; + +public: + const TiDB::VectorIndexKind kind; + const TiDB::DistanceMetric distance_metric; +}; + +struct VectorIndexWeightFunction +{ + size_t operator()(const String &, const VectorIndex & index) const { return index.memoryUsage(); } +}; + +class VectorIndexCache : public LRUCache, VectorIndexWeightFunction> +{ +private: + using Base = LRUCache, VectorIndexWeightFunction>; + +public: + explicit VectorIndexCache(size_t max_size_in_bytes) + : Base(max_size_in_bytes) + {} + + template + MappedPtr getOrSet(const Key & key, LoadFunc && load) + { + auto result = Base::getOrSet(key, load); + return result.first; + } +}; + +} // namespace DM + +} // namespace DB diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp b/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp new file mode 100644 index 00000000000..b26da737bf4 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp @@ -0,0 +1,226 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
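A hedged sketch of how a reader could combine the cache and the loader above; `vector_index_cache`, `index_cache_key` and `index_file_buf` are illustrative names, not code from this patch.

// index_cache_key would typically identify the DMFile column index file;
// index_file_buf is a ReadBuffer positioned at the serialized index.
auto index = vector_index_cache->getOrSet(index_cache_key, [&] {
    // Kind and distance metric are persisted alongside the column (dtpb::ColumnVectorIndexInfo),
    // so the matching concrete HNSW index can be restored here.
    return VectorIndex::load(TiDB::VectorIndexKind::HNSW, TiDB::DistanceMetric::L2, index_file_buf);
});
// Entries are weighted by VectorIndex::memoryUsage(); once the configured byte budget
// is exceeded, the least recently used indexes are evicted.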
+ +#include +#include +#include +#include +#include + +#include + +namespace DB::ErrorCodes +{ +extern const int INCORRECT_DATA; +extern const int INCORRECT_QUERY; +extern const int CANNOT_ALLOCATE_MEMORY; +} // namespace DB::ErrorCodes + +namespace DB::DM +{ + +template +USearchIndexWithSerialization::USearchIndexWithSerialization(size_t dimensions) + : Base(Base::make(unum::usearch::metric_punned_t(dimensions, Metric))) +{} + +template +void USearchIndexWithSerialization::serialize(WriteBuffer & ostr) const +{ + auto callback = [&ostr](void * from, size_t n) { + ostr.write(reinterpret_cast(from), n); + return true; + }; + Base::save_to_stream(callback); +} + +template +void USearchIndexWithSerialization::deserialize(ReadBuffer & istr) +{ + auto callback = [&istr](void * from, size_t n) { + istr.readStrict(reinterpret_cast(from), n); + return true; + }; + Base::load_from_stream(callback); +} + +template class USearchIndexWithSerialization; +template class USearchIndexWithSerialization; + +constexpr TiDB::DistanceMetric toTiDBDistanceMetric(unum::usearch::metric_kind_t metric) +{ + switch (metric) + { + case unum::usearch::metric_kind_t::l2sq_k: + return TiDB::DistanceMetric::L2; + case unum::usearch::metric_kind_t::cos_k: + return TiDB::DistanceMetric::COSINE; + default: + return TiDB::DistanceMetric::INVALID; + } +} + +constexpr tipb::VectorDistanceMetric toTiDBQueryDistanceMetric(unum::usearch::metric_kind_t metric) +{ + switch (metric) + { + case unum::usearch::metric_kind_t::l2sq_k: + return tipb::VectorDistanceMetric::L2; + case unum::usearch::metric_kind_t::cos_k: + return tipb::VectorDistanceMetric::Cosine; + default: + return tipb::VectorDistanceMetric::InvalidMetric; + } +} + +template +VectorIndexHNSW::VectorIndexHNSW(UInt32 dimensions_) + : VectorIndex(TiDB::VectorIndexKind::HNSW, toTiDBDistanceMetric(Metric)) + , dimensions(dimensions_) + , index(std::make_shared>(static_cast(dimensions_))) +{} + +template +void VectorIndexHNSW::addBlock(const IColumn & column, const ColumnVector * del_mark) +{ + // Note: column may be nullable. + const ColumnArray * col_array; + if (column.isColumnNullable()) + col_array = checkAndGetNestedColumn(&column); + else + col_array = checkAndGetColumn(&column); + + RUNTIME_CHECK(col_array != nullptr, column.getFamilyName()); + RUNTIME_CHECK(checkAndGetColumn>(col_array->getDataPtr().get()) != nullptr); + + const auto * del_mark_data = (!del_mark) ? nullptr : &(del_mark->getData()); + + if (!index->reserve(unum::usearch::ceil2(index->size() + column.size()))) + { + throw Exception(ErrorCodes::CANNOT_ALLOCATE_MEMORY, "Could not reserve memory for HNSW index"); + } + + for (int i = 0, i_max = col_array->size(); i < i_max; ++i) + { + auto row_offset = added_rows; + added_rows++; + + // Ignore rows with del_mark, as the column values are not meaningful. + if (del_mark_data != nullptr && (*del_mark_data)[i]) + continue; + + // Ignore NULL values, as they are not meaningful to store in index. + if (column.isNullAt(i)) + continue; + + // Expect all data to have matching dimensions. 
+ RUNTIME_CHECK(col_array->sizeAt(i) == dimensions); + + auto data = col_array->getDataAt(i); + RUNTIME_CHECK(data.size == dimensions * sizeof(Float32)); + + if (auto rc = index->add(row_offset, reinterpret_cast(data.data)); !rc) + throw Exception(ErrorCodes::INCORRECT_DATA, rc.error.release()); + } +} + +template +void VectorIndexHNSW::serializeBinary(WriteBuffer & ostr) const +{ + writeStringBinary(magic_enum::enum_name(kind), ostr); + writeStringBinary(magic_enum::enum_name(distance_metric), ostr); + writeIntBinary(dimensions, ostr); + index->serialize(ostr); +} + +template +VectorIndexPtr VectorIndexHNSW::deserializeBinary(ReadBuffer & istr) +{ + String kind; + readStringBinary(kind, istr); + RUNTIME_CHECK(magic_enum::enum_cast(kind) == TiDB::VectorIndexKind::HNSW); + + String distance_metric; + readStringBinary(distance_metric, istr); + RUNTIME_CHECK(magic_enum::enum_cast(distance_metric) == toTiDBDistanceMetric(Metric)); + + UInt32 dimensions; + readIntBinary(dimensions, istr); + + auto vi = std::make_shared>(dimensions); + vi->index->deserialize(istr); + return vi; +} + +template +std::vector VectorIndexHNSW::search( + const ANNQueryInfoPtr & queryInfo, + const RowFilter & valid_rows, + SearchStatistics & statistics) const +{ + RUNTIME_CHECK(queryInfo->ref_vec_f32().size() >= sizeof(UInt32)); + auto query_vec_size = readLittleEndian(queryInfo->ref_vec_f32().data()); + if (query_vec_size != dimensions) + throw Exception( + ErrorCodes::INCORRECT_QUERY, + "Query vector size {} does not match index dimensions {}", + query_vec_size, + dimensions); + + RUNTIME_CHECK(queryInfo->ref_vec_f32().size() >= sizeof(UInt32) + query_vec_size * sizeof(Float32)); + + if (queryInfo->distance_metric() != toTiDBQueryDistanceMetric(Metric)) + throw Exception( + ErrorCodes::INCORRECT_QUERY, + "Query distance metric {} does not match index distance metric {}", + tipb::VectorDistanceMetric_Name(queryInfo->distance_metric()), + tipb::VectorDistanceMetric_Name(toTiDBQueryDistanceMetric(Metric))); + + RUNTIME_CHECK(index != nullptr); + + auto predicate = [&valid_rows, &statistics](USearchIndexWithSerialization::member_cref_t const & member) { + statistics.visited_nodes++; + if (!valid_rows[member.key]) + statistics.discarded_nodes++; + return valid_rows[member.key]; + }; + + // TODO: Support efSearch. + auto result = index->search( // + reinterpret_cast(queryInfo->ref_vec_f32().data() + sizeof(UInt32)), + queryInfo->top_k(), + predicate); + std::vector keys(result.size()); + result.dump_to(keys.data()); + + // For some reason usearch does not always do the predicate for all search results. + // So we need to filter again. + keys.erase( + std::remove_if(keys.begin(), keys.end(), [&valid_rows](Key key) { return !valid_rows[key]; }), + keys.end()); + + return keys; +} + +template +void VectorIndexHNSW::get(Key key, std::vector & out) const +{ + out.resize(dimensions); + index->get(key, out.data()); +} + +template class VectorIndexHNSW; +template class VectorIndexHNSW; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.h new file mode 100644 index 00000000000..2663f0c0e18 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.h @@ -0,0 +1,67 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace DB::DM +{ + +using USearchImplType + = unum::usearch::index_dense_gt; + +template +class USearchIndexWithSerialization : public USearchImplType +{ + using Base = USearchImplType; + +public: + explicit USearchIndexWithSerialization(size_t dimensions); + void serialize(WriteBuffer & ostr) const; + void deserialize(ReadBuffer & istr); +}; + +template +using USearchIndexWithSerializationPtr = std::shared_ptr>; + +template +class VectorIndexHNSW : public VectorIndex +{ +public: + explicit VectorIndexHNSW(UInt32 dimensions_); + + void addBlock(const IColumn & column, const ColumnVector * del_mark) override; + + void serializeBinary(WriteBuffer & ostr) const override; + static VectorIndexPtr deserializeBinary(ReadBuffer & istr); + + size_t memoryUsage() const override { return index->memory_usage(); } + + std::vector search( // + const ANNQueryInfoPtr & queryInfo, + const RowFilter & valid_rows, + SearchStatistics & statistics) const override; + + void get(Key key, std::vector & out) const override; + +private: + const UInt32 dimensions; + const USearchIndexWithSerializationPtr index; + + UInt64 added_rows = 0; // Includes nulls and deletes. Used as the index key. +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/usearch_index_dense.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/usearch_index_dense.h new file mode 100644 index 00000000000..3912b7671d5 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/usearch_index_dense.h @@ -0,0 +1,2241 @@ +/** + * @brief This is a modified version of usearch's index_dense.hpp. + * It supports predicate fn when doing the vector search. + * + * Original implementation: https://github.com/unum-cloud/usearch/blob/v2.9.1/include/usearch/index_dense.hpp + */ + +// NOLINTBEGIN(readability-*,google-*,modernize-use-auto) + +#pragma once +#include // `aligned_alloc` + +#include // `std::function` +#include // `std::iota` +#include // `std::thread` +#include +#include +#include // `std::vector` + +#if defined(USEARCH_DEFINED_CPP17) +#include // `std::shared_mutex` +#endif + +namespace unum +{ +namespace usearch +{ + +template +class index_dense_gt; + +/** + * @brief The "magic" sequence helps infer the type of the file. + * USearch indexes start with the "usearch" string. + */ +constexpr char const * default_magic() +{ + return "usearch"; +} + +using index_dense_head_buffer_t = byte_t[64]; + +static_assert(sizeof(index_dense_head_buffer_t) == 64, "File header should be exactly 64 bytes"); + +/** + * @brief Serialized binary representations of the USearch index start with metadata. + * Metadata is parsed into a `index_dense_head_t`, containing the USearch package version, + * and the properties of the index. + * + * It uses: 13 bytes for file versioning, 22 bytes for structural information = 35 bytes. + * The following 24 bytes contain binary size of the graph, of the vectors, and the checksum, + * leaving 5 bytes at the end vacant. 
+ */ +struct index_dense_head_t +{ + // Versioning: + using magic_t = char[7]; + using version_t = std::uint16_t; + + // Versioning: 7 + 2 * 3 = 13 bytes + char const * magic; + misaligned_ref_gt version_major; + misaligned_ref_gt version_minor; + misaligned_ref_gt version_patch; + + // Structural: 4 * 3 = 12 bytes + misaligned_ref_gt kind_metric; + misaligned_ref_gt kind_scalar; + misaligned_ref_gt kind_key; + misaligned_ref_gt kind_compressed_slot; + + // Population: 8 * 3 = 24 bytes + misaligned_ref_gt count_present; + misaligned_ref_gt count_deleted; + misaligned_ref_gt dimensions; + misaligned_ref_gt multi; + + index_dense_head_t(byte_t * ptr) noexcept + : magic((char const *)exchange(ptr, ptr + sizeof(magic_t))) + , // + version_major(exchange(ptr, ptr + sizeof(version_t))) + , // + version_minor(exchange(ptr, ptr + sizeof(version_t))) + , // + version_patch(exchange(ptr, ptr + sizeof(version_t))) + , // + kind_metric(exchange(ptr, ptr + sizeof(metric_kind_t))) + , // + kind_scalar(exchange(ptr, ptr + sizeof(scalar_kind_t))) + , // + kind_key(exchange(ptr, ptr + sizeof(scalar_kind_t))) + , // + kind_compressed_slot(exchange(ptr, ptr + sizeof(scalar_kind_t))) + , // + count_present(exchange(ptr, ptr + sizeof(std::uint64_t))) + , // + count_deleted(exchange(ptr, ptr + sizeof(std::uint64_t))) + , // + dimensions(exchange(ptr, ptr + sizeof(std::uint64_t))) + , // + multi(exchange(ptr, ptr + sizeof(bool))) + {} +}; + +struct index_dense_head_result_t +{ + index_dense_head_buffer_t buffer; + index_dense_head_t head; + error_t error; + + explicit operator bool() const noexcept { return !error; } + index_dense_head_result_t failed(error_t message) noexcept + { + error = std::move(message); + return std::move(*this); + } +}; + +struct index_dense_config_t : public index_config_t +{ + std::size_t expansion_add = default_expansion_add(); + std::size_t expansion_search = default_expansion_search(); + bool exclude_vectors = false; + bool multi = false; + + /** + * @brief Allows you to reduce RAM consumption by avoiding + * reverse-indexing keys-to-vectors, and only keeping + * the vectors-to-keys mappings. + * + * ! This configuration parameter doesn't affect the serialized file, + * ! and is not preserved between runs. Makes sense for small vector + * ! representations that fit ina single cache line. + */ + bool enable_key_lookups = true; + + index_dense_config_t(index_config_t base) noexcept + : index_config_t(base) + {} + + index_dense_config_t( + std::size_t c = default_connectivity(), + std::size_t ea = default_expansion_add(), + std::size_t es = default_expansion_search()) noexcept + : index_config_t(c) + , expansion_add(ea ? ea : default_expansion_add()) + , expansion_search(es ? 
es : default_expansion_search()) + {} +}; + +struct index_dense_clustering_config_t +{ + std::size_t min_clusters = 0; + std::size_t max_clusters = 0; + enum mode_t + { + merge_smallest_k, + merge_closest_k, + } mode = merge_smallest_k; +}; + +struct index_dense_serialization_config_t +{ + bool exclude_vectors = false; + bool use_64_bit_dimensions = false; +}; + +struct index_dense_copy_config_t : public index_copy_config_t +{ + bool force_vector_copy = true; + + index_dense_copy_config_t() = default; + index_dense_copy_config_t(index_copy_config_t base) noexcept + : index_copy_config_t(base) + {} +}; + +struct index_dense_metadata_result_t +{ + index_dense_serialization_config_t config; + index_dense_head_buffer_t head_buffer; + index_dense_head_t head; + error_t error; + + explicit operator bool() const noexcept { return !error; } + index_dense_metadata_result_t failed(error_t message) noexcept + { + error = std::move(message); + return std::move(*this); + } + + index_dense_metadata_result_t() noexcept + : config() + , head_buffer() + , head(head_buffer) + , error() + {} + + index_dense_metadata_result_t(index_dense_metadata_result_t && other) noexcept + : config() + , head_buffer() + , head(head_buffer) + , error(std::move(other.error)) + { + std::memcpy(&config, &other.config, sizeof(other.config)); + std::memcpy(&head_buffer, &other.head_buffer, sizeof(other.head_buffer)); + } + + index_dense_metadata_result_t & operator=(index_dense_metadata_result_t && other) noexcept + { + std::memcpy(&config, &other.config, sizeof(other.config)); + std::memcpy(&head_buffer, &other.head_buffer, sizeof(other.head_buffer)); + error = std::move(other.error); + return *this; + } +}; + +/** + * @brief Extracts metadata from a pre-constructed index on disk, + * without loading it or mapping the whole binary file. + */ +inline index_dense_metadata_result_t index_dense_metadata_from_path(char const * file_path) noexcept +{ + index_dense_metadata_result_t result; + std::unique_ptr file(std::fopen(file_path, "rb"), &std::fclose); + if (!file) + return result.failed(std::strerror(errno)); + + // Read the header + std::size_t read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); + if (!read) + return result.failed(std::feof(file.get()) ? "End of file reached!" 
: std::strerror(errno)); + + // Check if the file immediately starts with the index, instead of vectors + result.config.exclude_vectors = true; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + + if (std::fseek(file.get(), 0L, SEEK_END) != 0) + return result.failed("Can't infer file size"); + + // Check if it starts with 32-bit + std::size_t const file_size = std::ftell(file.get()); + + std::uint32_t dimensions_u32[2]{0}; + std::memcpy(dimensions_u32, result.head_buffer, sizeof(dimensions_u32)); + std::size_t offset_if_u32 = std::size_t(dimensions_u32[0]) * dimensions_u32[1] + sizeof(dimensions_u32); + + std::uint64_t dimensions_u64[2]{0}; + std::memcpy(dimensions_u64, result.head_buffer, sizeof(dimensions_u64)); + std::size_t offset_if_u64 = std::size_t(dimensions_u64[0]) * dimensions_u64[1] + sizeof(dimensions_u64); + + // Check if it starts with 32-bit + if (offset_if_u32 + sizeof(index_dense_head_buffer_t) < file_size) + { + if (std::fseek(file.get(), static_cast(offset_if_u32), SEEK_SET) != 0) + return result.failed(std::strerror(errno)); + read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); + if (!read) + return result.failed(std::feof(file.get()) ? "End of file reached!" : std::strerror(errno)); + + result.config.exclude_vectors = false; + result.config.use_64_bit_dimensions = false; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + } + + // Check if it starts with 64-bit + if (offset_if_u64 + sizeof(index_dense_head_buffer_t) < file_size) + { + if (std::fseek(file.get(), static_cast(offset_if_u64), SEEK_SET) != 0) + return result.failed(std::strerror(errno)); + read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); + if (!read) + return result.failed(std::feof(file.get()) ? "End of file reached!" : std::strerror(errno)); + + // Check if it starts with 64-bit + result.config.exclude_vectors = false; + result.config.use_64_bit_dimensions = true; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + } + + return result.failed("Not a dense USearch index!"); +} + +/** + * @brief Extracts metadata from a pre-constructed index serialized into an in-memory buffer. 
+ */ +inline index_dense_metadata_result_t index_dense_metadata_from_buffer( + memory_mapped_file_t file, + std::size_t offset = 0) noexcept +{ + index_dense_metadata_result_t result; + + // Read the header + if (offset + sizeof(index_dense_head_buffer_t) >= file.size()) + return result.failed("End of file reached!"); + + byte_t * const file_data = file.data() + offset; + std::size_t const file_size = file.size() - offset; + std::memcpy(&result.head_buffer, file_data, sizeof(index_dense_head_buffer_t)); + + // Check if the file immediately starts with the index, instead of vectors + result.config.exclude_vectors = true; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + + // Check if it starts with 32-bit + std::uint32_t dimensions_u32[2]{0}; + std::memcpy(dimensions_u32, result.head_buffer, sizeof(dimensions_u32)); + std::size_t offset_if_u32 = std::size_t(dimensions_u32[0]) * dimensions_u32[1] + sizeof(dimensions_u32); + + std::uint64_t dimensions_u64[2]{0}; + std::memcpy(dimensions_u64, result.head_buffer, sizeof(dimensions_u64)); + std::size_t offset_if_u64 = std::size_t(dimensions_u64[0]) * dimensions_u64[1] + sizeof(dimensions_u64); + + // Check if it starts with 32-bit + if (offset_if_u32 + sizeof(index_dense_head_buffer_t) < file_size) + { + std::memcpy(&result.head_buffer, file_data + offset_if_u32, sizeof(index_dense_head_buffer_t)); + result.config.exclude_vectors = false; + result.config.use_64_bit_dimensions = false; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + } + + // Check if it starts with 64-bit + if (offset_if_u64 + sizeof(index_dense_head_buffer_t) < file_size) + { + std::memcpy(&result.head_buffer, file_data + offset_if_u64, sizeof(index_dense_head_buffer_t)); + result.config.exclude_vectors = false; + result.config.use_64_bit_dimensions = true; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + } + + return result.failed("Not a dense USearch index!"); +} + +/** + * @brief Oversimplified type-punned index for equidimensional vectors + * with automatic @b down-casting, hardware-specific @b SIMD metrics, + * and ability to @b remove existing vectors, common in Semantic Caching + * applications. + * + * @section Serialization + * + * The serialized binary form of `index_dense_gt` is made up of three parts: + * 1. Binary matrix, aka the `.bbin` part, + * 2. Metadata about used metrics, number of used vs free slots, + * 3. The HNSW index in a binary form. + * The first (1.) generally starts with 2 integers - number of rows (vectors) and @b single-byte columns. + * The second (2.) starts with @b "usearch"-magic-string, used to infer the file type on open. + * The third (3.) is implemented by the underlying `index_gt` class. 
+ */ +template // +class index_dense_gt +{ +public: + using vector_key_t = key_at; + using key_t = vector_key_t; + using compressed_slot_t = compressed_slot_at; + using distance_t = distance_punned_t; + using metric_t = metric_punned_t; + + using member_ref_t = member_ref_gt; + using member_cref_t = member_cref_gt; + + using head_t = index_dense_head_t; + using head_buffer_t = index_dense_head_buffer_t; + using head_result_t = index_dense_head_result_t; + + using serialization_config_t = index_dense_serialization_config_t; + + using dynamic_allocator_t = aligned_allocator_gt; + using tape_allocator_t = memory_mapping_allocator_gt<64>; + +private: + /// @brief Schema: input buffer, bytes in input buffer, output buffer. + using cast_t = std::function; + /// @brief Punned index. + using index_t = index_gt< // + distance_t, + vector_key_t, + compressed_slot_t, // + dynamic_allocator_t, + tape_allocator_t>; + using index_allocator_t = aligned_allocator_gt; + + using member_iterator_t = typename index_t::member_iterator_t; + using member_citerator_t = typename index_t::member_citerator_t; + + /// @brief Punned metric object. + class metric_proxy_t + { + index_dense_gt const * index_ = nullptr; + + public: + metric_proxy_t(index_dense_gt const & index) noexcept + : index_(&index) + {} + + inline distance_t operator()(byte_t const * a, member_cref_t b) const noexcept { return f(a, v(b)); } + inline distance_t operator()(member_cref_t a, member_cref_t b) const noexcept { return f(v(a), v(b)); } + + inline distance_t operator()(byte_t const * a, member_citerator_t b) const noexcept { return f(a, v(b)); } + inline distance_t operator()(member_citerator_t a, member_citerator_t b) const noexcept + { + return f(v(a), v(b)); + } + + inline distance_t operator()(byte_t const * a, byte_t const * b) const noexcept { return f(a, b); } + + inline byte_t const * v(member_cref_t m) const noexcept { return index_->vectors_lookup_[get_slot(m)]; } + inline byte_t const * v(member_citerator_t m) const noexcept { return index_->vectors_lookup_[get_slot(m)]; } + inline distance_t f(byte_t const * a, byte_t const * b) const noexcept { return index_->metric_(a, b); } + }; + + index_dense_config_t config_; + index_t * typed_ = nullptr; + + mutable std::vector cast_buffer_; + struct casts_t + { + cast_t from_b1x8; + cast_t from_i8; + cast_t from_f16; + cast_t from_f32; + cast_t from_f64; + + cast_t to_b1x8; + cast_t to_i8; + cast_t to_f16; + cast_t to_f32; + cast_t to_f64; + } casts_; + + /// @brief An instance of a potentially stateful `metric_t` used to initialize copies and forks. + metric_t metric_; + + using vectors_tape_allocator_t = memory_mapping_allocator_gt<8>; + /// @brief Allocator for the copied vectors, aligned to widest double-precision scalars. + vectors_tape_allocator_t vectors_tape_allocator_; + + /// @brief For every managed `compressed_slot_t` stores a pointer to the allocated vector copy. + mutable std::vector vectors_lookup_; + + /// @brief Originally forms and array of integers [0, threads], marking all + mutable std::vector available_threads_; + + /// @brief Mutex, controlling concurrent access to `available_threads_`. 
+ mutable std::mutex available_threads_mutex_; + +#if defined(USEARCH_DEFINED_CPP17) + using shared_mutex_t = std::shared_mutex; +#else + using shared_mutex_t = unfair_shared_mutex_t; +#endif + using shared_lock_t = shared_lock_gt; + using unique_lock_t = std::unique_lock; + + struct key_and_slot_t + { + vector_key_t key; + compressed_slot_t slot; + + bool any_slot() const { return slot == default_free_value(); } + static key_and_slot_t any_slot(vector_key_t key) { return {key, default_free_value()}; } + }; + + struct lookup_key_hash_t + { + using is_transparent = void; + std::size_t operator()(key_and_slot_t const & k) const noexcept { return std::hash{}(k.key); } + std::size_t operator()(vector_key_t const & k) const noexcept { return std::hash{}(k); } + }; + + struct lookup_key_same_t + { + using is_transparent = void; + bool operator()(key_and_slot_t const & a, vector_key_t const & b) const noexcept { return a.key == b; } + bool operator()(vector_key_t const & a, key_and_slot_t const & b) const noexcept { return a == b.key; } + bool operator()(key_and_slot_t const & a, key_and_slot_t const & b) const noexcept { return a.key == b.key; } + }; + + /// @brief Multi-Map from keys to IDs, and allocated vectors. + flat_hash_multi_set_gt slot_lookup_; + + /// @brief Mutex, controlling concurrent access to `slot_lookup_`. + mutable shared_mutex_t slot_lookup_mutex_; + + /// @brief Ring-shaped queue of deleted entries, to be reused on future insertions. + ring_gt free_keys_; + + /// @brief Mutex, controlling concurrent access to `free_keys_`. + mutable std::mutex free_keys_mutex_; + + /// @brief A constant for the reserved key value, used to mark deleted entries. + vector_key_t free_key_ = default_free_value(); + +public: + using search_result_t = typename index_t::search_result_t; + using cluster_result_t = typename index_t::cluster_result_t; + using add_result_t = typename index_t::add_result_t; + using stats_t = typename index_t::stats_t; + using match_t = typename index_t::match_t; + + index_dense_gt() = default; + index_dense_gt(index_dense_gt && other) + : config_(std::move(other.config_)) + , + + typed_(exchange(other.typed_, nullptr)) + , // + cast_buffer_(std::move(other.cast_buffer_)) + , // + casts_(std::move(other.casts_)) + , // + metric_(std::move(other.metric_)) + , // + + vectors_tape_allocator_(std::move(other.vectors_tape_allocator_)) + , // + vectors_lookup_(std::move(other.vectors_lookup_)) + , // + + available_threads_(std::move(other.available_threads_)) + , // + slot_lookup_(std::move(other.slot_lookup_)) + , // + free_keys_(std::move(other.free_keys_)) + , // + free_key_(std::move(other.free_key_)) + {} // + + index_dense_gt & operator=(index_dense_gt && other) + { + swap(other); + return *this; + } + + /** + * @brief Swaps the contents of this index with another index. + * @param other The other index to swap with. 
+ */ + void swap(index_dense_gt & other) + { + std::swap(config_, other.config_); + + std::swap(typed_, other.typed_); + std::swap(cast_buffer_, other.cast_buffer_); + std::swap(casts_, other.casts_); + std::swap(metric_, other.metric_); + + std::swap(vectors_tape_allocator_, other.vectors_tape_allocator_); + std::swap(vectors_lookup_, other.vectors_lookup_); + + std::swap(available_threads_, other.available_threads_); + std::swap(slot_lookup_, other.slot_lookup_); + std::swap(free_keys_, other.free_keys_); + std::swap(free_key_, other.free_key_); + } + + ~index_dense_gt() + { + if (typed_) + typed_->~index_t(); + index_allocator_t{}.deallocate(typed_, 1); + typed_ = nullptr; + } + + /** + * @brief Constructs an instance of ::index_dense_gt. + * @param[in] metric One of the provided or an @b ad-hoc metric, type-punned. + * @param[in] config The index configuration (optional). + * @param[in] free_key The key used for freed vectors (optional). + * @return An instance of ::index_dense_gt. + */ + static index_dense_gt make( // + metric_t metric, // + index_dense_config_t config = {}, // + vector_key_t free_key = default_free_value()) + { + scalar_kind_t scalar_kind = metric.scalar_kind(); + std::size_t hardware_threads = std::thread::hardware_concurrency(); + + index_dense_gt result; + result.config_ = config; + result.cast_buffer_.resize(hardware_threads * metric.bytes_per_vector()); + result.casts_ = make_casts_(scalar_kind); + result.metric_ = metric; + result.free_key_ = free_key; + + // Fill the thread IDs. + result.available_threads_.resize(hardware_threads); + std::iota(result.available_threads_.begin(), result.available_threads_.end(), 0ul); + + // Available since C11, but only C++17, so we use the C version. + index_t * raw = index_allocator_t{}.allocate(1); + new (raw) index_t(config); + result.typed_ = raw; + return result; + } + + static index_dense_gt make(char const * path, bool view = false) + { + index_dense_metadata_result_t meta = index_dense_metadata_from_path(path); + if (!meta) + return {}; + metric_punned_t metric(meta.head.dimensions, meta.head.kind_metric, meta.head.kind_scalar); + index_dense_gt result = make(metric); + if (!result) + return result; + if (view) + result.view(path); + else + result.load(path); + return result; + } + + explicit operator bool() const { return typed_; } + std::size_t connectivity() const { return typed_->connectivity(); } + std::size_t size() const { return typed_->size() - free_keys_.size(); } + std::size_t capacity() const { return typed_->capacity(); } + std::size_t max_level() const noexcept { return typed_->max_level(); } + index_dense_config_t const & config() const { return config_; } + index_limits_t const & limits() const { return typed_->limits(); } + bool multi() const { return config_.multi; } + + // The metric and its properties + metric_t const & metric() const { return metric_; } + void change_metric(metric_t metric) { metric_ = std::move(metric); } + + scalar_kind_t scalar_kind() const noexcept { return metric_.scalar_kind(); } + std::size_t bytes_per_vector() const noexcept { return metric_.bytes_per_vector(); } + std::size_t scalar_words() const noexcept { return metric_.scalar_words(); } + std::size_t dimensions() const noexcept { return metric_.dimensions(); } + + // Fetching and changing search criteria + std::size_t expansion_add() const { return config_.expansion_add; } + std::size_t expansion_search() const { return config_.expansion_search; } + void change_expansion_add(std::size_t n) { config_.expansion_add = n; } 
+ void change_expansion_search(std::size_t n) { config_.expansion_search = n; } + + member_citerator_t cbegin() const { return typed_->cbegin(); } + member_citerator_t cend() const { return typed_->cend(); } + member_citerator_t begin() const { return typed_->begin(); } + member_citerator_t end() const { return typed_->end(); } + member_iterator_t begin() { return typed_->begin(); } + member_iterator_t end() { return typed_->end(); } + + stats_t stats() const { return typed_->stats(); } + stats_t stats(std::size_t level) const { return typed_->stats(level); } + stats_t stats(stats_t * stats_per_level, std::size_t max_level) const + { + return typed_->stats(stats_per_level, max_level); + } + + dynamic_allocator_t const & allocator() const { return typed_->dynamic_allocator(); } + vector_key_t const & free_key() const { return free_key_; } + + /** + * @brief A relatively accurate lower bound on the amount of memory consumed by the system. + * In practice it's error will be below 10%. + * + * @see `serialized_length` for the length of the binary serialized representation. + */ + std::size_t memory_usage() const + { + return // + typed_->memory_usage(0) + // + typed_->tape_allocator().total_wasted() + // + typed_->tape_allocator().total_reserved() + // + vectors_tape_allocator_.total_allocated(); + } + + static constexpr std::size_t any_thread() { return std::numeric_limits::max(); } + static constexpr distance_t infinite_distance() { return std::numeric_limits::max(); } + + struct aggregated_distances_t + { + std::size_t count = 0; + distance_t mean = infinite_distance(); + distance_t min = infinite_distance(); + distance_t max = infinite_distance(); + }; + + // clang-format off + add_result_t add(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_b1x8); } + add_result_t add(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_i8); } + add_result_t add(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f16); } + add_result_t add(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f32); } + add_result_t add(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f64); } + + template + search_result_t search(b1x8_t const* vector, std::size_t wanted, predicate_at&& predicate = predicate_at{}, std::size_t thread = any_thread(), size_t expansion = default_expansion_search(), bool exact = false) const { return search_(vector, wanted, predicate, thread, expansion, exact, casts_.from_b1x8); } + template + search_result_t search(i8_t const* vector, std::size_t wanted, predicate_at&& predicate = predicate_at{}, std::size_t thread = any_thread(), size_t expansion = default_expansion_search(), bool exact = false) const { return search_(vector, wanted, predicate, thread, expansion, exact, casts_.from_i8); } + template + search_result_t search(f16_t const* vector, std::size_t wanted, predicate_at&& predicate = predicate_at{}, std::size_t thread = any_thread(), size_t expansion = default_expansion_search(), bool exact = false) 
const { return search_(vector, wanted, predicate, thread, expansion, exact, casts_.from_f16); } + template + search_result_t search(f32_t const* vector, std::size_t wanted, predicate_at&& predicate = predicate_at{}, std::size_t thread = any_thread(), size_t expansion = default_expansion_search(), bool exact = false) const { return search_(vector, wanted, predicate, thread, expansion, exact, casts_.from_f32); } + template + search_result_t search(f64_t const* vector, std::size_t wanted, predicate_at&& predicate = predicate_at{}, std::size_t thread = any_thread(), size_t expansion = default_expansion_search(), bool exact = false) const { return search_(vector, wanted, predicate, thread, expansion, exact, casts_.from_f64); } + + std::size_t get(vector_key_t key, b1x8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_b1x8); } + std::size_t get(vector_key_t key, i8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_i8); } + std::size_t get(vector_key_t key, f16_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_f16); } + std::size_t get(vector_key_t key, f32_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_f32); } + std::size_t get(vector_key_t key, f64_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_f64); } + + cluster_result_t cluster(b1x8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_b1x8); } + cluster_result_t cluster(i8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_i8); } + cluster_result_t cluster(f16_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_f16); } + cluster_result_t cluster(f32_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_f32); } + cluster_result_t cluster(f64_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_f64); } + + aggregated_distances_t distance_between(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_b1x8); } + aggregated_distances_t distance_between(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_i8); } + aggregated_distances_t distance_between(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_f16); } + aggregated_distances_t distance_between(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_f32); } + aggregated_distances_t distance_between(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_f64); } + // clang-format on + + /** + * @brief Computes the distance between two managed entities. + * If either key maps into more than one vector, will aggregate results + * exporting the mean, maximum, and minimum values. 
+ */ + aggregated_distances_t distance_between(vector_key_t a, vector_key_t b, std::size_t = any_thread()) const + { + shared_lock_t lock(slot_lookup_mutex_); + aggregated_distances_t result; + if (!multi()) + { + auto a_it = slot_lookup_.find(key_and_slot_t::any_slot(a)); + auto b_it = slot_lookup_.find(key_and_slot_t::any_slot(b)); + bool a_missing = a_it == slot_lookup_.end(); + bool b_missing = b_it == slot_lookup_.end(); + if (a_missing || b_missing) + return result; + + key_and_slot_t a_key_and_slot = *a_it; + byte_t const * a_vector = vectors_lookup_[a_key_and_slot.slot]; + key_and_slot_t b_key_and_slot = *b_it; + byte_t const * b_vector = vectors_lookup_[b_key_and_slot.slot]; + distance_t a_b_distance = metric_(a_vector, b_vector); + + result.mean = result.min = result.max = a_b_distance; + result.count = 1; + return result; + } + + auto a_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(a)); + auto b_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(b)); + bool a_missing = a_range.first == a_range.second; + bool b_missing = b_range.first == b_range.second; + if (a_missing || b_missing) + return result; + + result.min = std::numeric_limits::max(); + result.max = std::numeric_limits::min(); + result.mean = 0; + result.count = 0; + + while (a_range.first != a_range.second) + { + key_and_slot_t a_key_and_slot = *a_range.first; + byte_t const * a_vector = vectors_lookup_[a_key_and_slot.slot]; + while (b_range.first != b_range.second) + { + key_and_slot_t b_key_and_slot = *b_range.first; + byte_t const * b_vector = vectors_lookup_[b_key_and_slot.slot]; + distance_t a_b_distance = metric_(a_vector, b_vector); + + result.mean += a_b_distance; + result.min = (std::min)(result.min, a_b_distance); + result.max = (std::max)(result.max, a_b_distance); + result.count++; + + // + ++b_range.first; + } + ++a_range.first; + } + + result.mean /= result.count; + return result; + } + + /** + * @brief Identifies a node in a given `level`, that is the closest to the `key`. + */ + cluster_result_t cluster(vector_key_t key, std::size_t level, std::size_t thread = any_thread()) const + { + // Check if such `key` is even present. + shared_lock_t slots_lock(slot_lookup_mutex_); + auto key_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + cluster_result_t result; + if (key_range.first == key_range.second) + return result.failed("Key missing!"); + + index_cluster_config_t cluster_config; + thread_lock_t lock = thread_lock_(thread); + cluster_config.thread = lock.thread_id; + cluster_config.expansion = config_.expansion_search; + metric_proxy_t metric{*this}; + auto allow = [=](member_cref_t const & member) noexcept { + return member.key != free_key_; + }; + + // Find the closest cluster for any vector under that key. + while (key_range.first != key_range.second) + { + key_and_slot_t key_and_slot = *key_range.first; + byte_t const * vector_data = vectors_lookup_[key_and_slot.slot]; + cluster_result_t new_result = typed_->cluster(vector_data, level, metric, cluster_config, allow); + if (!new_result) + return new_result; + if (new_result.cluster.distance < result.cluster.distance) + result = std::move(new_result); + + ++key_range.first; + } + return result; + } + + /** + * @brief Reserves memory for the index and the keyed lookup. + * @return `true` if the memory reservation was successful, `false` otherwise. 
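Callers are expected to size the index before bulk insertion; a short sketch of reserving capacity, assuming index_limits_t exposes the members/threads_add/threads_search fields as in upstream usearch:

#include <thread>
#include <usearch/index_dense.hpp>

// Illustrative only: reserve slots and per-thread contexts up front.
bool prepare(unum::usearch::index_dense_t & index, std::size_t expected_members)
{
    unum::usearch::index_limits_t limits;
    limits.members = expected_members;
    limits.threads_add = std::thread::hardware_concurrency();
    limits.threads_search = std::thread::hardware_concurrency();
    return index.reserve(limits); // false means the allocation failed and capacity is unchanged
}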
+ */ + bool reserve(index_limits_t limits) + { + { + unique_lock_t lock(slot_lookup_mutex_); + slot_lookup_.reserve(limits.members); + vectors_lookup_.resize(limits.members); + } + return typed_->reserve(limits); + } + + /** + * @brief Erases all the vectors from the index. + * + * Will change `size()` to zero, but will keep the same `capacity()`. + * Will keep the number of available threads/contexts the same as it was. + */ + void clear() + { + unique_lock_t lookup_lock(slot_lookup_mutex_); + + std::unique_lock free_lock(free_keys_mutex_); + typed_->clear(); + slot_lookup_.clear(); + vectors_lookup_.clear(); + free_keys_.clear(); + vectors_tape_allocator_.reset(); + } + + /** + * @brief Erases all members from index, closing files, and returning RAM to OS. + * + * Will change both `size()` and `capacity()` to zero. + * Will deallocate all threads/contexts. + * If the index is memory-mapped - releases the mapping and the descriptor. + */ + void reset() + { + unique_lock_t lookup_lock(slot_lookup_mutex_); + + std::unique_lock free_lock(free_keys_mutex_); + std::unique_lock available_threads_lock(available_threads_mutex_); + typed_->reset(); + slot_lookup_.clear(); + vectors_lookup_.clear(); + free_keys_.clear(); + vectors_tape_allocator_.reset(); + + // Reset the thread IDs. + available_threads_.resize(std::thread::hardware_concurrency()); + std::iota(available_threads_.begin(), available_threads_.end(), 0ul); + } + + /** + * @brief Saves serialized binary index representation to a stream. + */ + template + serialization_result_t save_to_stream( + output_callback_at && output, // + serialization_config_t config = {}, // + progress_at && progress = {}) const + { + serialization_result_t result; + std::uint64_t matrix_rows = 0; + std::uint64_t matrix_cols = 0; + + // We may not want to put the vectors into the same file + if (!config.exclude_vectors) + { + // Save the matrix size + if (!config.use_64_bit_dimensions) + { + std::uint32_t dimensions[2]; + dimensions[0] = static_cast(typed_->size()); + dimensions[1] = static_cast(metric_.bytes_per_vector()); + if (!output(&dimensions, sizeof(dimensions))) + return result.failed("Failed to serialize into stream"); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + } + else + { + std::uint64_t dimensions[2]; + dimensions[0] = static_cast(typed_->size()); + dimensions[1] = static_cast(metric_.bytes_per_vector()); + if (!output(&dimensions, sizeof(dimensions))) + return result.failed("Failed to serialize into stream"); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + } + + // Dump the vectors one after another + for (std::uint64_t i = 0; i != matrix_rows; ++i) + { + byte_t * vector = vectors_lookup_[i]; + if (!output(vector, matrix_cols)) + return result.failed("Failed to serialize into stream"); + } + } + + // Augment metadata + { + index_dense_head_buffer_t buffer; + std::memset(buffer, 0, sizeof(buffer)); + index_dense_head_t head{buffer}; + std::memcpy(buffer, default_magic(), std::strlen(default_magic())); + + // Describe software version + using version_t = index_dense_head_t::version_t; + head.version_major = static_cast(USEARCH_VERSION_MAJOR); + head.version_minor = static_cast(USEARCH_VERSION_MINOR); + head.version_patch = static_cast(USEARCH_VERSION_PATCH); + + // Describes types used + head.kind_metric = metric_.metric_kind(); + head.kind_scalar = metric_.scalar_kind(); + head.kind_key = unum::usearch::scalar_kind(); + head.kind_compressed_slot = unum::usearch::scalar_kind(); + + head.count_present = size(); + 
head.count_deleted = typed_->size() - size(); + head.dimensions = dimensions(); + head.multi = multi(); + + if (!output(&buffer, sizeof(buffer))) + return result.failed("Failed to serialize into stream"); + } + + // Save the actual proximity graph + return typed_->save_to_stream(std::forward(output), std::forward(progress)); + } + + /** + * @brief Estimate the binary length (in bytes) of the serialized index. + */ + std::size_t serialized_length(serialization_config_t config = {}) const noexcept + { + std::size_t dimensions_length = 0; + std::size_t matrix_length = 0; + if (!config.exclude_vectors) + { + dimensions_length = config.use_64_bit_dimensions ? sizeof(std::uint64_t) * 2 : sizeof(std::uint32_t) * 2; + matrix_length = typed_->size() * metric_.bytes_per_vector(); + } + return dimensions_length + matrix_length + sizeof(index_dense_head_buffer_t) + typed_->serialized_length(); + } + + /** + * @brief Parses the index from file to RAM. + * @param[in] path The path to the file. + * @param[in] config Configuration parameters for imports. + * @return Outcome descriptor explicitly convertible to boolean. + */ + template + serialization_result_t load_from_stream( + input_callback_at && input, // + serialization_config_t config = {}, // + progress_at && progress = {}) + { + // Discard all previous memory allocations of `vectors_tape_allocator_` + reset(); + + // Infer the new index size + serialization_result_t result; + std::uint64_t matrix_rows = 0; + std::uint64_t matrix_cols = 0; + + // We may not want to load the vectors from the same file, or allow attaching them afterwards + if (!config.exclude_vectors) + { + // Save the matrix size + if (!config.use_64_bit_dimensions) + { + std::uint32_t dimensions[2]; + if (!input(&dimensions, sizeof(dimensions))) + return result.failed("Failed to read 32-bit dimensions of the matrix"); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + } + else + { + std::uint64_t dimensions[2]; + if (!input(&dimensions, sizeof(dimensions))) + return result.failed("Failed to read 64-bit dimensions of the matrix"); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + } + // Load the vectors one after another + vectors_lookup_.resize(matrix_rows); + for (std::uint64_t slot = 0; slot != matrix_rows; ++slot) + { + byte_t * vector = vectors_tape_allocator_.allocate(matrix_cols); + if (!input(vector, matrix_cols)) + return result.failed("Failed to read vectors"); + vectors_lookup_[slot] = vector; + } + } + + // Load metadata and choose the right metric + { + index_dense_head_buffer_t buffer; + if (!input(buffer, sizeof(buffer))) + return result.failed("Failed to read the index "); + + index_dense_head_t head{buffer}; + if (std::memcmp(buffer, default_magic(), std::strlen(default_magic())) != 0) + return result.failed("Magic header mismatch - the file isn't an index"); + + // Validate the software version + if (head.version_major != USEARCH_VERSION_MAJOR) + return result.failed("File format may be different, please rebuild"); + + // Check the types used + if (head.kind_key != unum::usearch::scalar_kind()) + return result.failed("Key type doesn't match, consider rebuilding"); + if (head.kind_compressed_slot != unum::usearch::scalar_kind()) + return result.failed("Slot type doesn't match, consider rebuilding"); + + config_.multi = head.multi; + metric_ = metric_t(head.dimensions, head.kind_metric, head.kind_scalar); + cast_buffer_.resize(available_threads_.size() * metric_.bytes_per_vector()); + casts_ = make_casts_(head.kind_scalar); + } + + // Pull 
the actual proximity graph + result = typed_->load_from_stream(std::forward(input), std::forward(progress)); + if (!result) + return result; + if (typed_->size() != static_cast(matrix_rows)) + return result.failed("Index size and the number of vectors doesn't match"); + + reindex_keys_(); + return result; + } + + /** + * @brief Parses the index from file, without loading it into RAM. + * @param[in] path The path to the file. + * @param[in] config Configuration parameters for imports. + * @return Outcome descriptor explicitly convertible to boolean. + */ + template + serialization_result_t view( + memory_mapped_file_t file, // + std::size_t offset = 0, + serialization_config_t config = {}, // + progress_at && progress = {}) + { + // Discard all previous memory allocations of `vectors_tape_allocator_` + reset(); + + serialization_result_t result = file.open_if_not(); + if (!result) + return result; + + // Infer the new index size + std::uint64_t matrix_rows = 0; + std::uint64_t matrix_cols = 0; + span_punned_t vectors_buffer; + + // We may not want to fetch the vectors from the same file, or allow attaching them afterwards + if (!config.exclude_vectors) + { + // Save the matrix size + if (!config.use_64_bit_dimensions) + { + std::uint32_t dimensions[2]; + if (file.size() - offset < sizeof(dimensions)) + return result.failed("File is corrupted and lacks matrix dimensions"); + std::memcpy(&dimensions, file.data() + offset, sizeof(dimensions)); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + offset += sizeof(dimensions); + } + else + { + std::uint64_t dimensions[2]; + if (file.size() - offset < sizeof(dimensions)) + return result.failed("File is corrupted and lacks matrix dimensions"); + std::memcpy(&dimensions, file.data() + offset, sizeof(dimensions)); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + offset += sizeof(dimensions); + } + vectors_buffer = {file.data() + offset, static_cast(matrix_rows * matrix_cols)}; + offset += vectors_buffer.size(); + } + + // Load metadata and choose the right metric + { + index_dense_head_buffer_t buffer; + if (file.size() - offset < sizeof(buffer)) + return result.failed("File is corrupted and lacks a header"); + + std::memcpy(buffer, file.data() + offset, sizeof(buffer)); + + index_dense_head_t head{buffer}; + if (std::memcmp(buffer, default_magic(), std::strlen(default_magic())) != 0) + return result.failed("Magic header mismatch - the file isn't an index"); + + // Validate the software version + if (head.version_major != USEARCH_VERSION_MAJOR) + return result.failed("File format may be different, please rebuild"); + + // Check the types used + if (head.kind_key != unum::usearch::scalar_kind()) + return result.failed("Key type doesn't match, consider rebuilding"); + if (head.kind_compressed_slot != unum::usearch::scalar_kind()) + return result.failed("Slot type doesn't match, consider rebuilding"); + + config_.multi = head.multi; + metric_ = metric_t(head.dimensions, head.kind_metric, head.kind_scalar); + cast_buffer_.resize(available_threads_.size() * metric_.bytes_per_vector()); + casts_ = make_casts_(head.kind_scalar); + offset += sizeof(buffer); + } + + // Pull the actual proximity graph + result = typed_->view(std::move(file), offset, std::forward(progress)); + if (!result) + return result; + if (typed_->size() != static_cast(matrix_rows)) + return result.failed("Index size and the number of vectors doesn't match"); + + // Address the vectors + vectors_lookup_.resize(matrix_rows); + if (!config.exclude_vectors) + for 
(std::uint64_t slot = 0; slot != matrix_rows; ++slot) + vectors_lookup_[slot] = (byte_t *)vectors_buffer.data() + matrix_cols * slot; + + reindex_keys_(); + return result; + } + + /** + * @brief Saves the index to a file. + * @param[in] path The path to the file. + * @param[in] config Configuration parameters for exports. + * @return Outcome descriptor explicitly convertible to boolean. + */ + template + serialization_result_t save(output_file_t file, serialization_config_t config = {}, progress_at && progress = {}) + const + { + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = save_to_stream( + [&](void const * buffer, std::size_t length) { + io_result = file.write(buffer, length); + return !!io_result; + }, + config, + std::forward(progress)); + + if (!stream_result) + { + io_result.error.release(); + return stream_result; + } + return io_result; + } + + /** + * @brief Memory-maps the serialized binary index representation from disk, + * @b without copying data into RAM, and fetching it on-demand. + */ + template + serialization_result_t save( + memory_mapped_file_t file, // + std::size_t offset = 0, // + serialization_config_t config = {}, // + progress_at && progress = {}) const + { + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = save_to_stream( + [&](void const * buffer, std::size_t length) { + if (offset + length > file.size()) + return false; + std::memcpy(file.data() + offset, buffer, length); + offset += length; + return true; + }, + config, + std::forward(progress)); + + return stream_result; + } + + /** + * @brief Parses the index from file to RAM. + * @param[in] path The path to the file. + * @param[in] config Configuration parameters for imports. + * @return Outcome descriptor explicitly convertible to boolean. + */ + template + serialization_result_t load(input_file_t file, serialization_config_t config = {}, progress_at && progress = {}) + { + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = load_from_stream( + [&](void * buffer, std::size_t length) { + io_result = file.read(buffer, length); + return !!io_result; + }, + config, + std::forward(progress)); + + if (!stream_result) + { + io_result.error.release(); + return stream_result; + } + return io_result; + } + + /** + * @brief Memory-maps the serialized binary index representation from disk, + * @b without copying data into RAM, and fetching it on-demand. 
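The stream-oriented save_to_stream/load_from_stream above are wrapped further down by file- and path-based overloads; a sketch of a round trip through disk, with a hypothetical path:

#include <usearch/index_dense.hpp>

// Illustrative only: persist an index, then re-open it from the same (hypothetical) path.
bool persist_and_reload(unum::usearch::index_dense_t & index, unum::usearch::metric_punned_t const & metric)
{
    using namespace unum::usearch;

    serialization_result_t saved = index.save("/tmp/vectors.usearch");
    if (!saved)
        return false;

    index_dense_t restored = index_dense_t::make(metric);
    serialization_result_t loaded = restored.load("/tmp/vectors.usearch");
    // To avoid copying vectors into RAM, the file can be memory-mapped instead:
    //   restored.view(memory_mapped_file_t("/tmp/vectors.usearch"));
    return static_cast<bool>(loaded) && restored.size() == index.size();
}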
+ */ + template + serialization_result_t load( + memory_mapped_file_t file, // + std::size_t offset = 0, // + serialization_config_t config = {}, // + progress_at && progress = {}) + { + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = load_from_stream( + [&](void * buffer, std::size_t length) { + if (offset + length > file.size()) + return false; + std::memcpy(buffer, file.data() + offset, length); + offset += length; + return true; + }, + config, + std::forward(progress)); + + return stream_result; + } + + template + serialization_result_t save( + char const * file_path, // + serialization_config_t config = {}, // + progress_at && progress = {}) const + { + return save(output_file_t(file_path), config, std::forward(progress)); + } + + template + serialization_result_t load( + char const * file_path, // + serialization_config_t config = {}, // + progress_at && progress = {}) + { + return load(input_file_t(file_path), config, std::forward(progress)); + } + + /** + * @brief Checks if a vector with specified key is present. + * @return `true` if the key is present in the index, `false` otherwise. + */ + bool contains(vector_key_t key) const + { + shared_lock_t lock(slot_lookup_mutex_); + return slot_lookup_.contains(key_and_slot_t::any_slot(key)); + } + + /** + * @brief Count the number of vectors with specified key present. + * @return Zero if nothing is found, a positive integer otherwise. + */ + std::size_t count(vector_key_t key) const + { + shared_lock_t lock(slot_lookup_mutex_); + return slot_lookup_.count(key_and_slot_t::any_slot(key)); + } + + struct labeling_result_t + { + error_t error{}; + std::size_t completed{}; + + explicit operator bool() const noexcept { return !error; } + labeling_result_t failed(error_t message) noexcept + { + error = std::move(message); + return std::move(*this); + } + }; + + /** + * @brief Removes an entry with the specified key from the index. + * @param[in] key The key of the entry to remove. + * @return The ::labeling_result_t indicating the result of the removal operation. + * If the removal was successful, `result.completed` will be `true`. + * If the key was not found in the index, `result.completed` will be `false`. + * If an error occurred during the removal operation, `result.error` will contain an error message. + */ + labeling_result_t remove(vector_key_t key) + { + labeling_result_t result; + + unique_lock_t lookup_lock(slot_lookup_mutex_); + auto matching_slots = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + if (matching_slots.first == matching_slots.second) + return result; + + // Grow the removed entries ring, if needed + std::size_t matching_count = std::distance(matching_slots.first, matching_slots.second); + std::unique_lock free_lock(free_keys_mutex_); + if (!free_keys_.reserve(free_keys_.size() + matching_count)) + return result.failed("Can't allocate memory for a free-list"); + + // A removed entry would be: + // - present in `free_keys_` + // - missing in the `slot_lookup_` + // - marked in the `typed_` index with a `free_key_` + for (auto slots_it = matching_slots.first; slots_it != matching_slots.second; ++slots_it) + { + compressed_slot_t slot = (*slots_it).slot; + free_keys_.push(slot); + typed_->at(slot).key = free_key_; + } + slot_lookup_.erase(key); + result.completed = matching_count; + + return result; + } + + /** + * @brief Removes multiple entries with the specified keys from the index. 
+ * @param[in] keys_begin The beginning of the keys range. + * @param[in] keys_end The ending of the keys range. + * @return The ::labeling_result_t indicating the result of the removal operation. + * `result.completed` will contain the number of keys that were successfully removed. + * `result.error` will contain an error message if an error occurred during the removal operation. + */ + template + labeling_result_t remove(keys_iterator_at keys_begin, keys_iterator_at keys_end) + { + labeling_result_t result; + unique_lock_t lookup_lock(slot_lookup_mutex_); + std::unique_lock free_lock(free_keys_mutex_); + // Grow the removed entries ring, if needed + std::size_t matching_count = 0; + for (auto keys_it = keys_begin; keys_it != keys_end; ++keys_it) + matching_count += slot_lookup_.count(key_and_slot_t::any_slot(*keys_it)); + + if (!free_keys_.reserve(free_keys_.size() + matching_count)) + return result.failed("Can't allocate memory for a free-list"); + + // Remove them one-by-one + for (auto keys_it = keys_begin; keys_it != keys_end; ++keys_it) + { + vector_key_t key = *keys_it; + auto matching_slots = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + // A removed entry would be: + // - present in `free_keys_` + // - missing in the `slot_lookup_` + // - marked in the `typed_` index with a `free_key_` + matching_count = 0; + for (auto slots_it = matching_slots.first; slots_it != matching_slots.second; ++slots_it) + { + compressed_slot_t slot = (*slots_it).slot; + free_keys_.push(slot); + typed_->at(slot).key = free_key_; + ++matching_count; + } + + slot_lookup_.erase(key); + result.completed += matching_count; + } + + return result; + } + + /** + * @brief Renames an entry with the specified key to a new key. + * @param[in] from The current key of the entry to rename. + * @param[in] to The new key to assign to the entry. + * @return The ::labeling_result_t indicating the result of the rename operation. + * If the rename was successful, `result.completed` will be `true`. + * If the entry with the current key was not found, `result.completed` will be `false`. + */ + labeling_result_t rename(vector_key_t from, vector_key_t to) + { + labeling_result_t result; + unique_lock_t lookup_lock(slot_lookup_mutex_); + + if (!multi() && slot_lookup_.contains(key_and_slot_t::any_slot(to))) + return result.failed("Renaming impossible, the key is already in use"); + + // The `from` may map to multiple entries + while (true) + { + key_and_slot_t key_and_slot_removed; + if (!slot_lookup_.pop_first(key_and_slot_t::any_slot(from), key_and_slot_removed)) + break; + + key_and_slot_t key_and_slot_replacing{to, key_and_slot_removed.slot}; + slot_lookup_.try_emplace(key_and_slot_replacing); // This can't fail + typed_->at(key_and_slot_removed.slot).key = to; + ++result.completed; + } + + return result; + } + + /** + * @brief Exports a range of keys for the vectors present in the index. + * @param[out] keys Pointer to the array where the keys will be exported. + * @param[in] offset The number of keys to skip. Useful for pagination. + * @param[in] limit The maximum number of keys to export, that can fit in ::keys. 
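A sketch of the key-management calls documented above, with hypothetical keys: remove frees the slot for reuse, rename keeps the vector but changes its key, and export_keys paginates over whatever remains:

#include <cstdint>
#include <vector>
#include <usearch/index_dense.hpp>

// Illustrative only; keys 42 and 43 are assumed to have been added earlier.
void relabel(unum::usearch::index_dense_t & index)
{
    index.remove(42);    // the slot becomes reusable by a later add()
    index.rename(43, 7); // same vector, new key

    std::vector<std::uint64_t> keys(index.size()); // the default key type is a 64-bit integer
    index.export_keys(keys.data(), /* offset = */ 0, /* limit = */ keys.size());
}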
+ */ + void export_keys(vector_key_t * keys, std::size_t offset, std::size_t limit) const + { + shared_lock_t lock(slot_lookup_mutex_); + offset = (std::min)(offset, slot_lookup_.size()); + slot_lookup_.for_each([&](key_and_slot_t const & key_and_slot) { + if (offset) + // Skip the first `offset` entries + --offset; + else if (limit) + { + *keys = key_and_slot.key; + ++keys; + --limit; + } + }); + } + + struct copy_result_t + { + index_dense_gt index; + error_t error; + + explicit operator bool() const noexcept { return !error; } + copy_result_t failed(error_t message) noexcept + { + error = std::move(message); + return std::move(*this); + } + }; + + /** + * @brief Copies the ::index_dense_gt @b with all the data in it. + * @param config The copy configuration (optional). + * @return A copy of the ::index_dense_gt instance. + */ + copy_result_t copy(index_dense_copy_config_t config = {}) const + { + copy_result_t result = fork(); + if (!result) + return result; + + auto typed_result = typed_->copy(config); + if (!typed_result) + return result.failed(std::move(typed_result.error)); + + // Export the free (removed) slot numbers + index_dense_gt & copy = result.index; + if (!copy.free_keys_.reserve(free_keys_.size())) + return result.failed(std::move(typed_result.error)); + for (std::size_t i = 0; i != free_keys_.size(); ++i) + copy.free_keys_.push(free_keys_[i]); + + // Allocate buffers and move the vectors themselves + if (!config.force_vector_copy && copy.config_.exclude_vectors) + copy.vectors_lookup_ = vectors_lookup_; + else + { + copy.vectors_lookup_.resize(vectors_lookup_.size()); + for (std::size_t slot = 0; slot != vectors_lookup_.size(); ++slot) + copy.vectors_lookup_[slot] = copy.vectors_tape_allocator_.allocate(copy.metric_.bytes_per_vector()); + if (std::count(copy.vectors_lookup_.begin(), copy.vectors_lookup_.end(), nullptr)) + return result.failed("Out of memory!"); + for (std::size_t slot = 0; slot != vectors_lookup_.size(); ++slot) + std::memcpy(copy.vectors_lookup_[slot], vectors_lookup_[slot], metric_.bytes_per_vector()); + } + + copy.slot_lookup_ = slot_lookup_; + *copy.typed_ = std::move(typed_result.index); + return result; + } + + /** + * @brief Copies the ::index_dense_gt model @b without any data. + * @return A similarly configured ::index_dense_gt instance. + */ + copy_result_t fork() const + { + copy_result_t result; + index_dense_gt & other = result.index; + + other.config_ = config_; + other.cast_buffer_ = cast_buffer_; + other.casts_ = casts_; + + other.metric_ = metric_; + other.available_threads_ = available_threads_; + other.free_key_ = free_key_; + + index_t * raw = index_allocator_t{}.allocate(1); + if (!raw) + return result.failed("Can't allocate the index"); + + new (raw) index_t(config()); + other.typed_ = raw; + return result; + } + + struct compaction_result_t + { + error_t error{}; + std::size_t pruned_edges{}; + + explicit operator bool() const noexcept { return !error; } + compaction_result_t failed(error_t message) noexcept + { + error = std::move(message); + return std::move(*this); + } + }; + + /** + * @brief Performs compaction on the index, pruning links to removed entries. + * @param executor The executor parallel processing. Default ::dummy_executor_t single-threaded. + * @param progress The progress tracker instance to use. Default ::dummy_progress_t reports nothing. + * @return The ::compaction_result_t indicating the result of the compaction operation. + * `result.pruned_edges` will contain the number of edges that were removed. 
+ * `result.error` will contain an error message if an error occurred during the compaction operation. + */ + template + compaction_result_t isolate(executor_at && executor = executor_at{}, progress_at && progress = progress_at{}) + { + compaction_result_t result; + std::atomic pruned_edges; + auto disallow = [&](member_cref_t const & member) noexcept { + bool freed = member.key == free_key_; + pruned_edges += freed; + return freed; + }; + typed_->isolate(disallow, std::forward(executor), std::forward(progress)); + result.pruned_edges = pruned_edges; + return result; + } + + class values_proxy_t + { + index_dense_gt const * index_; + + public: + values_proxy_t(index_dense_gt const & index) noexcept + : index_(&index) + {} + byte_t const * operator[](compressed_slot_t slot) const noexcept { return index_->vectors_lookup_[slot]; } + byte_t const * operator[](member_citerator_t it) const noexcept + { + return index_->vectors_lookup_[get_slot(it)]; + } + }; + + /** + * @brief Performs compaction on the index, pruning links to removed entries. + * @param executor The executor parallel processing. Default ::dummy_executor_t single-threaded. + * @param progress The progress tracker instance to use. Default ::dummy_progress_t reports nothing. + * @return The ::compaction_result_t indicating the result of the compaction operation. + * `result.pruned_edges` will contain the number of edges that were removed. + * `result.error` will contain an error message if an error occurred during the compaction operation. + */ + template + compaction_result_t compact(executor_at && executor = executor_at{}, progress_at && progress = progress_at{}) + { + compaction_result_t result; + + std::vector new_vectors_lookup(vectors_lookup_.size()); + vectors_tape_allocator_t new_vectors_allocator; + + auto track_slot_change = [&](vector_key_t, compressed_slot_t old_slot, compressed_slot_t new_slot) { + byte_t * new_vector = new_vectors_allocator.allocate(metric_.bytes_per_vector()); + byte_t * old_vector = vectors_lookup_[old_slot]; + std::memcpy(new_vector, old_vector, metric_.bytes_per_vector()); + new_vectors_lookup[new_slot] = new_vector; + }; + typed_->compact( + values_proxy_t{*this}, + metric_proxy_t{*this}, + track_slot_change, + std::forward(executor), + std::forward(progress)); + vectors_lookup_ = std::move(new_vectors_lookup); + vectors_tape_allocator_ = std::move(new_vectors_allocator); + return result; + } + + template < // + typename man_to_woman_at = dummy_key_to_key_mapping_t, // + typename woman_to_man_at = dummy_key_to_key_mapping_t, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > + join_result_t join( // + index_dense_gt const & women, // + index_join_config_t config = {}, // + man_to_woman_at && man_to_woman = man_to_woman_at{}, // + woman_to_man_at && woman_to_man = woman_to_man_at{}, // + executor_at && executor = executor_at{}, // + progress_at && progress = progress_at{}) const + { + index_dense_gt const & men = *this; + return unum::usearch::join( // + *men.typed_, + *women.typed_, // + values_proxy_t{men}, + values_proxy_t{women}, // + metric_proxy_t{men}, + metric_proxy_t{women}, // + config, // + std::forward(man_to_woman), // + std::forward(woman_to_man), // + std::forward(executor), // + std::forward(progress)); + } + + struct clustering_result_t + { + error_t error{}; + std::size_t clusters{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; + + explicit operator bool() const noexcept { return !error; } + 
clustering_result_t failed(error_t message) noexcept + { + error = std::move(message); + return std::move(*this); + } + }; + + /** + * @brief Implements clustering, classifying the given objects (vectors of member keys) + * into a given number of clusters. + * + * @param[in] queries_begin Iterator pointing to the first query. + * @param[in] queries_end Iterator pointing to the last query. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. + * @param[in] config Configuration parameters for clustering. + * + * @param[out] cluster_keys Pointer to the array where the cluster keys will be exported. + * @param[out] cluster_distances Pointer to the array where the distances to those centroids will be exported. + */ + template < // + typename queries_iterator_at, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > + clustering_result_t cluster( // + queries_iterator_at queries_begin, // + queries_iterator_at queries_end, // + index_dense_clustering_config_t config, // + vector_key_t * cluster_keys, // + distance_t * cluster_distances, // + executor_at && executor = executor_at{}, // + progress_at && progress = progress_at{}) + { + std::size_t const queries_count = queries_end - queries_begin; + + // Find the first level (top -> down) that has enough nodes to exceed `config.min_clusters`. + std::size_t level = max_level(); + if (config.min_clusters) + { + for (; level > 1; --level) + { + if (stats(level).nodes > config.min_clusters) + break; + } + } + else + level = 1, config.max_clusters = stats(1).nodes, config.min_clusters = 2; + + clustering_result_t result; + if (max_level() < 2) + return result.failed("Index too small to cluster!"); + + // A structure used to track the popularity of a specific cluster + struct cluster_t + { + vector_key_t centroid; + vector_key_t merged_into; + std::size_t popularity; + byte_t * vector; + }; + + auto centroid_id = [](cluster_t const & a, cluster_t const & b) { + return a.centroid < b.centroid; + }; + auto higher_popularity = [](cluster_t const & a, cluster_t const & b) { + return a.popularity > b.popularity; + }; + + std::atomic visited_members(0); + std::atomic computed_distances(0); + std::atomic atomic_error{nullptr}; + + using dynamic_allocator_traits_t = std::allocator_traits; + using clusters_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt clusters(queries_count); + if (!clusters) + return result.failed("Out of memory!"); + + map_to_clusters: + // Concurrently perform search until a certain depth + executor.dynamic(queries_count, [&](std::size_t thread_idx, std::size_t query_idx) { + auto result = cluster(queries_begin[query_idx], level, thread_idx); + if (!result) + { + atomic_error = result.error.release(); + return false; + } + + cluster_keys[query_idx] = result.cluster.member.key; + cluster_distances[query_idx] = result.cluster.distance; + + // Export in case we need to refine afterwards + clusters[query_idx].centroid = result.cluster.member.key; + clusters[query_idx].vector = vectors_lookup_[result.cluster.member.slot]; + clusters[query_idx].merged_into = free_key(); + clusters[query_idx].popularity = 1; + + visited_members += result.visited_members; + computed_distances += result.computed_distances; + return true; + }); + + if (atomic_error) + return result.failed(atomic_error.load()); + + // Now once we have identified the closest clusters, + // we can try reducing their quantity, 
refining + std::sort(clusters.begin(), clusters.end(), centroid_id); + + // Transform into run-length encoding, computing the number of unique clusters + std::size_t unique_clusters = 0; + { + std::size_t last_idx = 0; + for (std::size_t current_idx = 1; current_idx != clusters.size(); ++current_idx) + { + if (clusters[last_idx].centroid == clusters[current_idx].centroid) + { + clusters[last_idx].popularity++; + } + else + { + last_idx++; + clusters[last_idx] = clusters[current_idx]; + } + } + unique_clusters = last_idx + 1; + } + + // In some cases the queries may be co-located, all mapping into the same cluster on that + // level. In that case we refine the granularity and dive deeper into clusters: + if (unique_clusters < config.min_clusters && level > 1) + { + level--; + goto map_to_clusters; + } + + std::sort(clusters.data(), clusters.data() + unique_clusters, higher_popularity); + + // If clusters are too numerous, merge the ones that are too close to each other. + std::size_t merge_cycles = 0; + merge_nearby_clusters: + if (unique_clusters > config.max_clusters) + { + cluster_t & merge_source = clusters[unique_clusters - 1]; + std::size_t merge_target_idx = 0; + distance_t merge_distance = std::numeric_limits::max(); + + for (std::size_t candidate_idx = 0; candidate_idx + 1 < unique_clusters; ++candidate_idx) + { + distance_t distance = metric_(merge_source.vector, clusters[candidate_idx].vector); + if (distance < merge_distance) + { + merge_distance = distance; + merge_target_idx = candidate_idx; + } + } + + merge_source.merged_into = clusters[merge_target_idx].centroid; + clusters[merge_target_idx].popularity += exchange(merge_source.popularity, 0); + + // The target object may have to be swapped a few times to get to optimal position. + while (merge_target_idx + && clusters[merge_target_idx - 1].popularity < clusters[merge_target_idx].popularity) + std::swap(clusters[merge_target_idx - 1], clusters[merge_target_idx]), --merge_target_idx; + + unique_clusters--; + merge_cycles++; + goto merge_nearby_clusters; + } + + // Replace evicted clusters + if (merge_cycles) + { + // Sort dropped clusters by name to accelerate future lookups + auto clusters_end = clusters.data() + config.max_clusters + merge_cycles; + std::sort(clusters.data(), clusters_end, centroid_id); + + executor.dynamic(queries_count, [&](std::size_t thread_idx, std::size_t query_idx) { + vector_key_t & cluster_key = cluster_keys[query_idx]; + distance_t & cluster_distance = cluster_distances[query_idx]; + + // Recursively trace replacements of that cluster + while (true) + { + // To avoid implementing heterogeneous comparisons, lets wrap the `cluster_key` + cluster_t updated_cluster; + updated_cluster.centroid = cluster_key; + updated_cluster = *std::lower_bound(clusters.data(), clusters_end, updated_cluster, centroid_id); + if (updated_cluster.merged_into == free_key()) + break; + cluster_key = updated_cluster.merged_into; + } + + cluster_distance = distance_between(cluster_key, queries_begin[query_idx], thread_idx).mean; + return true; + }); + } + + result.computed_distances = computed_distances; + result.visited_members = visited_members; + + (void)progress; + return result; + } + +private: + struct thread_lock_t + { + index_dense_gt const & parent; + std::size_t thread_id; + bool engaged; + + ~thread_lock_t() + { + if (engaged) + parent.thread_unlock_(thread_id); + } + }; + + thread_lock_t thread_lock_(std::size_t thread_id) const + { + if (thread_id != any_thread()) + return {*this, thread_id, false}; + + 
available_threads_mutex_.lock(); + thread_id = available_threads_.back(); + available_threads_.pop_back(); + available_threads_mutex_.unlock(); + return {*this, thread_id, true}; + } + + void thread_unlock_(std::size_t thread_id) const + { + available_threads_mutex_.lock(); + available_threads_.push_back(thread_id); + available_threads_mutex_.unlock(); + } + + template + add_result_t add_( // + vector_key_t key, + scalar_at const * vector, // + std::size_t thread, + bool force_vector_copy, + cast_t const & cast) + { + if (!multi() && contains(key)) + return add_result_t{}.failed("Duplicate keys not allowed in high-level wrappers"); + + // Cast the vector, if needed for compatibility with `metric_` + thread_lock_t lock = thread_lock_(thread); + bool copy_vector = !config_.exclude_vectors || force_vector_copy; + byte_t const * vector_data = reinterpret_cast(vector); + { + byte_t * casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + bool casted = cast(vector_data, dimensions(), casted_data); + if (casted) + vector_data = casted_data, copy_vector = true; + } + + // Check if there are some removed entries, whose nodes we can reuse + compressed_slot_t free_slot = default_free_value(); + { + std::unique_lock lock(free_keys_mutex_); + free_keys_.try_pop(free_slot); + } + + // Perform the insertion or the update + bool reuse_node = free_slot != default_free_value(); + auto on_success = [&](member_ref_t member) { + unique_lock_t slot_lock(slot_lookup_mutex_); + slot_lookup_.try_emplace(key_and_slot_t{key, static_cast(member.slot)}); + if (copy_vector) + { + if (!reuse_node) + vectors_lookup_[member.slot] = vectors_tape_allocator_.allocate(metric_.bytes_per_vector()); + std::memcpy(vectors_lookup_[member.slot], vector_data, metric_.bytes_per_vector()); + } + else + vectors_lookup_[member.slot] = (byte_t *)vector_data; + }; + + index_update_config_t update_config; + update_config.thread = lock.thread_id; + update_config.expansion = config_.expansion_add; + + metric_proxy_t metric{*this}; + return reuse_node // + ? 
typed_->update(typed_->iterator_at(free_slot), key, vector_data, metric, update_config, on_success) + : typed_->add(key, vector_data, metric, update_config, on_success); + } + + template + search_result_t search_( // + scalar_at const * vector, + std::size_t wanted, // + predicate_at && predicate, + std::size_t thread, + size_t expansion, + bool exact, + cast_t const & cast) const + { + // Cast the vector, if needed for compatibility with `metric_` + thread_lock_t lock = thread_lock_(thread); + byte_t const * vector_data = reinterpret_cast(vector); + { + byte_t * casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + bool casted = cast(vector_data, dimensions(), casted_data); + if (casted) + vector_data = casted_data; + } + + index_search_config_t search_config; + search_config.thread = lock.thread_id; + search_config.expansion = expansion; + search_config.exact = exact; + + auto allow = [=](member_cref_t const & member) noexcept { + if (member.key == free_key_) + return false; + if constexpr (!is_dummy()) + return predicate(member); + return true; + }; + return typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow); + } + + template + cluster_result_t cluster_( // + scalar_at const * vector, + std::size_t level, // + std::size_t thread, + cast_t const & cast) const + { + // Cast the vector, if needed for compatibility with `metric_` + thread_lock_t lock = thread_lock_(thread); + byte_t const * vector_data = reinterpret_cast(vector); + { + byte_t * casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + bool casted = cast(vector_data, dimensions(), casted_data); + if (casted) + vector_data = casted_data; + } + + index_cluster_config_t cluster_config; + cluster_config.thread = lock.thread_id; + cluster_config.expansion = config_.expansion_search; + + auto allow = [=](member_cref_t const & member) noexcept { + return member.key != free_key_; + }; + return typed_->cluster(vector_data, level, metric_proxy_t{*this}, cluster_config, allow); + } + + template + aggregated_distances_t distance_between_( // + vector_key_t key, + scalar_at const * vector, // + std::size_t thread, + cast_t const & cast) const + { + // Cast the vector, if needed for compatibility with `metric_` + thread_lock_t lock = thread_lock_(thread); + byte_t const * vector_data = reinterpret_cast(vector); + { + byte_t * casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + bool casted = cast(vector_data, dimensions(), casted_data); + if (casted) + vector_data = casted_data; + } + + // Check if such `key` is even present. 
+ shared_lock_t slots_lock(slot_lookup_mutex_); + auto key_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + aggregated_distances_t result; + if (key_range.first == key_range.second) + return result; + + result.min = std::numeric_limits::max(); + result.max = std::numeric_limits::min(); + result.mean = 0; + result.count = 0; + + while (key_range.first != key_range.second) + { + key_and_slot_t key_and_slot = *key_range.first; + byte_t const * a_vector = vectors_lookup_[key_and_slot.slot]; + byte_t const * b_vector = vector_data; + distance_t a_b_distance = metric_(a_vector, b_vector); + + result.mean += a_b_distance; + result.min = (std::min)(result.min, a_b_distance); + result.max = (std::max)(result.max, a_b_distance); + result.count++; + + // + ++key_range.first; + } + + result.mean /= result.count; + return result; + } + + void reindex_keys_() + { + // Estimate number of entries first + std::size_t count_total = typed_->size(); + std::size_t count_removed = 0; + for (std::size_t i = 0; i != count_total; ++i) + { + member_cref_t member = typed_->at(i); + count_removed += member.key == free_key_; + } + + if (!count_removed && !config_.enable_key_lookups) + return; + + // Pull entries from the underlying `typed_` into either + // into `slot_lookup_`, or `free_keys_` if they are unused. + unique_lock_t lock(slot_lookup_mutex_); + slot_lookup_.clear(); + if (config_.enable_key_lookups) + slot_lookup_.reserve(count_total - count_removed); + free_keys_.clear(); + free_keys_.reserve(count_removed); + for (std::size_t i = 0; i != typed_->size(); ++i) + { + member_cref_t member = typed_->at(i); + if (member.key == free_key_) + free_keys_.push(static_cast(i)); + else if (config_.enable_key_lookups) + slot_lookup_.try_emplace(key_and_slot_t{vector_key_t(member.key), static_cast(i)}); + } + } + + template + std::size_t get_(vector_key_t key, scalar_at * reconstructed, std::size_t vectors_limit, cast_t const & cast) const + { + if (!multi()) + { + compressed_slot_t slot; + // Find the matching ID + { + shared_lock_t lock(slot_lookup_mutex_); + auto it = slot_lookup_.find(key_and_slot_t::any_slot(key)); + if (it == slot_lookup_.end()) + return false; + slot = (*it).slot; + } + // Export the entry + byte_t const * punned_vector = reinterpret_cast(vectors_lookup_[slot]); + bool casted = cast(punned_vector, dimensions(), (byte_t *)reconstructed); + if (!casted) + std::memcpy(reconstructed, punned_vector, metric_.bytes_per_vector()); + return true; + } + else + { + shared_lock_t lock(slot_lookup_mutex_); + auto equal_range_pair = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + std::size_t count_exported = 0; + for (auto begin = equal_range_pair.first; + begin != equal_range_pair.second && count_exported != vectors_limit; + ++begin, ++count_exported) + { + // + compressed_slot_t slot = (*begin).slot; + byte_t const * punned_vector = reinterpret_cast(vectors_lookup_[slot]); + byte_t * reconstructed_vector = (byte_t *)reconstructed + metric_.bytes_per_vector() * count_exported; + bool casted = cast(punned_vector, dimensions(), reconstructed_vector); + if (!casted) + std::memcpy(reconstructed_vector, punned_vector, metric_.bytes_per_vector()); + } + return count_exported; + } + } + + template + static casts_t make_casts_() + { + casts_t result; + + result.from_b1x8 = cast_gt{}; + result.from_i8 = cast_gt{}; + result.from_f16 = cast_gt{}; + result.from_f32 = cast_gt{}; + result.from_f64 = cast_gt{}; + + result.to_b1x8 = cast_gt{}; + result.to_i8 = cast_gt{}; + result.to_f16 = 
cast_gt{}; + result.to_f32 = cast_gt{}; + result.to_f64 = cast_gt{}; + + return result; + } + + static casts_t make_casts_(scalar_kind_t scalar_kind) + { + switch (scalar_kind) + { + case scalar_kind_t::f64_k: + return make_casts_(); + case scalar_kind_t::f32_k: + return make_casts_(); + case scalar_kind_t::f16_k: + return make_casts_(); + case scalar_kind_t::i8_k: + return make_casts_(); + case scalar_kind_t::b1x8_k: + return make_casts_(); + default: + return {}; + } + } +}; + +using index_dense_t = index_dense_gt<>; +using index_dense_big_t = index_dense_gt; + +/** + * @brief Adapts the Male-Optimal Stable Marriage algorithm for unequal sets + * to perform fast one-to-one matching between two large collections + * of vectors, using approximate nearest neighbors search. + * + * @param[inout] man_to_woman Container to map ::first keys to ::second. + * @param[inout] woman_to_man Container to map ::second keys to ::first. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. + */ +template < // + + typename men_key_at, // + typename women_key_at, // + typename men_slot_at, // + typename women_slot_at, // + + typename man_to_woman_at = dummy_key_to_key_mapping_t, // + typename woman_to_man_at = dummy_key_to_key_mapping_t, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > +static join_result_t join( // + index_dense_gt const & men, // + index_dense_gt const & women, // + + index_join_config_t config = {}, // + man_to_woman_at && man_to_woman = man_to_woman_at{}, // + woman_to_man_at && woman_to_man = woman_to_man_at{}, // + executor_at && executor = executor_at{}, // + progress_at && progress = progress_at{}) noexcept +{ + return men.join( // + women, + config, // + std::forward(woman_to_man), // + std::forward(man_to_woman), // + std::forward(executor), // + std::forward(progress)); +} + +} // namespace usearch +} // namespace unum + +// NOLINTEND(readability-*,google-*,modernize-use-auto) diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex_fwd.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex_fwd.h new file mode 100644 index 00000000000..abf481411f7 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex_fwd.h @@ -0,0 +1,30 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +namespace DB::DM +{ + +using ANNQueryInfoPtr = std::shared_ptr; + +class VectorIndex; +using VectorIndexPtr = std::shared_ptr; + +class VectorIndexCache; +using VectorIndexCachePtr = std::shared_ptr; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/ReadUtil.cpp b/dbms/src/Storages/DeltaMerge/ReadUtil.cpp index c4ca69698b9..88cb634b201 100644 --- a/dbms/src/Storages/DeltaMerge/ReadUtil.cpp +++ b/dbms/src/Storages/DeltaMerge/ReadUtil.cpp @@ -46,6 +46,37 @@ std::pair readBlock(SkippableBlockInputStreamPtr & stable, Skippabl } } +std::pair readBlockWithReturnFilter( + SkippableBlockInputStreamPtr & stable, + SkippableBlockInputStreamPtr & delta, + FilterPtr & filter) +{ + if (stable == nullptr && delta == nullptr) + { + return {{}, false}; + } + + if (stable == nullptr) + { + return {delta->read(filter, true), true}; + } + + auto block = stable->read(filter, true); + if (block) + { + return {block, false}; + } + else + { + stable = nullptr; + if (delta != nullptr) + { + block = delta->read(filter, true); + } + return {block, true}; + } +} + size_t skipBlock(SkippableBlockInputStreamPtr & stable, SkippableBlockInputStreamPtr & delta) { if (stable == nullptr && delta == nullptr) diff --git a/dbms/src/Storages/DeltaMerge/ReadUtil.h b/dbms/src/Storages/DeltaMerge/ReadUtil.h index b183dc809ce..378c6a96820 100644 --- a/dbms/src/Storages/DeltaMerge/ReadUtil.h +++ b/dbms/src/Storages/DeltaMerge/ReadUtil.h @@ -21,12 +21,20 @@ namespace DB::DM /** Read the next block. * Read from the stable first, then read from the delta. - * + * * Return: * the block and a flag indicating whether the block is from the delta. */ std::pair readBlock(SkippableBlockInputStreamPtr & stable, SkippableBlockInputStreamPtr & delta); +/** + * Like readBlock, but it forces the underlying stream to return the filter if any. + */ +std::pair readBlockWithReturnFilter( + SkippableBlockInputStreamPtr & stable, + SkippableBlockInputStreamPtr & delta, + FilterPtr & filter); + /** Skip the next block. * Return the number of rows of the next block. */ @@ -34,7 +42,7 @@ size_t skipBlock(SkippableBlockInputStreamPtr & stable, SkippableBlockInputStrea /** Read the next block with filter. * Read from the stable first, then read from the delta. - * + * * Return: * The block containing only the rows that pass the filter and a flag indicating whether the block is from the delta. 
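readBlockWithReturnFilter differs from readBlockWithFilter in that it hands the caller the raw filter instead of materializing a filtered block; a sketch of the calling pattern, with hypothetical stream variables:

#include <Storages/DeltaMerge/ReadUtil.h>

namespace DB::DM
{
// Illustrative only: drain the stable layer first, then the delta layer, honoring the returned filter.
void drain(SkippableBlockInputStreamPtr stable, SkippableBlockInputStreamPtr delta)
{
    while (true)
    {
        FilterPtr filter = nullptr;
        auto [block, from_delta] = readBlockWithReturnFilter(stable, delta, filter);
        if (!block)
            break;
        (void)from_delta; // true once the block comes from the delta layer
        // When `filter` is non-null, only rows with (*filter)[i] != 0 are valid;
        // when it is null, every row of `block` is valid.
    }
}
} // namespace DB::DM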
*/ diff --git a/dbms/src/Storages/DeltaMerge/ScanContext.h b/dbms/src/Storages/DeltaMerge/ScanContext.h index 78321dd8737..7659e04cdfc 100644 --- a/dbms/src/Storages/DeltaMerge/ScanContext.h +++ b/dbms/src/Storages/DeltaMerge/ScanContext.h @@ -84,6 +84,15 @@ class ScanContext // Building bitmap std::atomic build_bitmap_time_ns{0}; + std::atomic total_vector_idx_load_from_disk{0}; + std::atomic total_vector_idx_load_from_cache{0}; + std::atomic total_vector_idx_load_time_ms{0}; + std::atomic total_vector_idx_search_time_ms{0}; + std::atomic total_vector_idx_search_visited_nodes{0}; + std::atomic total_vector_idx_search_discarded_nodes{0}; + std::atomic total_vector_idx_read_vec_time_ms{0}; + std::atomic total_vector_idx_read_others_time_ms{0}; + const String resource_group_name; explicit ScanContext(const String & name = "") @@ -129,6 +138,15 @@ class ScanContext tiflash_scan_context_pb.max_remote_stream_ms() * 1000000); deserializeRegionNumberOfInstance(tiflash_scan_context_pb); + + total_vector_idx_load_from_disk = tiflash_scan_context_pb.total_vector_idx_load_from_disk(); + total_vector_idx_load_from_cache = tiflash_scan_context_pb.total_vector_idx_load_from_cache(); + total_vector_idx_load_time_ms = tiflash_scan_context_pb.total_vector_idx_load_time_ms(); + total_vector_idx_search_time_ms = tiflash_scan_context_pb.total_vector_idx_search_time_ms(); + total_vector_idx_search_visited_nodes = tiflash_scan_context_pb.total_vector_idx_search_visited_nodes(); + total_vector_idx_search_discarded_nodes = tiflash_scan_context_pb.total_vector_idx_search_discarded_nodes(); + total_vector_idx_read_vec_time_ms = tiflash_scan_context_pb.total_vector_idx_read_vec_time_ms(); + total_vector_idx_read_others_time_ms = tiflash_scan_context_pb.total_vector_idx_read_others_time_ms(); } tipb::TiFlashScanContext serialize() @@ -171,6 +189,15 @@ class ScanContext serializeRegionNumOfInstance(tiflash_scan_context_pb); + tiflash_scan_context_pb.set_total_vector_idx_load_from_disk(total_vector_idx_load_from_disk); + tiflash_scan_context_pb.set_total_vector_idx_load_from_cache(total_vector_idx_load_from_cache); + tiflash_scan_context_pb.set_total_vector_idx_load_time_ms(total_vector_idx_load_time_ms); + tiflash_scan_context_pb.set_total_vector_idx_search_time_ms(total_vector_idx_search_time_ms); + tiflash_scan_context_pb.set_total_vector_idx_search_visited_nodes(total_vector_idx_search_visited_nodes); + tiflash_scan_context_pb.set_total_vector_idx_search_discarded_nodes(total_vector_idx_search_discarded_nodes); + tiflash_scan_context_pb.set_total_vector_idx_read_vec_time_ms(total_vector_idx_read_vec_time_ms); + tiflash_scan_context_pb.set_total_vector_idx_read_others_time_ms(total_vector_idx_read_others_time_ms); + return tiflash_scan_context_pb; } @@ -217,6 +244,15 @@ class ScanContext other.remote_max_stream_cost_ns); mergeRegionNumberOfInstance(other); + + total_vector_idx_load_from_disk += other.total_vector_idx_load_from_disk; + total_vector_idx_load_from_cache += other.total_vector_idx_load_from_cache; + total_vector_idx_load_time_ms += other.total_vector_idx_load_time_ms; + total_vector_idx_search_time_ms += other.total_vector_idx_search_time_ms; + total_vector_idx_search_visited_nodes += other.total_vector_idx_search_visited_nodes; + total_vector_idx_search_discarded_nodes += other.total_vector_idx_search_discarded_nodes; + total_vector_idx_read_vec_time_ms += other.total_vector_idx_read_vec_time_ms; + total_vector_idx_read_others_time_ms += other.total_vector_idx_read_others_time_ms; } void merge(const 
tipb::TiFlashScanContext & other) @@ -258,6 +294,17 @@ class ScanContext other.max_remote_stream_ms() * 1000000); mergeRegionNumberOfInstance(other); + disagg_read_cache_hit_size += other.disagg_read_cache_hit_bytes(); + disagg_read_cache_miss_size += other.disagg_read_cache_miss_bytes(); + + total_vector_idx_load_from_disk += other.total_vector_idx_load_from_disk(); + total_vector_idx_load_from_cache += other.total_vector_idx_load_from_cache(); + total_vector_idx_load_time_ms += other.total_vector_idx_load_time_ms(); + total_vector_idx_search_time_ms += other.total_vector_idx_search_time_ms(); + total_vector_idx_search_visited_nodes += other.total_vector_idx_search_visited_nodes(); + total_vector_idx_search_discarded_nodes += other.total_vector_idx_search_discarded_nodes(); + total_vector_idx_read_vec_time_ms += other.total_vector_idx_read_vec_time_ms(); + total_vector_idx_read_others_time_ms += other.total_vector_idx_read_others_time_ms(); } String toJson() const; diff --git a/dbms/src/Storages/DeltaMerge/Segment.cpp b/dbms/src/Storages/DeltaMerge/Segment.cpp index 89937f04794..1e4e084bde3 100644 --- a/dbms/src/Storages/DeltaMerge/Segment.cpp +++ b/dbms/src/Storages/DeltaMerge/Segment.cpp @@ -3035,7 +3035,10 @@ BlockInputStreamPtr Segment::getBitmapFilterInputStream( enable_handle_clean_read, ReadTag::Query, is_fast_scan, - enable_del_clean_read); + enable_del_clean_read, + /* read_packs */ {}, + /* need_row_id */ false, + /* bitmap_filter */ bitmap_filter); auto columns_to_read_ptr = std::make_shared(columns_to_read); SkippableBlockInputStreamPtr delta_stream = std::make_shared( diff --git a/dbms/src/Storages/DeltaMerge/SkippableBlockInputStream.h b/dbms/src/Storages/DeltaMerge/SkippableBlockInputStream.h index a0a5d3ffb76..4992ca8e862 100644 --- a/dbms/src/Storages/DeltaMerge/SkippableBlockInputStream.h +++ b/dbms/src/Storages/DeltaMerge/SkippableBlockInputStream.h @@ -177,12 +177,18 @@ class ConcatSkippableBlockInputStream : public SkippableBlockInputStream } Block read() override + { + FilterPtr filter = nullptr; + return read(filter, false); + } + + Block read(FilterPtr & res_filter, bool return_filter) override { Block res; while (current_stream != children.end()) { - res = (*current_stream)->read(); + res = (*current_stream)->read(res_filter, return_filter); if (res) { diff --git a/dbms/src/Storages/DeltaMerge/StableValueSpace.cpp b/dbms/src/Storages/DeltaMerge/StableValueSpace.cpp index 96e5030c2e4..2545ffe055f 100644 --- a/dbms/src/Storages/DeltaMerge/StableValueSpace.cpp +++ b/dbms/src/Storages/DeltaMerge/StableValueSpace.cpp @@ -450,7 +450,8 @@ SkippableBlockInputStreamPtr StableValueSpace::Snapshot::getInputStream( bool is_fast_scan, bool enable_del_clean_read, const std::vector & read_packs, - bool need_row_id) + bool need_row_id, + BitmapFilterPtr bitmap_filter) { LOG_DEBUG( log, @@ -463,6 +464,9 @@ SkippableBlockInputStreamPtr StableValueSpace::Snapshot::getInputStream( std::vector rows; streams.reserve(stable->files.size()); rows.reserve(stable->files.size()); + + size_t last_rows = 0; + for (size_t i = 0; i < stable->files.size(); i++) { DMFileBlockInputStreamBuilder builder(context.db_context); @@ -473,7 +477,15 @@ SkippableBlockInputStreamPtr StableValueSpace::Snapshot::getInputStream( .setRowsThreshold(expected_block_size) .setReadPacks(read_packs.size() > i ? 
read_packs[i] : nullptr) .setReadTag(read_tag); - streams.push_back(builder.build(stable->files[i], read_columns, rowkey_ranges, context.scan_context)); + + if (bitmap_filter) + { + builder = builder.setBitmapFilter( + BitmapFilterView(bitmap_filter, last_rows, last_rows + stable->files[i]->getRows())); + last_rows += stable->files[i]->getRows(); + } + + streams.push_back(builder.build2(stable->files[i], read_columns, rowkey_ranges, context.scan_context)); rows.push_back(stable->files[i]->getRows()); } if (need_row_id) diff --git a/dbms/src/Storages/DeltaMerge/StableValueSpace.h b/dbms/src/Storages/DeltaMerge/StableValueSpace.h index 14077ef712b..37ad287e0df 100644 --- a/dbms/src/Storages/DeltaMerge/StableValueSpace.h +++ b/dbms/src/Storages/DeltaMerge/StableValueSpace.h @@ -232,7 +232,8 @@ class StableValueSpace : public std::enable_shared_from_this bool is_fast_scan = false, bool enable_del_clean_read = false, const std::vector & read_packs = {}, - bool need_row_id = false); + bool need_row_id = false, + BitmapFilterPtr bitmap_filter = nullptr); RowsAndBytes getApproxRowsAndBytes(const DMContext & context, const RowKeyRange & range) const; diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_minmax_index.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_minmax_index.cpp index 8cb836db36f..b20c1fd4ead 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_minmax_index.cpp +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_minmax_index.cpp @@ -2183,6 +2183,7 @@ try ColumnInfos column_infos = {a, b}; auto dag_query = std::make_unique( filters, + tipb::ANNQueryInfo{}, pushed_down_filters, // Not care now column_infos, std::vector{}, diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_storage_delta_merge.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_storage_delta_merge.cpp index b80680a8a34..21d5fd5f060 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_storage_delta_merge.cpp +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_storage_delta_merge.cpp @@ -127,6 +127,7 @@ try const google::protobuf::RepeatedPtrField pushed_down_filters{}; query_info.dag_query = std::make_unique( google::protobuf::RepeatedPtrField(), + tipb::ANNQueryInfo{}, pushed_down_filters, // Not care now std::vector{}, // Not care now std::vector{}, @@ -677,6 +678,7 @@ try const google::protobuf::RepeatedPtrField pushed_down_filters{}; query_info.dag_query = std::make_unique( google::protobuf::RepeatedPtrField(), + tipb::ANNQueryInfo{}, pushed_down_filters, // Not care now std::vector{}, // Not care now std::vector{}, @@ -790,6 +792,7 @@ try const google::protobuf::RepeatedPtrField pushed_down_filters{}; query_info.dag_query = std::make_unique( google::protobuf::RepeatedPtrField(), + tipb::ANNQueryInfo{}, pushed_down_filters, // Not care now std::vector{}, // Not care now std::vector{}, diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp new file mode 100644 index 00000000000..ba9aa5fab11 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp @@ -0,0 +1,1111 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS,n +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB::DM::tests +{ + +class VectorIndexTestUtils +{ +public: + const ColumnID vec_column_id = 100; + const String vec_column_name = "vec"; + + /// Create a column with values like [1], [2], [3], ... + /// Each value is a VectorFloat32 with exactly one dimension. + static ColumnWithTypeAndName colInt64(std::string_view sequence, const String & name = "", Int64 column_id = 0) + { + auto data = genSequence(sequence); + return createColumn(data, name, column_id); + } + + static ColumnWithTypeAndName colVecFloat32(std::string_view sequence, const String & name = "", Int64 column_id = 0) + { + auto data = genSequence(sequence); + std::vector data_in_array; + for (auto & v : data) + { + Array vec; + vec.push_back(static_cast(v)); + data_in_array.push_back(vec); + } + return createVecFloat32Column(data_in_array, name, column_id); + } + + static String encodeVectorFloat32(const std::vector & vec) + { + WriteBufferFromOwnString wb; + Array arr; + for (const auto & v : vec) + arr.push_back(static_cast(v)); + EncodeVectorFloat32(arr, wb); + return wb.str(); + } +}; + +class VectorIndexDMFileTest + : public VectorIndexTestUtils + , public DB::base::TiFlashStorageTestBasic + , public testing::WithParamInterface +{ +public: + void SetUp() override + { + TiFlashStorageTestBasic::SetUp(); + + parent_path = TiFlashStorageTestBasic::getTemporaryPath(); + path_pool = std::make_shared( + db_context->getPathPool().withTable("test", "VectorIndexDMFileTest", false)); + storage_pool = std::make_shared(*db_context, NullspaceID, /*ns_id*/ 100, *path_pool, "test.t1"); + dm_file = DMFile::create( + 1, + parent_path, + std::make_optional(), + 128 * 1024, + 16 * 1024 * 1024, + DMFileFormat::V3); + + DB::tests::TiFlashTestEnv::disableS3Config(); + + reload(); + } + + // Update dm_context. + void reload() + { + TiFlashStorageTestBasic::reload(); + + *path_pool = db_context->getPathPool().withTable("test", "t1", false); + dm_context = std::make_unique( + *db_context, + path_pool, + storage_pool, + /*min_version_*/ 0, + NullspaceID, + /*physical_table_id*/ 100, + false, + 1, + db_context->getSettingsRef()); + } + + DMFilePtr restoreDMFile() + { + auto file_id = dm_file->fileId(); + auto page_id = dm_file->pageId(); + auto file_provider = dbContext().getFileProvider(); + return DMFile::restore(file_provider, file_id, page_id, parent_path, DMFile::ReadMetaMode::all()); + } + + Context & dbContext() { return *db_context; } + +protected: + std::unique_ptr dm_context{}; + /// all these var live as ref in dm_context + std::shared_ptr path_pool{}; + std::shared_ptr storage_pool{}; + +protected: + String parent_path; + DMFilePtr dm_file = nullptr; + +public: + VectorIndexDMFileTest() { test_only_vec_column = GetParam(); } + +protected: + // DMFile has different logic when there is only vec column. + // So we test it independently. 
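+    // When true, the tests read only the Array(Float32) vector column (read_cols and
+    // createColumnNames are reduced accordingly), exercising the vector-only read path.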
+ bool test_only_vec_column = false; + + ColumnsWithTypeAndName createColumnData(const ColumnsWithTypeAndName & columns) + { + if (!test_only_vec_column) + return columns; + + // In test_only_vec_column mode, only contains the Array column. + for (const auto & col : columns) + { + if (col.type->getName() == "Array(Float32)") + return {col}; + } + + RUNTIME_CHECK(false); + } + + Strings createColumnNames() + { + if (!test_only_vec_column) + return {DMTestEnv::pk_name, vec_column_name}; + + // In test_only_vec_column mode, only contains the Array column. + return {vec_column_name}; + } +}; + +INSTANTIATE_TEST_CASE_P(VectorIndex, VectorIndexDMFileTest, testing::Bool()); + +TEST_P(VectorIndexDMFileTest, OnePack) +try +{ + auto cols = DMTestEnv::getDefaultColumns(DMTestEnv::PkType::HiddenTiDBRowID, /*add_nullable*/ true); + auto vec_cd = ColumnDefine(vec_column_id, vec_column_name, tests::typeFromString("Array(Float32)")); + vec_cd.vector_index = std::make_shared(TiDB::VectorIndexInfo{ + .kind = TiDB::VectorIndexKind::HNSW, + .dimension = 3, + .distance_metric = TiDB::DistanceMetric::L2, + }); + cols->emplace_back(vec_cd); + + ColumnDefines read_cols = *cols; + if (test_only_vec_column) + read_cols = {vec_cd}; + + // Prepare DMFile + { + Block block = DMTestEnv::prepareSimpleWriteBlockWithNullable(0, 3); + block.insert( + createVecFloat32Column({{1.0, 2.0, 3.0}, {0.0, 0.0, 0.0}, {1.0, 2.0, 3.5}}, vec_cd.name, vec_cd.id)); + auto stream = std::make_shared(dbContext(), dm_file, *cols); + stream->writePrefix(); + stream->write(block, DMFileBlockOutputStream::BlockProperty{0, 0, 0, 0}); + stream->writeSuffix(); + } + + dm_file = restoreDMFile(); + + // Read with exact match + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_cd.id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(1); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.5})); + + DMFileBlockInputStreamBuilder builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .build2( + dm_file, + read_cols, + RowKeyRanges{RowKeyRange::newAll(false, 1)}, + std::make_shared()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + createColumn({2}), + createVecFloat32Column({{1.0, 2.0, 3.5}}), + })); + } + + // Read with approximate match + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_cd.id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(1); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); + + DMFileBlockInputStreamBuilder builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .build2( + dm_file, + read_cols, + RowKeyRanges{RowKeyRange::newAll(false, 1)}, + std::make_shared()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + createColumn({2}), + createVecFloat32Column({{1.0, 2.0, 3.5}}), + })); + } + + // Read multiple rows + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_cd.id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(2); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); + + DMFileBlockInputStreamBuilder 
builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .build2( + dm_file, + read_cols, + RowKeyRanges{RowKeyRange::newAll(false, 1)}, + std::make_shared()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + createColumn({0, 2}), + createVecFloat32Column({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.5}}), + })); + } + + // Read with MVCC filter + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_cd.id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(1); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); + + auto bitmap_filter = std::make_shared(3, true); + bitmap_filter->set(/* start */ 2, /* limit */ 1, false); + + DMFileBlockInputStreamBuilder builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(bitmap_filter, 0, 3)) + .build2( + dm_file, + read_cols, + RowKeyRanges{RowKeyRange::newAll(false, 1)}, + std::make_shared()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + createColumn({0}), + createVecFloat32Column({{1.0, 2.0, 3.0}}), + })); + } + + // Query Top K = 0: the pack should be filtered out + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_cd.id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(0); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); + + DMFileBlockInputStreamBuilder builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .build2( + dm_file, + read_cols, + RowKeyRanges{RowKeyRange::newAll(false, 1)}, + std::make_shared()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + createColumn({}), + createVecFloat32Column({}), + })); + } + + // Query Top K > rows + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_cd.id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(10); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); + + DMFileBlockInputStreamBuilder builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .build2( + dm_file, + read_cols, + RowKeyRanges{RowKeyRange::newAll(false, 1)}, + std::make_shared()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + createColumn({0, 1, 2}), + createVecFloat32Column({{1.0, 2.0, 3.0}, {0.0, 0.0, 0.0}, {1.0, 2.0, 3.5}}), + })); + } + + // Illegal ANNQueryInfo: Ref Vector'dimension is different + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_cd.id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(10); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0})); + + DMFileBlockInputStreamBuilder builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .build2( + dm_file, + read_cols, + RowKeyRanges{RowKeyRange::newAll(false, 1)}, + 
std::make_shared()); + + try + { + stream->readPrefix(); + stream->read(); + FAIL(); + } + catch (const DB::Exception & ex) + { + ASSERT_STREQ("Query vector size 1 does not match index dimensions 3", ex.message().c_str()); + } + catch (...) + { + FAIL(); + } + } + + // Illegal ANNQueryInfo: Referencing a non-existed column. This simply cause vector index not used. + // The query will not fail, because ANNQueryInfo is passed globally in the whole read path. + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(5); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(1); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); + + DMFileBlockInputStreamBuilder builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .build2( + dm_file, + read_cols, + RowKeyRanges{RowKeyRange::newAll(false, 1)}, + std::make_shared()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + createColumn({0, 1, 2}), + createVecFloat32Column({{1.0, 2.0, 3.0}, {0.0, 0.0, 0.0}, {1.0, 2.0, 3.5}}), + })); + } + + // Illegal ANNQueryInfo: Different distance metric. + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_cd.id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::COSINE); + ann_query_info->set_top_k(1); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); + + DMFileBlockInputStreamBuilder builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .build2( + dm_file, + read_cols, + RowKeyRanges{RowKeyRange::newAll(false, 1)}, + std::make_shared()); + + try + { + stream->readPrefix(); + stream->read(); + FAIL(); + } + catch (const DB::Exception & ex) + { + ASSERT_STREQ("Query distance metric Cosine does not match index distance metric L2", ex.message().c_str()); + } + catch (...) + { + FAIL(); + } + } + + // Illegal ANNQueryInfo: The column exists but is not a vector column. + // Currently the query is fine and ANNQueryInfo is discarded, because we discovered that this column + // does not have index at all. 
+ { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(EXTRA_HANDLE_COLUMN_ID); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(1); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); + + DMFileBlockInputStreamBuilder builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .build2( + dm_file, + read_cols, + RowKeyRanges{RowKeyRange::newAll(false, 1)}, + std::make_shared()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + createColumn({0, 1, 2}), + createVecFloat32Column({{1.0, 2.0, 3.0}, {0.0, 0.0, 0.0}, {1.0, 2.0, 3.5}}), + })); + } +} +CATCH + +TEST_P(VectorIndexDMFileTest, MultiPacks) +try +{ + auto cols = DMTestEnv::getDefaultColumns(DMTestEnv::PkType::HiddenTiDBRowID, /*add_nullable*/ true); + auto vec_cd = ColumnDefine(vec_column_id, vec_column_name, tests::typeFromString("Array(Float32)")); + vec_cd.vector_index = std::make_shared(TiDB::VectorIndexInfo{ + .kind = TiDB::VectorIndexKind::HNSW, + .dimension = 3, + .distance_metric = TiDB::DistanceMetric::L2, + }); + cols->emplace_back(vec_cd); + + ColumnDefines read_cols = *cols; + if (test_only_vec_column) + read_cols = {vec_cd}; + + // Prepare DMFile + { + Block block1 = DMTestEnv::prepareSimpleWriteBlockWithNullable(0, 3); + block1.insert( + createVecFloat32Column({{1.0, 2.0, 3.0}, {0.0, 0.0, 0.0}, {1.0, 2.0, 3.5}}, vec_cd.name, vec_cd.id)); + + Block block2 = DMTestEnv::prepareSimpleWriteBlockWithNullable(3, 6); + block2.insert( + createVecFloat32Column({{5.0, 5.0, 5.0}, {5.0, 5.0, 7.0}, {0.0, 0.0, 0.0}}, vec_cd.name, vec_cd.id)); + + auto stream = std::make_shared(dbContext(), dm_file, *cols); + stream->writePrefix(); + stream->write(block1, DMFileBlockOutputStream::BlockProperty{0, 0, 0, 0}); + stream->write(block2, DMFileBlockOutputStream::BlockProperty{0, 0, 0, 0}); + stream->writeSuffix(); + } + + dm_file = restoreDMFile(); + + // Pack #0 is filtered out according to VecIndex + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_cd.id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(1); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({5.0, 5.0, 5.5})); + + DMFileBlockInputStreamBuilder builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(std::make_shared(6, true), 0, 6)) + .build2( + dm_file, + read_cols, + RowKeyRanges{RowKeyRange::newAll(false, 1)}, + std::make_shared()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + createColumn({3}), + createVecFloat32Column({{5.0, 5.0, 5.0}}), + })); + } + + // Pack #1 is filtered out according to VecIndex + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_cd.id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(1); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.0})); + + DMFileBlockInputStreamBuilder builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(std::make_shared(6, true), 0, 6)) + .build2( + dm_file, + read_cols, + RowKeyRanges{RowKeyRange::newAll(false, 1)}, + std::make_shared()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + 
createColumnNames(), + createColumnData({ + createColumn({0}), + createVecFloat32Column({{1.0, 2.0, 3.0}}), + })); + } + + // Both packs are reserved + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_cd.id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(2); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({0.0, 0.0, 0.0})); + + DMFileBlockInputStreamBuilder builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(std::make_shared(6, true), 0, 6)) + .build2( + dm_file, + read_cols, + RowKeyRanges{RowKeyRange::newAll(false, 1)}, + std::make_shared()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + createColumn({1, 5}), + createVecFloat32Column({{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}), + })); + } + + // Pack Filter + MVCC (the matching row #5 is marked as filtered out by MVCC) + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_cd.id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(2); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({0.0, 0.0, 0.0})); + + auto bitmap_filter = std::make_shared(6, true); + bitmap_filter->set(/* start */ 5, /* limit */ 1, false); + + DMFileBlockInputStreamBuilder builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(bitmap_filter, 0, 6)) + .build2( + dm_file, + read_cols, + RowKeyRanges{RowKeyRange::newAll(false, 1)}, + std::make_shared()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + createColumn({0, 1}), + createVecFloat32Column({{1.0, 2.0, 3.0}, {0.0, 0.0, 0.0}}), + })); + } +} +CATCH + +TEST_P(VectorIndexDMFileTest, WithPackFilter) +try +{ + auto cols = DMTestEnv::getDefaultColumns(DMTestEnv::PkType::HiddenTiDBRowID, /*add_nullable*/ true); + auto vec_cd = ColumnDefine(vec_column_id, vec_column_name, tests::typeFromString("Array(Float32)")); + vec_cd.vector_index = std::make_shared(TiDB::VectorIndexInfo{ + .kind = TiDB::VectorIndexKind::HNSW, + .dimension = 1, + .distance_metric = TiDB::DistanceMetric::L2, + }); + cols->emplace_back(vec_cd); + + ColumnDefines read_cols = *cols; + if (test_only_vec_column) + read_cols = {vec_cd}; + + // Prepare DMFile + { + Block block1 = DMTestEnv::prepareSimpleWriteBlockWithNullable(0, 3); + block1.insert(colVecFloat32("[0, 3)", vec_cd.name, vec_cd.id)); + + Block block2 = DMTestEnv::prepareSimpleWriteBlockWithNullable(3, 6); + block2.insert(colVecFloat32("[3, 6)", vec_cd.name, vec_cd.id)); + + Block block3 = DMTestEnv::prepareSimpleWriteBlockWithNullable(6, 9); + block3.insert(colVecFloat32("[6, 9)", vec_cd.name, vec_cd.id)); + + auto stream = std::make_shared(dbContext(), dm_file, *cols); + stream->writePrefix(); + stream->write(block1, DMFileBlockOutputStream::BlockProperty{0, 0, 0, 0}); + stream->write(block2, DMFileBlockOutputStream::BlockProperty{0, 0, 0, 0}); + stream->write(block3, DMFileBlockOutputStream::BlockProperty{0, 0, 0, 0}); + stream->writeSuffix(); + } + + dm_file = restoreDMFile(); + + // Pack Filter using RowKeyRange + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_cd.id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(1); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({8.0})); + + // 
This row key range will cause pack#0 and pack#1 reserved, and pack#2 filtered out. + auto row_key_ranges = RowKeyRanges{RowKeyRange::fromHandleRange(HandleRange(0, 5))}; + + auto bitmap_filter = std::make_shared(9, false); + bitmap_filter->set(0, 6); // 0~6 rows are valid, 6~9 rows are invalid due to pack filter. + + DMFileBlockInputStreamBuilder builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(bitmap_filter, 0, 9)) + .build2(dm_file, read_cols, row_key_ranges, std::make_shared()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + createColumn({5}), + createVecFloat32Column({{5.0}}), + })); + + // TopK=4 + ann_query_info->set_top_k(4); + builder = DMFileBlockInputStreamBuilder(dbContext()); + stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(bitmap_filter, 0, 9)) + .build2(dm_file, read_cols, row_key_ranges, std::make_shared()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + createColumn({2, 3, 4, 5}), + createVecFloat32Column({{2.0}, {3.0}, {4.0}, {5.0}}), + })); + } + + // Pack Filter + Bitmap Filter + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_cd.id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(3); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({8.0})); + + // This row key range will cause pack#0 and pack#1 reserved, and pack#2 filtered out. + auto row_key_ranges = RowKeyRanges{RowKeyRange::fromHandleRange(HandleRange(0, 5))}; + + // Valid rows are 0, 1, , 3, 4 + auto bitmap_filter = std::make_shared(9, false); + bitmap_filter->set(0, 2); + bitmap_filter->set(3, 2); + + DMFileBlockInputStreamBuilder builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(bitmap_filter, 0, 9)) + .build2(dm_file, read_cols, row_key_ranges, std::make_shared()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + createColumn({1, 3, 4}), + createVecFloat32Column({{1.0}, {3.0}, {4.0}}), + })); + } +} +CATCH + +class VectorIndexSegmentTestBase + : public VectorIndexTestUtils + , public SegmentTestBasic +{ +public: + BlockInputStreamPtr annQuery( + PageIdU64 segment_id, + Int64 begin, + Int64 end, + ColumnDefines columns_to_read, + UInt32 top_k, + const std::vector & ref_vec) + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_column_id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(top_k); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32(ref_vec)); + return read(segment_id, begin, end, columns_to_read, ann_query_info); + } + + BlockInputStreamPtr annQuery( + PageIdU64 segment_id, + ColumnDefines columns_to_read, + UInt32 top_k, + const std::vector & ref_vec) + { + auto [segment_start_key, segment_end_key] = getSegmentKeyRange(segment_id); + return annQuery(segment_id, segment_start_key, segment_end_key, columns_to_read, top_k, ref_vec); + } + + BlockInputStreamPtr read( + PageIdU64 segment_id, + Int64 begin, + Int64 end, + ColumnDefines columns_to_read, + ANNQueryInfoPtr ann_query) + { + auto range = buildRowKeyRange(begin, end); + auto [segment, snapshot] = getSegmentForRead(segment_id); + auto stream = segment->getBitmapFilterInputStream( + *dm_context, + columns_to_read, 
+ snapshot, + {range}, + std::make_shared(wrapWithANNQueryInfo({}, ann_query)), + std::numeric_limits::max(), + DEFAULT_BLOCK_SIZE, + DEFAULT_BLOCK_SIZE); + return stream; + } + + ColumnDefine cdPK() { return getExtraHandleColumnDefine(options.is_common_handle); } + + ColumnDefine cdVec() + { + // When used in read, no need to assign vector_index. + return ColumnDefine(vec_column_id, vec_column_name, tests::typeFromString("Array(Float32)")); + } + +protected: + Block prepareWriteBlockImpl(Int64 start_key, Int64 end_key, bool is_deleted) override + { + auto block = SegmentTestBasic::prepareWriteBlockImpl(start_key, end_key, is_deleted); + block.insert(colVecFloat32(fmt::format("[{}, {})", start_key, end_key), vec_column_name, vec_column_id)); + return block; + } + + void prepareColumns(const ColumnDefinesPtr & columns) override + { + auto vec_cd = ColumnDefine(vec_column_id, vec_column_name, tests::typeFromString("Array(Float32)")); + vec_cd.vector_index = std::make_shared(TiDB::VectorIndexInfo{ + .kind = TiDB::VectorIndexKind::HNSW, + .dimension = 1, + .distance_metric = TiDB::DistanceMetric::L2, + }); + columns->emplace_back(vec_cd); + } + +protected: + // DMFile has different logic when there is only vec column. + // So we test it independently. + bool test_only_vec_column = false; + int pack_size = 10; + + ColumnsWithTypeAndName createColumnData(const ColumnsWithTypeAndName & columns) + { + if (!test_only_vec_column) + return columns; + + // In test_only_vec_column mode, only contains the Array column. + for (const auto & col : columns) + { + if (col.type->getName() == "Array(Float32)") + return {col}; + } + + RUNTIME_CHECK(false); + } + + virtual Strings createColumnNames() + { + if (!test_only_vec_column) + return {DMTestEnv::pk_name, vec_column_name}; + + // In test_only_vec_column mode, only contains the Array column. + return {vec_column_name}; + } + + virtual ColumnDefines createQueryColumns() + { + if (!test_only_vec_column) + return {cdPK(), cdVec()}; + + return {cdVec()}; + } + + inline void assertStreamOut(BlockInputStreamPtr stream, std::string_view expected_sequence) + { + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + colInt64(expected_sequence), + colVecFloat32(expected_sequence), + })); + } +}; + +class VectorIndexSegmentTest1 + : public VectorIndexSegmentTestBase + , public testing::WithParamInterface +{ +public: + VectorIndexSegmentTest1() { test_only_vec_column = GetParam(); } +}; + +INSTANTIATE_TEST_CASE_P( // + VectorIndex, + VectorIndexSegmentTest1, + /* vec_only */ ::testing::Bool()); + +class VectorIndexSegmentTest2 + : public VectorIndexSegmentTestBase + , public testing::WithParamInterface> +{ +public: + VectorIndexSegmentTest2() { std::tie(test_only_vec_column, pack_size) = GetParam(); } +}; + +INSTANTIATE_TEST_CASE_P( // + VectorIndex, + VectorIndexSegmentTest2, + ::testing::Combine( // + /* vec_only */ ::testing::Bool(), + /* pack_size */ ::testing::Values(1, 2, 3, 4, 5))); + +TEST_P(VectorIndexSegmentTest1, DataInCFInMemory) +try +{ + // Vector in memory will not filter by ANNQuery at all. 
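+    // (The ANN top-k only takes effect in the stable layer in this patch; delta rows are
+    //  always returned, see DataInStableAndDelta below.)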
+ writeSegment(DELTA_MERGE_FIRST_SEGMENT_ID, 5, /* at */ 0); + auto stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {100.0}); + assertStreamOut(stream, "[0, 5)"); + + writeSegment(DELTA_MERGE_FIRST_SEGMENT_ID, 5, /* at */ 0); + stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {100.0}); + assertStreamOut(stream, "[0, 5)"); + + writeSegment(DELTA_MERGE_FIRST_SEGMENT_ID, 5, /* at */ 10); + stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {100.0}); + assertStreamOut(stream, "[0, 5)|[10, 15)"); + + writeSegment(DELTA_MERGE_FIRST_SEGMENT_ID, 5, /* at */ -10); + stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {100.0}); + assertStreamOut(stream, "[0, 5)|[10, 15)|[-10, -5)"); +} +CATCH + +TEST_P(VectorIndexSegmentTest1, DataInCFTiny) +try +{ + // Vector in column file tiny will not filter by ANNQuery at all. + writeSegment(DELTA_MERGE_FIRST_SEGMENT_ID, 5, /* at */ 0); + flushSegmentCache(DELTA_MERGE_FIRST_SEGMENT_ID); + + auto stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {100.0}); + assertStreamOut(stream, "[0, 5)"); + + writeSegment(DELTA_MERGE_FIRST_SEGMENT_ID, 5, /* at */ 0); + stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {100.0}); + assertStreamOut(stream, "[0, 5)"); + + flushSegmentCache(DELTA_MERGE_FIRST_SEGMENT_ID); + stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {100.0}); + assertStreamOut(stream, "[0, 5)"); + + writeSegment(DELTA_MERGE_FIRST_SEGMENT_ID, 5, /* at */ -10); + stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {100.0}); + assertStreamOut(stream, "[0, 5)|[-10, -5)"); + + flushSegmentCache(DELTA_MERGE_FIRST_SEGMENT_ID); + stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {100.0}); + assertStreamOut(stream, "[0, 5)|[-10, -5)"); + + writeSegment(DELTA_MERGE_FIRST_SEGMENT_ID, 12, /* at */ -10); + stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {100.0}); + assertStreamOut(stream, "[2, 5)|[-10, 2)"); +} +CATCH + +TEST_P(VectorIndexSegmentTest1, DataInCFBig) +try +{ + // Vector in column file big will not filter by ANNQuery at all. 
+ ingestDTFileIntoDelta(DELTA_MERGE_FIRST_SEGMENT_ID, 5, /* at */ 0, /* clear */ false); + flushSegmentCache(DELTA_MERGE_FIRST_SEGMENT_ID); + + auto stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {100.0}); + assertStreamOut(stream, "[0, 5)"); +} +CATCH + +TEST_P(VectorIndexSegmentTest2, DataInStable) +try +{ + db_context->getSettingsRef().dt_segment_stable_pack_rows = pack_size; + reloadDMContext(); + + ingestDTFileIntoDelta(DELTA_MERGE_FIRST_SEGMENT_ID, 5, /* at */ 0, /* clear */ false); + flushSegmentCache(DELTA_MERGE_FIRST_SEGMENT_ID); + mergeSegmentDelta(DELTA_MERGE_FIRST_SEGMENT_ID); + + auto stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {100.0}); + assertStreamOut(stream, "[4, 5)"); + + stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 3, {100.0}); + assertStreamOut(stream, "[2, 5)"); + + stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {1.1}); + assertStreamOut(stream, "[1, 2)"); + + stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 2, {1.1}); + assertStreamOut(stream, "[1, 3)"); +} +CATCH + +TEST_P(VectorIndexSegmentTest2, DataInStableAndDelta) +try +{ + db_context->getSettingsRef().dt_segment_stable_pack_rows = pack_size; + reloadDMContext(); + + ingestDTFileIntoDelta(DELTA_MERGE_FIRST_SEGMENT_ID, 5, /* at */ 0, /* clear */ false); + flushSegmentCache(DELTA_MERGE_FIRST_SEGMENT_ID); + mergeSegmentDelta(DELTA_MERGE_FIRST_SEGMENT_ID); + + writeSegment(DELTA_MERGE_FIRST_SEGMENT_ID, 10, /* at */ 20); + + // ANNQuery will be only effective to Stable layer. All delta data will be returned. + + auto stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {100.0}); + assertStreamOut(stream, "[4, 5)|[20, 30)"); + + stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 2, {10.0}); + assertStreamOut(stream, "[3, 5)|[20, 30)"); + + stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 5, {10.0}); + assertStreamOut(stream, "[0, 5)|[20, 30)"); + + stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 10, {10.0}); + assertStreamOut(stream, "[0, 5)|[20, 30)"); +} +CATCH + +TEST_P(VectorIndexSegmentTest2, SegmentSplit) +try +{ + db_context->getSettingsRef().dt_segment_stable_pack_rows = pack_size; + reloadDMContext(); + + // Stable: [0, 10), [20, 30) + ingestDTFileIntoDelta(DELTA_MERGE_FIRST_SEGMENT_ID, 10, /* at */ 0, /* clear */ false); + ingestDTFileIntoDelta(DELTA_MERGE_FIRST_SEGMENT_ID, 10, /* at */ 20, /* clear */ false); + flushSegmentCache(DELTA_MERGE_FIRST_SEGMENT_ID); + mergeSegmentDelta(DELTA_MERGE_FIRST_SEGMENT_ID); + + // Delta: [12, 18), [50, 60) + writeSegment(DELTA_MERGE_FIRST_SEGMENT_ID, 6, /* at */ 12); + writeSegment(DELTA_MERGE_FIRST_SEGMENT_ID, 10, /* at */ 50); + flushSegmentCache(DELTA_MERGE_FIRST_SEGMENT_ID); + + auto right_seg_id = splitSegmentAt(DELTA_MERGE_FIRST_SEGMENT_ID, 15, Segment::SplitMode::Logical); + RUNTIME_CHECK(right_seg_id.has_value()); + + auto stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {100.0}); + assertStreamOut(stream, "[9, 10)|[12, 15)"); + + stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 100, {100.0}); + assertStreamOut(stream, "[0, 10)|[12, 15)"); + + stream = annQuery(right_seg_id.value(), createQueryColumns(), 1, {100.0}); + assertStreamOut(stream, "[29, 30)|[15, 18)|[50, 60)"); + + stream = annQuery(right_seg_id.value(), createQueryColumns(), 100, {100.0}); + assertStreamOut(stream, "[20, 30)|[15, 18)|[50, 60)"); +} 
+CATCH + +class VectorIndexSegmentExtraColumnTest + : public VectorIndexSegmentTestBase + , public testing::WithParamInterface> +{ +public: + VectorIndexSegmentExtraColumnTest() { std::tie(test_only_vec_column, pack_size) = GetParam(); } + +protected: + const String extra_column_name = "extra"; + const ColumnID extra_column_id = 500; + + ColumnDefine cdExtra() + { + // When used in read, no need to assign vector_index. + return ColumnDefine(extra_column_id, extra_column_name, tests::typeFromString("Int64")); + } + + Block prepareWriteBlockImpl(Int64 start_key, Int64 end_key, bool is_deleted) override + { + auto block = VectorIndexSegmentTestBase::prepareWriteBlockImpl(start_key, end_key, is_deleted); + block.insert( + colInt64(fmt::format("[{}, {})", start_key + 1000, end_key + 1000), extra_column_name, extra_column_id)); + return block; + } + + void prepareColumns(const ColumnDefinesPtr & columns) override + { + VectorIndexSegmentTestBase::prepareColumns(columns); + columns->emplace_back(cdExtra()); + } + + Strings createColumnNames() override + { + if (!test_only_vec_column) + return {DMTestEnv::pk_name, vec_column_name, extra_column_name}; + + // In test_only_vec_column mode, only contains the Array column. + return {vec_column_name}; + } + + ColumnDefines createQueryColumns() override + { + if (!test_only_vec_column) + return {cdPK(), cdVec(), cdExtra()}; + + return {cdVec()}; + } +}; + +INSTANTIATE_TEST_CASE_P( + VectorIndex, + VectorIndexSegmentExtraColumnTest, + ::testing::Combine( // + /* vec_only */ ::testing::Bool(), + /* pack_size */ ::testing::Values(1 /*, 2, 3, 4, 5*/))); + +TEST_P(VectorIndexSegmentExtraColumnTest, DataInStableAndDelta) +try +{ + db_context->getSettingsRef().dt_segment_stable_pack_rows = pack_size; + reloadDMContext(); + + ingestDTFileIntoDelta(DELTA_MERGE_FIRST_SEGMENT_ID, 5, /* at */ 0, /* clear */ false); + flushSegmentCache(DELTA_MERGE_FIRST_SEGMENT_ID); + mergeSegmentDelta(DELTA_MERGE_FIRST_SEGMENT_ID); + + writeSegment(DELTA_MERGE_FIRST_SEGMENT_ID, 10, /* at */ 20); + + auto stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {100.0}); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + colInt64("[4, 5)|[20, 30)"), + colVecFloat32("[4, 5)|[20, 30)"), + colInt64("[1004, 1005)|[1020, 1030)"), + })); +} +CATCH + +} // namespace DB::DM::tests diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.cpp index 15fc3a2f481..b1d1e21960a 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.cpp +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.cpp @@ -349,7 +349,7 @@ std::pair SegmentTestBasic::getSegmentKeyRange(PageIdU64 segment_i return {start_key, end_key}; } -Block SegmentTestBasic::prepareWriteBlock(Int64 start_key, Int64 end_key, bool is_deleted) +Block SegmentTestBasic::prepareWriteBlockImpl(Int64 start_key, Int64 end_key, bool is_deleted) { RUNTIME_CHECK(start_key <= end_key); if (end_key == start_key) @@ -369,6 +369,11 @@ Block SegmentTestBasic::prepareWriteBlock(Int64 start_key, Int64 end_key, bool i is_deleted); } +Block SegmentTestBasic::prepareWriteBlock(Int64 start_key, Int64 end_key, bool is_deleted) +{ + return prepareWriteBlockImpl(start_key, end_key, is_deleted); +} + Block sortvstackBlocks(std::vector && blocks) { auto accumulated_block = vstackBlocks(std::move(blocks)); @@ -830,6 +835,7 @@ SegmentPtr SegmentTestBasic::buildFirstSegment( ColumnDefinesPtr cols = 
(!pre_define_columns) ? DMTestEnv::getDefaultColumns( is_common_handle ? DMTestEnv::PkType::CommonHandle : DMTestEnv::PkType::HiddenTiDBRowID) : pre_define_columns; + prepareColumns(cols); setColumns(cols); // Always return the first segment diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.h b/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.h index aa5220bf2c0..74b01c34691 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.h +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.h @@ -143,6 +143,10 @@ class SegmentTestBasic : public DB::base::TiFlashStorageTestBasic const ColumnDefinesPtr & tableColumns() const { return table_columns; } + virtual Block prepareWriteBlockImpl(Int64 start_key, Int64 end_key, bool is_deleted); + + virtual void prepareColumns(const ColumnDefinesPtr &) {} + /** * Reload a new DMContext according to latest storage status. * For example, if you have changed the settings, you should grab a new DMContext. diff --git a/dbms/src/Storages/StorageDeltaMerge.cpp b/dbms/src/Storages/StorageDeltaMerge.cpp index e59ba187ceb..f4f7b0a0966 100644 --- a/dbms/src/Storages/StorageDeltaMerge.cpp +++ b/dbms/src/Storages/StorageDeltaMerge.cpp @@ -47,6 +47,7 @@ #include #include #include +#include #include #include #include @@ -187,6 +188,7 @@ void StorageDeltaMerge::updateTableColumnInfo() if (itr != columns.end()) { col_def.default_value = itr->defaultValueToField(); + col_def.vector_index = itr->vector_index; } if (col_def.id != TiDBPkColumnID && col_def.id != VersionColumnID && col_def.id != DelMarkColumnID @@ -302,6 +304,22 @@ void StorageDeltaMerge::updateTableColumnInfo() rowkey_column_defines.push_back(handle_column_define); } rowkey_column_size = rowkey_column_defines.size(); + + LOG_INFO( + log, + "updateTableColumnInfo finished, table_name={} table_column_defines={}", + table_column_info->table_name, + [&] { + FmtBuffer fmt_buf; + fmt_buf.joinStr( + table_column_defines.begin(), + table_column_defines.end(), + [](const ColumnDefine & col, FmtBuffer & fb) { + fb.fmtAppend("{} {} {}", col.name, col.type->getFamilyName(), col.vector_index); + }, + ", "); + return fmt_buf.toString(); + }()); } void StorageDeltaMerge::clearData() @@ -761,6 +779,13 @@ DM::RSOperatorPtr StorageDeltaMerge::buildRSOperator( else LOG_DEBUG(tracing_logger, "Rough set filter is disabled."); + ANNQueryInfoPtr ann_query_info = nullptr; + if (dag_query->ann_query_info.query_type() != tipb::ANNQueryType::InvalidQueryType) + ann_query_info = std::make_shared(dag_query->ann_query_info); + + if (ann_query_info != nullptr) + rs_operator = wrapWithANNQueryInfo(rs_operator, ann_query_info); + return rs_operator; } diff --git a/dbms/src/Storages/StorageDisaggregatedRemote.cpp b/dbms/src/Storages/StorageDisaggregatedRemote.cpp index 8dd175ae12c..4d726754fc0 100644 --- a/dbms/src/Storages/StorageDisaggregatedRemote.cpp +++ b/dbms/src/Storages/StorageDisaggregatedRemote.cpp @@ -472,6 +472,7 @@ DM::RSOperatorPtr StorageDisaggregated::buildRSOperator( auto dag_query = std::make_unique( filter_conditions.conditions, + table_scan.getANNQueryInfo(), table_scan.getPushedDownFilters(), table_scan.getColumns(), std::vector{}, diff --git a/dbms/src/Storages/tests/gtest_filter_parser.cpp b/dbms/src/Storages/tests/gtest_filter_parser.cpp index 50cc6c10a78..166f79c3a6b 100644 --- a/dbms/src/Storages/tests/gtest_filter_parser.cpp +++ b/dbms/src/Storages/tests/gtest_filter_parser.cpp @@ -103,6 +103,7 @@ DM::RSOperatorPtr FilterParserTest::generateRsOperator( 
const google::protobuf::RepeatedPtrField pushed_down_filters{}; // don't care pushed down filters std::unique_ptr dag_query = std::make_unique( conditions, + tipb::ANNQueryInfo{}, pushed_down_filters, table_info.columns, std::vector(), // don't care runtime filter diff --git a/dbms/src/Storages/tests/gtests_parse_push_down_filter.cpp b/dbms/src/Storages/tests/gtests_parse_push_down_filter.cpp index 4bab7d8f809..eed139e999c 100644 --- a/dbms/src/Storages/tests/gtests_parse_push_down_filter.cpp +++ b/dbms/src/Storages/tests/gtests_parse_push_down_filter.cpp @@ -94,6 +94,7 @@ DM::PushDownFilterPtr ParsePushDownFilterTest::generatePushDownFilter( } dag_query = std::make_unique( conditions, + tipb::ANNQueryInfo{}, pushed_down_filters, table_info.columns, std::vector(), // don't care runtime filter diff --git a/dbms/src/TiDB/Schema/TiDB.cpp b/dbms/src/TiDB/Schema/TiDB.cpp index 688c1547e24..35645fd30fb 100644 --- a/dbms/src/TiDB/Schema/TiDB.cpp +++ b/dbms/src/TiDB/Schema/TiDB.cpp @@ -32,6 +32,7 @@ #include #include +#include namespace DB { @@ -403,6 +404,19 @@ try json->set("state", static_cast(state)); json->set("comment", comment); + if (vector_index) + { + RUNTIME_CHECK(vector_index->kind != VectorIndexKind::INVALID); + RUNTIME_CHECK(vector_index->distance_metric != DistanceMetric::INVALID); + + Poco::JSON::Object::Ptr vector_index_json = new Poco::JSON::Object(); + vector_index_json->set("kind", String(magic_enum::enum_name(vector_index->kind))); + vector_index_json->set("dimension", vector_index->dimension); + vector_index_json->set("distance_metric", String(magic_enum::enum_name(vector_index->distance_metric))); + + json->set("vector_index", vector_index_json); + } + #ifndef NDEBUG // Check stringify in Debug mode std::stringstream str; @@ -452,6 +466,27 @@ try collate = type_json->get("Collate"); state = static_cast(json->getValue("state")); comment = json->getValue("comment"); + + auto vector_index_json = json->getObject("vector_index"); + if (vector_index_json) + { + vector_index = std::make_shared(); + + auto vector_kind = magic_enum::enum_cast(vector_index_json->getValue("kind")); + RUNTIME_CHECK(vector_kind.has_value()); + RUNTIME_CHECK(vector_kind.value() != VectorIndexKind::INVALID); + vector_index->kind = vector_kind.value(); + + vector_index->dimension = vector_index_json->getValue("dimension"); + RUNTIME_CHECK(vector_index->dimension > 0); + RUNTIME_CHECK(vector_index->dimension <= 16000); // Just a protection + + auto distance_metric + = magic_enum::enum_cast(vector_index_json->getValue("distance_metric")); + RUNTIME_CHECK(distance_metric.has_value()); + RUNTIME_CHECK(distance_metric.value() != DistanceMetric::INVALID); + vector_index->distance_metric = distance_metric.value(); + } } catch (const Poco::Exception & e) { diff --git a/dbms/src/TiDB/Schema/TiDB.h b/dbms/src/TiDB/Schema/TiDB.h index 4b6ec38254b..49abf0ed44a 100644 --- a/dbms/src/TiDB/Schema/TiDB.h +++ b/dbms/src/TiDB/Schema/TiDB.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -200,6 +201,8 @@ struct ColumnInfo SchemaState state = StateNone; String comment; + VectorIndexInfoPtr vector_index = nullptr; + #ifdef M #error "Please undefine macro M first." #endif diff --git a/dbms/src/TiDB/Schema/VectorIndex.h b/dbms/src/TiDB/Schema/VectorIndex.h new file mode 100644 index 00000000000..8ba8b0a0d98 --- /dev/null +++ b/dbms/src/TiDB/Schema/VectorIndex.h @@ -0,0 +1,105 @@ +// Copyright 2024 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include +#include + +namespace TiDB +{ + +enum class VectorIndexKind +{ + INVALID = 0, + + // Note: Field names must match TiDB's enum definition. + HNSW, +}; + +enum class DistanceMetric +{ + INVALID = 0, + + // Note: Field names must match TiDB's enum definition. + L1, + L2, + COSINE, + INNER_PRODUCT, +}; + + +struct VectorIndexInfo +{ + VectorIndexKind kind = VectorIndexKind::INVALID; + UInt64 dimension = 0; + DistanceMetric distance_metric = DistanceMetric::INVALID; +}; + +using VectorIndexInfoPtr = std::shared_ptr; + +} // namespace TiDB + +template <> +struct fmt::formatter +{ + static constexpr auto parse(format_parse_context & ctx) { return ctx.begin(); } + + template + auto format(const TiDB::VectorIndexKind & v, FormatContext & ctx) const -> decltype(ctx.out()) + { + return format_to(ctx.out(), "{}", magic_enum::enum_name(v)); + } +}; + +template <> +struct fmt::formatter +{ + static constexpr auto parse(format_parse_context & ctx) { return ctx.begin(); } + + template + auto format(const TiDB::DistanceMetric & d, FormatContext & ctx) const -> decltype(ctx.out()) + { + return format_to(ctx.out(), "{}", magic_enum::enum_name(d)); + } +}; + +template <> +struct fmt::formatter +{ + static constexpr auto parse(format_parse_context & ctx) { return ctx.begin(); } + + template + auto format(const TiDB::VectorIndexInfo & vi, FormatContext & ctx) const -> decltype(ctx.out()) + { + return format_to(ctx.out(), "{}:{}", vi.kind, vi.distance_metric); + } +}; + +template <> +struct fmt::formatter +{ + static constexpr auto parse(format_parse_context & ctx) { return ctx.begin(); } + + template + auto format(const TiDB::VectorIndexInfoPtr & vi, FormatContext & ctx) const -> decltype(ctx.out()) + { + if (!vi) + return format_to(ctx.out(), ""); + return format_to(ctx.out(), "{}", *vi); + } +};
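
For reviewers, below is a minimal usage sketch of the read path this patch adds. It is not part of the patch itself: it follows the pattern of gtest_dm_vector_index.cpp above and assumes that test harness (dbContext(), dm_file and the encodeVectorFloat32 helper come from those tests); the column id 100, the dimension 3 and the reference vector are illustrative values borrowed from the OnePack test, and the explicit template arguments are my reading of the calls shown there.

    // Sketch only, under the assumptions stated above.
    // 1. Declare the vector column with an HNSW index (kind / dimension / metric come from the TiDB schema).
    auto vec_cd = ColumnDefine(100, "vec", tests::typeFromString("Array(Float32)"));
    vec_cd.vector_index = std::make_shared<TiDB::VectorIndexInfo>(TiDB::VectorIndexInfo{
        .kind = TiDB::VectorIndexKind::HNSW,
        .dimension = 3,
        .distance_metric = TiDB::DistanceMetric::L2,
    });

    // 2. Build the ANN hint: top-1 neighbour of [1.0, 2.0, 3.5] under L2 distance.
    auto ann_query_info = std::make_shared<tipb::ANNQueryInfo>();
    ann_query_info->set_column_id(vec_cd.id);
    ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2);
    ann_query_info->set_top_k(1);
    ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.5}));

    // 3. Wrap the hint into the rough-set filter and read the DMFile through the
    //    vector-index-aware stream (build2). The BitmapFilterView carries MVCC /
    //    delete-mark visibility, so the index only reports rows that are still visible.
    DMFileBlockInputStreamBuilder builder(dbContext());
    auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info))
                      .setBitmapFilter(BitmapFilterView(std::make_shared<BitmapFilter>(3, true), 0, 3))
                      .build2(
                          dm_file,
                          {vec_cd},
                          RowKeyRanges{RowKeyRange::newAll(false, 1)},
                          std::make_shared<ScanContext>());

The hint travels inside the RSOperator (via WithANNQueryInfo), so it reaches only the stable-layer DMFile readers; delta data is never pruned by it, which is why the segment tests above expect all delta rows back. On the schema side, the ColumnInfo (de)serialization added in TiDB.cpp persists the index as a nested JSON object of the form "vector_index": {"kind": "HNSW", "dimension": 3, "distance_metric": "L2"} (values illustrative).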