From 78121731d602161338d84495c3e18de8606b9981 Mon Sep 17 00:00:00 2001 From: "zhenshan.cao" Date: Fri, 13 Sep 2024 17:12:03 +0800 Subject: [PATCH] fix: Fix improper use of offset in HybridSearch Signed-off-by: zhenshan.cao --- internal/proxy/search_util.go | 65 ++++++++++++++++++++++++------ internal/proxy/task_search.go | 12 ++++-- internal/proxy/task_search_test.go | 30 +++++++++++--- 3 files changed, 85 insertions(+), 22 deletions(-) diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index 544fe85c4de93..00097ba4d3e3d 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -26,32 +26,73 @@ type rankParams struct { roundDecimal int64 } -// parseSearchInfo returns QueryInfo and offset -func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, ignoreOffset bool) (*planpb.QueryInfo, int64, error) { - // 0. parse iterator field - isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair) +func (r *rankParams) GetLimit() int64 { + if r != nil { + return r.limit + } + return 0 +} - // 1. parse offset and real topk - topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair) - if err != nil { - return nil, 0, errors.New(TopKKey + " not found in search_params") +func (r *rankParams) GetOffset() int64 { + if r != nil { + return r.offset } - topK, err := strconv.ParseInt(topKStr, 0, 64) + return 0 +} + +func (r *rankParams) GetRoundDecimal() int64 { + if r != nil { + return r.roundDecimal + } + return 0 +} + +func (r *rankParams) String() string { + return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", r.GetLimit(), r.GetOffset(), r.GetRoundDecimal()) +} + +// parseSearchInfo returns QueryInfo and offset +func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) (*planpb.QueryInfo, int64, error) { + var topK int64 + isAdvanced := rankParams != nil + externalLimit := rankParams.GetLimit() + rankParams.GetOffset() + topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair) if err != nil { - return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr) + if externalLimit <= 0 { + return nil, 0, fmt.Errorf("%s is required", TopKKey) + } + topK = externalLimit + } else { + topKInParam, err := strconv.ParseInt(topKStr, 0, 64) + if err != nil { + if externalLimit <= 0 { + return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr) + } + topK = externalLimit + } else { + if topKInParam < externalLimit { + topK = externalLimit + } else { + topK = topKInParam + } + } } + + isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair) + if err := validateLimit(topK); err != nil { if isIterator == "True" { - topK = Params.QuotaConfig.TopKLimit.GetAsInt64() // 1. if the request is from iterator, we set topK to QuotaLimit as the iterator can resolve too large topK problem // 2. GetAsInt64 has cached inside, no need to worry about cpu cost for parsing here + topK = Params.QuotaConfig.TopKLimit.GetAsInt64() } else { return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err) } } var offset int64 - if !ignoreOffset { + // ignore offset if isAdvanced + if !isAdvanced { offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, searchParamsPair) if err == nil { offset, err = strconv.ParseInt(offsetStr, 0, 64) diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index e32603e4b2079..b93640e8d1f21 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -170,8 +170,11 @@ func (t *searchTask) PreExecute(ctx context.Context) error { if t.SearchRequest.GetIsAdvanced() { t.rankParams, err = parseRankParams(t.request.GetSearchParams()) if err != nil { + log.Info("parseRankParams failed", zap.Error(err)) return err } + } else { + t.rankParams = nil } // Manually update nq if not set. nq, err := t.checkNq(ctx) @@ -343,11 +346,12 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]() log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName)) + // fetch search_growing from search param t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs())) t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs())) for index, subReq := range t.request.GetSubReqs() { - plan, queryInfo, offset, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl(), true) + plan, queryInfo, offset, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl()) if err != nil { return err } @@ -423,7 +427,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error { log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName)) // fetch search_growing from search param - plan, queryInfo, offset, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl(), false) + plan, queryInfo, offset, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl()) if err != nil { return err } @@ -469,7 +473,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error { return nil } -func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string, ignoreOffset bool) (*planpb.PlanNode, *planpb.QueryInfo, int64, error) { +func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string) (*planpb.PlanNode, *planpb.QueryInfo, int64, error) { annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, params) if err != nil || len(annsFieldName) == 0 { vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema) @@ -482,7 +486,7 @@ func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string } annsFieldName = vecFields[0].Name } - queryInfo, offset, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, ignoreOffset) + queryInfo, offset, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams) if parseErr != nil { return nil, nil, 0, parseErr } diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index b0f9d769b21aa..97767d7c6ee63 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -1935,7 +1935,7 @@ func TestSearchTask_ErrExecute(t *testing.T) { assert.NoError(t, task.Execute(ctx)) } -func TestTaskSearch_parseQueryInfo(t *testing.T) { +func TestTaskSearch_parseSearchInfo(t *testing.T) { t.Run("parseSearchInfo no error", func(t *testing.T) { var targetOffset int64 = 200 @@ -1971,7 +1971,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { - info, offset, err := parseSearchInfo(test.validParams, nil, false) + info, offset, err := parseSearchInfo(test.validParams, nil, nil) assert.NoError(t, err) assert.NotNil(t, info) if test.description == "offsetParam" { @@ -1981,6 +1981,24 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) { } }) + t.Run("parseSearchInfo externalLimit", func(t *testing.T) { + var externalLimit int64 = 200 + offsetParam := getValidSearchParams() + offsetParam = append(offsetParam, &commonpb.KeyValuePair{ + Key: OffsetKey, + Value: strconv.FormatInt(10, 10), + }) + rank := &rankParams{ + limit: externalLimit, + } + + info, offset, err := parseSearchInfo(offsetParam, nil, rank) + assert.NoError(t, err) + assert.NotNil(t, info) + assert.Equal(t, externalLimit, info.GetTopk()) + assert.Equal(t, int64(0), offset) + }) + t.Run("parseSearchInfo error", func(t *testing.T) { spNoTopk := []*commonpb.KeyValuePair{{ Key: AnnsFieldKey, @@ -2060,7 +2078,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { - info, offset, err := parseSearchInfo(test.invalidParams, nil, false) + info, offset, err := parseSearchInfo(test.invalidParams, nil, nil) assert.Error(t, err) assert.Nil(t, info) assert.Zero(t, offset) @@ -2087,7 +2105,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) { schema := &schemapb.CollectionSchema{ Fields: fields, } - info, _, err := parseSearchInfo(normalParam, schema, false) + info, _, err := parseSearchInfo(normalParam, schema, nil) assert.Nil(t, info) assert.ErrorIs(t, err, merr.ErrParameterInvalid) }) @@ -2106,7 +2124,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) { schema := &schemapb.CollectionSchema{ Fields: fields, } - info, _, err := parseSearchInfo(normalParam, schema, false) + info, _, err := parseSearchInfo(normalParam, schema, nil) assert.Nil(t, info) assert.ErrorIs(t, err, merr.ErrParameterInvalid) }) @@ -2125,7 +2143,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) { schema := &schemapb.CollectionSchema{ Fields: fields, } - info, _, err := parseSearchInfo(normalParam, schema, false) + info, _, err := parseSearchInfo(normalParam, schema, nil) assert.NotNil(t, info) assert.NoError(t, err) assert.Equal(t, Params.QuotaConfig.TopKLimit.GetAsInt64(), info.Topk)