From 1b4f7e3ac118b812f03c4cb84a281d3d5a3fa574 Mon Sep 17 00:00:00 2001 From: Bingyi Sun Date: Thu, 14 Nov 2024 18:40:32 +0800 Subject: [PATCH] enhance: Add more expr ut for chunked segment (#37600) related pr: #37570 --------- Signed-off-by: sunby --- .../core/unittest/test_chunked_segment.cpp | 209 ++++++++++++------ 1 file changed, 138 insertions(+), 71 deletions(-) diff --git a/internal/core/unittest/test_chunked_segment.cpp b/internal/core/unittest/test_chunked_segment.cpp index 97cf3fc03be75..3761382e00ce1 100644 --- a/internal/core/unittest/test_chunked_segment.cpp +++ b/internal/core/unittest/test_chunked_segment.cpp @@ -24,13 +24,17 @@ #include "knowhere/comp/index_param.h" #include "mmap/ChunkedColumn.h" #include "mmap/Types.h" +#include "pb/plan.pb.h" #include "query/ExecPlanNodeVisitor.h" #include "query/SearchOnSealed.h" #include "segcore/SegcoreConfig.h" +#include "segcore/SegmentSealed.h" #include "segcore/SegmentSealedImpl.h" #include "test_utils/DataGen.h" #include #include +#include +#include #include struct DeferRelease { @@ -148,79 +152,119 @@ TEST(test_chunk_segment, TestSearchOnSealed) { } } -TEST(test_chunk_segment, TestTermExpr) { - auto schema = std::make_shared(); - auto int64_fid = schema->AddDebugField("int64", DataType::INT64, true); - auto pk_fid = schema->AddDebugField("pk", DataType::INT64, true); - schema->AddField(FieldName("ts"), TimestampFieldID, DataType::INT64, true); - schema->set_primary_field_id(pk_fid); - auto segment = - segcore::CreateSealedSegment(schema, - nullptr, - -1, - segcore::SegcoreConfig::default_config(), - false, - false, - true); - size_t test_data_count = 1000; - - auto arrow_i64_field = arrow::field("int64", arrow::int64()); - auto arrow_pk_field = arrow::field("pk", arrow::int64()); - auto arrow_ts_field = arrow::field("ts", arrow::int64()); - std::vector> arrow_fields = { - arrow_i64_field, arrow_pk_field, arrow_ts_field}; - - std::vector field_ids = {int64_fid, pk_fid, TimestampFieldID}; - - int start_id = 1; - int chunk_num = 2; - - std::vector field_infos; - for (auto fid : field_ids) { - FieldDataInfo field_info; - field_info.field_id = fid.get(); - field_info.row_count = test_data_count * chunk_num; - field_infos.push_back(field_info); - } +class TestChunkSegment : public testing::Test { + protected: + void + SetUp() override { + auto schema = std::make_shared(); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64, true); + auto pk_fid = schema->AddDebugField("pk", DataType::INT64, true); + auto str_fid = + schema->AddDebugField("string1", DataType::VARCHAR, true); + auto str2_fid = + schema->AddDebugField("string2", DataType::VARCHAR, true); + schema->AddField( + FieldName("ts"), TimestampFieldID, DataType::INT64, true); + schema->set_primary_field_id(pk_fid); + segment = segcore::CreateSealedSegment( + schema, + nullptr, + -1, + segcore::SegcoreConfig::default_config(), + false, + false, + true); + test_data_count = 1000; + + auto arrow_i64_field = arrow::field("int64", arrow::int64()); + auto arrow_pk_field = arrow::field("pk", arrow::int64()); + auto arrow_ts_field = arrow::field("ts", arrow::int64()); + auto arrow_str_field = arrow::field("string1", arrow::int64()); + auto arrow_str2_field = arrow::field("string2", arrow::int64()); + std::vector> arrow_fields = { + arrow_i64_field, + arrow_pk_field, + arrow_ts_field, + arrow_str_field, + arrow_str2_field}; + + std::vector field_ids = { + int64_fid, pk_fid, TimestampFieldID, str_fid, str2_fid}; + fields = {{"int64", int64_fid}, + {"pk", pk_fid}, + {"ts", TimestampFieldID}, + {"string1", str_fid}, + {"string2", str2_fid}}; + + int start_id = 1; + chunk_num = 2; + + std::vector field_infos; + for (auto fid : field_ids) { + FieldDataInfo field_info; + field_info.field_id = fid.get(); + field_info.row_count = test_data_count * chunk_num; + field_infos.push_back(field_info); + } - // generate data - for (int chunk_id = 0; chunk_id < chunk_num; - chunk_id++, start_id += test_data_count) { - std::vector test_data(test_data_count); - std::iota(test_data.begin(), test_data.end(), start_id); - - auto builder = std::make_shared(); - auto status = builder->AppendValues(test_data.begin(), test_data.end()); - ASSERT_TRUE(status.ok()); - auto res = builder->Finish(); - ASSERT_TRUE(res.ok()); - std::shared_ptr arrow_int64; - arrow_int64 = res.ValueOrDie(); - - for (int i = 0; i < arrow_fields.size(); i++) { - auto f = arrow_fields[i]; - auto fid = field_ids[i]; - auto arrow_schema = - std::make_shared(arrow::FieldVector(1, f)); - auto record_batch = arrow::RecordBatch::Make( - arrow_schema, arrow_int64->length(), {arrow_int64}); - - auto res2 = arrow::RecordBatchReader::Make({record_batch}); - ASSERT_TRUE(res2.ok()); - auto arrow_reader = res2.ValueOrDie(); - - field_infos[i].arrow_reader_channel->push( - std::make_shared( - arrow_reader, nullptr, nullptr)); + // generate data + for (int chunk_id = 0; chunk_id < chunk_num; + chunk_id++, start_id += test_data_count) { + std::vector test_data(test_data_count); + std::iota(test_data.begin(), test_data.end(), start_id); + + auto builder = std::make_shared(); + auto status = + builder->AppendValues(test_data.begin(), test_data.end()); + ASSERT_TRUE(status.ok()); + auto res = builder->Finish(); + ASSERT_TRUE(res.ok()); + std::shared_ptr arrow_int64; + arrow_int64 = res.ValueOrDie(); + + auto str_builder = std::make_shared(); + for (int i = 0; i < test_data_count; i++) { + auto status = str_builder->Append("test" + std::to_string(i)); + ASSERT_TRUE(status.ok()); + } + std::shared_ptr arrow_str; + status = str_builder->Finish(&arrow_str); + ASSERT_TRUE(status.ok()); + + for (int i = 0; i < arrow_fields.size(); i++) { + auto f = arrow_fields[i]; + auto fid = field_ids[i]; + auto arrow_schema = + std::make_shared(arrow::FieldVector(1, f)); + + auto col = i < 3 ? arrow_int64 : arrow_str; + auto record_batch = arrow::RecordBatch::Make( + arrow_schema, arrow_int64->length(), {col}); + + auto res2 = arrow::RecordBatchReader::Make({record_batch}); + ASSERT_TRUE(res2.ok()); + auto arrow_reader = res2.ValueOrDie(); + + field_infos[i].arrow_reader_channel->push( + std::make_shared( + arrow_reader, nullptr, nullptr)); + } } - } - // load - for (int i = 0; i < field_infos.size(); i++) { - field_infos[i].arrow_reader_channel->close(); - segment->LoadFieldData(field_ids[i], field_infos[i]); + // load + for (int i = 0; i < field_infos.size(); i++) { + field_infos[i].arrow_reader_channel->close(); + segment->LoadFieldData(field_ids[i], field_infos[i]); + } } + segcore::SegmentSealedUPtr segment; + int chunk_num; + int test_data_count; + std::unordered_map fields; +}; + +TEST_F(TestChunkSegment, TestTermExpr) { // query int64 expr std::vector filter_data; for (int i = 1; i <= 10; ++i) { @@ -229,7 +273,7 @@ TEST(test_chunk_segment, TestTermExpr) { filter_data.push_back(v); } auto term_filter_expr = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), filter_data); + expr::ColumnInfo(fields.at("int64"), DataType::INT64), filter_data); BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, term_filter_expr); @@ -239,7 +283,7 @@ TEST(test_chunk_segment, TestTermExpr) { // query pk expr auto pk_term_filter_expr = std::make_shared( - expr::ColumnInfo(pk_fid, DataType::INT64), filter_data); + expr::ColumnInfo(fields.at("pk"), DataType::INT64), filter_data); plan = std::make_shared(DEFAULT_PLANNODE_ID, pk_term_filter_expr); final = query::ExecuteQueryExpr( @@ -252,10 +296,33 @@ TEST(test_chunk_segment, TestTermExpr) { v.set_int64_val(test_data_count + 1); filter_data2.push_back(v); pk_term_filter_expr = std::make_shared( - expr::ColumnInfo(pk_fid, DataType::INT64), filter_data2); + expr::ColumnInfo(fields.at("pk"), DataType::INT64), filter_data2); plan = std::make_shared(DEFAULT_PLANNODE_ID, pk_term_filter_expr); final = query::ExecuteQueryExpr( plan, segment.get(), chunk_num * test_data_count, MAX_TIMESTAMP); ASSERT_EQ(1, final.count()); } + +TEST_F(TestChunkSegment, TestCompareExpr) { + auto expr = std::make_shared(fields.at("int64"), + fields.at("pk"), + DataType::INT64, + DataType::INT64, + proto::plan::OpType::Equal); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final = query::ExecuteQueryExpr( + plan, segment.get(), chunk_num * test_data_count, MAX_TIMESTAMP); + ASSERT_EQ(chunk_num * test_data_count, final.count()); + + expr = std::make_shared(fields.at("string1"), + fields.at("string2"), + DataType::VARCHAR, + DataType::VARCHAR, + proto::plan::OpType::Equal); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = query::ExecuteQueryExpr( + plan, segment.get(), chunk_num * test_data_count, MAX_TIMESTAMP); + ASSERT_EQ(chunk_num * test_data_count, final.count()); +} \ No newline at end of file