Skip to content

Commit

Permalink
Fix for missing HybridQuery results when concurrent segment search is…
Browse files Browse the repository at this point in the history
… enabled (#800)

* Adding merge logic for multiple collector result case

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored Jun 25, 2024
1 parent 3e4a2ef commit 25d2e82
Show file tree
Hide file tree
Showing 17 changed files with 1,015 additions and 105 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
- Fix for missing HybridQuery results when concurrent segment search is enabled ([#800](https://github.com/opensearch-project/neural-search/pull/800))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import java.util.List;
import java.util.Objects;

import static org.apache.lucene.search.TotalHits.Relation;
import static org.opensearch.neuralsearch.search.query.TopDocsMerger.TOP_DOCS_MERGER_TOP_SCORES;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;

Expand All @@ -46,12 +48,12 @@ public abstract class HybridCollectorManager implements CollectorManager<Collect

private final int numHits;
private final HitsThresholdChecker hitsThresholdChecker;
private final boolean isSingleShard;
private final int trackTotalHitsUpTo;
private final SortAndFormats sortAndFormats;
@Nullable
private final Weight filterWeight;
private static final float boost_factor = 1f;
private final TopDocsMerger topDocsMerger;

/**
* Create new instance of HybridCollectorManager depending on the concurrent search beeing enabled or disabled.
Expand All @@ -62,7 +64,6 @@ public abstract class HybridCollectorManager implements CollectorManager<Collect
public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
final IndexReader reader = searchContext.searcher().getIndexReader();
final int totalNumDocs = Math.max(0, reader.numDocs());
boolean isSingleShard = searchContext.numberOfShards() == 1;
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();

Expand All @@ -83,15 +84,13 @@ public static CollectorManager createHybridCollectorManager(final SearchContext
? new HybridCollectorConcurrentSearchManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
searchContext.sort(),
filteringWeight
)
: new HybridCollectorNonConcurrentManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
searchContext.sort(),
filteringWeight
Expand All @@ -118,6 +117,27 @@ public Collector newCollector() {
*/
@Override
public ReduceableSearchResult reduce(Collection<Collector> collectors) {
final List<HybridTopScoreDocCollector> hybridTopScoreDocCollectors = getHybridScoreDocCollectors(collectors);
if (hybridTopScoreDocCollectors.isEmpty()) {
throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
}

List<ReduceableSearchResult> results = new ArrayList<>();
DocValueFormat[] docValueFormats = getSortValueFormats(sortAndFormats);
for (HybridTopScoreDocCollector hybridTopScoreDocCollector : hybridTopScoreDocCollectors) {
List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
TopDocs newTopDocs = getNewTopDocs(
getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridTopScoreDocCollector.getTotalHits()),
topDocs
);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore());

results.add((QuerySearchResult result) -> reduceCollectorResults(result, topDocsAndMaxScore, docValueFormats, newTopDocs));
}
return reduceSearchResults(results);
}

private List<HybridTopScoreDocCollector> getHybridScoreDocCollectors(Collection<Collector> collectors) {
final List<HybridTopScoreDocCollector> hybridTopScoreDocCollectors = new ArrayList<>();
// check if collector for hybrid query scores is part of this search context. It can be wrapped into MultiCollectorWrapper
// in case multiple collector managers are registered. We use hybrid scores collector to format scores into
Expand All @@ -136,20 +156,7 @@ public ReduceableSearchResult reduce(Collection<Collector> collectors) {
hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) ((FilteredCollector) collector).getCollector());
}
}

if (!hybridTopScoreDocCollectors.isEmpty()) {
HybridTopScoreDocCollector hybridTopScoreDocCollector = hybridTopScoreDocCollectors.stream()
.findFirst()
.orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query"));
List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
TopDocs newTopDocs = getNewTopDocs(
getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard, hybridTopScoreDocCollector.getTotalHits()),
topDocs
);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore());
return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); };
}
throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
return hybridTopScoreDocCollectors;
}

private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
Expand Down Expand Up @@ -195,15 +202,10 @@ private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> top
return new TopDocs(totalHits, scoreDocs);
}

private TotalHits getTotalHits(
int trackTotalHitsUpTo,
final List<TopDocs> topDocs,
final boolean isSingleShard,
final long maxTotalHits
) {
final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<TopDocs> topDocs, final long maxTotalHits) {
final Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? Relation.GREATER_THAN_OR_EQUAL_TO
: Relation.EQUAL_TO;
if (topDocs == null || topDocs.isEmpty()) {
return new TotalHits(0, relation);
}
Expand All @@ -215,6 +217,45 @@ private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats
return sortAndFormats == null ? null : sortAndFormats.formats;
}

private void reduceCollectorResults(
QuerySearchResult result,
TopDocsAndMaxScore topDocsAndMaxScore,
DocValueFormat[] docValueFormats,
TopDocs newTopDocs
) {
// this is case of first collector, query result object doesn't have any top docs set, so we can
// just set new top docs without merge
// this call is effectively checking if QuerySearchResult.topDoc is null. using it in such way because
// getter throws exception in case topDocs is null
if (result.hasConsumedTopDocs()) {
result.topDocs(topDocsAndMaxScore, docValueFormats);
return;
}
// in this case top docs are already present in result, and we need to merge next result object with what we have.
// if collector doesn't have any hits we can just skip it and save some cycles by not doing merge
if (newTopDocs.totalHits.value == 0) {
return;
}
// we need to do actual merge because query result and current collector both have some score hits
TopDocsAndMaxScore originalTotalDocsAndHits = result.topDocs();
TopDocsAndMaxScore mergeTopDocsAndMaxScores = topDocsMerger.merge(originalTotalDocsAndHits, topDocsAndMaxScore);
result.topDocs(mergeTopDocsAndMaxScores, docValueFormats);
}

/**
* For collection of search results, return a single one that has results from all individual result objects.
* @param results collection of search results
* @return single search result that represents all results as one object
*/
private ReduceableSearchResult reduceSearchResults(List<ReduceableSearchResult> results) {
return (result) -> {
for (ReduceableSearchResult r : results) {
// call reduce for results of each single collector, this will update top docs in query result
r.reduce(result);
}
};
}

/**
* Implementation of the HybridCollector that reuses instance of collector on each even call. This allows caller to
* use saved state of collector
Expand All @@ -225,12 +266,11 @@ static class HybridCollectorNonConcurrentManager extends HybridCollectorManager
public HybridCollectorNonConcurrentManager(
int numHits,
HitsThresholdChecker hitsThresholdChecker,
boolean isSingleShard,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats,
Weight filteringWeight
) {
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight);
super(numHits, hitsThresholdChecker, trackTotalHitsUpTo, sortAndFormats, filteringWeight, TOP_DOCS_MERGER_TOP_SCORES);
scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null");
}

Expand All @@ -255,12 +295,11 @@ static class HybridCollectorConcurrentSearchManager extends HybridCollectorManag
public HybridCollectorConcurrentSearchManager(
int numHits,
HitsThresholdChecker hitsThresholdChecker,
boolean isSingleShard,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats,
Weight filteringWeight
) {
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight);
super(numHits, hitsThresholdChecker, trackTotalHitsUpTo, sortAndFormats, filteringWeight, TOP_DOCS_MERGER_TOP_SCORES);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.lucene.search.ScoreDoc;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryScoreDocElement;

/**
* Merges two ScoreDoc arrays into one
*/
@NoArgsConstructor(access = AccessLevel.PACKAGE)
class HybridQueryScoreDocsMerger<T extends ScoreDoc> {

private static final int MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC = 3;

/**
* Merge two score docs objects, result ScoreDocs[] object will have all hits per sub-query from both original objects.
* Input and output ScoreDocs are in format that is specific to Hybrid Query. This method should not be used for ScoreDocs from
* other query types.
* Logic is based on assumption that hits of every sub-query are sorted by score.
* Method returns new object and doesn't mutate original ScoreDocs arrays.
* @param sourceScoreDocs original score docs from query result
* @param newScoreDocs new score docs that we need to merge into existing scores
* @return merged array of ScoreDocs objects
*/
public T[] merge(final T[] sourceScoreDocs, final T[] newScoreDocs, final Comparator<T> comparator) {
if (Objects.requireNonNull(sourceScoreDocs, "score docs cannot be null").length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC
|| Objects.requireNonNull(newScoreDocs, "score docs cannot be null").length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC) {
throw new IllegalArgumentException("cannot merge top docs because it does not have enough elements");
}
// we overshoot and preallocate more than we need - length of both top docs combined.
// we will take only portion of the array at the end
List<T> mergedScoreDocs = new ArrayList<>(sourceScoreDocs.length + newScoreDocs.length);
int sourcePointer = 0;
// mark beginning of hybrid query results by start element
mergedScoreDocs.add(sourceScoreDocs[sourcePointer]);
sourcePointer++;
// new pointer is set to 1 as we don't care about it start-stop element
int newPointer = 1;

while (sourcePointer < sourceScoreDocs.length - 1 && newPointer < newScoreDocs.length - 1) {
// every iteration is for results of one sub-query
mergedScoreDocs.add(sourceScoreDocs[sourcePointer]);
sourcePointer++;
newPointer++;
// simplest case when both arrays have results for sub-query
while (sourcePointer < sourceScoreDocs.length
&& isHybridQueryScoreDocElement(sourceScoreDocs[sourcePointer])
&& newPointer < newScoreDocs.length
&& isHybridQueryScoreDocElement(newScoreDocs[newPointer])) {
if (comparator.compare(sourceScoreDocs[sourcePointer], newScoreDocs[newPointer]) >= 0) {
mergedScoreDocs.add(sourceScoreDocs[sourcePointer]);
sourcePointer++;
} else {
mergedScoreDocs.add(newScoreDocs[newPointer]);
newPointer++;
}
}
// at least one object got exhausted at this point, now merge all elements from object that's left
while (sourcePointer < sourceScoreDocs.length && isHybridQueryScoreDocElement(sourceScoreDocs[sourcePointer])) {
mergedScoreDocs.add(sourceScoreDocs[sourcePointer]);
sourcePointer++;
}
while (newPointer < newScoreDocs.length && isHybridQueryScoreDocElement(newScoreDocs[newPointer])) {
mergedScoreDocs.add(newScoreDocs[newPointer]);
newPointer++;
}
}
// mark end of hybrid query results by end element
mergedScoreDocs.add(sourceScoreDocs[sourceScoreDocs.length - 1]);
return mergedScoreDocs.toArray((T[]) new ScoreDoc[0]);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import com.google.common.annotations.VisibleForTesting;
import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;

import java.util.Comparator;
import java.util.Objects;

/**
* Utility class for merging TopDocs and MaxScore across multiple search queries
*/
@RequiredArgsConstructor(access = AccessLevel.PACKAGE)
class TopDocsMerger {

private final HybridQueryScoreDocsMerger<ScoreDoc> scoreDocsMerger;
@VisibleForTesting
protected static final Comparator<ScoreDoc> SCORE_DOC_BY_SCORE_COMPARATOR = Comparator.comparing((scoreDoc) -> scoreDoc.score);
/**
* Uses hybrid query score docs merger to merge internal score docs
*/
static final TopDocsMerger TOP_DOCS_MERGER_TOP_SCORES = new TopDocsMerger(new HybridQueryScoreDocsMerger<>());

/**
* Merge TopDocs and MaxScore from multiple search queries into a single TopDocsAndMaxScore object.
* @param source TopDocsAndMaxScore for the original query
* @param newTopDocs TopDocsAndMaxScore for the new query
* @return merged TopDocsAndMaxScore object
*/
public TopDocsAndMaxScore merge(TopDocsAndMaxScore source, TopDocsAndMaxScore newTopDocs) {
if (Objects.isNull(newTopDocs) || Objects.isNull(newTopDocs.topDocs) || newTopDocs.topDocs.totalHits.value == 0) {
return source;
}
// we need to merge hits per individual sub-query
// format of results in both new and source TopDocs is following
// doc_id | magic_number_1
// doc_id | magic_number_2
// ...
// doc_id | magic_number_2
// ...
// doc_id | magic_number_2
// ...
// doc_id | magic_number_1
ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge(
source.topDocs.scoreDocs,
newTopDocs.topDocs.scoreDocs,
SCORE_DOC_BY_SCORE_COMPARATOR
);
TotalHits mergedTotalHits = getMergedTotalHits(source, newTopDocs);
TopDocsAndMaxScore result = new TopDocsAndMaxScore(
new TopDocs(mergedTotalHits, mergedScoreDocs),
Math.max(source.maxScore, newTopDocs.maxScore)
);
return result;
}

private TotalHits getMergedTotalHits(TopDocsAndMaxScore source, TopDocsAndMaxScore newTopDocs) {
// merged value is a lower bound - if both are equal_to than merged will also be equal_to,
// otherwise assign greater_than_or_equal
TotalHits.Relation mergedHitsRelation = source.topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
|| newTopDocs.topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
return new TotalHits(source.topDocs.totalHits.value + newTopDocs.topDocs.totalHits.value, mergedHitsRelation);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,28 @@ public static boolean isHybridQueryStartStopElement(final ScoreDoc scoreDoc) {
public static boolean isHybridQueryDelimiterElement(final ScoreDoc scoreDoc) {
return Objects.nonNull(scoreDoc) && scoreDoc.doc >= 0 && Float.compare(scoreDoc.score, MAGIC_NUMBER_DELIMITER) == 0;
}

/**
* Checking if passed scoreDocs object is a special element (start/stop or delimiter) in the list of hybrid query result scores
* @param scoreDoc score doc object to check on
* @return true if it is a special element
*/
public static boolean isHybridQuerySpecialElement(final ScoreDoc scoreDoc) {
if (Objects.isNull(scoreDoc)) {
return false;
}
return isHybridQueryStartStopElement(scoreDoc) || isHybridQueryDelimiterElement(scoreDoc);
}

/**
* Checking if passed scoreDocs object is a document score element
* @param scoreDoc score doc object to check on
* @return true if element has score
*/
public static boolean isHybridQueryScoreDocElement(final ScoreDoc scoreDoc) {
if (Objects.isNull(scoreDoc)) {
return false;
}
return !isHybridQuerySpecialElement(scoreDoc);
}
}
Loading

0 comments on commit 25d2e82

Please sign in to comment.