From f3285212e85fece27169ded525ec3f9c698eced5 Mon Sep 17 00:00:00 2001
From: smellthemoon <64083300+smellthemoon@users.noreply.github.com>
Date: Fri, 6 Sep 2024 14:53:04 +0800
Subject: [PATCH] fix: valid data not appended when transferring to insert
 record (#36027)

Fix valid data not being appended when transferring to an insert record,
and add a small check rejecting nullable fields as the groupBy field.
#35924

Signed-off-by: lixinguo
Co-authored-by: lixinguo
---
 internal/core/src/segcore/InsertRecord.h      |  15 +-
 internal/proxy/search_util.go                 |   3 +
 internal/proxy/task_search_test.go            |  19 +++
 internal/storage/utils.go                     |  10 ++
 pkg/util/typeutil/schema.go                   |   4 +-
 tests/integration/null_data/null_data_test.go | 132 ++++++++++++++++++
 6 files changed, 173 insertions(+), 10 deletions(-)

diff --git a/internal/core/src/segcore/InsertRecord.h b/internal/core/src/segcore/InsertRecord.h
index b622638b41e06..a731e84bab1f6 100644
--- a/internal/core/src/segcore/InsertRecord.h
+++ b/internal/core/src/segcore/InsertRecord.h
@@ -293,14 +293,15 @@ class ThreadSafeValidData {
             total += field_data->get_num_rows();
         }
         if (length_ + total > data_.size()) {
-            data_.reserve(length_ + total);
+            data_.resize(length_ + total);
         }
-        length_ += total;
+
         for (auto& field_data : datas) {
             auto num_row = field_data->get_num_rows();
             for (size_t i = 0; i < num_row; i++) {
-                data_.push_back(field_data->is_valid(i));
+                data_[length_ + i] = field_data->is_valid(i);
             }
+            length_ += num_row;
         }
     }
 
@@ -311,14 +312,10 @@ class ThreadSafeValidData {
         std::unique_lock lck(mutex_);
         if (field_meta.is_nullable()) {
             if (length_ + num_rows > data_.size()) {
-                data_.reserve(length_ + num_rows);
+                data_.resize(length_ + num_rows);
             }
-
             auto src = data->valid_data().data();
-            for (size_t i = 0; i < num_rows; ++i) {
-                data_.push_back(src[i]);
-                // data_[length_ + i] = src[i];
-            }
+            std::copy_n(src, num_rows, data_.data() + length_);
             length_ += num_rows;
         }
     }
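A note on the two hunks above: std::vector::reserve only allocates capacity and never changes size(), so the old mix of reserve, push_back, and a separately advanced length_ let the vector's size and length_ disagree. The fix resizes up front and then writes by position. Go slices have the same capacity-versus-length trap, so the corrected pattern is worth a sketch; this is illustrative only, and all names are hypothetical:

package main

import "fmt"

// appendValidData mirrors ThreadSafeValidData::append after the fix:
// grow the length first ("resize"), then write each batch by position.
func appendValidData(data []bool, batches [][]bool) []bool {
    length := len(data)
    total := 0
    for _, b := range batches {
        total += len(b)
    }
    // The old code's mistake, translated: make([]bool, 0, length+total)
    // only reserves capacity, so data[length+i] = ... would be out of
    // range while len(data) is still 0.
    grown := make([]bool, length+total) // sets the length, not just capacity
    copy(grown, data)
    for _, b := range batches {
        copy(grown[length:], b) // positional write within the new length
        length += len(b)
    }
    return grown
}

func main() {
    fmt.Println(appendValidData([]bool{true}, [][]bool{{true, false}, {false}}))
    // Output: [true true false false]
}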
diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go
index ad6a90dd08857..be9e3fc4b2e10 100644
--- a/internal/proxy/search_util.go
+++ b/internal/proxy/search_util.go
@@ -108,6 +108,9 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
 	if groupByFieldName != "" {
 		fields := schema.GetFields()
 		for _, field := range fields {
+			if field.Name == groupByFieldName && field.GetNullable() {
+				return nil, 0, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("groupBy field(%s) not support nullable == true", groupByFieldName))
+			}
 			if field.Name == groupByFieldName {
 				groupByFieldId = field.FieldID
 				break
diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go
index 86f8d9edcc7b5..4d53945526309 100644
--- a/internal/proxy/task_search_test.go
+++ b/internal/proxy/task_search_test.go
@@ -2188,6 +2188,25 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
 		assert.Nil(t, info)
 		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
 	})
+	t.Run("check nullable and groupBy", func(t *testing.T) {
+		normalParam := getValidSearchParams()
+		normalParam = append(normalParam, &commonpb.KeyValuePair{
+			Key:   GroupByFieldKey,
+			Value: "string_field",
+		})
+		fields := make([]*schemapb.FieldSchema, 0)
+		fields = append(fields, &schemapb.FieldSchema{
+			FieldID:  int64(101),
+			Name:     "string_field",
+			Nullable: true,
+		})
+		schema := &schemapb.CollectionSchema{
+			Fields: fields,
+		}
+		info, _, err := parseSearchInfo(normalParam, schema, false)
+		assert.Nil(t, info)
+		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+	})
 	t.Run("check iterator and topK", func(t *testing.T) {
 		normalParam := getValidSearchParams()
 		normalParam = append(normalParam, &commonpb.KeyValuePair{
diff --git a/internal/storage/utils.go b/internal/storage/utils.go
index 50371bab9b88b..f6566402af00d 100644
--- a/internal/storage/utils.go
+++ b/internal/storage/utils.go
@@ -1040,6 +1040,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
                     },
                 },
             },
+            ValidData: rawData.ValidData,
         }
     case *Int8FieldData:
         int32Data := make([]int32, len(rawData.Data))
@@ -1058,6 +1059,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
                     },
                 },
             },
+            ValidData: rawData.ValidData,
         }
     case *Int16FieldData:
         int32Data := make([]int32, len(rawData.Data))
@@ -1076,6 +1078,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
                     },
                 },
             },
+            ValidData: rawData.ValidData,
         }
     case *Int32FieldData:
         fieldData = &schemapb.FieldData{
@@ -1090,6 +1093,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
                     },
                 },
             },
+            ValidData: rawData.ValidData,
         }
     case *Int64FieldData:
         fieldData = &schemapb.FieldData{
@@ -1104,6 +1108,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
                     },
                 },
             },
+            ValidData: rawData.ValidData,
         }
     case *FloatFieldData:
         fieldData = &schemapb.FieldData{
@@ -1118,6 +1123,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
                     },
                 },
             },
+            ValidData: rawData.ValidData,
         }
     case *DoubleFieldData:
         fieldData = &schemapb.FieldData{
@@ -1132,6 +1138,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
                     },
                 },
             },
+            ValidData: rawData.ValidData,
         }
     case *StringFieldData:
         fieldData = &schemapb.FieldData{
@@ -1146,6 +1153,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
                     },
                 },
             },
+            ValidData: rawData.ValidData,
         }
     case *ArrayFieldData:
         fieldData = &schemapb.FieldData{
@@ -1160,6 +1168,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
                     },
                 },
             },
+            ValidData: rawData.ValidData,
         }
     case *JSONFieldData:
         fieldData = &schemapb.FieldData{
@@ -1174,6 +1183,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
                     },
                 },
             },
+            ValidData: rawData.ValidData,
         }
     case *FloatVectorFieldData:
         fieldData = &schemapb.FieldData{
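Each scalar branch of TransferInsertDataToInsertRecord above now forwards rawData.ValidData into the produced FieldData. For a nullable column the scalar payload is compacted to the valid rows only (the integration test below sends an all-null column as an empty LongData plus an all-false ValidData), so the per-row ValidData flags are the only way to line values back up with row offsets; dropping them, as the code did before this patch, lost the nulls. A sketch of re-expanding such a column; the helper is hypothetical, not a Milvus API:

package main

import "fmt"

// expandNullable re-aligns a compacted nullable column: valid has one entry
// per row, data holds only the values of rows whose valid flag is true.
func expandNullable(data []int64, valid []bool) []*int64 {
    out := make([]*int64, len(valid))
    next := 0
    for i, ok := range valid {
        if ok {
            v := data[next] // copy so each row gets its own pointer
            out[i] = &v
            next++
        }
    }
    return out // nil entries are the NULL rows
}

func main() {
    rows := expandNullable([]int64{7, 9}, []bool{true, false, true})
    for i, r := range rows {
        if r == nil {
            fmt.Printf("row %d: NULL\n", i)
        } else {
            fmt.Printf("row %d: %d\n", i, *r)
        }
    }
}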
diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go
index 597e08c4a9560..9f040c463fa3f 100644
--- a/pkg/util/typeutil/schema.go
+++ b/pkg/util/typeutil/schema.go
@@ -932,7 +932,9 @@ func MergeFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData) error
                 dst = append(dst, scalarFieldData)
                 fieldID2Data[srcFieldData.FieldId] = scalarFieldData
             }
-            dstScalar := fieldID2Data[srcFieldData.FieldId].GetScalars()
+            fieldData := fieldID2Data[srcFieldData.FieldId]
+            fieldData.ValidData = append(fieldData.ValidData, srcFieldData.GetValidData()...)
+            dstScalar := fieldData.GetScalars()
             switch srcScalar := fieldType.Scalars.Data.(type) {
             case *schemapb.ScalarField_BoolData:
                 if dstScalar.GetBoolData() == nil {
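MergeFieldData above had the same omission on the merge path: the scalar payloads of src were concatenated onto dst, but their ValidData was not, so after a merge the validity flags covered too few rows and values shifted against their rows. The shape of the fix in isolation, with stand-in types rather than the real schemapb ones:

package main

import "fmt"

// column is a stand-in for the nullable part of a schemapb.FieldData:
// compacted values plus one validity flag per row.
type column struct {
    vals  []int64
    valid []bool
}

// merge mirrors the MergeFieldData fix: values and validity must be
// appended together, or dst's later rows pair with the wrong flags.
func merge(dst, src column) column {
    return column{
        vals:  append(dst.vals, src.vals...),
        valid: append(dst.valid, src.valid...),
    }
}

func main() {
    dst := column{vals: []int64{1}, valid: []bool{true, false}}   // rows: 1, NULL
    src := column{vals: []int64{2, 3}, valid: []bool{true, true}} // rows: 2, 3
    fmt.Println(merge(dst, src)) // {[1 2 3] [true false true true]}
}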
diff --git a/tests/integration/null_data/null_data_test.go b/tests/integration/null_data/null_data_test.go
index 5a34341fa68c4..a6707ba4ff7d9 100644
--- a/tests/integration/null_data/null_data_test.go
+++ b/tests/integration/null_data/null_data_test.go
@@ -242,6 +242,138 @@ func (s *NullDataSuite) run() {
     s.Equal(commonpb.ErrorCode_Success, queryResult.GetStatus().GetErrorCode())
     s.checkNullableFieldData(nullableFid.GetName(), queryResult.GetFieldsData(), start)
 
+    fieldsData[2] = integration.NewInt64FieldDataNullableWithStart(nullableFid.GetName(), rowNum, start)
+    fieldsDataForUpsert := make([]*schemapb.FieldData, 0)
+    fieldsDataForUpsert = append(fieldsDataForUpsert, integration.NewInt64FieldDataWithStart(integration.Int64Field, rowNum, start))
+    fieldsDataForUpsert = append(fieldsDataForUpsert, fVecColumn)
+    nullableFidDataForUpsert := &schemapb.FieldData{
+        Type:      schemapb.DataType_Int64,
+        FieldName: nullableFid.GetName(),
+        Field: &schemapb.FieldData_Scalars{
+            Scalars: &schemapb.ScalarField{
+                Data: &schemapb.ScalarField_LongData{
+                    LongData: &schemapb.LongArray{
+                        Data: []int64{},
+                    },
+                },
+            },
+        },
+        ValidData: make([]bool, rowNum),
+    }
+    fieldsDataForUpsert = append(fieldsDataForUpsert, nullableFidDataForUpsert)
+    insertResult, err = c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
+        DbName:         dbName,
+        CollectionName: collectionName,
+        FieldsData:     fieldsData,
+        HashKeys:       hashKeys,
+        NumRows:        uint32(rowNum),
+    })
+    s.NoError(err)
+    s.Equal(insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
+
+    upsertResult, err := c.Proxy.Upsert(ctx, &milvuspb.UpsertRequest{
+        DbName:         dbName,
+        CollectionName: collectionName,
+        FieldsData:     fieldsDataForUpsert,
+        HashKeys:       hashKeys,
+        NumRows:        uint32(rowNum),
+    })
+    s.NoError(err)
+    s.Equal(upsertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
+
+    // create index
+    createIndexStatus, err = c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
+        CollectionName: collectionName,
+        FieldName:      fVecColumn.FieldName,
+        IndexName:      "_default",
+        ExtraParams:    integration.ConstructIndexParam(dim, s.indexType, s.metricType),
+    })
+    if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
+        log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
+    }
+    s.NoError(err)
+    s.Equal(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())
+
+    s.WaitForIndexBuilt(ctx, collectionName, fVecColumn.FieldName)
+
+    desCollResp, err = c.Proxy.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{
+        CollectionName: collectionName,
+    })
+    s.NoError(err)
+    s.Equal(desCollResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
+
+    compactResp, err = c.Proxy.ManualCompaction(ctx, &milvuspb.ManualCompactionRequest{
+        CollectionID: desCollResp.GetCollectionID(),
+    })
+
+    s.NoError(err)
+    s.Equal(compactResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
+
+    compacted = func() bool {
+        resp, err := c.Proxy.GetCompactionState(ctx, &milvuspb.GetCompactionStateRequest{
+            CompactionID: compactResp.GetCompactionID(),
+        })
+        if err != nil {
+            return false
+        }
+        return resp.GetState() == commonpb.CompactionState_Completed
+    }
+    for !compacted() {
+        time.Sleep(3 * time.Second)
+    }
+
+    // load
+    loadStatus, err = c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
+        DbName:         dbName,
+        CollectionName: collectionName,
+    })
+    s.NoError(err)
+    if loadStatus.GetErrorCode() != commonpb.ErrorCode_Success {
+        log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
+    }
+    s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
+    s.WaitForLoad(ctx, collectionName)
+
+    // flush
+    flushResp, err = c.Proxy.Flush(ctx, &milvuspb.FlushRequest{
+        DbName:          dbName,
+        CollectionNames: []string{collectionName},
+    })
+    s.NoError(err)
+    segmentIDs, has = flushResp.GetCollSegIDs()[collectionName]
+    ids = segmentIDs.GetData()
+    s.Require().NotEmpty(segmentIDs)
+    s.Require().True(has)
+    flushTs, has = flushResp.GetCollFlushTs()[collectionName]
+    s.True(has)
+
+    segments, err = c.MetaWatcher.ShowSegments()
+    s.NoError(err)
+    s.NotEmpty(segments)
+    for _, segment := range segments {
+        log.Info("ShowSegments result", zap.String("segment", segment.String()))
+    }
+    s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName)
+
+    // search
+    searchResult, err = c.Proxy.Search(ctx, searchReq)
+    err = merr.CheckRPCCall(searchResult, err)
+    s.NoError(err)
+    s.checkNullableFieldData(nullableFid.GetName(), searchResult.GetResults().GetFieldsData(), start)
+
+    queryResult, err = c.Proxy.Query(ctx, &milvuspb.QueryRequest{
+        DbName:         dbName,
+        CollectionName: collectionName,
+        Expr:           expr,
+        OutputFields:   []string{"nullableFid"},
+    })
+    if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
+        log.Warn("queryResult fail reason", zap.String("reason", queryResult.GetStatus().GetReason()))
+    }
+    s.NoError(err)
+    s.Equal(commonpb.ErrorCode_Success, queryResult.GetStatus().GetErrorCode())
+    s.checkNullableFieldData(nullableFid.GetName(), queryResult.GetFieldsData(), start)
+
     // // expr will not select null data
     // exprResult, err := c.Proxy.Query(ctx, &milvuspb.QueryRequest{
     //     DbName: dbName,
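One convention worth calling out from the test above: an all-null nullable column is expressed as an empty typed payload plus an all-false ValidData of length rowNum, exactly how nullableFidDataForUpsert is built. The same construction as a reusable sketch; the helper is hypothetical, and the schemapb import path is assumed from the package the test references:

package nulldata

import (
    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

// newAllNullInt64Column builds a nullable int64 column whose every row is
// NULL: no compacted values, and one false validity flag per row.
func newAllNullInt64Column(name string, rowNum int) *schemapb.FieldData {
    return &schemapb.FieldData{
        Type:      schemapb.DataType_Int64,
        FieldName: name,
        Field: &schemapb.FieldData_Scalars{
            Scalars: &schemapb.ScalarField{
                Data: &schemapb.ScalarField_LongData{
                    LongData: &schemapb.LongArray{Data: []int64{}},
                },
            },
        },
        ValidData: make([]bool, rowNum), // zero-value false => every row NULL
    }
}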