Commit

fix: change pymilvus version for hybridsearch-groupby(#36407) (#36451)
related: #36407

---------

Signed-off-by: MrPresent-Han <[email protected]>
Co-authored-by: MrPresent-Han <[email protected]>
MrPresent-Han and MrPresent-Han authored Sep 24, 2024
1 parent 98a917c commit d55d9d6
Showing 3 changed files with 75 additions and 82 deletions.
150 changes: 73 additions & 77 deletions internal/proxy/search_reduce_util.go
@@ -218,95 +218,91 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
		totalResCount += subSearchNqOffset[i][nq-1]
	}

The single-sub-search shortcut below is deleted (together with its matching closing brace near the end of the hunk), so every request now goes through the group-by reduce path:

-	if subSearchNum == 1 && offset == 0 {
-		ret.Results = subSearchResultData[0]
-	} else {

The body of the former else branch is kept and de-indented by one level; after the change the hunk reads:

	var realTopK int64 = -1
	var retSize int64

	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
	// reducing nq * topk results
	for i := int64(0); i < nq; i++ {
		var (
			// cursor of current data of each subSearch for merging the j-th data of TopK.
			// sum(cursors) == j
			cursors = make([]int64, subSearchNum)

			j              int64
			groupByValMap  = make(map[interface{}][]*groupReduceInfo)
			skipOffsetMap  = make(map[interface{}]bool)
			groupByValList = make([]interface{}, limit)
			groupByValIdx  = 0
		)

		for j = 0; j < groupBound; {
			subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
			if subSearchIdx == -1 {
				break
			}
			subSearchRes := subSearchResultData[subSearchIdx]

			id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
			score := subSearchRes.GetScores()[resultDataIdx]
			groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))
			if groupByVal == nil {
				return nil, errors.New("get nil groupByVal from subSearchRes, wrong states, as milvus doesn't support nil value," +
					"there must be sth wrong on queryNode side")
			}

			if int64(len(skipOffsetMap)) < offset || skipOffsetMap[groupByVal] {
				skipOffsetMap[groupByVal] = true
				// the first offset's group will be ignored
			} else if len(groupByValMap[groupByVal]) == 0 && int64(len(groupByValMap)) >= limit {
				// skip when groupbyMap has been full and found new groupByVal
			} else if int64(len(groupByValMap[groupByVal])) >= groupSize {
				// skip when target group has been full
			} else {
				if len(groupByValMap[groupByVal]) == 0 {
					groupByValList[groupByValIdx] = groupByVal
					groupByValIdx++
				}
				groupByValMap[groupByVal] = append(groupByValMap[groupByVal], &groupReduceInfo{
					subSearchIdx: subSearchIdx,
					resultIdx:    resultDataIdx, id: id, score: score,
				})
				j++
			}

			cursors[subSearchIdx]++
		}

		// assemble all eligible values in group
		// values in groupByValList is sorted by the highest score in each group
		for _, groupVal := range groupByValList {
			if groupVal != nil {
				groupEntities := groupByValMap[groupVal]
				for _, groupEntity := range groupEntities {
					subResData := subSearchResultData[groupEntity.subSearchIdx]
					retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx)
					typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
					ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
					if err := typeutil.AppendGroupByValue(ret.Results, groupVal, subResData.GetGroupByFieldValue().GetType()); err != nil {
						log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
						return ret, err
					}
				}
			}
		}

		if realTopK != -1 && realTopK != j {
			log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
		}
		realTopK = j
		ret.Results.Topks = append(ret.Results.Topks, realTopK)

		// limit search result to avoid oom
		if retSize > maxOutputSize {
			return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
		}
	}
	ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query

	if !metric.PositivelyRelated(metricType) {
		for k := range ret.Results.Scores {
			ret.Results.Scores[k] *= -1
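
To make the hunk above easier to follow: offset skips whole groups rather than individual hits, limit counts distinct group-by values, and group_size caps the number of hits kept per group. The short Python sketch below is a hypothetical helper written only to illustrate that selection rule for a single query; it is not Milvus or pymilvus code.

# Illustrative sketch of the group-by reduce rule above: offset drops whole
# groups, limit bounds the number of groups, group_size bounds hits per group.
# Hypothetical helper for this page only; not part of the Milvus codebase.
def reduce_group_by(hits, offset, limit, group_size):
    """hits: (id, score, group_val) tuples, already sorted by score descending."""
    skipped_groups = set()   # groups consumed by `offset`
    kept = {}                # group_val -> list of (id, score)
    order = []               # groups in the order their best-scoring hit appeared
    for hit_id, score, group_val in hits:
        if group_val in skipped_groups or len(skipped_groups) < offset:
            skipped_groups.add(group_val)     # the first `offset` groups are dropped entirely
        elif group_val not in kept and len(kept) >= limit:
            continue                          # already have `limit` groups; ignore new ones
        elif len(kept.get(group_val, ())) >= group_size:
            continue                          # this group is already full
        else:
            if group_val not in kept:
                kept[group_val] = []
                order.append(group_val)
            kept[group_val].append((hit_id, score))
    return [(g, kept[g]) for g in order]

hits = [(1, 0.9, "a"), (2, 0.8, "b"), (3, 0.7, "a"), (4, 0.6, "c"), (5, 0.5, "b"), (6, 0.4, "d")]
print(reduce_group_by(hits, offset=1, limit=2, group_size=2))
# -> [('b', [(2, 0.8), (5, 0.5)]), ('c', [(4, 0.6)])]

In the printed result, group "a" is consumed by offset=1, at most two groups are kept, and each group holds at most two hits, which is the same behaviour the Go loop enforces via skipOffsetMap and groupByValMap.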
4 changes: 2 additions & 2 deletions tests/python_client/requirements.txt
@@ -12,8 +12,8 @@ allure-pytest==2.7.0
 pytest-print==0.2.1
 pytest-level==0.1.1
 pytest-xdist==2.5.0
-pymilvus==2.5.0rc80
-pymilvus[bulk_writer]==2.5.0rc80
+pymilvus==2.5.0rc81
+pymilvus[bulk_writer]==2.5.0rc81
 pytest-rerunfailures==9.1.1
 git+https://github.com/Projectplace/pytest-tags
 ndg-httpsclient
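
To confirm that the client actually in use matches the new pin, a generic check using only the standard library (not part of the test suite) is:

# Prints the installed pymilvus version; expect 2.5.0rc81 once
# tests/python_client/requirements.txt has been installed.
from importlib.metadata import version

print(version("pymilvus"))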
3 changes: 0 additions & 3 deletions tests/python_client/testcases/test_mix_scenes.py
@@ -1266,7 +1266,6 @@ def prepare_data(self):
         self.collection_wrap.load()
 
     @pytest.mark.tags(CaseLabel.L0)
-    @pytest.mark.xfail(reason="issue #36407")
     @pytest.mark.parametrize("group_by_field", [DataType.VARCHAR.name, "varchar_with_index"])
     def test_search_group_size(self, group_by_field):
         """
@@ -1308,7 +1307,6 @@ def test_search_group_size(self, group_by_field):
         assert len(set(group_values)) == limit
 
     @pytest.mark.tags(CaseLabel.L0)
-    @pytest.mark.xfail(reason="issue #36407")
     def test_hybrid_search_group_size(self):
         """
         hybrid search group by on 3 different float vector fields with group by varchar field with group size
@@ -1360,7 +1358,6 @@ def test_hybrid_search_group_size(self):
         group_distances = [res[i][l + 1].distance]
 
     @pytest.mark.tags(CaseLabel.L1)
-    @pytest.mark.xfail(reason="issue #36407")
     def test_hybrid_search_group_by(self):
         """
         verify hybrid search group by works with different Rankers
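
The three tests re-enabled above drive grouping through the pymilvus client. A minimal usage sketch follows, assuming pymilvus>=2.5.0rc81 (as pinned above), a Milvus instance at localhost:19530, and an already-built, loaded collection; the collection and field names are illustrative and are not taken from the test file.

# Hedged sketch of a hybrid search with grouping, similar to what the tests exercise.
# Assumes pymilvus>=2.5.0rc81 and an existing loaded collection; names are illustrative.
from pymilvus import AnnSearchRequest, Collection, RRFRanker, connections

connections.connect(host="localhost", port="19530")
collection = Collection("demo_collection")  # hypothetical collection with two vector fields and a varchar field

req1 = AnnSearchRequest(data=[[0.1] * 128], anns_field="float_vector_1",
                        param={"metric_type": "COSINE"}, limit=100)
req2 = AnnSearchRequest(data=[[0.2] * 128], anns_field="float_vector_2",
                        param={"metric_type": "COSINE"}, limit=100)

res = collection.hybrid_search(
    reqs=[req1, req2],
    rerank=RRFRanker(),          # reciprocal-rank-fusion reranker across the two requests
    limit=10,                    # number of groups to return
    group_by_field="varchar",    # group hits by this scalar field
    group_size=3,                # keep up to 3 hits per group
    output_fields=["varchar"],
)
for hit in res[0]:               # hits for the first (and only) query vector
    print(hit.id, hit.distance)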
