Skip to content

Commit

Permalink
fix: fix chunked segment term filter expression and add ut
Browse files Browse the repository at this point in the history
Signed-off-by: sunby <[email protected]>
  • Loading branch information
sunby committed Nov 3, 2024
1 parent f190e5d commit fdbe970
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 39 deletions.
12 changes: 4 additions & 8 deletions internal/core/src/exec/expression/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,14 +363,10 @@ class SegmentExpr : public Expr {
if (segment_->type() == SegmentType::Sealed) {
// first is the raw data, second is valid_data
// use valid_data to see if raw data is null
auto data_vec = segment_
->get_batch_views<T>(
field_id_, i, data_pos, size)
.first;
auto valid_data = segment_
->get_batch_views<T>(
field_id_, i, data_pos, size)
.second;
auto fetched_data = segment_->get_batch_views<T>(

Check warning on line 366 in internal/core/src/exec/expression/Expr.h

View check run for this annotation

Codecov / codecov/patch

internal/core/src/exec/expression/Expr.h#L366

Added line #L366 was not covered by tests
field_id_, i, data_pos, size);
auto data_vec = fetched_data.first;
auto valid_data = fetched_data.second;

Check warning on line 369 in internal/core/src/exec/expression/Expr.h

View check run for this annotation

Codecov / codecov/patch

internal/core/src/exec/expression/Expr.h#L368-L369

Added lines #L368 - L369 were not covered by tests
func(data_vec.data(),
valid_data.data(),
size,
Expand Down
44 changes: 18 additions & 26 deletions internal/core/src/mmap/ChunkedColumn.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ class ChunkedColumnBase : public ColumnBase {
return true;
}

// Reports whether the row at `offset` within chunk `chunk_id` holds a valid
// (non-null) value.
//
// For a column that is not nullable every row counts as valid and `true` is
// returned without touching the chunk. For a nullable column the answer is
// delegated to the chunk's own isValid(), so `offset` is interpreted by that
// chunk, not as a segment-wide row offset.
//
// NOTE(review): `chunk_id` is used to index `chunks_` without a range check
// here; callers are presumably expected to pass a valid chunk id - confirm.
bool
IsValid(int64_t chunk_id, int64_t offset) const {
if (nullable_) {
return chunks_[chunk_id]->isValid(offset);

Check warning on line 101 in internal/core/src/mmap/ChunkedColumn.h

View check run for this annotation

Codecov / codecov/patch

internal/core/src/mmap/ChunkedColumn.h#L100-L101

Added lines #L100 - L101 were not covered by tests
}
return true;
}

bool
IsNullable() const {
return nullable_;
Expand Down Expand Up @@ -136,7 +144,7 @@ class ChunkedColumnBase : public ColumnBase {

// used for sequential access for search
virtual BufferView
GetBatchBuffer(int64_t start_offset, int64_t length) {
GetBatchBuffer(int64_t chunk_id, int64_t start_offset, int64_t length) {

Check warning on line 147 in internal/core/src/mmap/ChunkedColumn.h

View check run for this annotation

Codecov / codecov/patch

internal/core/src/mmap/ChunkedColumn.h#L147

Added line #L147 was not covered by tests
PanicInfo(ErrorCode::Unsupported,
"GetBatchBuffer only supported for VariableColumn");
}
Expand Down Expand Up @@ -313,33 +321,17 @@ class ChunkedVariableColumn : public ChunkedColumnBase {
}

BufferView
GetBatchBuffer(int64_t start_offset, int64_t length) override {
if (start_offset < 0 || start_offset > num_rows_ ||
start_offset + length > num_rows_) {
PanicInfo(ErrorCode::OutOfRange, "index out of range");
}

int chunk_num = chunks_.size();

auto [start_chunk_id, start_offset_in_chunk] =
GetChunkIDByOffset(start_offset);
GetBatchBuffer(int64_t chunk_id,

Check warning on line 324 in internal/core/src/mmap/ChunkedColumn.h

View check run for this annotation

Codecov / codecov/patch

internal/core/src/mmap/ChunkedColumn.h#L324

Added line #L324 was not covered by tests
int64_t start_offset,
int64_t length) override {
BufferView buffer_view;

std::vector<BufferView::Element> elements;
for (; start_chunk_id < chunk_num && length > 0; ++start_chunk_id) {
int chunk_size = chunks_[start_chunk_id]->RowNums();
int len =
std::min(int64_t(chunk_size - start_offset_in_chunk), length);
elements.push_back(
{chunks_[start_chunk_id]->Data(),
std::dynamic_pointer_cast<StringChunk>(chunks_[start_chunk_id])
->Offsets(),
static_cast<int>(start_offset_in_chunk),
static_cast<int>(start_offset_in_chunk + len)});

start_offset_in_chunk = 0;
length -= len;
}
elements.push_back(
{chunks_[chunk_id]->Data(),

Check warning on line 330 in internal/core/src/mmap/ChunkedColumn.h

View check run for this annotation

Codecov / codecov/patch

internal/core/src/mmap/ChunkedColumn.h#L329-L330

Added lines #L329 - L330 were not covered by tests
std::dynamic_pointer_cast<StringChunk>(chunks_[chunk_id])
->Offsets(),

Check warning on line 332 in internal/core/src/mmap/ChunkedColumn.h

View check run for this annotation

Codecov / codecov/patch

internal/core/src/mmap/ChunkedColumn.h#L332

Added line #L332 was not covered by tests
static_cast<int>(start_offset),
static_cast<int>(start_offset + length)});

Check warning on line 334 in internal/core/src/mmap/ChunkedColumn.h

View check run for this annotation

Codecov / codecov/patch

internal/core/src/mmap/ChunkedColumn.h#L334

Added line #L334 was not covered by tests

buffer_view.data_ = elements;
return buffer_view;
Expand Down
15 changes: 10 additions & 5 deletions internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -776,11 +776,13 @@ ChunkedSegmentSealedImpl::get_chunk_buffer(FieldId field_id,
if (field_data->IsNullable()) {
valid_data.reserve(length);
for (int i = 0; i < length; i++) {
valid_data.push_back(field_data->IsValid(start_offset + i));
valid_data.push_back(
field_data->IsValid(chunk_id, start_offset + i));

Check warning on line 780 in internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp#L779-L780

Added lines #L779 - L780 were not covered by tests
}
}
return std::make_pair(field_data->GetBatchBuffer(start_offset, length),
valid_data);
return std::make_pair(
field_data->GetBatchBuffer(chunk_id, start_offset, length),
valid_data);

Check warning on line 785 in internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp#L783-L785

Added lines #L783 - L785 were not covered by tests
}
PanicInfo(ErrorCode::UnexpectedError,
"get_chunk_buffer only used for variable column field");
Expand Down Expand Up @@ -1201,8 +1203,9 @@ ChunkedSegmentSealedImpl::search_pk(const PkType& pk,
[](const int64_t& elem, const int64_t& value) {
return elem < value;
});
auto num_rows_until_chunk = pk_column->GetNumRowsUntilChunk(i);

Check warning on line 1206 in internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp#L1206

Added line #L1206 was not covered by tests
for (; it != src + chunk_row_num && *it == target; it++) {
auto offset = it - src;
auto offset = it - src + num_rows_until_chunk;

Check warning on line 1208 in internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp#L1208

Added line #L1208 was not covered by tests
if (insert_record_.timestamps_[offset] <= timestamp) {
pk_offsets.emplace_back(offset);
}
Expand All @@ -1220,8 +1223,10 @@ ChunkedSegmentSealedImpl::search_pk(const PkType& pk,
for (int i = 0; i < num_chunk; ++i) {
auto views = var_column->StringViews(i).first;
auto it = std::lower_bound(views.begin(), views.end(), target);
auto num_rows_until_chunk = pk_column->GetNumRowsUntilChunk(i);

Check warning on line 1226 in internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp#L1226

Added line #L1226 was not covered by tests
for (; it != views.end() && *it == target; it++) {
auto offset = std::distance(views.begin(), it);
auto offset =
std::distance(views.begin(), it) + num_rows_until_chunk;

Check warning on line 1229 in internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp#L1228-L1229

Added lines #L1228 - L1229 were not covered by tests
if (insert_record_.timestamps_[offset] <= timestamp) {
pk_offsets.emplace_back(offset);
}
Expand Down
104 changes: 104 additions & 0 deletions internal/core/unittest/test_chunked_segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,21 @@
#include <gtest/gtest.h>
#include <algorithm>
#include <cstdint>
#include "arrow/table_builder.h"
#include "arrow/type_fwd.h"
#include "common/BitsetView.h"
#include "common/FieldDataInterface.h"
#include "common/QueryInfo.h"
#include "common/Schema.h"
#include "expr/ITypeExpr.h"
#include "knowhere/comp/index_param.h"
#include "mmap/ChunkedColumn.h"
#include "query/ExecPlanNodeVisitor.h"
#include "query/SearchOnSealed.h"
#include "segcore/SegcoreConfig.h"
#include "segcore/SegmentSealedImpl.h"
#include "test_utils/DataGen.h"
#include <numeric>
#include <vector>

struct DeferRelease {
Expand Down Expand Up @@ -108,3 +116,99 @@ TEST(test_chunk_segment, TestSearchOnSealed) {
ASSERT_TRUE(offsets.find(i) != offsets.end());
}
}

// Checks that a term ("value IN {...}") filter expression yields the expected
// number of hits on a sealed segment whose fields are each fed from two
// record-batch readers (presumably ending up as two chunks per field - the
// trailing `true` passed to CreateSealedSegment looks like the chunked-segment
// switch; confirm against its declaration).
//
// Both the "int32" field and the "pk" field are loaded with the values
// 0..999 twice (2000 rows in total), so filtering on the values 0..9 must
// match 10 rows per load, i.e. 20 rows, for either field.
TEST(test_chunk_segment, TestTermExpr) {
    auto schema = std::make_shared<Schema>();
    auto int32_fid = schema->AddDebugField("int32", DataType::INT32, true);
    auto pk_fid = schema->AddDebugField("pk", DataType::INT32, true);
    auto segment =
        segcore::CreateSealedSegment(schema,
                                     nullptr,
                                     -1,
                                     segcore::SegcoreConfig::default_config(),
                                     false,
                                     false,
                                     true);

    // generate test data: the values 0, 1, ..., test_data_count - 1
    size_t test_data_count = 1000;
    std::vector<int32_t> test_data(test_data_count);
    std::iota(test_data.begin(), test_data.end(), 0);
    auto builder = std::make_shared<arrow::Int32Builder>();
    auto status = builder->AppendValues(test_data.begin(), test_data.end());
    ASSERT_TRUE(status.ok());
    auto res = builder->Finish();
    ASSERT_TRUE(res.ok());
    std::shared_ptr<arrow::Array> arrow_int32;
    arrow_int32 = res.ValueOrDie();

    // Every record batch carries exactly one column, so each one gets its own
    // single-field schema. Pairing a two-field schema with a one-column batch
    // produces an invalid RecordBatch (schema/column count mismatch).
    auto arrow_i32_field = arrow::field("int32", arrow::int32());
    auto arrow_pk_field = arrow::field("pk", arrow::int32());
    auto i32_record_batch =
        arrow::RecordBatch::Make(arrow::schema({arrow_i32_field}),
                                 arrow_int32->length(),
                                 {arrow_int32});
    auto pk_record_batch =
        arrow::RecordBatch::Make(arrow::schema({arrow_pk_field}),
                                 arrow_int32->length(),
                                 {arrow_int32});

    // int32 field data: two independent readers over the same batch
    auto res2 = arrow::RecordBatchReader::Make({i32_record_batch});
    ASSERT_TRUE(res2.ok());
    auto arrow_reader = res2.ValueOrDie();
    res2 = arrow::RecordBatchReader::Make({i32_record_batch});
    ASSERT_TRUE(res2.ok());
    auto arrow_reader2 = res2.ValueOrDie();

    // pk field data: two independent readers over the same batch
    res2 = arrow::RecordBatchReader::Make({pk_record_batch});
    ASSERT_TRUE(res2.ok());
    auto arrow_pk_reader = res2.ValueOrDie();
    res2 = arrow::RecordBatchReader::Make({pk_record_batch});
    ASSERT_TRUE(res2.ok());
    auto arrow_pk_reader2 = res2.ValueOrDie();

    // load int32 field
    FieldDataInfo i32_field_info;
    i32_field_info.field_id = int32_fid.get();
    i32_field_info.row_count = test_data_count * 2;
    i32_field_info.arrow_reader_channel->push(
        std::make_shared<ArrowDataWrapper>(arrow_reader, nullptr, nullptr));
    i32_field_info.arrow_reader_channel->push(
        std::make_shared<ArrowDataWrapper>(arrow_reader2, nullptr, nullptr));
    i32_field_info.arrow_reader_channel->close();
    segment->LoadFieldData(int32_fid, i32_field_info);

    // load pk field
    FieldDataInfo pk_field_info;
    pk_field_info.field_id = pk_fid.get();
    pk_field_info.row_count = test_data_count * 2;
    pk_field_info.arrow_reader_channel->push(
        std::make_shared<ArrowDataWrapper>(arrow_pk_reader, nullptr, nullptr));
    pk_field_info.arrow_reader_channel->push(
        std::make_shared<ArrowDataWrapper>(arrow_pk_reader2, nullptr, nullptr));
    pk_field_info.arrow_reader_channel->close();
    segment->LoadFieldData(pk_fid, pk_field_info);

    // term values 0..9; each occurs once per loaded reader
    std::vector<proto::plan::GenericValue> filter_data;
    for (int i = 0; i < 10; ++i) {
        proto::plan::GenericValue v;
        v.set_int64_val(i);
        filter_data.push_back(v);
    }

    // query int32 expr
    auto term_filter_expr = std::make_shared<expr::TermFilterExpr>(
        expr::ColumnInfo(int32_fid, DataType::INT32), filter_data);
    BitsetType final;
    auto plan = std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID,
                                                       term_filter_expr);
    final = query::ExecuteQueryExpr(
        plan, segment.get(), 2 * test_data_count, MAX_TIMESTAMP);
    // 10 matching values x 2 loads of the same data
    ASSERT_EQ(20, final.count());

    // query pk expr
    auto pk_term_filter_expr = std::make_shared<expr::TermFilterExpr>(
        expr::ColumnInfo(pk_fid, DataType::INT32), filter_data);
    plan = std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID,
                                                  pk_term_filter_expr);
    final = query::ExecuteQueryExpr(
        plan, segment.get(), 2 * test_data_count, MAX_TIMESTAMP);
    ASSERT_EQ(20, final.count());
}

0 comments on commit fdbe970

Please sign in to comment.