diff --git a/internal/core/src/common/QueryResult.h b/internal/core/src/common/QueryResult.h index b5298d01b4346..75d54987ea607 100644 --- a/internal/core/src/common/QueryResult.h +++ b/internal/core/src/common/QueryResult.h @@ -66,8 +66,9 @@ struct OffsetDisPairComparator { }; struct VectorIterator { public: - VectorIterator(int chunk_count, int64_t chunk_rows = -1) - : chunk_rows_(chunk_rows) { + VectorIterator(int chunk_count, + const std::vector& total_rows_until_chunk = {}) + : total_rows_until_chunk_(total_rows_until_chunk) { iterators_.reserve(chunk_count); } @@ -119,7 +120,7 @@ struct VectorIterator { private: int64_t convert_to_segment_offset(int64_t chunk_offset, int chunk_idx) { - if (chunk_rows_ == -1) { + if (total_rows_until_chunk_.size() == 0) { AssertInfo( iterators_.size() == 1, "Wrong state for vectorIterators, which having incorrect " @@ -129,7 +130,7 @@ struct VectorIterator { iterators_.size()); return chunk_offset; } - return chunk_idx * chunk_rows_ + chunk_offset; + return total_rows_until_chunk_[chunk_idx] + chunk_offset; } private: @@ -139,7 +140,7 @@ struct VectorIterator { OffsetDisPairComparator> heap_; bool sealed = false; - int64_t chunk_rows_ = -1; + std::vector total_rows_until_chunk_; //currently, VectorIterator is guaranteed to be used serially without concurrent problem, in the future //we may need to add mutex to protect the variable sealed }; @@ -163,7 +164,7 @@ struct SearchResult { AssembleChunkVectorIterators( int64_t nq, int chunk_count, - int64_t rows_per_chunk, + const std::vector& total_rows_until_chunk, const std::vector& kw_iterators) { AssertInfo(kw_iterators.size() == nq * chunk_count, "kw_iterators count:{} is not equal to nq*chunk_count:{}, " @@ -176,7 +177,7 @@ struct SearchResult { vec_iter_idx = vec_iter_idx % nq; if (vector_iterators.size() < nq) { auto vector_iterator = std::make_shared( - chunk_count, rows_per_chunk); + chunk_count, total_rows_until_chunk); vector_iterators.emplace_back(vector_iterator); } auto kw_iterator = kw_iterators[i]; diff --git a/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h b/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h index e6a95c6603809..640789518cdf1 100644 --- a/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h +++ b/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h @@ -140,7 +140,7 @@ PrepareVectorIteratorsFromIndex(const SearchInfo& search_info, index.VectorIterators(dataset, search_conf, bitset); if (iterators_val.has_value()) { search_result.AssembleChunkVectorIterators( - nq, 1, -1, iterators_val.value()); + nq, 1, {0}, iterators_val.value()); } else { LOG_ERROR( "Returned knowhere iterator has non-ready iterators " diff --git a/internal/core/src/mmap/ChunkedColumn.h b/internal/core/src/mmap/ChunkedColumn.h index 91a7bf230b3ca..2a9e3ff6db40b 100644 --- a/internal/core/src/mmap/ChunkedColumn.h +++ b/internal/core/src/mmap/ChunkedColumn.h @@ -165,6 +165,11 @@ class ChunkedColumnBase : public ColumnBase { return num_rows_until_chunk_[chunk_id]; } + const std::vector& + GetNumRowsUntilChunk() const { + return num_rows_until_chunk_; + } + protected: bool nullable_{false}; size_t num_rows_{0}; diff --git a/internal/core/src/query/SearchOnGrowing.cpp b/internal/core/src/query/SearchOnGrowing.cpp index 0222ce2cd8b74..f71efb7562a6b 100644 --- a/internal/core/src/query/SearchOnGrowing.cpp +++ b/internal/core/src/query/SearchOnGrowing.cpp @@ -150,11 +150,12 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, } } if (info.group_by_field_id_.has_value()) { + std::vector chunk_rows(max_chunk, 0); + for (int i = 1; i < max_chunk; ++i) { + chunk_rows[i] = i * vec_size_per_chunk; + } search_result.AssembleChunkVectorIterators( - num_queries, - max_chunk, - vec_size_per_chunk, - final_qr.chunk_iterators()); + num_queries, max_chunk, chunk_rows, final_qr.chunk_iterators()); } else { search_result.distances_ = std::move(final_qr.mutable_distances()); search_result.seg_offsets_ = diff --git a/internal/core/src/query/SearchOnSealed.cpp b/internal/core/src/query/SearchOnSealed.cpp index 9b3a4df287599..2bd7e8edb8ac6 100644 --- a/internal/core/src/query/SearchOnSealed.cpp +++ b/internal/core/src/query/SearchOnSealed.cpp @@ -161,8 +161,10 @@ SearchOnSealed(const Schema& schema, offset += chunk_size; } if (search_info.group_by_field_id_.has_value()) { - result.AssembleChunkVectorIterators( - num_queries, 1, -1, final_qr.chunk_iterators()); + result.AssembleChunkVectorIterators(num_queries, + num_chunk, + column->GetNumRowsUntilChunk(), + final_qr.chunk_iterators()); } else { result.distances_ = std::move(final_qr.mutable_distances()); result.seg_offsets_ = std::move(final_qr.mutable_seg_offsets()); @@ -201,7 +203,7 @@ SearchOnSealed(const Schema& schema, auto sub_qr = BruteForceSearchIterators( dataset, vec_data, row_count, search_info, bitset, data_type); result.AssembleChunkVectorIterators( - num_queries, 1, -1, sub_qr.chunk_iterators()); + num_queries, 1, {0}, sub_qr.chunk_iterators()); } else { auto sub_qr = BruteForceSearch( dataset, vec_data, row_count, search_info, bitset, data_type); diff --git a/internal/core/unittest/test_chunked_segment.cpp b/internal/core/unittest/test_chunked_segment.cpp index b0b624b68a793..d9b34218bdc7b 100644 --- a/internal/core/unittest/test_chunked_segment.cpp +++ b/internal/core/unittest/test_chunked_segment.cpp @@ -92,7 +92,7 @@ TEST(test_chunk_segment, TestSearchOnSealed) { search_info, query_data, 1, - chunk_size * chunk_num, + total_row_count, bv, search_result); @@ -107,4 +107,31 @@ TEST(test_chunk_segment, TestSearchOnSealed) { for (int i = 0; i < total_row_count; i++) { ASSERT_TRUE(offsets.find(i) != offsets.end()); } + + // test with group by + search_info.group_by_field_id_ = fakevec_id; + std::fill(bitset_data, bitset_data + bitset_size, 0); + query::SearchOnSealed(*schema, + column, + search_info, + query_data, + 1, + total_row_count, + bv, + search_result); + + ASSERT_EQ(1, search_result.vector_iterators_->size()); + + auto iter = search_result.vector_iterators_->at(0); + // collect all offsets + offsets.clear(); + while (iter->HasNext()) { + auto [offset, distance] = iter->Next().value(); + offsets.insert(offset); + } + + ASSERT_EQ(total_row_count, offsets.size()); + for (int i = 0; i < total_row_count; i++) { + ASSERT_TRUE(offsets.find(i) != offsets.end()); + } }