From cb49ed910e802672b96ad4165f0840998eaa4c70 Mon Sep 17 00:00:00 2001 From: lixinguo Date: Sat, 14 Sep 2024 19:33:47 +0800 Subject: [PATCH] enhance: add search params in search request in restful(#36304) Signed-off-by: lixinguo --- internal/datanode/importv2/pool_test.go | 3 +- .../proxy/httpserver/handler_v2.go | 98 +++++++++---------- .../proxy/httpserver/handler_v2_test.go | 43 +++++++- .../proxy/httpserver/request_v2.go | 64 +++++++----- .../distributed/proxy/httpserver/utils.go | 12 +++ .../proxy/httpserver/utils_test.go | 13 +++ pkg/common/common.go | 2 + .../testcases/test_vector_operations.py | 3 + 8 files changed, 153 insertions(+), 85 deletions(-) diff --git a/internal/datanode/importv2/pool_test.go b/internal/datanode/importv2/pool_test.go index 06873c6d31ae5..4449a5031c812 100644 --- a/internal/datanode/importv2/pool_test.go +++ b/internal/datanode/importv2/pool_test.go @@ -20,9 +20,10 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/pkg/config" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/assert" ) func TestResizePools(t *testing.T) { diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 24092e75d5ca9..b8852ac485aab 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -944,61 +944,48 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche }) } -func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[string]float64) ([]*commonpb.KeyValuePair, error) { - params := map[string]interface{}{ // auto generated mapping - "level": int(commonpb.ConsistencyLevel_Bounded), - } - if reqParams != nil { - radius, radiusOk := reqParams[ParamRadius] - rangeFilter, rangeFilterOk := reqParams[ParamRangeFilter] - if rangeFilterOk { - if !radiusOk { - log.Ctx(ctx).Warn("high level restful api, search params invalid, because only " + ParamRangeFilter) - HTTPAbortReturn(c, http.StatusOK, gin.H{ - HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), - HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: invalid search params", - }) - return nil, merr.ErrIncorrectParameterFormat - } - params[ParamRangeFilter] = rangeFilter - } - if radiusOk { - params[ParamRadius] = radius - } - } - bs, _ := json.Marshal(params) - searchParams := []*commonpb.KeyValuePair{ - {Key: Params, Value: string(bs)}, - } - return searchParams, nil +func generateSearchParams(ctx context.Context, c *gin.Context, reqSearchParams searchParams) []*commonpb.KeyValuePair { + var searchParams []*commonpb.KeyValuePair + bs, _ := json.Marshal(reqSearchParams.Params) + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: Params, Value: string(bs)}) + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.IgnoreGrowing, Value: strconv.FormatBool(reqSearchParams.IgnoreGrowing)}) + // need to exposure ParamRoundDecimal in req? + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"}) + return searchParams } func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { httpReq := anyReq.(*SearchReqV2) req := &milvuspb.SearchRequest{ - DbName: dbName, - CollectionName: httpReq.CollectionName, - Dsl: httpReq.Filter, - DslType: commonpb.DslType_BoolExprV1, - OutputFields: httpReq.OutputFields, - PartitionNames: httpReq.PartitionNames, - UseDefaultConsistency: true, + DbName: dbName, + CollectionName: httpReq.CollectionName, + Dsl: httpReq.Filter, + DslType: commonpb.DslType_BoolExprV1, + OutputFields: httpReq.OutputFields, + PartitionNames: httpReq.PartitionNames, } - c.Set(ContextRequest, req) - - collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName) + var err error + req.ConsistencyLevel, req.UseDefaultConsistency, err = convertConsistencyLevel(httpReq.ConsistencyLevel) if err != nil { + log.Ctx(ctx).Warn("high level restful api, search with consistency_level invalid", zap.Error(err)) + HTTPAbortReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:" + err.Error(), + }) return nil, err } - searchParams, err := generateSearchParams(ctx, c, httpReq.Params) + c.Set(ContextRequest, req) + + collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName) if err != nil { return nil, err } + + searchParams := generateSearchParams(ctx, c, httpReq.SearchParams) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField}) - searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"}) body, _ := c.Get(gin.BodyBytesKey) placeholderGroup, err := generatePlaceholderGroup(ctx, string(body.([]byte)), collSchema, httpReq.AnnsField) if err != nil { @@ -1044,6 +1031,16 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq Requests: []*milvuspb.SearchRequest{}, OutputFields: httpReq.OutputFields, } + var err error + req.ConsistencyLevel, req.UseDefaultConsistency, err = convertConsistencyLevel(httpReq.ConsistencyLevel) + if err != nil { + log.Ctx(ctx).Warn("high level restful api, search with consistency_level invalid", zap.Error(err)) + HTTPAbortReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:" + err.Error(), + }) + return nil, err + } c.Set(ContextRequest, req) collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName) @@ -1053,15 +1050,11 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq body, _ := c.Get(gin.BodyBytesKey) searchArray := gjson.Get(string(body.([]byte)), "search").Array() for i, subReq := range httpReq.Search { - searchParams, err := generateSearchParams(ctx, c, subReq.Params) - if err != nil { - return nil, err - } + searchParams := generateSearchParams(ctx, c, subReq.SearchParams) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(subReq.Limit), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField}) - searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"}) placeholderGroup, err := generatePlaceholderGroup(ctx, searchArray[i].Raw, collSchema, subReq.AnnsField) if err != nil { log.Ctx(ctx).Warn("high level restful api, search with vector invalid", zap.Error(err)) @@ -1072,15 +1065,14 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq return nil, err } searchReq := &milvuspb.SearchRequest{ - DbName: dbName, - CollectionName: httpReq.CollectionName, - Dsl: subReq.Filter, - PlaceholderGroup: placeholderGroup, - DslType: commonpb.DslType_BoolExprV1, - OutputFields: httpReq.OutputFields, - PartitionNames: httpReq.PartitionNames, - SearchParams: searchParams, - UseDefaultConsistency: true, + DbName: dbName, + CollectionName: httpReq.CollectionName, + Dsl: subReq.Filter, + PlaceholderGroup: placeholderGroup, + DslType: commonpb.DslType_BoolExprV1, + OutputFields: httpReq.OutputFields, + PartitionNames: httpReq.PartitionNames, + SearchParams: searchParams, } req.Requests = append(req.Requests, searchReq) } diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index 96aaa4289f9ed..b53c997edc4fd 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -1424,7 +1424,7 @@ func TestSearchV2(t *testing.T) { Schema: generateCollectionSchema(schemapb.DataType_Int64), ShardsNum: ShardNumDefault, Status: &StatusSuccess, - }, nil).Times(12) + }, nil).Times(11) mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{ TopK: int64(3), OutputFields: outputFields, @@ -1465,6 +1465,12 @@ func TestSearchV2(t *testing.T) { Status: &StatusSuccess, }, nil).Times(10) mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + Status: &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, + }, nil).Once() testEngine := initHTTPServerV2(mp, false) queryTestCases := []requestBodyTestCase{} queryTestCases = append(queryTestCases, requestBodyTestCase{ @@ -1473,7 +1479,7 @@ func TestSearchV2(t *testing.T) { }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"],"consistencyLevel": "Strong"}`), }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, @@ -1481,8 +1487,8 @@ func TestSearchV2(t *testing.T) { }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"range_filter": 0.1}}`), - errMsg: "can only accept json format request, error: invalid search params", + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"ignoreGrowing": "true"}}`), + errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.ignoreGrowing of type bool", errCode: 1801, // ErrIncorrectParameterFormat }) queryTestCases = append(queryTestCases, requestBodyTestCase{ @@ -1556,6 +1562,17 @@ func TestSearchV2(t *testing.T) { `{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` + `], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`), }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: AdvancedSearchAction, + requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` + + `{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` + + `{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` + + `{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` + + `{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` + + `], "consistencyLevel":"unknown","rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`), + errMsg: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:parameter:'unknown' is incorrect, please check it: invalid parameter", + errCode: 1100, // ErrParameterInvalid + }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: AdvancedSearchAction, requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` + @@ -1604,6 +1621,24 @@ func TestSearchV2(t *testing.T) { path: SearchAction, requestBody: []byte(`{"collectionName": "book", "data": [{"1": 0.1}], "annsField": "sparseFloatVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"params":"a"}}`), + errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.params of type map[string]interface {}", + errCode: 1801, // ErrIncorrectParameterFormat + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"],"consistencyLevel": "unknown"}`), + errMsg: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:parameter:'unknown' is incorrect, please check it: invalid parameter", + errCode: 1100, // ErrParameterInvalid + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": ["AQ=="], "annsField": "binaryVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) for _, testcase := range queryTestCases { t.Run(testcase.path, func(t *testing.T) { diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index fc7d82dc1fbd4..b8a77f6759bfe 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -141,18 +141,28 @@ type CollectionDataReq struct { func (req *CollectionDataReq) GetDbName() string { return req.DbName } +type searchParams struct { + // not use metricType any more, just for compatibility + MetricType string `json:"metricType"` + Params map[string]interface{} `json:"params"` + IgnoreGrowing bool `json:"ignoreGrowing"` +} + type SearchReqV2 struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" binding:"required"` - Data []interface{} `json:"data" binding:"required"` - AnnsField string `json:"annsField"` - PartitionNames []string `json:"partitionNames"` - Filter string `json:"filter"` - GroupByField string `json:"groupingField"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` - OutputFields []string `json:"outputFields"` - Params map[string]float64 `json:"params"` + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + Data []interface{} `json:"data" binding:"required"` + AnnsField string `json:"annsField"` + PartitionNames []string `json:"partitionNames"` + Filter string `json:"filter"` + GroupByField string `json:"groupingField"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + OutputFields []string `json:"outputFields"` + SearchParams searchParams `json:"searchParams"` + ConsistencyLevel string `json:"consistencyLevel"` + // not use Params any more, just for compatibility + Params map[string]float64 `json:"params"` } func (req *SearchReqV2) GetDbName() string { return req.DbName } @@ -163,25 +173,25 @@ type Rand struct { } type SubSearchReq struct { - Data []interface{} `json:"data" binding:"required"` - AnnsField string `json:"annsField"` - Filter string `json:"filter"` - GroupByField string `json:"groupingField"` - MetricType string `json:"metricType"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` - IgnoreGrowing bool `json:"ignoreGrowing"` - Params map[string]float64 `json:"params"` + Data []interface{} `json:"data" binding:"required"` + AnnsField string `json:"annsField"` + Filter string `json:"filter"` + GroupByField string `json:"groupingField"` + MetricType string `json:"metricType"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + SearchParams searchParams `json:"searchParams"` } type HybridSearchReq struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" binding:"required"` - PartitionNames []string `json:"partitionNames"` - Search []SubSearchReq `json:"search"` - Rerank Rand `json:"rerank"` - Limit int32 `json:"limit"` - OutputFields []string `json:"outputFields"` + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionNames []string `json:"partitionNames"` + Search []SubSearchReq `json:"search"` + Rerank Rand `json:"rerank"` + Limit int32 `json:"limit"` + OutputFields []string `json:"outputFields"` + ConsistencyLevel string `json:"consistencyLevel"` } func (req *HybridSearchReq) GetDbName() string { return req.DbName } diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index e131e53dd95b3..51be89bf5c338 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -1314,3 +1314,15 @@ func convertToExtraParams(indexParam IndexParam) ([]*commonpb.KeyValuePair, erro } return params, nil } + +func convertConsistencyLevel(reqConsistencyLevel string) (commonpb.ConsistencyLevel, bool, error) { + if reqConsistencyLevel != "" { + level, ok := commonpb.ConsistencyLevel_value[reqConsistencyLevel] + if !ok { + return 0, false, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("parameter:'%s' is incorrect, please check it", reqConsistencyLevel)) + } + return commonpb.ConsistencyLevel(level), false, nil + } + // ConsistencyLevel_Bounded default in PyMilvus + return commonpb.ConsistencyLevel_Bounded, true, nil +} diff --git a/internal/distributed/proxy/httpserver/utils_test.go b/internal/distributed/proxy/httpserver/utils_test.go index c9a4a1f42b38b..90a8de362abc8 100644 --- a/internal/distributed/proxy/httpserver/utils_test.go +++ b/internal/distributed/proxy/httpserver/utils_test.go @@ -1406,3 +1406,16 @@ func TestConvertToExtraParams(t *testing.T) { } } } + +func TestConvertConsistencyLevel(t *testing.T) { + consistencyLevel, useDefaultConsistency, err := convertConsistencyLevel("") + assert.Equal(t, nil, err) + assert.Equal(t, consistencyLevel, commonpb.ConsistencyLevel_Bounded) + assert.Equal(t, true, useDefaultConsistency) + consistencyLevel, useDefaultConsistency, err = convertConsistencyLevel("Strong") + assert.Equal(t, nil, err) + assert.Equal(t, consistencyLevel, commonpb.ConsistencyLevel_Strong) + assert.Equal(t, false, useDefaultConsistency) + _, _, err = convertConsistencyLevel("test") + assert.NotNil(t, err) +} diff --git a/pkg/common/common.go b/pkg/common/common.go index 085dc6939dfb9..09f94b3df0f55 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -128,6 +128,8 @@ const ( BitmapCardinalityLimitKey = "bitmap_cardinality_limit" IsSparseKey = "is_sparse" AutoIndexName = "AUTOINDEX" + IgnoreGrowing = "ignore_growing" + ConsistencyLevel = "consistency_level" ) // Collection properties key diff --git a/tests/restful_client_v2/testcases/test_vector_operations.py b/tests/restful_client_v2/testcases/test_vector_operations.py index 98a935f2b613b..988cf945204c8 100644 --- a/tests/restful_client_v2/testcases/test_vector_operations.py +++ b/tests/restful_client_v2/testcases/test_vector_operations.py @@ -926,6 +926,7 @@ class TestSearchVector(TestBase): @pytest.mark.parametrize("auto_id", [True]) @pytest.mark.parametrize("is_partition_key", [True]) @pytest.mark.parametrize("enable_dynamic_schema", [True]) + @pytest.mark.skip(reason="behavior change;todo:@zhuwenxing") @pytest.mark.parametrize("nb", [3000]) @pytest.mark.parametrize("dim", [16]) def test_search_vector_with_all_vector_datatype(self, nb, dim, insert_round, auto_id, @@ -1031,6 +1032,7 @@ def test_search_vector_with_all_vector_datatype(self, nb, dim, insert_round, aut @pytest.mark.parametrize("enable_dynamic_schema", [True]) @pytest.mark.parametrize("nb", [3000]) @pytest.mark.parametrize("dim", [128]) + @pytest.mark.skip(reason="behavior change;todo:@zhuwenxing") @pytest.mark.parametrize("nq", [1, 2]) def test_search_vector_with_float_vector_datatype(self, nb, dim, insert_round, auto_id, is_partition_key, enable_dynamic_schema, nq): @@ -1225,6 +1227,7 @@ def test_search_vector_with_sparse_float_vector_datatype(self, nb, dim, insert_r @pytest.mark.parametrize("enable_dynamic_schema", [True]) @pytest.mark.parametrize("nb", [3000]) @pytest.mark.parametrize("dim", [128]) + @pytest.mark.skip(reason="behavior change;todo:@zhuwenxing") def test_search_vector_with_binary_vector_datatype(self, nb, dim, insert_round, auto_id, is_partition_key, enable_dynamic_schema): """