Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Part 3] Concurrent segment search bug in Sorting #808

1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
- Fix for missing HybridQuery results when concurrent segment search is enabled ([#800](https://github.com/opensearch-project/neural-search/pull/800))
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import java.util.List;
import java.util.Objects;

import static org.apache.lucene.search.TotalHits.Relation;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults;
Expand All @@ -56,14 +57,14 @@ public abstract class HybridCollectorManager implements CollectorManager<Collect

private final int numHits;
private final HitsThresholdChecker hitsThresholdChecker;
private final boolean isSingleShard;
private final int trackTotalHitsUpTo;
private final SortAndFormats sortAndFormats;
@Nullable
private final Weight filterWeight;
private static final float boost_factor = 1f;
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
private final TopDocsMerger topDocsMerger;
@Nullable
private final FieldDoc after;
private static final float boost_factor = 1f;

/**
* Create new instance of HybridCollectorManager depending on the concurrent search beeing enabled or disabled.
Expand All @@ -74,7 +75,6 @@ public abstract class HybridCollectorManager implements CollectorManager<Collect
public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
final IndexReader reader = searchContext.searcher().getIndexReader();
final int totalNumDocs = Math.max(0, reader.numDocs());
boolean isSingleShard = searchContext.numberOfShards() == 1;
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
if (searchContext.sort() != null) {
Expand All @@ -98,7 +98,6 @@ public static CollectorManager createHybridCollectorManager(final SearchContext
? new HybridCollectorConcurrentSearchManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
searchContext.sort(),
filteringWeight,
Expand All @@ -107,7 +106,6 @@ public static CollectorManager createHybridCollectorManager(final SearchContext
: new HybridCollectorNonConcurrentManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
searchContext.sort(),
filteringWeight,
Expand Down Expand Up @@ -150,66 +148,62 @@ private Collector getHybridQueryCollector() {
*/
@Override
public ReduceableSearchResult reduce(Collection<Collector> collectors) {
final List<HybridTopScoreDocCollector> hybridTopScoreDocCollectors = new ArrayList<>();
final List<HybridTopFieldDocSortCollector> hybridSortedTopDocCollectors = new ArrayList<>();
// check if collector for hybrid query scores is part of this search context. It can be wrapped into MultiCollectorWrapper
// in case multiple collector managers are registered. We use hybrid scores collector to format scores into
// format specific for hybrid search query: start, sub-query-delimiter, scores, stop
for (final Collector collector : collectors) {
if (collector instanceof MultiCollectorWrapper) {
for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) {
if (sub instanceof HybridTopScoreDocCollector) {
hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) sub);
} else if (sub instanceof HybridTopFieldDocSortCollector) {
hybridSortedTopDocCollectors.add((HybridTopFieldDocSortCollector) sub);
}
}
} else if (collector instanceof HybridTopScoreDocCollector) {
hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) collector);
} else if (collector instanceof HybridTopFieldDocSortCollector) {
hybridSortedTopDocCollectors.add((HybridTopFieldDocSortCollector) collector);
} else if (collector instanceof FilteredCollector
&& ((FilteredCollector) collector).getCollector() instanceof HybridTopScoreDocCollector) {
hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) ((FilteredCollector) collector).getCollector());
} else if (collector instanceof FilteredCollector
&& ((FilteredCollector) collector).getCollector() instanceof HybridTopFieldDocSortCollector) {
hybridSortedTopDocCollectors.add((HybridTopFieldDocSortCollector) ((FilteredCollector) collector).getCollector());
}
final List<Collector> hybridSearchCollectors = getHybridSearchCollectors(collectors);
if (hybridSearchCollectors.isEmpty()) {
throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
}
return reduceSearchResults(getSearchResults(hybridSearchCollectors));
}

if (!hybridTopScoreDocCollectors.isEmpty()) {
HybridTopScoreDocCollector hybridTopScoreDocCollector = hybridTopScoreDocCollectors.stream()
.findFirst()
.orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query"));
List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
TopDocs newTopDocs = getNewTopDocs(
getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard, hybridTopScoreDocCollector.getTotalHits()),
topDocs
);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore());
return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); };
private List<ReduceableSearchResult> getSearchResults(final List<Collector> hybridSearchCollectors) {
List<ReduceableSearchResult> results = new ArrayList<>();
DocValueFormat[] docValueFormats = getSortValueFormats(sortAndFormats);
for (Collector collector : hybridSearchCollectors) {
TopDocsAndMaxScore topDocsAndMaxScore = getTopDocsAndAndMaxScore(collector, docValueFormats);
results.add((QuerySearchResult result) -> reduceCollectorResults(result, topDocsAndMaxScore, docValueFormats));
}
return results;
}

// TODO: Cater the fix for the Bug https://github.com/opensearch-project/neural-search/issues/799
if (!hybridSortedTopDocCollectors.isEmpty()) {
HybridTopFieldDocSortCollector hybridSortedTopScoreDocCollector = hybridSortedTopDocCollectors.stream()
.findFirst()
.orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query"));

List<TopFieldDocs> topFieldDocs = hybridSortedTopScoreDocCollector.topDocs();
long maxTotalHits = hybridSortedTopScoreDocCollector.getTotalHits();
float maxScore = hybridSortedTopScoreDocCollector.getMaxScore();

TopDocs newTopDocs = getNewTopFieldDocs(
getTotalHits(this.trackTotalHitsUpTo, topFieldDocs, isSingleShard, maxTotalHits),
private TopDocsAndMaxScore getTopDocsAndAndMaxScore(final Collector collector, final DocValueFormat[] docValueFormats) {
float maxScore;
TopDocs newTopDocs;
if (docValueFormats != null) {
HybridTopFieldDocSortCollector hybridTopFieldDocSortCollector = (HybridTopFieldDocSortCollector) collector;
List<TopFieldDocs> topFieldDocs = hybridTopFieldDocSortCollector.topDocs();
maxScore = hybridTopFieldDocSortCollector.getMaxScore();
newTopDocs = getNewTopFieldDocs(
getTotalHits(this.trackTotalHitsUpTo, topFieldDocs, hybridTopFieldDocSortCollector.getTotalHits()),
topFieldDocs,
sortAndFormats.sort.getSort()
);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); };
} else {
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
HybridTopScoreDocCollector hybridTopScoreDocCollector = (HybridTopScoreDocCollector) collector;
List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
maxScore = hybridTopScoreDocCollector.getMaxScore();
newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridTopScoreDocCollector.getTotalHits()), topDocs);
}
return new TopDocsAndMaxScore(newTopDocs, maxScore);
}

throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
private List<Collector> getHybridSearchCollectors(final Collection<Collector> collectors) {
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
final List<Collector> hybridSearchCollectors = new ArrayList<>();
for (final Collector collector : collectors) {
if (collector instanceof MultiCollectorWrapper) {
for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) {
if (sub instanceof HybridTopScoreDocCollector || sub instanceof HybridTopFieldDocSortCollector) {
hybridSearchCollectors.add(sub);
}
}
} else if (collector instanceof HybridTopScoreDocCollector || collector instanceof HybridTopFieldDocSortCollector) {
hybridSearchCollectors.add(collector);
} else if (collector instanceof FilteredCollector
&& (((FilteredCollector) collector).getCollector() instanceof HybridTopScoreDocCollector
|| ((FilteredCollector) collector).getCollector() instanceof HybridTopFieldDocSortCollector)) {
hybridSearchCollectors.add(((FilteredCollector) collector).getCollector());
}
}
return hybridSearchCollectors;
}

private static void validateSortCriteria(SearchContext searchContext, boolean trackScores) {
Expand Down Expand Up @@ -302,10 +296,11 @@ private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> top
return new TopDocs(totalHits, scoreDocs);
}

private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<?> topDocs, final boolean isSingleShard, final long maxTotalHits) {
final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<?> topDocs, final long maxTotalHits) {
final Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? Relation.GREATER_THAN_OR_EQUAL_TO
: Relation.EQUAL_TO;

if (topDocs == null || topDocs.isEmpty()) {
return new TotalHits(0, relation);
}
Expand Down Expand Up @@ -372,6 +367,54 @@ private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats
return sortAndFormats == null ? null : sortAndFormats.formats;
}

private void reduceCollectorResults(
final QuerySearchResult result,
final TopDocsAndMaxScore topDocsAndMaxScore,
final DocValueFormat[] docValueFormats
) {
// this is case of first collector, query result object doesn't have any top docs set, so we can
// just set new top docs without merge
// this call is effectively checking if QuerySearchResult.topDoc is null. using it in such way because
// getter throws exception in case topDocs is null
if (result.hasConsumedTopDocs()) {
result.topDocs(topDocsAndMaxScore, docValueFormats);
return;
}
// in this case top docs are already present in result, and we need to merge next result object with what we have.
// if collector doesn't have any hits we can just skip it and save some cycles by not doing merge
if (topDocsAndMaxScore.topDocs.totalHits.value == 0) {
return;
}
// we need to do actual merge because query result and current collector both have some score hits
TopDocsAndMaxScore originalTotalDocsAndHits = result.topDocs();
result.topDocs(getMergeTopDocsAndMaxScores(originalTotalDocsAndHits, topDocsAndMaxScore), docValueFormats);
}

private TopDocsAndMaxScore getMergeTopDocsAndMaxScores(
final TopDocsAndMaxScore originalTotalDocsAndHits,
final TopDocsAndMaxScore topDocsAndMaxScore
) {
if (sortAndFormats != null) {
return topDocsMerger.mergeFieldDocs(originalTotalDocsAndHits, topDocsAndMaxScore, sortAndFormats);
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
} else {
return topDocsMerger.merge(originalTotalDocsAndHits, topDocsAndMaxScore);
}
}

/**
* For collection of search results, return a single one that has results from all individual result objects.
* @param results collection of search results
* @return single search result that represents all results as one object
*/
private ReduceableSearchResult reduceSearchResults(final List<ReduceableSearchResult> results) {
return (result) -> {
for (ReduceableSearchResult r : results) {
// call reduce for results of each single collector, this will update top docs in query result
r.reduce(result);
}
};
}

/**
* Implementation of the HybridCollector that reuses instance of collector on each even call. This allows caller to
* use saved state of collector
Expand All @@ -382,7 +425,6 @@ static class HybridCollectorNonConcurrentManager extends HybridCollectorManager
public HybridCollectorNonConcurrentManager(
int numHits,
HitsThresholdChecker hitsThresholdChecker,
boolean isSingleShard,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats,
Weight filteringWeight,
Expand All @@ -391,10 +433,10 @@ public HybridCollectorNonConcurrentManager(
super(
numHits,
hitsThresholdChecker,
isSingleShard,
trackTotalHitsUpTo,
sortAndFormats,
filteringWeight,
new TopDocsMerger(sortAndFormats),
(FieldDoc) searchAfter
);
scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null");
Expand All @@ -421,7 +463,6 @@ static class HybridCollectorConcurrentSearchManager extends HybridCollectorManag
public HybridCollectorConcurrentSearchManager(
int numHits,
HitsThresholdChecker hitsThresholdChecker,
boolean isSingleShard,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats,
Weight filteringWeight,
Expand All @@ -430,10 +471,10 @@ public HybridCollectorConcurrentSearchManager(
super(
numHits,
hitsThresholdChecker,
isSingleShard,
trackTotalHitsUpTo,
sortAndFormats,
filteringWeight,
new TopDocsMerger(sortAndFormats),
(FieldDoc) searchAfter
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import java.util.Comparator;
import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.SortField;

/**
* Comparator class that compares two field docs as per the sorting criteria
*/
@RequiredArgsConstructor(access = AccessLevel.PACKAGE)
class HybridQueryFieldDocComparator implements Comparator<FieldDoc> {
final SortField[] sortFields;
final FieldComparator<?>[] comparators;
final int[] reverseMul;
final Comparator<ScoreDoc> tieBreaker;

public HybridQueryFieldDocComparator(SortField[] sortFields, Comparator<ScoreDoc> tieBreaker) {
this.sortFields = sortFields;
this.tieBreaker = tieBreaker;
comparators = new FieldComparator[sortFields.length];
reverseMul = new int[sortFields.length];
for (int compIDX = 0; compIDX < sortFields.length; compIDX++) {
final SortField sortField = sortFields[compIDX];
comparators[compIDX] = sortField.getComparator(1, Pruning.NONE);
reverseMul[compIDX] = sortField.getReverse() ? -1 : 1;
}
}

@Override
public int compare(final FieldDoc firstFD, final FieldDoc secondFD) {
for (int compIDX = 0; compIDX < comparators.length; compIDX++) {
final FieldComparator comp = comparators[compIDX];

final int cmp = reverseMul[compIDX] * comp.compareValues(firstFD.fields[compIDX], secondFD.fields[compIDX]);

if (cmp != 0) {
return cmp;
}
}
return tieBreakCompare(firstFD, secondFD, tieBreaker);
}

private int tieBreakCompare(ScoreDoc firstDoc, ScoreDoc secondDoc, Comparator<ScoreDoc> tieBreaker) {
assert tieBreaker != null;
int value = tieBreaker.compare(firstDoc, secondDoc);
return value;
}
}
Loading
Loading