From 41d847444e0604336e34dc89ac3fb8e87df94fae Mon Sep 17 00:00:00 2001 From: lixinguo Date: Sat, 14 Sep 2024 19:33:47 +0800 Subject: [PATCH] enhance: add search params in search request in restful Signed-off-by: lixinguo --- internal/datacoord/import_util.go | 3 +- .../proxy/httpserver/handler_v2.go | 88 +++++++++---------- .../proxy/httpserver/handler_v2_test.go | 18 ++-- .../proxy/httpserver/request_v2.go | 64 ++++++++------ pkg/common/common.go | 2 + 5 files changed, 94 insertions(+), 81 deletions(-) diff --git a/internal/datacoord/import_util.go b/internal/datacoord/import_util.go index 84af224e51060..3408744fe4bcb 100644 --- a/internal/datacoord/import_util.go +++ b/internal/datacoord/import_util.go @@ -23,14 +23,13 @@ import ( "sort" "time" - "github.com/milvus-io/milvus/internal/proto/indexpb" - "github.com/cockroachdb/errors" "github.com/samber/lo" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/datacoord/allocator" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/importutilv2" diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 857dc6d07de54..4b0df99a00101 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -905,45 +905,34 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche }) } -func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[string]float64) ([]*commonpb.KeyValuePair, error) { - params := map[string]interface{}{ // auto generated mapping - "level": int(commonpb.ConsistencyLevel_Bounded), - } - if reqParams != nil { - radius, radiusOk := reqParams[ParamRadius] - rangeFilter, rangeFilterOk := reqParams[ParamRangeFilter] - if rangeFilterOk { - if !radiusOk { - log.Ctx(ctx).Warn("high level restful api, search params invalid, because only " + ParamRangeFilter) - HTTPAbortReturn(c, http.StatusOK, gin.H{ - HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), - HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: invalid search params", - }) - return nil, merr.ErrIncorrectParameterFormat - } - params[ParamRangeFilter] = rangeFilter - } - if radiusOk { - params[ParamRadius] = radius - } - } - bs, _ := json.Marshal(params) - searchParams := []*commonpb.KeyValuePair{ - {Key: Params, Value: string(bs)}, - } +func generateSearchParams(ctx context.Context, c *gin.Context, reqSearchParams searchParams) ([]*commonpb.KeyValuePair, error) { + var searchParams []*commonpb.KeyValuePair + bs, _ := json.Marshal(reqSearchParams.Params) + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: Params, Value: string(bs)}) + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.IgnoreGrowing, Value: strconv.FormatBool(reqSearchParams.IgnoreGrowing)}) + // need to exposure ParamRoundDecimal in req? + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"}) return searchParams, nil } func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { httpReq := anyReq.(*SearchReqV2) req := &milvuspb.SearchRequest{ - DbName: dbName, - CollectionName: httpReq.CollectionName, - Dsl: httpReq.Filter, - DslType: commonpb.DslType_BoolExprV1, - OutputFields: httpReq.OutputFields, - PartitionNames: httpReq.PartitionNames, - UseDefaultConsistency: true, + DbName: dbName, + CollectionName: httpReq.CollectionName, + Dsl: httpReq.Filter, + DslType: commonpb.DslType_BoolExprV1, + OutputFields: httpReq.OutputFields, + PartitionNames: httpReq.PartitionNames, + } + if httpReq.ConsistencyLevel != "" { + level, ok := commonpb.ConsistencyLevel_value[httpReq.ConsistencyLevel] + if !ok { + return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("parameter:'%s' is incorrect, please check it", httpReq.ConsistencyLevel)) + } + req.ConsistencyLevel = commonpb.ConsistencyLevel(level) + } else { + req.UseDefaultConsistency = true } c.Set(ContextRequest, req) @@ -951,7 +940,8 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN if err != nil { return nil, err } - searchParams, err := generateSearchParams(ctx, c, httpReq.Params) + + searchParams, err := generateSearchParams(ctx, c, httpReq.SearchParams) if err != nil { return nil, err } @@ -959,7 +949,6 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField}) - searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"}) body, _ := c.Get(gin.BodyBytesKey) placeholderGroup, err := generatePlaceholderGroup(ctx, string(body.([]byte)), collSchema, httpReq.AnnsField) if err != nil { @@ -1005,6 +994,15 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq Requests: []*milvuspb.SearchRequest{}, OutputFields: httpReq.OutputFields, } + if httpReq.ConsistencyLevel != "" { + level, ok := commonpb.ConsistencyLevel_value[httpReq.ConsistencyLevel] + if !ok { + return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("parameter:'%s' is incorrect, please check it", httpReq.ConsistencyLevel)) + } + req.ConsistencyLevel = commonpb.ConsistencyLevel(level) + } else { + req.UseDefaultConsistency = true + } c.Set(ContextRequest, req) collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName) @@ -1014,7 +1012,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq body, _ := c.Get(gin.BodyBytesKey) searchArray := gjson.Get(string(body.([]byte)), "search").Array() for i, subReq := range httpReq.Search { - searchParams, err := generateSearchParams(ctx, c, subReq.Params) + searchParams, err := generateSearchParams(ctx, c, subReq.SearchParams) if err != nil { return nil, err } @@ -1022,7 +1020,6 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField}) - searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"}) placeholderGroup, err := generatePlaceholderGroup(ctx, searchArray[i].Raw, collSchema, subReq.AnnsField) if err != nil { log.Ctx(ctx).Warn("high level restful api, search with vector invalid", zap.Error(err)) @@ -1033,15 +1030,14 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq return nil, err } searchReq := &milvuspb.SearchRequest{ - DbName: dbName, - CollectionName: httpReq.CollectionName, - Dsl: subReq.Filter, - PlaceholderGroup: placeholderGroup, - DslType: commonpb.DslType_BoolExprV1, - OutputFields: httpReq.OutputFields, - PartitionNames: httpReq.PartitionNames, - SearchParams: searchParams, - UseDefaultConsistency: true, + DbName: dbName, + CollectionName: httpReq.CollectionName, + Dsl: subReq.Filter, + PlaceholderGroup: placeholderGroup, + DslType: commonpb.DslType_BoolExprV1, + OutputFields: httpReq.OutputFields, + PartitionNames: httpReq.PartitionNames, + SearchParams: searchParams, } req.Requests = append(req.Requests, searchReq) } diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index 1728dac1017fc..4441245095af8 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -1349,7 +1349,7 @@ func TestSearchV2(t *testing.T) { Schema: generateCollectionSchema(schemapb.DataType_Int64), ShardsNum: ShardNumDefault, Status: &StatusSuccess, - }, nil).Times(12) + }, nil).Times(11) mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{ TopK: int64(3), OutputFields: outputFields, @@ -1398,7 +1398,7 @@ func TestSearchV2(t *testing.T) { }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"],"consistencyLevel": "Strong"}`), }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, @@ -1406,8 +1406,8 @@ func TestSearchV2(t *testing.T) { }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"range_filter": 0.1}}`), - errMsg: "can only accept json format request, error: invalid search params", + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"ignore_growing": "true"}}`), + errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.ignore_growing of type bool", errCode: 1801, // ErrIncorrectParameterFormat }) queryTestCases = append(queryTestCases, requestBodyTestCase{ @@ -1423,7 +1423,7 @@ func TestSearchV2(t *testing.T) { queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, requestBody: []byte(`{"collectionName": "book", "data": [["0.1", "0.2"]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`), - errMsg: "can only accept json format request, error: Mismatch type float32 with value string \"at index 8: mismatched type with value\\n\\n\\t[\\\"0.1\\\", \\\"0.2\\\"]\\n\\t........^.....\\n\": invalid parameter[expected=FloatVector][actual=[\"0.1\", \"0.2\"]]", + errMsg: "can only accept json format request, error: Mismatch type float32 with value number \"at index 2: mismatched type with value\\n\\n\\t[\\\"0.1\\\", \\\"0.2\\\"]\\n\\t..^...........\\n\": invalid parameter[expected=FloatVector][actual=[\"0.1\", \"0.2\"]]", errCode: 1801, }) queryTestCases = append(queryTestCases, requestBodyTestCase{ @@ -1465,7 +1465,7 @@ func TestSearchV2(t *testing.T) { queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "annsField": "binaryVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), - errMsg: "can only accept json format request, error: Mismatch type uint8 with value number \"at index 7: mismatched type with value\\n\\n\\t[[0.1, 0.2]]\\n\\t.......^....\\n\": invalid parameter[expected=BinaryVector][actual=[[0.1, 0.2]]]", + errMsg: "can only accept json format request, error: Mismatch type []uint8 with value array \"at index 1: mismatched type with value\\n\\n\\t[[0.1, 0.2]]\\n\\t.^..........\\n\": invalid parameter[expected=BinaryVector][actual=[[0.1, 0.2]]]", errCode: 1801, }) queryTestCases = append(queryTestCases, requestBodyTestCase{ @@ -1529,6 +1529,12 @@ func TestSearchV2(t *testing.T) { path: SearchAction, requestBody: []byte(`{"collectionName": "book", "data": [{"1": 0.1}], "annsField": "sparseFloatVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"params":"a"}}`), + errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.params of type map[string]interface {}", + errCode: 1801, // ErrIncorrectParameterFormat + }) for _, testcase := range queryTestCases { t.Run(testcase.path, func(t *testing.T) { diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index b292b73a82e65..650924f2ac337 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -141,18 +141,28 @@ type CollectionDataReq struct { func (req *CollectionDataReq) GetDbName() string { return req.DbName } +type searchParams struct { + // not use metricType any more, just for compatibility + MetricType string `json:"metricType"` + Params map[string]interface{} `json:"params"` + IgnoreGrowing bool `json:"ignore_growing"` +} + type SearchReqV2 struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" binding:"required"` - Data []interface{} `json:"data" binding:"required"` - AnnsField string `json:"annsField"` - PartitionNames []string `json:"partitionNames"` - Filter string `json:"filter"` - GroupByField string `json:"groupingField"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` - OutputFields []string `json:"outputFields"` - Params map[string]float64 `json:"params"` + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + Data []interface{} `json:"data" binding:"required"` + AnnsField string `json:"annsField"` + PartitionNames []string `json:"partitionNames"` + Filter string `json:"filter"` + GroupByField string `json:"groupingField"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + OutputFields []string `json:"outputFields"` + SearchParams searchParams `json:"searchParams"` + ConsistencyLevel string `json:"consistencyLevel"` + // not use Params any more, just for compatibility + Params map[string]float64 `json:"params"` } func (req *SearchReqV2) GetDbName() string { return req.DbName } @@ -163,25 +173,25 @@ type Rand struct { } type SubSearchReq struct { - Data []interface{} `json:"data" binding:"required"` - AnnsField string `json:"annsField"` - Filter string `json:"filter"` - GroupByField string `json:"groupingField"` - MetricType string `json:"metricType"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` - IgnoreGrowing bool `json:"ignoreGrowing"` - Params map[string]float64 `json:"params"` + Data []interface{} `json:"data" binding:"required"` + AnnsField string `json:"annsField"` + Filter string `json:"filter"` + GroupByField string `json:"groupingField"` + MetricType string `json:"metricType"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + SearchParams searchParams `json:"searchParams"` } type HybridSearchReq struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" binding:"required"` - PartitionNames []string `json:"partitionNames"` - Search []SubSearchReq `json:"search"` - Rerank Rand `json:"rerank"` - Limit int32 `json:"limit"` - OutputFields []string `json:"outputFields"` + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionNames []string `json:"partitionNames"` + Search []SubSearchReq `json:"search"` + Rerank Rand `json:"rerank"` + Limit int32 `json:"limit"` + OutputFields []string `json:"outputFields"` + ConsistencyLevel string `json:"consistencyLevel"` } func (req *HybridSearchReq) GetDbName() string { return req.DbName } diff --git a/pkg/common/common.go b/pkg/common/common.go index 94f361da4a316..a1c81d098908c 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -133,6 +133,8 @@ const ( IsSparseKey = "is_sparse" AutoIndexName = "AUTOINDEX" BitmapCardinalityLimitKey = "bitmap_cardinality_limit" + IgnoreGrowing = "ignore_growing" + ConsistencyLevel = "consistency_level" ) // Collection properties key