Skip to content

Commit

Permalink
enhance: add search params in search request in restful
Browse files Browse the repository at this point in the history
Signed-off-by: lixinguo <[email protected]>
  • Loading branch information
lixinguo committed Sep 17, 2024
1 parent 526a672 commit 41d8474
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 81 deletions.
3 changes: 1 addition & 2 deletions internal/datacoord/import_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@ import (
"sort"
"time"

"github.com/milvus-io/milvus/internal/proto/indexpb"

"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/zap"

"github.com/milvus-io/milvus/internal/datacoord/allocator"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/importutilv2"
Expand Down
88 changes: 42 additions & 46 deletions internal/distributed/proxy/httpserver/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -905,61 +905,50 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
})
}

func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[string]float64) ([]*commonpb.KeyValuePair, error) {
params := map[string]interface{}{ // auto generated mapping
"level": int(commonpb.ConsistencyLevel_Bounded),
}
if reqParams != nil {
radius, radiusOk := reqParams[ParamRadius]
rangeFilter, rangeFilterOk := reqParams[ParamRangeFilter]
if rangeFilterOk {
if !radiusOk {
log.Ctx(ctx).Warn("high level restful api, search params invalid, because only " + ParamRangeFilter)
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat),
HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: invalid search params",
})
return nil, merr.ErrIncorrectParameterFormat
}
params[ParamRangeFilter] = rangeFilter
}
if radiusOk {
params[ParamRadius] = radius
}
}
bs, _ := json.Marshal(params)
searchParams := []*commonpb.KeyValuePair{
{Key: Params, Value: string(bs)},
}
func generateSearchParams(ctx context.Context, c *gin.Context, reqSearchParams searchParams) ([]*commonpb.KeyValuePair, error) {
var searchParams []*commonpb.KeyValuePair
bs, _ := json.Marshal(reqSearchParams.Params)
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: Params, Value: string(bs)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.IgnoreGrowing, Value: strconv.FormatBool(reqSearchParams.IgnoreGrowing)})
// need to exposure ParamRoundDecimal in req?
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
return searchParams, nil
}

func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*SearchReqV2)
req := &milvuspb.SearchRequest{
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: httpReq.Filter,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
PartitionNames: httpReq.PartitionNames,
UseDefaultConsistency: true,
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: httpReq.Filter,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
PartitionNames: httpReq.PartitionNames,
}
if httpReq.ConsistencyLevel != "" {
level, ok := commonpb.ConsistencyLevel_value[httpReq.ConsistencyLevel]
if !ok {
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("parameter:'%s' is incorrect, please check it", httpReq.ConsistencyLevel))
}
req.ConsistencyLevel = commonpb.ConsistencyLevel(level)
} else {
req.UseDefaultConsistency = true
}
c.Set(ContextRequest, req)

collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
if err != nil {
return nil, err
}
searchParams, err := generateSearchParams(ctx, c, httpReq.Params)

searchParams, err := generateSearchParams(ctx, c, httpReq.SearchParams)
if err != nil {
return nil, err
}
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
body, _ := c.Get(gin.BodyBytesKey)
placeholderGroup, err := generatePlaceholderGroup(ctx, string(body.([]byte)), collSchema, httpReq.AnnsField)
if err != nil {
Expand Down Expand Up @@ -1005,6 +994,15 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
Requests: []*milvuspb.SearchRequest{},
OutputFields: httpReq.OutputFields,
}
if httpReq.ConsistencyLevel != "" {
level, ok := commonpb.ConsistencyLevel_value[httpReq.ConsistencyLevel]
if !ok {
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("parameter:'%s' is incorrect, please check it", httpReq.ConsistencyLevel))
}
req.ConsistencyLevel = commonpb.ConsistencyLevel(level)
} else {
req.UseDefaultConsistency = true
}
c.Set(ContextRequest, req)

collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
Expand All @@ -1014,15 +1012,14 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
body, _ := c.Get(gin.BodyBytesKey)
searchArray := gjson.Get(string(body.([]byte)), "search").Array()
for i, subReq := range httpReq.Search {
searchParams, err := generateSearchParams(ctx, c, subReq.Params)
searchParams, err := generateSearchParams(ctx, c, subReq.SearchParams)
if err != nil {
return nil, err
}
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(subReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
placeholderGroup, err := generatePlaceholderGroup(ctx, searchArray[i].Raw, collSchema, subReq.AnnsField)
if err != nil {
log.Ctx(ctx).Warn("high level restful api, search with vector invalid", zap.Error(err))
Expand All @@ -1033,15 +1030,14 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
return nil, err
}
searchReq := &milvuspb.SearchRequest{
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: subReq.Filter,
PlaceholderGroup: placeholderGroup,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
PartitionNames: httpReq.PartitionNames,
SearchParams: searchParams,
UseDefaultConsistency: true,
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: subReq.Filter,
PlaceholderGroup: placeholderGroup,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
PartitionNames: httpReq.PartitionNames,
SearchParams: searchParams,
}
req.Requests = append(req.Requests, searchReq)
}
Expand Down
18 changes: 12 additions & 6 deletions internal/distributed/proxy/httpserver/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1349,7 +1349,7 @@ func TestSearchV2(t *testing.T) {
Schema: generateCollectionSchema(schemapb.DataType_Int64),
ShardsNum: ShardNumDefault,
Status: &StatusSuccess,
}, nil).Times(12)
}, nil).Times(11)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{
TopK: int64(3),
OutputFields: outputFields,
Expand Down Expand Up @@ -1398,16 +1398,16 @@ func TestSearchV2(t *testing.T) {
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"],"consistencyLevel": "Strong"}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"range_filter": 0.1}}`),
errMsg: "can only accept json format request, error: invalid search params",
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"ignore_growing": "true"}}`),
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.ignore_growing of type bool",
errCode: 1801, // ErrIncorrectParameterFormat
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
Expand All @@ -1423,7 +1423,7 @@ func TestSearchV2(t *testing.T) {
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [["0.1", "0.2"]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`),
errMsg: "can only accept json format request, error: Mismatch type float32 with value string \"at index 8: mismatched type with value\\n\\n\\t[\\\"0.1\\\", \\\"0.2\\\"]\\n\\t........^.....\\n\": invalid parameter[expected=FloatVector][actual=[\"0.1\", \"0.2\"]]",
errMsg: "can only accept json format request, error: Mismatch type float32 with value number \"at index 2: mismatched type with value\\n\\n\\t[\\\"0.1\\\", \\\"0.2\\\"]\\n\\t..^...........\\n\": invalid parameter[expected=FloatVector][actual=[\"0.1\", \"0.2\"]]",
errCode: 1801,
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
Expand Down Expand Up @@ -1465,7 +1465,7 @@ func TestSearchV2(t *testing.T) {
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "annsField": "binaryVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
errMsg: "can only accept json format request, error: Mismatch type uint8 with value number \"at index 7: mismatched type with value\\n\\n\\t[[0.1, 0.2]]\\n\\t.......^....\\n\": invalid parameter[expected=BinaryVector][actual=[[0.1, 0.2]]]",
errMsg: "can only accept json format request, error: Mismatch type []uint8 with value array \"at index 1: mismatched type with value\\n\\n\\t[[0.1, 0.2]]\\n\\t.^..........\\n\": invalid parameter[expected=BinaryVector][actual=[[0.1, 0.2]]]",
errCode: 1801,
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
Expand Down Expand Up @@ -1529,6 +1529,12 @@ func TestSearchV2(t *testing.T) {
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [{"1": 0.1}], "annsField": "sparseFloatVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"params":"a"}}`),
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.params of type map[string]interface {}",
errCode: 1801, // ErrIncorrectParameterFormat
})

for _, testcase := range queryTestCases {
t.Run(testcase.path, func(t *testing.T) {
Expand Down
64 changes: 37 additions & 27 deletions internal/distributed/proxy/httpserver/request_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,28 @@ type CollectionDataReq struct {

func (req *CollectionDataReq) GetDbName() string { return req.DbName }

type searchParams struct {
// not use metricType any more, just for compatibility
MetricType string `json:"metricType"`
Params map[string]interface{} `json:"params"`
IgnoreGrowing bool `json:"ignore_growing"`
}

type SearchReqV2 struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
PartitionNames []string `json:"partitionNames"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
OutputFields []string `json:"outputFields"`
Params map[string]float64 `json:"params"`
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
PartitionNames []string `json:"partitionNames"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
OutputFields []string `json:"outputFields"`
SearchParams searchParams `json:"searchParams"`
ConsistencyLevel string `json:"consistencyLevel"`
// not use Params any more, just for compatibility
Params map[string]float64 `json:"params"`
}

func (req *SearchReqV2) GetDbName() string { return req.DbName }
Expand All @@ -163,25 +173,25 @@ type Rand struct {
}

type SubSearchReq struct {
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
MetricType string `json:"metricType"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
IgnoreGrowing bool `json:"ignoreGrowing"`
Params map[string]float64 `json:"params"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
MetricType string `json:"metricType"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
SearchParams searchParams `json:"searchParams"`
}

type HybridSearchReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionNames []string `json:"partitionNames"`
Search []SubSearchReq `json:"search"`
Rerank Rand `json:"rerank"`
Limit int32 `json:"limit"`
OutputFields []string `json:"outputFields"`
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionNames []string `json:"partitionNames"`
Search []SubSearchReq `json:"search"`
Rerank Rand `json:"rerank"`
Limit int32 `json:"limit"`
OutputFields []string `json:"outputFields"`
ConsistencyLevel string `json:"consistencyLevel"`
}

func (req *HybridSearchReq) GetDbName() string { return req.DbName }
Expand Down
2 changes: 2 additions & 0 deletions pkg/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ const (
IsSparseKey = "is_sparse"
AutoIndexName = "AUTOINDEX"
BitmapCardinalityLimitKey = "bitmap_cardinality_limit"
IgnoreGrowing = "ignore_growing"
ConsistencyLevel = "consistency_level"
)

// Collection properties key
Expand Down

0 comments on commit 41d8474

Please sign in to comment.