From ea06a530501ff5ad51ea244ff0bba0d1cb98601a Mon Sep 17 00:00:00 2001 From: lixinguo Date: Sat, 14 Sep 2024 19:33:47 +0800 Subject: [PATCH] enhance: add search params in search request in restful Signed-off-by: lixinguo --- internal/datacoord/import_util.go | 3 +- .../proxy/httpserver/handler_v2.go | 80 ++++++++----------- .../proxy/httpserver/handler_v2_test.go | 14 +++- .../proxy/httpserver/request_v2.go | 64 ++++++++------- .../distributed/proxy/httpserver/utils.go | 12 +++ .../proxy/httpserver/utils_test.go | 13 +++ pkg/common/common.go | 2 + 7 files changed, 109 insertions(+), 79 deletions(-) diff --git a/internal/datacoord/import_util.go b/internal/datacoord/import_util.go index 84af224e51060..3408744fe4bcb 100644 --- a/internal/datacoord/import_util.go +++ b/internal/datacoord/import_util.go @@ -23,14 +23,13 @@ import ( "sort" "time" - "github.com/milvus-io/milvus/internal/proto/indexpb" - "github.com/cockroachdb/errors" "github.com/samber/lo" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/datacoord/allocator" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/importutilv2" diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 857dc6d07de54..80a9b7e4829cd 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -905,45 +905,30 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche }) } -func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[string]float64) ([]*commonpb.KeyValuePair, error) { - params := map[string]interface{}{ // auto generated mapping - "level": int(commonpb.ConsistencyLevel_Bounded), - } - if reqParams != nil { - radius, radiusOk := reqParams[ParamRadius] - rangeFilter, rangeFilterOk := reqParams[ParamRangeFilter] - if rangeFilterOk { - if !radiusOk { - log.Ctx(ctx).Warn("high level restful api, search params invalid, because only " + ParamRangeFilter) - HTTPAbortReturn(c, http.StatusOK, gin.H{ - HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), - HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: invalid search params", - }) - return nil, merr.ErrIncorrectParameterFormat - } - params[ParamRangeFilter] = rangeFilter - } - if radiusOk { - params[ParamRadius] = radius - } - } - bs, _ := json.Marshal(params) - searchParams := []*commonpb.KeyValuePair{ - {Key: Params, Value: string(bs)}, - } +func generateSearchParams(ctx context.Context, c *gin.Context, reqSearchParams searchParams) ([]*commonpb.KeyValuePair, error) { + var searchParams []*commonpb.KeyValuePair + bs, _ := json.Marshal(reqSearchParams.Params) + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: Params, Value: string(bs)}) + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.IgnoreGrowing, Value: strconv.FormatBool(reqSearchParams.IgnoreGrowing)}) + // need to exposure ParamRoundDecimal in req? + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"}) return searchParams, nil } func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { httpReq := anyReq.(*SearchReqV2) req := &milvuspb.SearchRequest{ - DbName: dbName, - CollectionName: httpReq.CollectionName, - Dsl: httpReq.Filter, - DslType: commonpb.DslType_BoolExprV1, - OutputFields: httpReq.OutputFields, - PartitionNames: httpReq.PartitionNames, - UseDefaultConsistency: true, + DbName: dbName, + CollectionName: httpReq.CollectionName, + Dsl: httpReq.Filter, + DslType: commonpb.DslType_BoolExprV1, + OutputFields: httpReq.OutputFields, + PartitionNames: httpReq.PartitionNames, + } + var err error + req.ConsistencyLevel, req.UseDefaultConsistency, err = convertConsistencyLevel(httpReq.ConsistencyLevel) + if err != nil { + return nil, err } c.Set(ContextRequest, req) @@ -951,7 +936,8 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN if err != nil { return nil, err } - searchParams, err := generateSearchParams(ctx, c, httpReq.Params) + + searchParams, err := generateSearchParams(ctx, c, httpReq.SearchParams) if err != nil { return nil, err } @@ -959,7 +945,6 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField}) - searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"}) body, _ := c.Get(gin.BodyBytesKey) placeholderGroup, err := generatePlaceholderGroup(ctx, string(body.([]byte)), collSchema, httpReq.AnnsField) if err != nil { @@ -1005,6 +990,11 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq Requests: []*milvuspb.SearchRequest{}, OutputFields: httpReq.OutputFields, } + var err error + req.ConsistencyLevel, req.UseDefaultConsistency, err = convertConsistencyLevel(httpReq.ConsistencyLevel) + if err != nil { + return nil, err + } c.Set(ContextRequest, req) collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName) @@ -1014,7 +1004,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq body, _ := c.Get(gin.BodyBytesKey) searchArray := gjson.Get(string(body.([]byte)), "search").Array() for i, subReq := range httpReq.Search { - searchParams, err := generateSearchParams(ctx, c, subReq.Params) + searchParams, err := generateSearchParams(ctx, c, subReq.SearchParams) if err != nil { return nil, err } @@ -1022,7 +1012,6 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField}) - searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"}) placeholderGroup, err := generatePlaceholderGroup(ctx, searchArray[i].Raw, collSchema, subReq.AnnsField) if err != nil { log.Ctx(ctx).Warn("high level restful api, search with vector invalid", zap.Error(err)) @@ -1033,15 +1022,14 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq return nil, err } searchReq := &milvuspb.SearchRequest{ - DbName: dbName, - CollectionName: httpReq.CollectionName, - Dsl: subReq.Filter, - PlaceholderGroup: placeholderGroup, - DslType: commonpb.DslType_BoolExprV1, - OutputFields: httpReq.OutputFields, - PartitionNames: httpReq.PartitionNames, - SearchParams: searchParams, - UseDefaultConsistency: true, + DbName: dbName, + CollectionName: httpReq.CollectionName, + Dsl: subReq.Filter, + PlaceholderGroup: placeholderGroup, + DslType: commonpb.DslType_BoolExprV1, + OutputFields: httpReq.OutputFields, + PartitionNames: httpReq.PartitionNames, + SearchParams: searchParams, } req.Requests = append(req.Requests, searchReq) } diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index 1728dac1017fc..3582c30f1aba2 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -1349,7 +1349,7 @@ func TestSearchV2(t *testing.T) { Schema: generateCollectionSchema(schemapb.DataType_Int64), ShardsNum: ShardNumDefault, Status: &StatusSuccess, - }, nil).Times(12) + }, nil).Times(11) mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{ TopK: int64(3), OutputFields: outputFields, @@ -1398,7 +1398,7 @@ func TestSearchV2(t *testing.T) { }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"],"consistencyLevel": "Strong"}`), }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, @@ -1406,8 +1406,8 @@ func TestSearchV2(t *testing.T) { }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"range_filter": 0.1}}`), - errMsg: "can only accept json format request, error: invalid search params", + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"ignore_growing": "true"}}`), + errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.ignore_growing of type bool", errCode: 1801, // ErrIncorrectParameterFormat }) queryTestCases = append(queryTestCases, requestBodyTestCase{ @@ -1529,6 +1529,12 @@ func TestSearchV2(t *testing.T) { path: SearchAction, requestBody: []byte(`{"collectionName": "book", "data": [{"1": 0.1}], "annsField": "sparseFloatVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"params":"a"}}`), + errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.params of type map[string]interface {}", + errCode: 1801, // ErrIncorrectParameterFormat + }) for _, testcase := range queryTestCases { t.Run(testcase.path, func(t *testing.T) { diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index b292b73a82e65..650924f2ac337 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -141,18 +141,28 @@ type CollectionDataReq struct { func (req *CollectionDataReq) GetDbName() string { return req.DbName } +type searchParams struct { + // not use metricType any more, just for compatibility + MetricType string `json:"metricType"` + Params map[string]interface{} `json:"params"` + IgnoreGrowing bool `json:"ignore_growing"` +} + type SearchReqV2 struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" binding:"required"` - Data []interface{} `json:"data" binding:"required"` - AnnsField string `json:"annsField"` - PartitionNames []string `json:"partitionNames"` - Filter string `json:"filter"` - GroupByField string `json:"groupingField"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` - OutputFields []string `json:"outputFields"` - Params map[string]float64 `json:"params"` + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + Data []interface{} `json:"data" binding:"required"` + AnnsField string `json:"annsField"` + PartitionNames []string `json:"partitionNames"` + Filter string `json:"filter"` + GroupByField string `json:"groupingField"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + OutputFields []string `json:"outputFields"` + SearchParams searchParams `json:"searchParams"` + ConsistencyLevel string `json:"consistencyLevel"` + // not use Params any more, just for compatibility + Params map[string]float64 `json:"params"` } func (req *SearchReqV2) GetDbName() string { return req.DbName } @@ -163,25 +173,25 @@ type Rand struct { } type SubSearchReq struct { - Data []interface{} `json:"data" binding:"required"` - AnnsField string `json:"annsField"` - Filter string `json:"filter"` - GroupByField string `json:"groupingField"` - MetricType string `json:"metricType"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` - IgnoreGrowing bool `json:"ignoreGrowing"` - Params map[string]float64 `json:"params"` + Data []interface{} `json:"data" binding:"required"` + AnnsField string `json:"annsField"` + Filter string `json:"filter"` + GroupByField string `json:"groupingField"` + MetricType string `json:"metricType"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + SearchParams searchParams `json:"searchParams"` } type HybridSearchReq struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" binding:"required"` - PartitionNames []string `json:"partitionNames"` - Search []SubSearchReq `json:"search"` - Rerank Rand `json:"rerank"` - Limit int32 `json:"limit"` - OutputFields []string `json:"outputFields"` + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionNames []string `json:"partitionNames"` + Search []SubSearchReq `json:"search"` + Rerank Rand `json:"rerank"` + Limit int32 `json:"limit"` + OutputFields []string `json:"outputFields"` + ConsistencyLevel string `json:"consistencyLevel"` } func (req *HybridSearchReq) GetDbName() string { return req.DbName } diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index 7d534927ed8b6..af4c665036575 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -1287,3 +1287,15 @@ func CheckLimiter(ctx context.Context, req interface{}, pxy types.ProxyComponent metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.SuccessLabel).Inc() return nil, nil } + +func convertConsistencyLevel(reqConsistencyLevel string) (commonpb.ConsistencyLevel, bool, error) { + if reqConsistencyLevel != "" { + level, ok := commonpb.ConsistencyLevel_value[reqConsistencyLevel] + if !ok { + return 0, false, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("parameter:'%s' is incorrect, please check it", reqConsistencyLevel)) + } + return commonpb.ConsistencyLevel(level), false, nil + } + // ConsistencyLevel_Session default in PyMilvus + return commonpb.ConsistencyLevel_Session, true, nil +} diff --git a/internal/distributed/proxy/httpserver/utils_test.go b/internal/distributed/proxy/httpserver/utils_test.go index 897728f38de51..b5a1af02a9532 100644 --- a/internal/distributed/proxy/httpserver/utils_test.go +++ b/internal/distributed/proxy/httpserver/utils_test.go @@ -1372,3 +1372,16 @@ func TestBuildQueryResps(t *testing.T) { _, err = buildQueryResp(int64(0), outputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIDs(schemapb.DataType_Int64, 3), []float32{0.01, 0.04}, true) assert.Equal(t, nil, err) } + +func TestConvertConsistencyLevel(t *testing.T) { + consistencyLevel, useDefaultConsistency, err := convertConsistencyLevel("") + assert.Equal(t, nil, err) + assert.Equal(t, consistencyLevel, commonpb.ConsistencyLevel_Session) + assert.Equal(t, true, useDefaultConsistency) + consistencyLevel, useDefaultConsistency, err = convertConsistencyLevel("Strong") + assert.Equal(t, nil, err) + assert.Equal(t, consistencyLevel, commonpb.ConsistencyLevel_Strong) + assert.Equal(t, false, useDefaultConsistency) + consistencyLevel, useDefaultConsistency, err = convertConsistencyLevel("test") + assert.NotNil(t, err) +} diff --git a/pkg/common/common.go b/pkg/common/common.go index 94f361da4a316..a1c81d098908c 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -133,6 +133,8 @@ const ( IsSparseKey = "is_sparse" AutoIndexName = "AUTOINDEX" BitmapCardinalityLimitKey = "bitmap_cardinality_limit" + IgnoreGrowing = "ignore_growing" + ConsistencyLevel = "consistency_level" ) // Collection properties key