diff --git a/internal/proxy/search_reduce_util.go b/internal/proxy/search_reduce_util.go
index 7a547f7f1f792..15029f1e01256 100644
--- a/internal/proxy/search_reduce_util.go
+++ b/internal/proxy/search_reduce_util.go
@@ -218,95 +218,91 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
 		totalResCount += subSearchNqOffset[i][nq-1]
 	}
 
-	if subSearchNum == 1 && offset == 0 {
-		ret.Results = subSearchResultData[0]
-	} else {
-		var realTopK int64 = -1
-		var retSize int64
+	var realTopK int64 = -1
+	var retSize int64
 
-		maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
-		// reducing nq * topk results
-		for i := int64(0); i < nq; i++ {
-			var (
-				// cursor of current data of each subSearch for merging the j-th data of TopK.
-				// sum(cursors) == j
-				cursors = make([]int64, subSearchNum)
-
-				j              int64
-				groupByValMap  = make(map[interface{}][]*groupReduceInfo)
-				skipOffsetMap  = make(map[interface{}]bool)
-				groupByValList = make([]interface{}, limit)
-				groupByValIdx  = 0
-			)
-
-			for j = 0; j < groupBound; {
-				subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
-				if subSearchIdx == -1 {
-					break
-				}
-				subSearchRes := subSearchResultData[subSearchIdx]
-
-				id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
-				score := subSearchRes.GetScores()[resultDataIdx]
-				groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))
-				if groupByVal == nil {
-					return nil, errors.New("get nil groupByVal from subSearchRes, wrong states, as milvus doesn't support nil value," +
-						"there must be sth wrong on queryNode side")
-				}
+	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
+	// reducing nq * topk results
+	for i := int64(0); i < nq; i++ {
+		var (
+			// cursor of current data of each subSearch for merging the j-th data of TopK.
+			// sum(cursors) == j
+			cursors = make([]int64, subSearchNum)
+
+			j              int64
+			groupByValMap  = make(map[interface{}][]*groupReduceInfo)
+			skipOffsetMap  = make(map[interface{}]bool)
+			groupByValList = make([]interface{}, limit)
+			groupByValIdx  = 0
+		)
+
+		for j = 0; j < groupBound; {
+			subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
+			if subSearchIdx == -1 {
+				break
+			}
+			subSearchRes := subSearchResultData[subSearchIdx]
+
+			id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
+			score := subSearchRes.GetScores()[resultDataIdx]
+			groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))
+			if groupByVal == nil {
+				return nil, errors.New("get nil groupByVal from subSearchRes, wrong states, as milvus doesn't support nil value," +
+					"there must be sth wrong on queryNode side")
+			}
 
-				if int64(len(skipOffsetMap)) < offset || skipOffsetMap[groupByVal] {
-					skipOffsetMap[groupByVal] = true
-					// the first offset's group will be ignored
-				} else if len(groupByValMap[groupByVal]) == 0 && int64(len(groupByValMap)) >= limit {
-					// skip when groupbyMap has been full and found new groupByVal
-				} else if int64(len(groupByValMap[groupByVal])) >= groupSize {
-					// skip when target group has been full
-				} else {
-					if len(groupByValMap[groupByVal]) == 0 {
-						groupByValList[groupByValIdx] = groupByVal
-						groupByValIdx++
-					}
-					groupByValMap[groupByVal] = append(groupByValMap[groupByVal], &groupReduceInfo{
-						subSearchIdx: subSearchIdx,
-						resultIdx:    resultDataIdx, id: id, score: score,
-					})
-					j++
+			if int64(len(skipOffsetMap)) < offset || skipOffsetMap[groupByVal] {
+				skipOffsetMap[groupByVal] = true
+				// the first offset's group will be ignored
+			} else if len(groupByValMap[groupByVal]) == 0 && int64(len(groupByValMap)) >= limit {
+				// skip when groupbyMap has been full and found new groupByVal
+			} else if int64(len(groupByValMap[groupByVal])) >= groupSize {
+				// skip when target group has been full
+			} else {
+				if len(groupByValMap[groupByVal]) == 0 {
+					groupByValList[groupByValIdx] = groupByVal
+					groupByValIdx++
				}
-
-				cursors[subSearchIdx]++
+				groupByValMap[groupByVal] = append(groupByValMap[groupByVal], &groupReduceInfo{
+					subSearchIdx: subSearchIdx,
+					resultIdx:    resultDataIdx, id: id, score: score,
+				})
+				j++
 			}
 
-			// assemble all eligible values in group
-			// values in groupByValList is sorted by the highest score in each group
-			for _, groupVal := range groupByValList {
-				if groupVal != nil {
-					groupEntities := groupByValMap[groupVal]
-					for _, groupEntity := range groupEntities {
-						subResData := subSearchResultData[groupEntity.subSearchIdx]
-						retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx)
-						typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
-						ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
-						if err := typeutil.AppendGroupByValue(ret.Results, groupVal, subResData.GetGroupByFieldValue().GetType()); err != nil {
-							log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
-							return ret, err
-						}
+			cursors[subSearchIdx]++
+		}
+
+		// assemble all eligible values in group
+		// values in groupByValList is sorted by the highest score in each group
+		for _, groupVal := range groupByValList {
+			if groupVal != nil {
+				groupEntities := groupByValMap[groupVal]
+				for _, groupEntity := range groupEntities {
+					subResData := subSearchResultData[groupEntity.subSearchIdx]
+					retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx)
+					typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
+					ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
+					if err := typeutil.AppendGroupByValue(ret.Results, groupVal, subResData.GetGroupByFieldValue().GetType()); err != nil {
+						log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
+						return ret, err
 					}
 				}
 			}
+		}
 
-			if realTopK != -1 && realTopK != j {
-				log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
-			}
-			realTopK = j
-			ret.Results.Topks = append(ret.Results.Topks, realTopK)
+		if realTopK != -1 && realTopK != j {
+			log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
+		}
+		realTopK = j
+		ret.Results.Topks = append(ret.Results.Topks, realTopK)
 
-			// limit search result to avoid oom
-			if retSize > maxOutputSize {
-				return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
-			}
+		// limit search result to avoid oom
+		if retSize > maxOutputSize {
+			return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
 		}
-		ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
 	}
+	ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
 
 	if !metric.PositivelyRelated(metricType) {
 		for k := range ret.Results.Scores {
 			ret.Results.Scores[k] *= -1
diff --git a/tests/python_client/requirements.txt b/tests/python_client/requirements.txt
index 5bc61b6224c43..3e2af12e12590 100644
--- a/tests/python_client/requirements.txt
+++ b/tests/python_client/requirements.txt
@@ -12,8 +12,8 @@ allure-pytest==2.7.0
 pytest-print==0.2.1
 pytest-level==0.1.1
 pytest-xdist==2.5.0
-pymilvus==2.5.0rc80
-pymilvus[bulk_writer]==2.5.0rc80
+pymilvus==2.5.0rc81
+pymilvus[bulk_writer]==2.5.0rc81
 pytest-rerunfailures==9.1.1
 git+https://github.com/Projectplace/pytest-tags
 ndg-httpsclient
diff --git a/tests/python_client/testcases/test_mix_scenes.py b/tests/python_client/testcases/test_mix_scenes.py
index b578bbad31d0a..8ce02adf00b5c 100644
--- a/tests/python_client/testcases/test_mix_scenes.py
+++ b/tests/python_client/testcases/test_mix_scenes.py
@@ -1266,7 +1266,6 @@ def prepare_data(self):
         self.collection_wrap.load()
 
     @pytest.mark.tags(CaseLabel.L0)
-    @pytest.mark.xfail(reason="issue #36407")
     @pytest.mark.parametrize("group_by_field", [DataType.VARCHAR.name, "varchar_with_index"])
     def test_search_group_size(self, group_by_field):
         """
@@ -1308,7 +1307,6 @@ def test_search_group_size(self, group_by_field):
         assert len(set(group_values)) == limit
 
     @pytest.mark.tags(CaseLabel.L0)
-    @pytest.mark.xfail(reason="issue #36407")
     def test_hybrid_search_group_size(self):
         """
         hybrid search group by on 3 different float vector fields with group by varchar field with group size
@@ -1360,7 +1358,6 @@ def test_hybrid_search_group_size(self):
             group_distances = [res[i][l + 1].distance]
 
     @pytest.mark.tags(CaseLabel.L1)
-    @pytest.mark.xfail(reason="issue #36407")
     def test_hybrid_search_group_by(self):
         """
         verify hybrid search group by works with different Rankers
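The proxy-side reduce path above applies three per-query rules while merging sub-search results: the first `offset` distinct group values are skipped, at most `limit` groups are kept, and each group holds at most `groupSize` hits. A minimal standalone sketch of just those selection rules follows; the `hit` type and `reduceGroups` helper are hypothetical names for illustration only, not part of the Milvus code, and the real loop merges per-sub-search cursors via `selectHighestScoreIndex` rather than taking a pre-sorted slice.

```go
package main

import "fmt"

// hit is a hypothetical, simplified stand-in for one scored sub-search result.
type hit struct {
	id       int64
	score    float32
	groupVal string
}

// reduceGroups applies the same three selection rules as the patched reduce loop:
// skip the first `offset` distinct group values, keep at most `limit` groups, and
// keep at most `groupSize` hits per group. Hits must already be ordered by
// descending score.
func reduceGroups(hits []hit, offset, limit, groupSize int) map[string][]hit {
	skipped := make(map[string]bool)  // groups consumed by the offset
	groups := make(map[string][]hit)  // groups kept in the result
	for _, h := range hits {
		switch {
		case len(skipped) < offset || skipped[h.groupVal]:
			skipped[h.groupVal] = true // the first `offset` groups are ignored
		case len(groups[h.groupVal]) == 0 && len(groups) >= limit:
			// a new group appears after `limit` groups are already present: skip
		case len(groups[h.groupVal]) >= groupSize:
			// the target group is already full: skip
		default:
			groups[h.groupVal] = append(groups[h.groupVal], h)
		}
	}
	return groups
}

func main() {
	hits := []hit{
		{id: 1, score: 0.9, groupVal: "a"},
		{id: 2, score: 0.8, groupVal: "b"},
		{id: 3, score: 0.7, groupVal: "a"},
		{id: 4, score: 0.6, groupVal: "c"},
		{id: 5, score: 0.5, groupVal: "b"},
	}
	// offset=1 drops group "a"; limit=2 and groupSize=1 keep one hit each from "b" and "c".
	fmt.Println(reduceGroups(hits, 1, 2, 1))
}
```

Running the sketch with `offset=1, limit=2, groupSize=1` keeps a single hit from groups "b" and "c", which mirrors the order in which the patched loop applies the offset, limit, and group-size checks.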