
[Backport 2.x] Fix for missing HybridQuery results when concurrent segment search is enabled #804

Merged: 1 commit merged on Jun 25, 2024
CHANGELOG.md: 1 addition, 0 deletions

@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
- Fix for missing HybridQuery results when concurrent segment search is enabled ([#800](https://github.com/opensearch-project/neural-search/pull/800))
### Infrastructure
### Documentation
### Maintenance

HybridCollectorManager.java

@@ -34,6 +34,8 @@
import java.util.List;
import java.util.Objects;

import static org.apache.lucene.search.TotalHits.Relation;
import static org.opensearch.neuralsearch.search.query.TopDocsMerger.TOP_DOCS_MERGER_TOP_SCORES;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;

@@ -46,12 +48,12 @@ public abstract class HybridCollectorManager implements CollectorManager<Collect

private final int numHits;
private final HitsThresholdChecker hitsThresholdChecker;
private final boolean isSingleShard;
private final int trackTotalHitsUpTo;
private final SortAndFormats sortAndFormats;
@Nullable
private final Weight filterWeight;
private static final float boost_factor = 1f;
private final TopDocsMerger topDocsMerger;

/**
* Create a new instance of HybridCollectorManager depending on whether concurrent search is enabled or disabled.
@@ -62,7 +64,6 @@ public abstract class HybridCollectorManager implements CollectorManager<Collect
public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
final IndexReader reader = searchContext.searcher().getIndexReader();
final int totalNumDocs = Math.max(0, reader.numDocs());
boolean isSingleShard = searchContext.numberOfShards() == 1;
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();

@@ -83,15 +84,13 @@ public static CollectorManager createHybridCollectorManager(final SearchContext
? new HybridCollectorConcurrentSearchManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
searchContext.sort(),
filteringWeight
)
: new HybridCollectorNonConcurrentManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
searchContext.sort(),
filteringWeight
@@ -118,6 +117,27 @@ public Collector newCollector() {
*/
@Override
public ReduceableSearchResult reduce(Collection<Collector> collectors) {
final List<HybridTopScoreDocCollector> hybridTopScoreDocCollectors = getHybridScoreDocCollectors(collectors);
if (hybridTopScoreDocCollectors.isEmpty()) {
throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
}

List<ReduceableSearchResult> results = new ArrayList<>();
DocValueFormat[] docValueFormats = getSortValueFormats(sortAndFormats);
for (HybridTopScoreDocCollector hybridTopScoreDocCollector : hybridTopScoreDocCollectors) {
List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
TopDocs newTopDocs = getNewTopDocs(
getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridTopScoreDocCollector.getTotalHits()),
topDocs
);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore());

results.add((QuerySearchResult result) -> reduceCollectorResults(result, topDocsAndMaxScore, docValueFormats, newTopDocs));
}
return reduceSearchResults(results);
}

private List<HybridTopScoreDocCollector> getHybridScoreDocCollectors(Collection<Collector> collectors) {
final List<HybridTopScoreDocCollector> hybridTopScoreDocCollectors = new ArrayList<>();
// check if collector for hybrid query scores is part of this search context. It can be wrapped into MultiCollectorWrapper
// in case multiple collector managers are registered. We use hybrid scores collector to format scores into
@@ -136,20 +156,7 @@ public ReduceableSearchResult reduce(Collection<Collector> collectors) {
hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) ((FilteredCollector) collector).getCollector());
}
}

if (!hybridTopScoreDocCollectors.isEmpty()) {
HybridTopScoreDocCollector hybridTopScoreDocCollector = hybridTopScoreDocCollectors.stream()
.findFirst()
.orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query"));
List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
TopDocs newTopDocs = getNewTopDocs(
getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard, hybridTopScoreDocCollector.getTotalHits()),
topDocs
);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore());
return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); };
}
throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
return hybridTopScoreDocCollectors;
}

private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
@@ -195,15 +202,10 @@ private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> top
return new TopDocs(totalHits, scoreDocs);
}

private TotalHits getTotalHits(
int trackTotalHitsUpTo,
final List<TopDocs> topDocs,
final boolean isSingleShard,
final long maxTotalHits
) {
final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<TopDocs> topDocs, final long maxTotalHits) {
final Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? Relation.GREATER_THAN_OR_EQUAL_TO
: Relation.EQUAL_TO;
if (topDocs == null || topDocs.isEmpty()) {
return new TotalHits(0, relation);
}
@@ -215,6 +217,45 @@ private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats
return sortAndFormats == null ? null : sortAndFormats.formats;
}

private void reduceCollectorResults(
QuerySearchResult result,
TopDocsAndMaxScore topDocsAndMaxScore,
DocValueFormat[] docValueFormats,
TopDocs newTopDocs
) {
// this is the case of the first collector: the query result object doesn't have any top docs set yet,
// so we can just set the new top docs without a merge.
// this call is effectively checking whether QuerySearchResult.topDocs is null; we use it this way because
// the getter throws an exception when topDocs is null
if (result.hasConsumedTopDocs()) {
result.topDocs(topDocsAndMaxScore, docValueFormats);
return;
}
// in this case top docs are already present in the result, and we need to merge the next result object with what we have.
// if the collector doesn't have any hits we can just skip it and save some cycles by not doing the merge
if (newTopDocs.totalHits.value == 0) {
return;
}
// we need to do an actual merge because the query result and the current collector both have score hits
TopDocsAndMaxScore originalTotalDocsAndHits = result.topDocs();
TopDocsAndMaxScore mergeTopDocsAndMaxScores = topDocsMerger.merge(originalTotalDocsAndHits, topDocsAndMaxScore);
result.topDocs(mergeTopDocsAndMaxScores, docValueFormats);
}

/**
* For collection of search results, return a single one that has results from all individual result objects.
* @param results collection of search results
* @return single search result that represents all results as one object
*/
private ReduceableSearchResult reduceSearchResults(List<ReduceableSearchResult> results) {
return (result) -> {
for (ReduceableSearchResult r : results) {
// call reduce for results of each single collector, this will update top docs in query result
r.reduce(result);
}
};
}

/**
* Implementation of the HybridCollector that reuses instance of collector on each even call. This allows caller to
* use saved state of collector
@@ -225,12 +266,11 @@ static class HybridCollectorNonConcurrentManager extends HybridCollectorManager
public HybridCollectorNonConcurrentManager(
int numHits,
HitsThresholdChecker hitsThresholdChecker,
boolean isSingleShard,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats,
Weight filteringWeight
) {
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight);
super(numHits, hitsThresholdChecker, trackTotalHitsUpTo, sortAndFormats, filteringWeight, TOP_DOCS_MERGER_TOP_SCORES);
scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null");
}

@@ -255,12 +295,11 @@ static class HybridCollectorConcurrentSearchManager extends HybridCollectorManag
public HybridCollectorConcurrentSearchManager(
int numHits,
HitsThresholdChecker hitsThresholdChecker,
boolean isSingleShard,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats,
Weight filteringWeight
) {
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight);
super(numHits, hitsThresholdChecker, trackTotalHitsUpTo, sortAndFormats, filteringWeight, TOP_DOCS_MERGER_TOP_SCORES);
}
}
}
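
The core of the fix is the reworked reduce(): with concurrent segment search enabled the searcher creates one collector per leaf slice, and the previous implementation read only the first HybridTopScoreDocCollector, so hits gathered by the remaining slices were dropped. The new code builds a ReduceableSearchResult for every collector and merges them through TopDocsMerger. A minimal sketch of that difference, using a hypothetical SliceHits holder rather than the plugin's classes:

```java
import java.util.List;

class ConcurrentReduceSketch {

    // Hypothetical holder for the hits one slice collected; only for this illustration.
    static class SliceHits {
        final int[] docIds;
        final float maxScore;

        SliceHits(int[] docIds, float maxScore) {
            this.docIds = docIds;
            this.maxScore = maxScore;
        }
    }

    // Old behaviour (roughly): only the first slice's hits survived the reduce phase.
    static SliceHits reduceFirstOnly(List<SliceHits> slices) {
        return slices.get(0);
    }

    // New behaviour (roughly): every slice contributes, mirroring how reduce() now
    // loops over all HybridTopScoreDocCollector instances and merges their TopDocs.
    static SliceHits reduceAll(List<SliceHits> slices) {
        int total = 0;
        float maxScore = Float.NEGATIVE_INFINITY;
        for (SliceHits slice : slices) {
            total += slice.docIds.length;
            maxScore = Math.max(maxScore, slice.maxScore);
        }
        int[] merged = new int[total];
        int offset = 0;
        for (SliceHits slice : slices) {
            System.arraycopy(slice.docIds, 0, merged, offset, slice.docIds.length);
            offset += slice.docIds.length;
        }
        return new SliceHits(merged, maxScore);
    }
}
```

The real implementation merges full hybrid-formatted TopDocs, including the start/stop and delimiter markers, rather than plain doc id arrays.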
HybridQueryScoreDocsMerger.java (new file)

@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.lucene.search.ScoreDoc;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryScoreDocElement;

/**
* Merges two ScoreDoc arrays into one
*/
@NoArgsConstructor(access = AccessLevel.PACKAGE)
class HybridQueryScoreDocsMerger<T extends ScoreDoc> {

private static final int MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC = 3;

/**
* Merges two score docs arrays; the resulting ScoreDoc[] contains all hits per sub-query from both original arrays.
* Input and output ScoreDocs are in the format that is specific to Hybrid Query. This method should not be used for ScoreDocs from
* other query types.
* The logic is based on the assumption that the hits of every sub-query are sorted by score.
* The method returns a new object and doesn't mutate the original ScoreDocs arrays.
* @param sourceScoreDocs original score docs from query result
* @param newScoreDocs new score docs that we need to merge into existing scores
* @return merged array of ScoreDocs objects
*/
public T[] merge(final T[] sourceScoreDocs, final T[] newScoreDocs, final Comparator<T> comparator) {
if (Objects.requireNonNull(sourceScoreDocs, "score docs cannot be null").length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC
|| Objects.requireNonNull(newScoreDocs, "score docs cannot be null").length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC) {
throw new IllegalArgumentException("cannot merge top docs because it does not have enough elements");
}
// we overshoot and preallocate more than we need - the length of both top docs combined.
// we will take only a portion of the array at the end
List<T> mergedScoreDocs = new ArrayList<>(sourceScoreDocs.length + newScoreDocs.length);
int sourcePointer = 0;
// mark beginning of hybrid query results by start element
mergedScoreDocs.add(sourceScoreDocs[sourcePointer]);
sourcePointer++;
// the new pointer is set to 1 to skip the start element of the new score docs
int newPointer = 1;

while (sourcePointer < sourceScoreDocs.length - 1 && newPointer < newScoreDocs.length - 1) {
// every iteration is for results of one sub-query
mergedScoreDocs.add(sourceScoreDocs[sourcePointer]);
sourcePointer++;
newPointer++;
// simplest case: both arrays have results for the current sub-query
while (sourcePointer < sourceScoreDocs.length
&& isHybridQueryScoreDocElement(sourceScoreDocs[sourcePointer])
&& newPointer < newScoreDocs.length
&& isHybridQueryScoreDocElement(newScoreDocs[newPointer])) {
if (comparator.compare(sourceScoreDocs[sourcePointer], newScoreDocs[newPointer]) >= 0) {
mergedScoreDocs.add(sourceScoreDocs[sourcePointer]);
sourcePointer++;
} else {
mergedScoreDocs.add(newScoreDocs[newPointer]);
newPointer++;
}
}
// at least one of the arrays is exhausted at this point; append the remaining elements from the other one
while (sourcePointer < sourceScoreDocs.length && isHybridQueryScoreDocElement(sourceScoreDocs[sourcePointer])) {
mergedScoreDocs.add(sourceScoreDocs[sourcePointer]);
sourcePointer++;
}
while (newPointer < newScoreDocs.length && isHybridQueryScoreDocElement(newScoreDocs[newPointer])) {
mergedScoreDocs.add(newScoreDocs[newPointer]);
newPointer++;
}
}
// mark end of hybrid query results by end element
mergedScoreDocs.add(sourceScoreDocs[sourceScoreDocs.length - 1]);
return mergedScoreDocs.toArray((T[]) new ScoreDoc[0]);
}
}
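
At its heart the merger is a two-pointer merge of two score-descending lists, repeated once per sub-query section between the delimiter elements. Below is a minimal, self-contained sketch of that per-sub-query step, assuming a single section and ignoring the start/stop and delimiter bookkeeping that the real class handles:

```java
import org.apache.lucene.search.ScoreDoc;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

class ScoreDocsMergeSketch {

    // Two-pointer merge of one sub-query section; both inputs are assumed to be sorted by score descending.
    static ScoreDoc[] mergeOneSubQuery(ScoreDoc[] source, ScoreDoc[] added, Comparator<ScoreDoc> byScore) {
        List<ScoreDoc> merged = new ArrayList<>(source.length + added.length);
        int i = 0;
        int j = 0;
        while (i < source.length && j < added.length) {
            // prefer the source element on ties, same as the real merger
            if (byScore.compare(source[i], added[j]) >= 0) {
                merged.add(source[i++]);
            } else {
                merged.add(added[j++]);
            }
        }
        while (i < source.length) {
            merged.add(source[i++]);
        }
        while (j < added.length) {
            merged.add(added[j++]);
        }
        return merged.toArray(new ScoreDoc[0]);
    }

    public static void main(String[] args) {
        Comparator<ScoreDoc> byScore = Comparator.comparing((ScoreDoc scoreDoc) -> scoreDoc.score);
        ScoreDoc[] source = { new ScoreDoc(1, 0.9f), new ScoreDoc(5, 0.4f) };
        ScoreDoc[] added = { new ScoreDoc(7, 0.7f), new ScoreDoc(2, 0.1f) };
        for (ScoreDoc scoreDoc : mergeOneSubQuery(source, added, byScore)) {
            System.out.println(scoreDoc.doc + " -> " + scoreDoc.score); // docs 1, 7, 5, 2 in descending score order
        }
    }
}
```

Preferring the source element when the comparator returns a tie keeps the merge stable, so hits that were already in the query result keep their original relative order.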
TopDocsMerger.java (new file)

@@ -0,0 +1,74 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import com.google.common.annotations.VisibleForTesting;
import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;

import java.util.Comparator;
import java.util.Objects;

/**
* Utility class for merging TopDocs and MaxScore across multiple search queries
*/
@RequiredArgsConstructor(access = AccessLevel.PACKAGE)
class TopDocsMerger {

private final HybridQueryScoreDocsMerger<ScoreDoc> scoreDocsMerger;
@VisibleForTesting
protected static final Comparator<ScoreDoc> SCORE_DOC_BY_SCORE_COMPARATOR = Comparator.comparing((scoreDoc) -> scoreDoc.score);
/**
* Uses hybrid query score docs merger to merge internal score docs
*/
static final TopDocsMerger TOP_DOCS_MERGER_TOP_SCORES = new TopDocsMerger(new HybridQueryScoreDocsMerger<>());

/**
* Merge TopDocs and MaxScore from multiple search queries into a single TopDocsAndMaxScore object.
* @param source TopDocsAndMaxScore for the original query
* @param newTopDocs TopDocsAndMaxScore for the new query
* @return merged TopDocsAndMaxScore object
*/
public TopDocsAndMaxScore merge(TopDocsAndMaxScore source, TopDocsAndMaxScore newTopDocs) {
if (Objects.isNull(newTopDocs) || Objects.isNull(newTopDocs.topDocs) || newTopDocs.topDocs.totalHits.value == 0) {
return source;
}
// we need to merge hits per individual sub-query
// format of results in both new and source TopDocs is following
// doc_id | magic_number_1
// doc_id | magic_number_2
// ...
// doc_id | magic_number_2
// ...
// doc_id | magic_number_2
// ...
// doc_id | magic_number_1
ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge(
source.topDocs.scoreDocs,
newTopDocs.topDocs.scoreDocs,
SCORE_DOC_BY_SCORE_COMPARATOR
);
TotalHits mergedTotalHits = getMergedTotalHits(source, newTopDocs);
TopDocsAndMaxScore result = new TopDocsAndMaxScore(
new TopDocs(mergedTotalHits, mergedScoreDocs),
Math.max(source.maxScore, newTopDocs.maxScore)
);
return result;
}

private TotalHits getMergedTotalHits(TopDocsAndMaxScore source, TopDocsAndMaxScore newTopDocs) {
// the merged value is a lower bound: if both relations are equal_to then the merged relation is also equal_to,
// otherwise assign greater_than_or_equal
TotalHits.Relation mergedHitsRelation = source.topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
|| newTopDocs.topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
return new TotalHits(source.topDocs.totalHits.value + newTopDocs.topDocs.totalHits.value, mergedHitsRelation);
}
}
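
The "doc_id | magic_number" comment above describes the hybrid result layout that both mergers rely on. Below is a hedged illustration of a two-sub-query layout, assuming the format util's factory methods take the anchor doc id as their single argument; the real hit scores are made up:

```java
import org.apache.lucene.search.ScoreDoc;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;

class HybridLayoutSketch {

    // Assumed signature: both factory methods take the anchor doc id (0 here is arbitrary).
    static ScoreDoc[] twoSubQueryLayout() {
        return new ScoreDoc[] {
            createStartStopElementForHybridSearchResults(0),  // doc_id | magic_number_1 (start)
            createDelimiterElementForHybridSearchResults(0),  // doc_id | magic_number_2 (sub-query 1 begins)
            new ScoreDoc(12, 0.8f),                           // sub-query 1 hits, sorted by score
            new ScoreDoc(3, 0.5f),
            createDelimiterElementForHybridSearchResults(0),  // doc_id | magic_number_2 (sub-query 2 begins)
            new ScoreDoc(7, 0.9f),                            // sub-query 2 hits
            createStartStopElementForHybridSearchResults(0)   // doc_id | magic_number_1 (stop)
        };
    }
}
```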
HybridSearchResultFormatUtil.java

@@ -52,4 +52,28 @@ public static boolean isHybridQueryStartStopElement(final ScoreDoc scoreDoc) {
public static boolean isHybridQueryDelimiterElement(final ScoreDoc scoreDoc) {
return Objects.nonNull(scoreDoc) && scoreDoc.doc >= 0 && Float.compare(scoreDoc.score, MAGIC_NUMBER_DELIMITER) == 0;
}

/**
* Checks whether the passed ScoreDoc object is a special element (start/stop or delimiter) in the list of hybrid query result scores
* @param scoreDoc score doc object to check on
* @return true if it is a special element
*/
public static boolean isHybridQuerySpecialElement(final ScoreDoc scoreDoc) {
if (Objects.isNull(scoreDoc)) {
return false;
}
return isHybridQueryStartStopElement(scoreDoc) || isHybridQueryDelimiterElement(scoreDoc);
}

/**
* Checks whether the passed ScoreDoc object is a document score element (a real hit rather than a special marker)
* @param scoreDoc score doc object to check on
* @return true if the element is a document hit with a score
*/
public static boolean isHybridQueryScoreDocElement(final ScoreDoc scoreDoc) {
if (Objects.isNull(scoreDoc)) {
return false;
}
return !isHybridQuerySpecialElement(scoreDoc);
}
}
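
The two new predicates make it possible to distinguish real hits from the framing markers without comparing against the magic-number constants directly. A hypothetical usage sketch, not part of the PR:

```java
import org.apache.lucene.search.ScoreDoc;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryScoreDocElement;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQuerySpecialElement;

class HybridFormatInspectionSketch {

    // counts the real document hits in a hybrid-formatted ScoreDoc array
    static int countDocumentHits(ScoreDoc[] hybridFormattedDocs) {
        int hits = 0;
        for (ScoreDoc scoreDoc : hybridFormattedDocs) {
            if (isHybridQueryScoreDocElement(scoreDoc)) {
                hits++; // a real hit: not null and not a special marker
            }
        }
        return hits;
    }

    // counts the start/stop and delimiter markers in the same array
    static int countSpecialMarkers(ScoreDoc[] hybridFormattedDocs) {
        int markers = 0;
        for (ScoreDoc scoreDoc : hybridFormattedDocs) {
            if (isHybridQuerySpecialElement(scoreDoc)) {
                markers++; // start/stop element or sub-query delimiter
            }
        }
        return markers;
    }
}
```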