Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: sunby <[email protected]>
  • Loading branch information
sunby committed Nov 4, 2024
1 parent fdbe970 commit a1b1595
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 57 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ INSTALL_PATH := $(PWD)/bin
LIBRARY_PATH := $(PWD)/lib
PGO_PATH := $(PWD)/configs/pgo
OS := $(shell uname -s)
mode = Release
mode = Debug

use_disk_index = OFF
ifdef disk_index
Expand Down
105 changes: 49 additions & 56 deletions internal/core/unittest/test_chunked_segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
#include "arrow/table_builder.h"
#include "arrow/type_fwd.h"
#include "common/BitsetView.h"
#include "common/Consts.h"
#include "common/FieldDataInterface.h"
#include "common/QueryInfo.h"
#include "common/Schema.h"
#include "common/Types.h"
#include "expr/ITypeExpr.h"
#include "knowhere/comp/index_param.h"
#include "mmap/ChunkedColumn.h"
Expand All @@ -26,6 +28,7 @@
#include "segcore/SegcoreConfig.h"
#include "segcore/SegmentSealedImpl.h"
#include "test_utils/DataGen.h"
#include <memory>
#include <numeric>
#include <vector>

Expand Down Expand Up @@ -119,8 +122,10 @@ TEST(test_chunk_segment, TestSearchOnSealed) {

TEST(test_chunk_segment, TestTermExpr) {
auto schema = std::make_shared<Schema>();
auto int32_fid = schema->AddDebugField("int32", DataType::INT32, true);
auto pk_fid = schema->AddDebugField("pk", DataType::INT32, true);
auto int64_fid = schema->AddDebugField("int64", DataType::INT64, true);
auto pk_fid = schema->AddDebugField("pk", DataType::INT64, true);
schema->AddField(FieldName("ts"), TimestampFieldID, DataType::INT64, true);
schema->set_primary_field_id(pk_fid);
auto segment =
segcore::CreateSealedSegment(schema,
nullptr,
Expand All @@ -130,72 +135,60 @@ TEST(test_chunk_segment, TestTermExpr) {
false,
true);
// generate test data
std::shared_ptr<arrow::Schema> arrow_schema;
auto arrow_i32_field = arrow::field("int32", arrow::int32());
auto arrow_pk_field = arrow::field("pk", arrow::int32());
arrow_schema = arrow::schema({arrow_i32_field, arrow_pk_field});

size_t test_data_count = 1000;
std::vector<int32_t> test_data(test_data_count);
std::vector<int64_t> test_data(test_data_count);
std::iota(test_data.begin(), test_data.end(), 0);
auto builder = std::make_shared<arrow::Int32Builder>();
auto builder = std::make_shared<arrow::Int64Builder>();
auto status = builder->AppendValues(test_data.begin(), test_data.end());
ASSERT_TRUE(status.ok());
auto res = builder->Finish();
ASSERT_TRUE(res.ok());
std::shared_ptr<arrow::Array> arrow_int32;
arrow_int32 = res.ValueOrDie();

auto record_batch = arrow::RecordBatch::Make(
arrow_schema, arrow_int32->length(), {arrow_int32});

// int32 field data
auto res2 = arrow::RecordBatchReader::Make({record_batch});
ASSERT_TRUE(res2.ok());
auto arrow_reader = res2.ValueOrDie();
res2 = arrow::RecordBatchReader::Make({record_batch});
ASSERT_TRUE(res2.ok());
auto arrow_reader2 = res2.ValueOrDie();

// pk field data
res2 = arrow::RecordBatchReader::Make({record_batch});
ASSERT_TRUE(res2.ok());
auto arrow_pk_reader = res2.ValueOrDie();
res2 = arrow::RecordBatchReader::Make({record_batch});
ASSERT_TRUE(res2.ok());
auto arrow_pk_reader2 = res2.ValueOrDie();

// load int32 field
FieldDataInfo i32_field_info;
i32_field_info.field_id = int32_fid.get();
i32_field_info.row_count = test_data_count * 2;
i32_field_info.arrow_reader_channel->push(
std::make_shared<ArrowDataWrapper>(arrow_reader, nullptr, nullptr));
i32_field_info.arrow_reader_channel->push(
std::make_shared<ArrowDataWrapper>(arrow_reader2, nullptr, nullptr));
i32_field_info.arrow_reader_channel->close();
segment->LoadFieldData(int32_fid, i32_field_info);

// load pk field
FieldDataInfo pk_field_info;
pk_field_info.field_id = pk_fid.get();
pk_field_info.row_count = test_data_count * 2;
pk_field_info.arrow_reader_channel->push(
std::make_shared<ArrowDataWrapper>(arrow_pk_reader, nullptr, nullptr));
pk_field_info.arrow_reader_channel->push(
std::make_shared<ArrowDataWrapper>(arrow_pk_reader2, nullptr, nullptr));
pk_field_info.arrow_reader_channel->close();
segment->LoadFieldData(pk_fid, pk_field_info);

// query int32 expr
std::shared_ptr<arrow::Array> arrow_int64;
arrow_int64 = res.ValueOrDie();

auto arrow_i64_field = arrow::field("int64", arrow::int64());
auto arrow_pk_field = arrow::field("pk", arrow::int64());
auto arrow_ts_field = arrow::field("ts", arrow::int64());
std::vector<std::shared_ptr<arrow::Field>> arrow_fields = {
arrow_i64_field, arrow_pk_field, arrow_ts_field};
std::vector<FieldId> field_ids = {int64_fid, pk_fid, TimestampFieldID};

for (int i = 0; i < arrow_fields.size(); i++) {
auto f = arrow_fields[i];
auto fid = field_ids[i];
auto arrow_schema =
std::make_shared<arrow::Schema>(arrow::FieldVector(1, f));
auto record_batch = arrow::RecordBatch::Make(
arrow_schema, arrow_int64->length(), {arrow_int64});

auto res2 = arrow::RecordBatchReader::Make({record_batch});
ASSERT_TRUE(res2.ok());
auto arrow_reader = res2.ValueOrDie();
res2 = arrow::RecordBatchReader::Make({record_batch});
ASSERT_TRUE(res2.ok());
auto arrow_reader2 = res2.ValueOrDie();

FieldDataInfo field_info;
field_info.field_id = fid.get();
field_info.row_count = test_data_count * 2;
field_info.arrow_reader_channel->push(
std::make_shared<ArrowDataWrapper>(arrow_reader, nullptr, nullptr));
field_info.arrow_reader_channel->push(
std::make_shared<ArrowDataWrapper>(
arrow_reader2, nullptr, nullptr));
field_info.arrow_reader_channel->close();
segment->LoadFieldData(fid, field_info);
}

// query int64 expr
std::vector<proto::plan::GenericValue> filter_data;
for (int i = 0; i < 10; ++i) {
proto::plan::GenericValue v;
v.set_int64_val(i);
filter_data.push_back(v);
}
auto term_filter_expr = std::make_shared<expr::TermFilterExpr>(
expr::ColumnInfo(int32_fid, DataType::INT32), filter_data);
expr::ColumnInfo(int64_fid, DataType::INT64), filter_data);
BitsetType final;
auto plan = std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID,
term_filter_expr);
Expand All @@ -205,7 +198,7 @@ TEST(test_chunk_segment, TestTermExpr) {

// query pk expr
auto pk_term_filter_expr = std::make_shared<expr::TermFilterExpr>(
expr::ColumnInfo(pk_fid, DataType::INT32), filter_data);
expr::ColumnInfo(pk_fid, DataType::INT64), filter_data);
plan = std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID,
pk_term_filter_expr);
final = query::ExecuteQueryExpr(
Expand Down

0 comments on commit a1b1595

Please sign in to comment.