Skip to content

Commit

Permalink
fix: unexpected shrotcut when proxy reducing(milvus-io#36407)
Browse files Browse the repository at this point in the history
Signed-off-by: MrPresent-Han <[email protected]>
  • Loading branch information
MrPresent-Han committed Sep 24, 2024
1 parent f004f8e commit b6e3b3b
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 78 deletions.
4 changes: 2 additions & 2 deletions internal/distributed/proxy/httpserver/request_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ type HybridSearchReq struct {
Search []SubSearchReq `json:"search"`
Rerank Rand `json:"rerank"`
Limit int32 `json:"limit"`
GroupByField string `json:"groupingField"`
GroupSize int32 `json:"groupSize"`
GroupByField string `json:"groupingField"`
GroupSize int32 `json:"groupSize"`
OutputFields []string `json:"outputFields"`
ConsistencyLevel string `json:"consistencyLevel"`
}
Expand Down
152 changes: 76 additions & 76 deletions internal/proxy/search_reduce_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,95 +218,95 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
totalResCount += subSearchNqOffset[i][nq-1]
}

if subSearchNum == 1 && offset == 0 {
/*if subSearchNum == 1 && offset == 0 {
ret.Results = subSearchResultData[0]
} else {
var realTopK int64 = -1
var retSize int64
} else {*/
var realTopK int64 = -1
var retSize int64

maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
// reducing nq * topk results
for i := int64(0); i < nq; i++ {
var (
// cursor of current data of each subSearch for merging the j-th data of TopK.
// sum(cursors) == j
cursors = make([]int64, subSearchNum)

j int64
groupByValMap = make(map[interface{}][]*groupReduceInfo)
skipOffsetMap = make(map[interface{}]bool)
groupByValList = make([]interface{}, limit)
groupByValIdx = 0
)

for j = 0; j < groupBound; {
subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
if subSearchIdx == -1 {
break
}
subSearchRes := subSearchResultData[subSearchIdx]

id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
score := subSearchRes.GetScores()[resultDataIdx]
groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))
if groupByVal == nil {
return nil, errors.New("get nil groupByVal from subSearchRes, wrong states, as milvus doesn't support nil value," +
"there must be sth wrong on queryNode side")
}
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
// reducing nq * topk results
for i := int64(0); i < nq; i++ {
var (
// cursor of current data of each subSearch for merging the j-th data of TopK.
// sum(cursors) == j
cursors = make([]int64, subSearchNum)

j int64
groupByValMap = make(map[interface{}][]*groupReduceInfo)
skipOffsetMap = make(map[interface{}]bool)
groupByValList = make([]interface{}, limit)
groupByValIdx = 0
)

for j = 0; j < groupBound; {
subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
if subSearchIdx == -1 {
break
}
subSearchRes := subSearchResultData[subSearchIdx]

id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
score := subSearchRes.GetScores()[resultDataIdx]
groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))
if groupByVal == nil {
return nil, errors.New("get nil groupByVal from subSearchRes, wrong states, as milvus doesn't support nil value," +
"there must be sth wrong on queryNode side")
}

if int64(len(skipOffsetMap)) < offset || skipOffsetMap[groupByVal] {
skipOffsetMap[groupByVal] = true
// the first offset's group will be ignored
} else if len(groupByValMap[groupByVal]) == 0 && int64(len(groupByValMap)) >= limit {
// skip when groupbyMap has been full and found new groupByVal
} else if int64(len(groupByValMap[groupByVal])) >= groupSize {
// skip when target group has been full
} else {
if len(groupByValMap[groupByVal]) == 0 {
groupByValList[groupByValIdx] = groupByVal
groupByValIdx++
}
groupByValMap[groupByVal] = append(groupByValMap[groupByVal], &groupReduceInfo{
subSearchIdx: subSearchIdx,
resultIdx: resultDataIdx, id: id, score: score,
})
j++
if int64(len(skipOffsetMap)) < offset || skipOffsetMap[groupByVal] {
skipOffsetMap[groupByVal] = true
// the first offset's group will be ignored
} else if len(groupByValMap[groupByVal]) == 0 && int64(len(groupByValMap)) >= limit {
// skip when groupbyMap has been full and found new groupByVal
} else if int64(len(groupByValMap[groupByVal])) >= groupSize {
// skip when target group has been full
} else {
if len(groupByValMap[groupByVal]) == 0 {
groupByValList[groupByValIdx] = groupByVal
groupByValIdx++
}

cursors[subSearchIdx]++
groupByValMap[groupByVal] = append(groupByValMap[groupByVal], &groupReduceInfo{
subSearchIdx: subSearchIdx,
resultIdx: resultDataIdx, id: id, score: score,
})
j++
}

// assemble all eligible values in group
// values in groupByValList is sorted by the highest score in each group
for _, groupVal := range groupByValList {
if groupVal != nil {
groupEntities := groupByValMap[groupVal]
for _, groupEntity := range groupEntities {
subResData := subSearchResultData[groupEntity.subSearchIdx]
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx)
typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
if err := typeutil.AppendGroupByValue(ret.Results, groupVal, subResData.GetGroupByFieldValue().GetType()); err != nil {
log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
return ret, err
}
cursors[subSearchIdx]++
}

// assemble all eligible values in group
// values in groupByValList is sorted by the highest score in each group
for _, groupVal := range groupByValList {
if groupVal != nil {
groupEntities := groupByValMap[groupVal]
for _, groupEntity := range groupEntities {
subResData := subSearchResultData[groupEntity.subSearchIdx]
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx)
typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
if err := typeutil.AppendGroupByValue(ret.Results, groupVal, subResData.GetGroupByFieldValue().GetType()); err != nil {
log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
return ret, err
}
}
}
}

if realTopK != -1 && realTopK != j {
log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
}
realTopK = j
ret.Results.Topks = append(ret.Results.Topks, realTopK)
if realTopK != -1 && realTopK != j {
log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
}
realTopK = j
ret.Results.Topks = append(ret.Results.Topks, realTopK)

// limit search result to avoid oom
if retSize > maxOutputSize {
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
}
// limit search result to avoid oom
if retSize > maxOutputSize {
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
}
ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
}
ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
//}
if !metric.PositivelyRelated(metricType) {
for k := range ret.Results.Scores {
ret.Results.Scores[k] *= -1
Expand Down

0 comments on commit b6e3b3b

Please sign in to comment.