Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Part 3] Concurrent segment search bug in Sorting #808

Original file line number Diff line number Diff line change
Expand Up @@ -292,15 +292,14 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
if (vectorSupplier().get() == null) {
return this;
}
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName(), vectorSupplier.get()).filter(filter());
if (maxDistance != null) {
knnQueryBuilder.maxDistance(maxDistance);
} else if (minScore != null) {
knnQueryBuilder.minScore(minScore);
} else {
knnQueryBuilder.k(k);
}
return knnQueryBuilder;
return KNNQueryBuilder.builder()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is needed to fix the neural-search package, because the main branch is currently broken; the fix is included in this PR. Please ignore this change in review, as it will be taken care of when cherry-picking onto main.

.fieldName(fieldName())
.vector(vectorSupplier.get())
.filter(filter())
.maxDistance(maxDistance)
.minScore(minScore)
.k(k)
.build();
}

SetOnce<float[]> vectorSetOnce = new SetOnce<>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.collector;

import java.util.List;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.TopDocs;

/**
* Common contract for the collectors used by hybrid search.
* <p>
* It extends Lucene's {@link Collector} with accessors for the collected outcome, so that code
* consuming the collectors can read top docs, total hits and max score through one type
* instead of depending on a concrete collector class.
*/
public interface HybridSearchCollector extends Collector {
/**
* Returns the collected top docs, one entry per sub-query of the hybrid query.
* The wildcard element type lets an implementation expose a more specific
* {@link TopDocs} subtype.
*
* @return List of topDocs which contains topDocs of individual subqueries.
*/
List<? extends TopDocs> topDocs();

/**
* Returns the number of hits counted by this collector.
*
* @return count of total hits per shard
*/
int getTotalHits();

/**
* Returns the highest score seen by this collector.
*
* @return maxScore found on a shard
*/
float getMaxScore();
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.FieldValueHitQueue;
import org.apache.lucene.search.ScoreDoc;
Expand All @@ -38,7 +37,7 @@
The individual query results are sorted as per the sort criteria sent in the search request.
*/
@Log4j2
public abstract class HybridTopFieldDocSortCollector implements Collector {
public abstract class HybridTopFieldDocSortCollector implements HybridSearchCollector {
private final int numHits;
private final HitsThresholdChecker hitsThresholdChecker;
private final Sort sort;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import lombok.Getter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorable;
Expand All @@ -30,7 +29,7 @@
* Collects the TopDocs after executing hybrid query. Uses HybridQueryTopDocs as DTO to handle each sub query results
*/
@Log4j2
public class HybridTopScoreDocCollector implements Collector {
public class HybridTopScoreDocCollector implements HybridSearchCollector {
private static final TopDocs EMPTY_TOPDOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
private int docBase;
private final HitsThresholdChecker hitsThresholdChecker;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.common.lucene.search.FilteredCollector;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.collector.HybridSearchCollector;
import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector;
import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector;
import org.opensearch.neuralsearch.search.collector.SimpleFieldCollector;
Expand All @@ -41,6 +42,7 @@
import java.util.List;
import java.util.Objects;

import static org.apache.lucene.search.TotalHits.Relation;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults;
Expand All @@ -56,14 +58,14 @@ public abstract class HybridCollectorManager implements CollectorManager<Collect

private final int numHits;
private final HitsThresholdChecker hitsThresholdChecker;
private final boolean isSingleShard;
private final int trackTotalHitsUpTo;
private final SortAndFormats sortAndFormats;
@Nullable
private final Weight filterWeight;
private static final float boost_factor = 1f;
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
private final TopDocsMerger topDocsMerger;
@Nullable
private final FieldDoc after;
private static final float boost_factor = 1f;

/**
* Create new instance of HybridCollectorManager depending on the concurrent search being enabled or disabled.
Expand All @@ -74,7 +76,6 @@ public abstract class HybridCollectorManager implements CollectorManager<Collect
public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
final IndexReader reader = searchContext.searcher().getIndexReader();
final int totalNumDocs = Math.max(0, reader.numDocs());
boolean isSingleShard = searchContext.numberOfShards() == 1;
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
if (searchContext.sort() != null) {
Expand All @@ -98,7 +99,6 @@ public static CollectorManager createHybridCollectorManager(final SearchContext
? new HybridCollectorConcurrentSearchManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
searchContext.sort(),
filteringWeight,
Expand All @@ -107,7 +107,6 @@ public static CollectorManager createHybridCollectorManager(final SearchContext
: new HybridCollectorNonConcurrentManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
searchContext.sort(),
filteringWeight,
Expand Down Expand Up @@ -150,66 +149,62 @@ private Collector getHybridQueryCollector() {
*/
@Override
public ReduceableSearchResult reduce(Collection<Collector> collectors) {
final List<HybridTopScoreDocCollector> hybridTopScoreDocCollectors = new ArrayList<>();
final List<HybridTopFieldDocSortCollector> hybridSortedTopDocCollectors = new ArrayList<>();
// check if collector for hybrid query scores is part of this search context. It can be wrapped into MultiCollectorWrapper
// in case multiple collector managers are registered. We use hybrid scores collector to format scores into
// format specific for hybrid search query: start, sub-query-delimiter, scores, stop
for (final Collector collector : collectors) {
if (collector instanceof MultiCollectorWrapper) {
for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) {
if (sub instanceof HybridTopScoreDocCollector) {
hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) sub);
} else if (sub instanceof HybridTopFieldDocSortCollector) {
hybridSortedTopDocCollectors.add((HybridTopFieldDocSortCollector) sub);
}
}
} else if (collector instanceof HybridTopScoreDocCollector) {
hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) collector);
} else if (collector instanceof HybridTopFieldDocSortCollector) {
hybridSortedTopDocCollectors.add((HybridTopFieldDocSortCollector) collector);
} else if (collector instanceof FilteredCollector
&& ((FilteredCollector) collector).getCollector() instanceof HybridTopScoreDocCollector) {
hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) ((FilteredCollector) collector).getCollector());
} else if (collector instanceof FilteredCollector
&& ((FilteredCollector) collector).getCollector() instanceof HybridTopFieldDocSortCollector) {
hybridSortedTopDocCollectors.add((HybridTopFieldDocSortCollector) ((FilteredCollector) collector).getCollector());
}
final List<HybridSearchCollector> hybridSearchCollectors = getHybridSearchCollectors(collectors);
if (hybridSearchCollectors.isEmpty()) {
throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
}
return reduceSearchResults(getSearchResults(hybridSearchCollectors));
}

if (!hybridTopScoreDocCollectors.isEmpty()) {
HybridTopScoreDocCollector hybridTopScoreDocCollector = hybridTopScoreDocCollectors.stream()
.findFirst()
.orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query"));
List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
TopDocs newTopDocs = getNewTopDocs(
getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard, hybridTopScoreDocCollector.getTotalHits()),
topDocs
);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore());
return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); };
/**
 * Builds one reducible search result per hybrid collector.
 * <p>
 * The top docs and max score of every collector are computed eagerly, while this method runs;
 * only the step that writes them into the query result is deferred until the returned
 * object is reduced.
 *
 * @param hybridSearchCollectors collectors whose results must be exposed
 * @return list with one {@link ReduceableSearchResult} per collector, in the order of the input list
 */
private List<ReduceableSearchResult> getSearchResults(final List<HybridSearchCollector> hybridSearchCollectors) {
    // null when the request has no sort criteria (sortAndFormats is null)
    final DocValueFormat[] sortValueFormats = getSortValueFormats(sortAndFormats);
    final List<ReduceableSearchResult> reduceableResults = new ArrayList<>();
    for (final HybridSearchCollector hybridSearchCollector : hybridSearchCollectors) {
        final TopDocsAndMaxScore collectorTopDocs = getTopDocsAndAndMaxScore(hybridSearchCollector, sortValueFormats);
        reduceableResults.add(
            (QuerySearchResult queryResult) -> reduceCollectorResults(queryResult, collectorTopDocs, sortValueFormats)
        );
    }
    return reduceableResults;
}

// TODO: Cater the fix for the Bug https://github.com/opensearch-project/neural-search/issues/799
if (!hybridSortedTopDocCollectors.isEmpty()) {
HybridTopFieldDocSortCollector hybridSortedTopScoreDocCollector = hybridSortedTopDocCollectors.stream()
.findFirst()
.orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query"));

List<TopFieldDocs> topFieldDocs = hybridSortedTopScoreDocCollector.topDocs();
long maxTotalHits = hybridSortedTopScoreDocCollector.getTotalHits();
float maxScore = hybridSortedTopScoreDocCollector.getMaxScore();

TopDocs newTopDocs = getNewTopFieldDocs(
getTotalHits(this.trackTotalHitsUpTo, topFieldDocs, isSingleShard, maxTotalHits),
topFieldDocs,
private TopDocsAndMaxScore getTopDocsAndAndMaxScore(
final HybridSearchCollector hybridSearchCollector,
final DocValueFormat[] docValueFormats
) {
TopDocs newTopDocs;
List<? extends TopDocs> topDocs = hybridSearchCollector.topDocs();
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
if (docValueFormats != null) {
newTopDocs = getNewTopFieldDocs(
getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()),
(List<TopFieldDocs>) topDocs,
sortAndFormats.sort.getSort()
);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); };
} else {
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
newTopDocs = getNewTopDocs(
getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()),
(List<TopDocs>) topDocs
);
}
return new TopDocsAndMaxScore(newTopDocs, hybridSearchCollector.getMaxScore());
}

throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
/**
 * Extracts the hybrid search collectors from the given collectors.
 * <p>
 * A hybrid collector can show up in three shapes:
 * <ul>
 *   <li>directly, as a {@link HybridTopScoreDocCollector} or {@link HybridTopFieldDocSortCollector};</li>
 *   <li>as the delegate of a {@link FilteredCollector};</li>
 *   <li>as a direct member of a {@link MultiCollectorWrapper} (members are not unwrapped any further).</li>
 * </ul>
 * Collectors that match none of these shapes are ignored.
 *
 * @param collectors collectors to inspect
 * @return hybrid search collectors in encounter order; empty list if none was found
 */
private List<HybridSearchCollector> getHybridSearchCollectors(final Collection<Collector> collectors) {
    final List<HybridSearchCollector> found = new ArrayList<>();
    for (final Collector candidate : collectors) {
        if (candidate instanceof MultiCollectorWrapper) {
            // look at the wrapped collectors themselves, one level deep
            for (final Collector nested : ((MultiCollectorWrapper) candidate).getCollectors()) {
                if (nested instanceof HybridTopScoreDocCollector || nested instanceof HybridTopFieldDocSortCollector) {
                    found.add((HybridSearchCollector) nested);
                }
            }
            continue;
        }
        // a filtered collector is judged by its delegate, any other collector by itself
        final Collector target = candidate instanceof FilteredCollector
            ? ((FilteredCollector) candidate).getCollector()
            : candidate;
        if (target instanceof HybridTopScoreDocCollector || target instanceof HybridTopFieldDocSortCollector) {
            found.add((HybridSearchCollector) target);
        }
    }
    return found;
}

private static void validateSortCriteria(SearchContext searchContext, boolean trackScores) {
Expand Down Expand Up @@ -302,10 +297,11 @@ private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> top
return new TopDocs(totalHits, scoreDocs);
}

private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<?> topDocs, final boolean isSingleShard, final long maxTotalHits) {
final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<?> topDocs, final long maxTotalHits) {
final Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? Relation.GREATER_THAN_OR_EQUAL_TO
: Relation.EQUAL_TO;

if (topDocs == null || topDocs.isEmpty()) {
return new TotalHits(0, relation);
}
Expand Down Expand Up @@ -372,6 +368,44 @@ private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats
return sortAndFormats == null ? null : sortAndFormats.formats;
}

/**
 * Folds the top docs of a single collector into the query result.
 *
 * @param result query result that accumulates the top docs of all collectors
 * @param topDocsAndMaxScore top docs and max score produced by one collector
 * @param docValueFormats sort value formats to store together with the top docs, may be null
 */
private void reduceCollectorResults(
    final QuerySearchResult result,
    final TopDocsAndMaxScore topDocsAndMaxScore,
    final DocValueFormat[] docValueFormats
) {
    // First collector: the query result has no top docs yet, so there is nothing to merge with.
    // hasConsumedTopDocs() is used as a "top docs are null" check because the topDocs() getter
    // throws an exception when top docs are null.
    final boolean resultHasNoTopDocs = result.hasConsumedTopDocs();
    if (resultHasNoTopDocs) {
        result.topDocs(topDocsAndMaxScore, docValueFormats);
        return;
    }
    // The result already holds top docs. A collector without hits cannot change them,
    // so the merge is only done when this collector actually has hits.
    final boolean collectorHasHits = topDocsAndMaxScore.topDocs.totalHits.value != 0;
    if (collectorHasHits) {
        final TopDocsAndMaxScore alreadyReduced = result.topDocs();
        final TopDocsAndMaxScore merged = topDocsMerger.merge(alreadyReduced, topDocsAndMaxScore);
        result.topDocs(merged, docValueFormats);
    }
}

/**
 * Combines several search results into one.
 * <p>
 * The returned object does no work until it is reduced; at that point it reduces every
 * element of {@code results}, in list order, into the same query result.
 *
 * @param results collection of search results
 * @return single search result that represents all results as one object
 */
private ReduceableSearchResult reduceSearchResults(final List<ReduceableSearchResult> results) {
    return queryResult -> {
        // each per-collector result updates the top docs held by the shared query result
        for (final ReduceableSearchResult collectorResult : results) {
            collectorResult.reduce(queryResult);
        }
    };
}

/**
* Implementation of the HybridCollector that reuses instance of collector on each even call. This allows caller to
* use saved state of collector
Expand All @@ -382,7 +416,6 @@ static class HybridCollectorNonConcurrentManager extends HybridCollectorManager
public HybridCollectorNonConcurrentManager(
int numHits,
HitsThresholdChecker hitsThresholdChecker,
boolean isSingleShard,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats,
Weight filteringWeight,
Expand All @@ -391,10 +424,10 @@ public HybridCollectorNonConcurrentManager(
super(
numHits,
hitsThresholdChecker,
isSingleShard,
trackTotalHitsUpTo,
sortAndFormats,
filteringWeight,
new TopDocsMerger(sortAndFormats),
(FieldDoc) searchAfter
);
scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null");
Expand All @@ -421,7 +454,6 @@ static class HybridCollectorConcurrentSearchManager extends HybridCollectorManag
public HybridCollectorConcurrentSearchManager(
int numHits,
HitsThresholdChecker hitsThresholdChecker,
boolean isSingleShard,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats,
Weight filteringWeight,
Expand All @@ -430,10 +462,10 @@ public HybridCollectorConcurrentSearchManager(
super(
numHits,
hitsThresholdChecker,
isSingleShard,
trackTotalHitsUpTo,
sortAndFormats,
filteringWeight,
new TopDocsMerger(sortAndFormats),
(FieldDoc) searchAfter
);
}
Expand Down
Loading
Loading