Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: not return err if consistencyLevel is not set to a valid value #36714

Merged
merged 1 commit into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions internal/distributed/proxy/httpserver/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -905,14 +905,14 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
})
}

func generateSearchParams(ctx context.Context, c *gin.Context, reqSearchParams searchParams) ([]*commonpb.KeyValuePair, error) {
func generateSearchParams(ctx context.Context, c *gin.Context, reqSearchParams searchParams) []*commonpb.KeyValuePair {
var searchParams []*commonpb.KeyValuePair
bs, _ := json.Marshal(reqSearchParams.Params)
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: Params, Value: string(bs)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.IgnoreGrowing, Value: strconv.FormatBool(reqSearchParams.IgnoreGrowing)})
// need to exposure ParamRoundDecimal in req?
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
return searchParams, nil
return searchParams
}

func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
Expand All @@ -928,19 +928,22 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
var err error
req.ConsistencyLevel, req.UseDefaultConsistency, err = convertConsistencyLevel(httpReq.ConsistencyLevel)
if err != nil {
log.Ctx(ctx).Warn("high level restful api, search with consistency_level invalid", zap.Error(err))
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(err),
HTTPReturnMessage: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:" + err.Error(),
})
return nil, err
}
c.Set(ContextRequest, req)

collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
if err != nil {
// has already throw http in GetCollectionSchema if fails to get schema
return nil, err
}

searchParams, err := generateSearchParams(ctx, c, httpReq.SearchParams)
if err != nil {
return nil, err
}
searchParams := generateSearchParams(ctx, c, httpReq.SearchParams)
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField})
Expand Down Expand Up @@ -995,21 +998,24 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
var err error
req.ConsistencyLevel, req.UseDefaultConsistency, err = convertConsistencyLevel(httpReq.ConsistencyLevel)
if err != nil {
log.Ctx(ctx).Warn("high level restful api, search with consistency_level invalid", zap.Error(err))
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(err),
HTTPReturnMessage: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:" + err.Error(),
})
return nil, err
}
c.Set(ContextRequest, req)

collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
if err != nil {
// has already throw http in GetCollectionSchema if fails to get schema
return nil, err
}
body, _ := c.Get(gin.BodyBytesKey)
searchArray := gjson.Get(string(body.([]byte)), "search").Array()
for i, subReq := range httpReq.Search {
searchParams, err := generateSearchParams(ctx, c, subReq.SearchParams)
if err != nil {
return nil, err
}
searchParams := generateSearchParams(ctx, c, subReq.SearchParams)
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(subReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField})
Expand Down
33 changes: 31 additions & 2 deletions internal/distributed/proxy/httpserver/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1390,6 +1390,12 @@ func TestSearchV2(t *testing.T) {
Status: &StatusSuccess,
}, nil).Times(10)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: &commonpb.Status{
Code: 1100,
Reason: "mock",
},
}, nil).Once()
testEngine := initHTTPServerV2(mp, false)
queryTestCases := []requestBodyTestCase{}
queryTestCases = append(queryTestCases, requestBodyTestCase{
Expand All @@ -1406,8 +1412,8 @@ func TestSearchV2(t *testing.T) {
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"ignore_growing": "true"}}`),
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.ignore_growing of type bool",
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"ignoreGrowing": "true"}}`),
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.ignoreGrowing of type bool",
errCode: 1801, // ErrIncorrectParameterFormat
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
Expand Down Expand Up @@ -1481,6 +1487,17 @@ func TestSearchV2(t *testing.T) {
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
`{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
`], "consistencyLevel":"unknown","rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
errMsg: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:parameter:'unknown' is incorrect, please check it: invalid parameter",
errCode: 1100, // ErrParameterInvalid
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
Expand Down Expand Up @@ -1535,6 +1552,18 @@ func TestSearchV2(t *testing.T) {
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.params of type map[string]interface {}",
errCode: 1801, // ErrIncorrectParameterFormat
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"],"consistencyLevel": "unknown"}`),
errMsg: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:parameter:'unknown' is incorrect, please check it: invalid parameter",
errCode: 1100, // ErrParameterInvalid
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": ["AQ=="], "annsField": "binaryVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})

for _, testcase := range queryTestCases {
t.Run(testcase.path, func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion internal/distributed/proxy/httpserver/request_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ type searchParams struct {
// not use metricType any more, just for compatibility
MetricType string `json:"metricType"`
Params map[string]interface{} `json:"params"`
IgnoreGrowing bool `json:"ignore_growing"`
IgnoreGrowing bool `json:"ignoreGrowing"`
}

type SearchReqV2 struct {
Expand Down
4 changes: 2 additions & 2 deletions internal/distributed/proxy/httpserver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,6 @@ func convertConsistencyLevel(reqConsistencyLevel string) (commonpb.ConsistencyLe
}
return commonpb.ConsistencyLevel(level), false, nil
}
// ConsistencyLevel_Session default in PyMilvus
return commonpb.ConsistencyLevel_Session, true, nil
// ConsistencyLevel_Bounded default in PyMilvus
return commonpb.ConsistencyLevel_Bounded, true, nil
}
2 changes: 1 addition & 1 deletion internal/distributed/proxy/httpserver/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1376,7 +1376,7 @@ func TestBuildQueryResps(t *testing.T) {
func TestConvertConsistencyLevel(t *testing.T) {
consistencyLevel, useDefaultConsistency, err := convertConsistencyLevel("")
assert.Equal(t, nil, err)
assert.Equal(t, consistencyLevel, commonpb.ConsistencyLevel_Session)
assert.Equal(t, consistencyLevel, commonpb.ConsistencyLevel_Bounded)
assert.Equal(t, true, useDefaultConsistency)
consistencyLevel, useDefaultConsistency, err = convertConsistencyLevel("Strong")
assert.Equal(t, nil, err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1651,7 +1651,7 @@ def test_search_vector_with_ignore_growing(self, ignore_growing):
"limit": limit,
"offset": 0,
"searchParams": {
"ignore_growing": ignore_growing
"ignoreGrowing": ignore_growing

}
}
Expand Down
Loading