Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: add search params in search request in restful(#36304) #37673

Merged
merged 3 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion internal/datanode/importv2/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"

"github.com/milvus-io/milvus/pkg/config"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/stretchr/testify/assert"
)

func TestResizePools(t *testing.T) {
Expand Down
98 changes: 45 additions & 53 deletions internal/distributed/proxy/httpserver/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -944,61 +944,48 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
})
}

func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[string]float64) ([]*commonpb.KeyValuePair, error) {
params := map[string]interface{}{ // auto generated mapping
"level": int(commonpb.ConsistencyLevel_Bounded),
}
if reqParams != nil {
radius, radiusOk := reqParams[ParamRadius]
rangeFilter, rangeFilterOk := reqParams[ParamRangeFilter]
if rangeFilterOk {
if !radiusOk {
log.Ctx(ctx).Warn("high level restful api, search params invalid, because only " + ParamRangeFilter)
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat),
HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: invalid search params",
})
return nil, merr.ErrIncorrectParameterFormat
}
params[ParamRangeFilter] = rangeFilter
}
if radiusOk {
params[ParamRadius] = radius
}
}
bs, _ := json.Marshal(params)
searchParams := []*commonpb.KeyValuePair{
{Key: Params, Value: string(bs)},
}
return searchParams, nil
func generateSearchParams(ctx context.Context, c *gin.Context, reqSearchParams searchParams) []*commonpb.KeyValuePair {
var searchParams []*commonpb.KeyValuePair
bs, _ := json.Marshal(reqSearchParams.Params)
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: Params, Value: string(bs)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.IgnoreGrowing, Value: strconv.FormatBool(reqSearchParams.IgnoreGrowing)})
// need to exposure ParamRoundDecimal in req?
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
return searchParams
}

func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*SearchReqV2)
req := &milvuspb.SearchRequest{
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: httpReq.Filter,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
PartitionNames: httpReq.PartitionNames,
UseDefaultConsistency: true,
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: httpReq.Filter,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
PartitionNames: httpReq.PartitionNames,
}
c.Set(ContextRequest, req)

collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
var err error
req.ConsistencyLevel, req.UseDefaultConsistency, err = convertConsistencyLevel(httpReq.ConsistencyLevel)
if err != nil {
log.Ctx(ctx).Warn("high level restful api, search with consistency_level invalid", zap.Error(err))
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(err),
HTTPReturnMessage: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:" + err.Error(),
})
return nil, err
}
searchParams, err := generateSearchParams(ctx, c, httpReq.Params)
c.Set(ContextRequest, req)

collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
if err != nil {
return nil, err
}

searchParams := generateSearchParams(ctx, c, httpReq.SearchParams)
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
body, _ := c.Get(gin.BodyBytesKey)
placeholderGroup, err := generatePlaceholderGroup(ctx, string(body.([]byte)), collSchema, httpReq.AnnsField)
if err != nil {
Expand Down Expand Up @@ -1044,6 +1031,16 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
Requests: []*milvuspb.SearchRequest{},
OutputFields: httpReq.OutputFields,
}
var err error
req.ConsistencyLevel, req.UseDefaultConsistency, err = convertConsistencyLevel(httpReq.ConsistencyLevel)
if err != nil {
log.Ctx(ctx).Warn("high level restful api, search with consistency_level invalid", zap.Error(err))
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(err),
HTTPReturnMessage: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:" + err.Error(),
})
return nil, err
}
c.Set(ContextRequest, req)

collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
Expand All @@ -1053,15 +1050,11 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
body, _ := c.Get(gin.BodyBytesKey)
searchArray := gjson.Get(string(body.([]byte)), "search").Array()
for i, subReq := range httpReq.Search {
searchParams, err := generateSearchParams(ctx, c, subReq.Params)
if err != nil {
return nil, err
}
searchParams := generateSearchParams(ctx, c, subReq.SearchParams)
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(subReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
placeholderGroup, err := generatePlaceholderGroup(ctx, searchArray[i].Raw, collSchema, subReq.AnnsField)
if err != nil {
log.Ctx(ctx).Warn("high level restful api, search with vector invalid", zap.Error(err))
Expand All @@ -1072,15 +1065,14 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
return nil, err
}
searchReq := &milvuspb.SearchRequest{
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: subReq.Filter,
PlaceholderGroup: placeholderGroup,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
PartitionNames: httpReq.PartitionNames,
SearchParams: searchParams,
UseDefaultConsistency: true,
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: subReq.Filter,
PlaceholderGroup: placeholderGroup,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
PartitionNames: httpReq.PartitionNames,
SearchParams: searchParams,
}
req.Requests = append(req.Requests, searchReq)
}
Expand Down
43 changes: 39 additions & 4 deletions internal/distributed/proxy/httpserver/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1424,7 +1424,7 @@ func TestSearchV2(t *testing.T) {
Schema: generateCollectionSchema(schemapb.DataType_Int64),
ShardsNum: ShardNumDefault,
Status: &StatusSuccess,
}, nil).Times(12)
}, nil).Times(11)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{
TopK: int64(3),
OutputFields: outputFields,
Expand Down Expand Up @@ -1465,6 +1465,12 @@ func TestSearchV2(t *testing.T) {
Status: &StatusSuccess,
}, nil).Times(10)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: &commonpb.Status{
Code: 1100,
Reason: "mock",
},
}, nil).Once()
testEngine := initHTTPServerV2(mp, false)
queryTestCases := []requestBodyTestCase{}
queryTestCases = append(queryTestCases, requestBodyTestCase{
Expand All @@ -1473,16 +1479,16 @@ func TestSearchV2(t *testing.T) {
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"],"consistencyLevel": "Strong"}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"range_filter": 0.1}}`),
errMsg: "can only accept json format request, error: invalid search params",
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"ignoreGrowing": "true"}}`),
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.ignoreGrowing of type bool",
errCode: 1801, // ErrIncorrectParameterFormat
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
Expand Down Expand Up @@ -1556,6 +1562,17 @@ func TestSearchV2(t *testing.T) {
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
`{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
`], "consistencyLevel":"unknown","rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
errMsg: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:parameter:'unknown' is incorrect, please check it: invalid parameter",
errCode: 1100, // ErrParameterInvalid
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
Expand Down Expand Up @@ -1604,6 +1621,24 @@ func TestSearchV2(t *testing.T) {
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [{"1": 0.1}], "annsField": "sparseFloatVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"params":"a"}}`),
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.params of type map[string]interface {}",
errCode: 1801, // ErrIncorrectParameterFormat
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"],"consistencyLevel": "unknown"}`),
errMsg: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:parameter:'unknown' is incorrect, please check it: invalid parameter",
errCode: 1100, // ErrParameterInvalid
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": ["AQ=="], "annsField": "binaryVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})

for _, testcase := range queryTestCases {
t.Run(testcase.path, func(t *testing.T) {
Expand Down
64 changes: 37 additions & 27 deletions internal/distributed/proxy/httpserver/request_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,28 @@ type CollectionDataReq struct {

func (req *CollectionDataReq) GetDbName() string { return req.DbName }

type searchParams struct {
// not use metricType any more, just for compatibility
MetricType string `json:"metricType"`
Params map[string]interface{} `json:"params"`
IgnoreGrowing bool `json:"ignoreGrowing"`
}

type SearchReqV2 struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
PartitionNames []string `json:"partitionNames"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
OutputFields []string `json:"outputFields"`
Params map[string]float64 `json:"params"`
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
PartitionNames []string `json:"partitionNames"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
OutputFields []string `json:"outputFields"`
SearchParams searchParams `json:"searchParams"`
ConsistencyLevel string `json:"consistencyLevel"`
// not use Params any more, just for compatibility
Params map[string]float64 `json:"params"`
}

func (req *SearchReqV2) GetDbName() string { return req.DbName }
Expand All @@ -163,25 +173,25 @@ type Rand struct {
}

type SubSearchReq struct {
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
MetricType string `json:"metricType"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
IgnoreGrowing bool `json:"ignoreGrowing"`
Params map[string]float64 `json:"params"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
MetricType string `json:"metricType"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
SearchParams searchParams `json:"searchParams"`
}

type HybridSearchReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionNames []string `json:"partitionNames"`
Search []SubSearchReq `json:"search"`
Rerank Rand `json:"rerank"`
Limit int32 `json:"limit"`
OutputFields []string `json:"outputFields"`
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionNames []string `json:"partitionNames"`
Search []SubSearchReq `json:"search"`
Rerank Rand `json:"rerank"`
Limit int32 `json:"limit"`
OutputFields []string `json:"outputFields"`
ConsistencyLevel string `json:"consistencyLevel"`
}

func (req *HybridSearchReq) GetDbName() string { return req.DbName }
Expand Down
12 changes: 12 additions & 0 deletions internal/distributed/proxy/httpserver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -1314,3 +1314,15 @@ func convertToExtraParams(indexParam IndexParam) ([]*commonpb.KeyValuePair, erro
}
return params, nil
}

func convertConsistencyLevel(reqConsistencyLevel string) (commonpb.ConsistencyLevel, bool, error) {
if reqConsistencyLevel != "" {
level, ok := commonpb.ConsistencyLevel_value[reqConsistencyLevel]
if !ok {
return 0, false, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("parameter:'%s' is incorrect, please check it", reqConsistencyLevel))
}
return commonpb.ConsistencyLevel(level), false, nil
}
// ConsistencyLevel_Bounded default in PyMilvus
return commonpb.ConsistencyLevel_Bounded, true, nil
}
13 changes: 13 additions & 0 deletions internal/distributed/proxy/httpserver/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1406,3 +1406,16 @@ func TestConvertToExtraParams(t *testing.T) {
}
}
}

func TestConvertConsistencyLevel(t *testing.T) {
consistencyLevel, useDefaultConsistency, err := convertConsistencyLevel("")
assert.Equal(t, nil, err)
assert.Equal(t, consistencyLevel, commonpb.ConsistencyLevel_Bounded)
assert.Equal(t, true, useDefaultConsistency)
consistencyLevel, useDefaultConsistency, err = convertConsistencyLevel("Strong")
assert.Equal(t, nil, err)
assert.Equal(t, consistencyLevel, commonpb.ConsistencyLevel_Strong)
assert.Equal(t, false, useDefaultConsistency)
_, _, err = convertConsistencyLevel("test")
assert.NotNil(t, err)
}
2 changes: 2 additions & 0 deletions pkg/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ const (
BitmapCardinalityLimitKey = "bitmap_cardinality_limit"
IsSparseKey = "is_sparse"
AutoIndexName = "AUTOINDEX"
IgnoreGrowing = "ignore_growing"
ConsistencyLevel = "consistency_level"
)

// Collection properties key
Expand Down
9 changes: 6 additions & 3 deletions tests/restful_client_v2/base/testbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def init_collection(self, collection_name, pk_field="id", metric_type="L2", dim=
batch_size = batch_size
batch = nb // batch_size
remainder = nb % batch_size
data = []

full_data = []
insert_ids = []
for i in range(batch):
nb = batch_size
Expand All @@ -116,6 +117,7 @@ def init_collection(self, collection_name, pk_field="id", metric_type="L2", dim=
assert rsp['code'] == 0
if return_insert_id:
insert_ids.extend(rsp['data']['insertIds'])
full_data.extend(data)
# insert remainder data
if remainder:
nb = remainder
Expand All @@ -128,10 +130,11 @@ def init_collection(self, collection_name, pk_field="id", metric_type="L2", dim=
assert rsp['code'] == 0
if return_insert_id:
insert_ids.extend(rsp['data']['insertIds'])
full_data.extend(data)
if return_insert_id:
return schema_payload, data, insert_ids
return schema_payload, full_data, insert_ids

return schema_payload, data
return schema_payload, full_data

def wait_collection_load_completed(self, name):
t0 = time.time()
Expand Down
Loading
Loading