diff --git a/CHANGELOG.md b/CHANGELOG.md index 528231b07..2ddf281f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes +- Adding two phase iterator for hybrid query ([#624](https://github.com/opensearch-project/neural-search/pull/624)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java index 8846f6977..01d271cdd 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -12,7 +12,6 @@ import java.util.List; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; @@ -77,12 +76,12 @@ public String toString(String field) { /** * Re-writes queries into primitive queries. Callers are expected to call rewrite multiple times if necessary, * until the rewritten query is the same as the original query. - * @param reader + * @param indexSearcher * @return * @throws IOException */ @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (subQueries.isEmpty()) { return new MatchNoDocsQuery("empty HybridQuery"); } @@ -90,7 +89,7 @@ public Query rewrite(IndexReader reader) throws IOException { boolean actuallyRewritten = false; List rewrittenSubQueries = new ArrayList<>(); for (Query subQuery : subQueries) { - Query rewrittenSub = subQuery.rewrite(reader); + Query rewrittenSub = subQuery.rewrite(indexSearcher); /* we keep rewrite sub-query unless it's not equal to itself, it may take multiple levels of recursive calls queries need to be rewritten from high-level clauses into lower-level clauses because low-level clauses perform better. For hybrid query we need to track progress of re-write for all sub-queries */ @@ -102,7 +101,7 @@ public Query rewrite(IndexReader reader) throws IOException { return new HybridQuery(rewrittenSubQueries); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } /** diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index aa4242c2e..46c087894 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -27,7 +27,6 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryShardException; -import org.opensearch.index.query.Rewriteable; import org.opensearch.index.query.QueryBuilderVisitor; import lombok.Getter; @@ -290,7 +289,7 @@ private void writeQueries(StreamOutput out, List queries private Collection toQueries(Collection queryBuilders, QueryShardContext context) throws QueryShardException { List queries = queryBuilders.stream().map(qb -> { try { - return Rewriteable.rewrite(qb, context).toQuery(context); + return qb.rewrite(context).toQuery(context); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java index 5abfd0b5e..188a90209 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java @@ -6,6 +6,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -18,10 +19,15 @@ import org.apache.lucene.search.DisjunctionDISIApproximation; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.Weight; import lombok.Getter; +import org.apache.lucene.util.PriorityQueue; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; /** * Class abstracts functionality of Scorer for hybrid query. When iterating over documents in increasing @@ -40,12 +46,60 @@ public final class HybridQueryScorer extends Scorer { private final Map> queryToIndex; + private final DocIdSetIterator approximation; + HybridScorePropagator disjunctionBlockPropagator; + private final TwoPhase twoPhase; + public HybridQueryScorer(Weight weight, List subScorers) throws IOException { + this(weight, subScorers, ScoreMode.TOP_SCORES); + } + + public HybridQueryScorer(Weight weight, List subScorers, ScoreMode scoreMode) throws IOException { super(weight); + // max this.subScorers = Collections.unmodifiableList(subScorers); + // custom subScores = new float[subScorers.size()]; this.queryToIndex = mapQueryToIndex(); + // base this.subScorersPQ = initializeSubScorersPQ(); + // base + boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES; + this.approximation = new HybridDisjunctionDISIApproximation(this.subScorersPQ); + // max + if (scoreMode == ScoreMode.TOP_SCORES) { + this.disjunctionBlockPropagator = new HybridScorePropagator(subScorers); + } else { + this.disjunctionBlockPropagator = null; + } + // base + boolean hasApproximation = false; + float sumMatchCost = 0; + long sumApproxCost = 0; + // Compute matchCost as the average over the matchCost of the subScorers. + // This is weighted by the cost, which is an expected number of matching documents. + for (DisiWrapper w : subScorersPQ) { + long costWeight = (w.cost <= 1) ? 1 : w.cost; + sumApproxCost += costWeight; + if (w.twoPhaseView != null) { + hasApproximation = true; + sumMatchCost += w.matchCost * costWeight; + } + } + if (!hasApproximation) { // no sub scorer supports approximations + twoPhase = null; + } else { + final float matchCost = sumMatchCost / sumApproxCost; + twoPhase = new TwoPhase(approximation, matchCost, subScorersPQ, needsScores); + } + } + + @Override + public int advanceShallow(int target) throws IOException { + if (disjunctionBlockPropagator != null) { + return disjunctionBlockPropagator.advanceShallow(target); + } + return super.advanceShallow(target); } /** @@ -55,11 +109,14 @@ public HybridQueryScorer(Weight weight, List subScorers) throws IOExcept */ @Override public float score() throws IOException { - DisiWrapper topList = subScorersPQ.topList(); + return score(getSubMatches()); + } + + private float score(DisiWrapper topList) throws IOException { float totalScore = 0.0f; for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) { // check if this doc has match in the subQuery. If not, add score as 0.0 and continue - if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) { + if (disiWrapper.scorer.docID() == NO_MORE_DOCS) { continue; } totalScore += disiWrapper.scorer.score(); @@ -67,13 +124,30 @@ public float score() throws IOException { return totalScore; } + DisiWrapper getSubMatches() throws IOException { + if (twoPhase == null) { + return subScorersPQ.topList(); + } else { + return twoPhase.getSubMatches(); + } + } + /** * Return a DocIdSetIterator over matching documents. * @return DocIdSetIterator object */ @Override public DocIdSetIterator iterator() { - return new DisjunctionDISIApproximation(this.subScorersPQ); + if (twoPhase != null) { + return TwoPhaseIterator.asDocIdSetIterator(twoPhase); + } else { + return approximation; + } + } + + @Override + public TwoPhaseIterator twoPhaseIterator() { + return twoPhase; } /** @@ -93,12 +167,28 @@ public float getMaxScore(int upTo) throws IOException { }).max(Float::compare).orElse(0.0f); } + @Override + public void setMinCompetitiveScore(float minScore) throws IOException { + if (disjunctionBlockPropagator != null) { + disjunctionBlockPropagator.setMinCompetitiveScore(minScore); + } + + for (Scorer scorer : subScorers) { + if (Objects.nonNull(scorer)) { + scorer.setMinCompetitiveScore(minScore); + } + } + } + /** * Returns the doc ID that is currently being scored. * @return document id */ @Override public int docID() { + if (subScorersPQ.size() == 0) { + return NO_MORE_DOCS; + } return subScorersPQ.top().doc; } @@ -169,4 +259,133 @@ private DisiPriorityQueue initializeSubScorersPQ() { } return subScorersPQ; } + + @Override + public Collection getChildren() throws IOException { + ArrayList children = new ArrayList<>(); + for (DisiWrapper scorer = getSubMatches(); scorer != null; scorer = scorer.next) { + children.add(new ChildScorable(scorer.scorer, "SHOULD")); + } + return children; + } + + static class TwoPhase extends TwoPhaseIterator { + private final float matchCost; + // list of verified matches on the current doc + DisiWrapper verifiedMatches; + // priority queue of approximations on the current doc that have not been verified yet + final PriorityQueue unverifiedMatches; + DisiPriorityQueue subScorers; + boolean needsScores; + + private TwoPhase(DocIdSetIterator approximation, float matchCost, DisiPriorityQueue subScorers, boolean needsScores) { + super(approximation); + this.matchCost = matchCost; + this.subScorers = subScorers; + unverifiedMatches = new PriorityQueue<>(subScorers.size()) { + @Override + protected boolean lessThan(DisiWrapper a, DisiWrapper b) { + return a.matchCost < b.matchCost; + } + }; + this.needsScores = needsScores; + } + + DisiWrapper getSubMatches() throws IOException { + // iteration order does not matter + for (DisiWrapper w : unverifiedMatches) { + if (w.twoPhaseView.matches()) { + w.next = verifiedMatches; + verifiedMatches = w; + } + } + unverifiedMatches.clear(); + return verifiedMatches; + } + + @Override + public boolean matches() throws IOException { + verifiedMatches = null; + unverifiedMatches.clear(); + + for (DisiWrapper w = subScorers.topList(); w != null;) { + DisiWrapper next = w.next; + + if (w.twoPhaseView == null) { + // implicitly verified, move it to verifiedMatches + w.next = verifiedMatches; + verifiedMatches = w; + + if (!needsScores) { + // we can stop here + return true; + } + } else { + unverifiedMatches.add(w); + } + w = next; + } + + if (verifiedMatches != null) { + return true; + } + + // verify subs that have an two-phase iterator + // least-costly ones first + while (unverifiedMatches.size() > 0) { + DisiWrapper w = unverifiedMatches.pop(); + if (w.twoPhaseView.matches()) { + w.next = null; + verifiedMatches = w; + return true; + } + } + + return false; + } + + @Override + public float matchCost() { + return matchCost; + } + } + + static class HybridDisjunctionDISIApproximation extends DocIdSetIterator { + final DocIdSetIterator delegate; + final DisiPriorityQueue subIterators; + + public HybridDisjunctionDISIApproximation(DisiPriorityQueue subIterators) { + delegate = new DisjunctionDISIApproximation(subIterators); + this.subIterators = subIterators; + } + + @Override + public long cost() { + return delegate.cost(); + } + + @Override + public int docID() { + if (subIterators.size() == 0) { + return NO_MORE_DOCS; + } + return delegate.docID(); + } + + @Override + public int nextDoc() throws IOException { + if (subIterators.size() == 0) { + return NO_MORE_DOCS; + } + return delegate.nextDoc(); + } + + @Override + public int advance(int target) throws IOException { + if (subIterators.size() == 0) { + return NO_MORE_DOCS; + } + return delegate.advance(target); + } + } } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java index 69ee5015f..76bdd5f00 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java @@ -5,10 +5,12 @@ package org.opensearch.neuralsearch.query; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; +import lombok.RequiredArgsConstructor; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; @@ -16,6 +18,7 @@ import org.apache.lucene.search.MatchesUtils; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; /** @@ -23,18 +26,18 @@ */ public final class HybridQueryWeight extends Weight { - private final HybridQuery queries; // The Weights for our subqueries, in 1-1 correspondence private final List weights; private final ScoreMode scoreMode; + static final int BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD = 16; + /** * Construct the Weight for this Query searched by searcher. Recursively construct subquery weights. */ public HybridQueryWeight(HybridQuery hybridQuery, IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { super(hybridQuery); - this.queries = hybridQuery; weights = hybridQuery.getSubQueries().stream().map(q -> { try { return searcher.createWeight(q, scoreMode, boost); @@ -65,6 +68,20 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException { return MatchesUtils.fromSubMatches(mis); } + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + List scorerSuppliers = new ArrayList<>(); + for (Weight w : weights) { + ScorerSupplier ss = w.scorerSupplier(context); + scorerSuppliers.add(ss); + } + + if (scorerSuppliers.isEmpty()) { + return null; + } + return new HybridScorerSupplier(scorerSuppliers, this, scoreMode); + } + /** * Create the scorer used to score our associated Query * @@ -75,19 +92,12 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException { */ @Override public Scorer scorer(LeafReaderContext context) throws IOException { - List scorers = weights.stream().map(w -> { - try { - return w.scorer(context); - } catch (IOException e) { - throw new RuntimeException(e); - } - }).collect(Collectors.toList()); - // if there are no matches in any of the scorers (sub-queries) we need to return - // scorer as null to avoid problems with disi result iterators - if (scorers.stream().allMatch(Objects::isNull)) { + ScorerSupplier supplier = scorerSupplier(context); + if (supplier == null) { return null; } - return new HybridQueryScorer(this, scorers); + supplier.setTopLevelScoringClause(); + return supplier.get(Long.MAX_VALUE); } /** @@ -98,6 +108,11 @@ public Scorer scorer(LeafReaderContext context) throws IOException { */ @Override public boolean isCacheable(LeafReaderContext ctx) { + if (weights.size() > BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD) { + // Disallow caching large queries to not encourage users + // to build large queries + return false; + } return weights.stream().allMatch(w -> w.isCacheable(ctx)); } @@ -113,4 +128,50 @@ public boolean isCacheable(LeafReaderContext ctx) { public Explanation explain(LeafReaderContext context, int doc) throws IOException { throw new UnsupportedOperationException("Explain is not supported"); } + + @RequiredArgsConstructor + static class HybridScorerSupplier extends ScorerSupplier { + private long cost = -1; + private final List scorerSuppliers; + private final Weight weight; + private final ScoreMode scoreMode; + + @Override + public Scorer get(long leadCost) throws IOException { + List tScorers = new ArrayList<>(); + for (ScorerSupplier ss : scorerSuppliers) { + if (Objects.nonNull(ss)) { + tScorers.add(ss.get(leadCost)); + } else { + tScorers.add(null); + } + } + return new HybridQueryScorer(weight, tScorers, scoreMode); + } + + @Override + public long cost() { + if (cost == -1) { + long cost = 0; + for (ScorerSupplier ss : scorerSuppliers) { + if (Objects.nonNull(ss)) { + cost += ss.cost(); + } + } + this.cost = cost; + } + return cost; + } + + @Override + public void setTopLevelScoringClause() throws IOException { + for (ScorerSupplier ss : scorerSuppliers) { + // sub scorers need to be able to skip too as calls to setMinCompetitiveScore get + // propagated + if (Objects.nonNull(ss)) { + ss.setTopLevelScoringClause(); + } + } + } + }; } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridScorePropagator.java b/src/main/java/org/opensearch/neuralsearch/query/HybridScorePropagator.java new file mode 100644 index 000000000..92e1bbf7e --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridScorePropagator.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Comparator; +import java.util.Objects; + +public class HybridScorePropagator { + + private static final Comparator MAX_SCORE_COMPARATOR = Comparator.comparing((Scorer s) -> { + try { + return s.getMaxScore(DocIdSetIterator.NO_MORE_DOCS); + } catch (IOException e) { + throw new RuntimeException(e); + } + }).thenComparing(s -> s.iterator().cost()); + + private final Scorer[] scorers; + private final float[] maxScores; + private int leadIndex = 0; + + HybridScorePropagator(Collection scorers) throws IOException { + this.scorers = scorers.stream().filter(Objects::nonNull).toArray(Scorer[]::new); + for (Scorer scorer : this.scorers) { + scorer.advanceShallow(0); + } + Arrays.sort(this.scorers, MAX_SCORE_COMPARATOR); + + maxScores = new float[this.scorers.length]; + for (int i = 0; i < this.scorers.length; ++i) { + maxScores[i] = this.scorers[i].getMaxScore(DocIdSetIterator.NO_MORE_DOCS); + } + } + + /** See {@link Scorer#advanceShallow(int)}. */ + int advanceShallow(int target) throws IOException { + // For scorers that are below the lead index, just propagate. + for (int i = 0; i < leadIndex; ++i) { + Scorer s = scorers[i]; + if (s.docID() < target) { + s.advanceShallow(target); + } + } + + // For scorers above the lead index, we take the minimum + // boundary. + Scorer leadScorer = scorers[leadIndex]; + int upTo = leadScorer.advanceShallow(Math.max(leadScorer.docID(), target)); + + for (int i = leadIndex + 1; i < scorers.length; ++i) { + Scorer scorer = scorers[i]; + if (scorer.docID() <= target) { + upTo = Math.min(scorer.advanceShallow(target), upTo); + } + } + + // If the maximum scoring clauses are beyond `target`, then we use their + // docID as a boundary. It helps not consider them when computing the + // maximum score and get a lower score upper bound. + for (int i = scorers.length - 1; i > leadIndex; --i) { + Scorer scorer = scorers[i]; + if (scorer.docID() > target) { + upTo = Math.min(upTo, scorer.docID() - 1); + } else { + break; + } + } + + return upTo; + } + + /** + * Set the minimum competitive score to filter out clauses that score less than this threshold. + * + * @see Scorer#setMinCompetitiveScore + */ + void setMinCompetitiveScore(float minScore) throws IOException { + // Update the lead index if necessary + while (leadIndex < maxScores.length - 1 && minScore > maxScores[leadIndex]) { + leadIndex++; + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java index 8b7a12d29..9190bfeac 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java +++ b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java @@ -19,7 +19,6 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TopScoreDocCollector; import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.PriorityQueue; import org.opensearch.neuralsearch.query.HybridQueryScorer; @@ -47,16 +46,35 @@ public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThreshol } @Override - public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + public LeafCollector getLeafCollector(LeafReaderContext context) { docBase = context.docBase; - return new TopScoreDocCollector.ScorerLeafCollector() { + return new LeafCollector() { HybridQueryScorer compoundQueryScorer; @Override public void setScorer(Scorable scorer) throws IOException { - super.setScorer(scorer); - compoundQueryScorer = (HybridQueryScorer) scorer; + if (scorer instanceof HybridQueryScorer) { + compoundQueryScorer = (HybridQueryScorer) scorer; + } else { + compoundQueryScorer = getHybridQueryScorer(scorer); + } + } + + private HybridQueryScorer getHybridQueryScorer(final Scorable scorer) throws IOException { + if (scorer == null) { + return null; + } + if (scorer instanceof HybridQueryScorer) { + return (HybridQueryScorer) scorer; + } + for (Scorable.ChildScorable childScorable : scorer.getChildren()) { + HybridQueryScorer hybridQueryScorer = getHybridQueryScorer(childScorable.child); + if (hybridQueryScorer != null) { + return hybridQueryScorer; + } + } + return null; } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java new file mode 100644 index 000000000..36a9002e8 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -0,0 +1,230 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import lombok.RequiredArgsConstructor; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.neuralsearch.search.HitsThresholdChecker; +import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.MultiCollectorWrapper; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; +import org.opensearch.search.sort.SortAndFormats; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; + +@RequiredArgsConstructor +public abstract class HybridCollectorManager implements CollectorManager { + + private final int numHits; + private final HitsThresholdChecker hitsThresholdChecker; + private final boolean isSingleShard; + private final int trackTotalHitsUpTo; + private final SortAndFormats sortAndFormats; + + public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException { + final IndexReader reader = searchContext.searcher().getIndexReader(); + final int totalNumDocs = Math.max(0, reader.numDocs()); + boolean isSingleShard = searchContext.numberOfShards() == 1; + int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs); + int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo(); + + return searchContext.shouldUseConcurrentSearch() + ? new HybridCollectorConcurrentSearchManager( + numDocs, + new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())), + isSingleShard, + trackTotalHitsUpTo, + searchContext.sort() + ) + : new HybridCollectorNonConcurrentManager( + numDocs, + new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())), + isSingleShard, + trackTotalHitsUpTo, + searchContext.sort() + ); + } + + @Override + abstract public Collector newCollector(); + + Collector getCollector() { + Collector hybridcollector = new HybridTopScoreDocCollector(numHits, hitsThresholdChecker); + return hybridcollector; + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) { + final List hybridTopScoreDocCollectors = new ArrayList<>(); + + for (final Collector collector : collectors) { + if (collector instanceof MultiCollectorWrapper) { + for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) { + if (sub instanceof HybridTopScoreDocCollector) { + hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) sub); + } + } + } else if (collector instanceof HybridTopScoreDocCollector) { + hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) collector); + } + } + + if (!hybridTopScoreDocCollectors.isEmpty()) { + HybridTopScoreDocCollector hybridTopScoreDocCollector = hybridTopScoreDocCollectors.stream() + .findFirst() + .orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query")); + List topDocs = hybridTopScoreDocCollector.topDocs(); + TopDocs newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard), topDocs); + float maxScore = getMaxScore(topDocs); + TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore); + return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); }; + } + throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors"); + } + + private TopDocs getNewTopDocs(final TotalHits totalHits, final List topDocs) { + ScoreDoc[] scoreDocs = new ScoreDoc[0]; + if (Objects.nonNull(topDocs)) { + // for a single shard case we need to do score processing at coordinator level. + // this is workaround for current core behaviour, for single shard fetch phase is executed + // right after query phase and processors are called after actual fetch is done + // find any valid doc Id, or set it to -1 if there is not a single match + int delimiterDocId = topDocs.stream() + .filter(Objects::nonNull) + .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs)) + .map(topDoc -> topDoc.scoreDocs) + .filter(scoreDoc -> scoreDoc.length > 0) + .map(scoreDoc -> scoreDoc[0].doc) + .findFirst() + .orElse(-1); + if (delimiterDocId == -1) { + return new TopDocs(totalHits, scoreDocs); + } + // format scores using following template: + // doc_id | magic_number_1 + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_1 + List result = new ArrayList<>(); + result.add(createStartStopElementForHybridSearchResults(delimiterDocId)); + for (TopDocs topDoc : topDocs) { + if (Objects.isNull(topDoc) || Objects.isNull(topDoc.scoreDocs)) { + result.add(createDelimiterElementForHybridSearchResults(delimiterDocId)); + continue; + } + result.add(createDelimiterElementForHybridSearchResults(delimiterDocId)); + result.addAll(Arrays.asList(topDoc.scoreDocs)); + } + result.add(createStartStopElementForHybridSearchResults(delimiterDocId)); + scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new); + } + return new TopDocs(totalHits, scoreDocs); + } + + private TotalHits getTotalHits(int trackTotalHitsUpTo, final List topDocs, final boolean isSingleShard) { + final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED + ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO; + if (topDocs == null || topDocs.isEmpty()) { + return new TotalHits(0, relation); + } + + List scoreDocs = topDocs.stream() + .map(topdDoc -> topdDoc.scoreDocs) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + Set uniqueDocIds = new HashSet<>(); + for (ScoreDoc[] scoreDocsArray : scoreDocs) { + uniqueDocIds.addAll(Arrays.stream(scoreDocsArray).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList())); + } + long maxTotalHits = uniqueDocIds.size(); + + return new TotalHits(maxTotalHits, relation); + } + + private float getMaxScore(final List topDocs) { + if (topDocs.isEmpty()) { + return 0.0f; + } else { + return topDocs.stream() + .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0]) + .map(scoreDoc -> scoreDoc.score) + .max(Float::compare) + .get(); + } + } + + private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) { + return sortAndFormats == null ? null : sortAndFormats.formats; + } + + static class HybridCollectorNonConcurrentManager extends HybridCollectorManager { + Collector maxScoreCollector; + + public HybridCollectorNonConcurrentManager( + int numHits, + HitsThresholdChecker hitsThresholdChecker, + boolean isSingleShard, + int trackTotalHitsUpTo, + SortAndFormats sortAndFormats + ) { + super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats); + } + + @Override + public Collector newCollector() { + if (Objects.isNull(maxScoreCollector)) { + maxScoreCollector = getCollector(); + return maxScoreCollector; + } else { + Collector toReturnCollector = maxScoreCollector; + maxScoreCollector = null; + return toReturnCollector; + } + } + } + + static class HybridCollectorConcurrentSearchManager extends HybridCollectorManager { + + public HybridCollectorConcurrentSearchManager( + int numHits, + HitsThresholdChecker hitsThresholdChecker, + boolean isSingleShard, + int trackTotalHitsUpTo, + SortAndFormats sortAndFormats + ) { + super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats); + } + + @Override + public Collector newCollector() { + return getCollector(); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index bf05fdc9d..5fc6017f2 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -4,46 +4,34 @@ */ package org.opensearch.neuralsearch.search.query; -import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; -import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; -import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext; - import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; +import java.util.Collection; import java.util.LinkedList; import java.util.List; -import java.util.Objects; +import java.util.Map; -import org.apache.lucene.index.IndexReader; +import com.google.common.annotations.VisibleForTesting; +import lombok.AllArgsConstructor; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHitCountCollector; -import org.apache.lucene.search.TotalHits; -import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.common.settings.Settings; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.SeqNoFieldMapper; import org.opensearch.index.search.NestedHelper; import org.opensearch.neuralsearch.query.HybridQuery; -import org.opensearch.neuralsearch.search.HitsThresholdChecker; -import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; -import org.opensearch.search.DocValueFormat; +import org.opensearch.search.aggregations.AggregationProcessor; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QueryCollectorContext; import org.opensearch.search.query.QueryPhase; +import org.opensearch.search.query.QueryPhaseExecutionException; import org.opensearch.search.query.QueryPhaseSearcherWrapper; import org.opensearch.search.query.QuerySearchResult; -import org.opensearch.search.query.TopDocsCollectorContext; -import org.opensearch.search.rescore.RescoreContext; -import org.opensearch.search.sort.SortAndFormats; - -import com.google.common.annotations.VisibleForTesting; +import org.opensearch.search.query.ReduceableSearchResult; import lombok.extern.log4j.Log4j2; @@ -66,15 +54,17 @@ public boolean searchWith( final boolean hasFilterCollector, final boolean hasTimeout ) throws IOException { - if (isHybridQuery(query, searchContext)) { + if (!isHybridQuery(query, searchContext)) { + validateQuery(searchContext, query); + return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); + } else { Query hybridQuery = extractHybridQuery(searchContext, query); - return searchWithCollector(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); + return super.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); } - validateQuery(searchContext, query); - return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } - private boolean isHybridQuery(final Query query, final SearchContext searchContext) { + @VisibleForTesting + static boolean isHybridQuery(final Query query, final SearchContext searchContext) { if (query instanceof HybridQuery) { return true; } else if (isWrappedHybridQuery(query) && hasNestedFieldOrNestedDocs(query, searchContext)) { @@ -103,7 +93,7 @@ private boolean isHybridQuery(final Query query, final SearchContext searchConte // we have already checked if query in instance of Boolean in higher level else if condition return ((BooleanQuery) query).clauses() .stream() - .filter(clause -> clause.getQuery() instanceof HybridQuery == false) + .filter(clause -> !(clause.getQuery() instanceof HybridQuery)) .allMatch(clause -> { return clause.getOccur() == BooleanClause.Occur.FILTER && clause.getQuery() instanceof FieldExistsQuery @@ -113,16 +103,17 @@ private boolean isHybridQuery(final Query query, final SearchContext searchConte return false; } - private boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { + private static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query); } - private boolean isWrappedHybridQuery(final Query query) { + private static boolean isWrappedHybridQuery(final Query query) { return query instanceof BooleanQuery && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); } - private Query extractHybridQuery(final SearchContext searchContext, final Query query) { + @VisibleForTesting + protected Query extractHybridQuery(final SearchContext searchContext, final Query query) { if (hasNestedFieldOrNestedDocs(query, searchContext) && isWrappedHybridQuery(query) && ((BooleanQuery) query).clauses().size() > 0) { @@ -180,152 +171,68 @@ private void validateNestedBooleanQuery(final Query query, final int level) { } } - @VisibleForTesting - protected boolean searchWithCollector( - final SearchContext searchContext, - final ContextIndexSearcher searcher, - final Query query, - final LinkedList collectors, - final boolean hasFilterCollector, - final boolean hasTimeout - ) throws IOException { - log.debug("searching with custom doc collector, shard {}", searchContext.shardTarget().getShardId()); + private int getMaxDepthLimit(final SearchContext searchContext) { + Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings(); + return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue(); + } - final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector); - collectors.addFirst(topDocsFactory); - if (searchContext.size() == 0) { - final TotalHitCountCollector collector = new TotalHitCountCollector(); - searcher.search(query, collector); - return false; - } - final IndexReader reader = searchContext.searcher().getIndexReader(); - int totalNumDocs = Math.max(0, reader.numDocs()); - int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs); - final boolean shouldRescore = !searchContext.rescore().isEmpty(); - if (shouldRescore) { - for (RescoreContext rescoreContext : searchContext.rescore()) { - numDocs = Math.max(numDocs, rescoreContext.getWindowSize()); - } - } + @Override + public AggregationProcessor aggregationProcessor(SearchContext searchContext) { + AggregationProcessor coreAggProcessor = super.aggregationProcessor(searchContext); + return new HybridAggregationProcessor(coreAggProcessor); + } - final QuerySearchResult queryResult = searchContext.queryResult(); + @AllArgsConstructor + public static class HybridAggregationProcessor implements AggregationProcessor { - final HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector( - numDocs, - new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())) - ); + private final AggregationProcessor delegateAggsProcessor; - searcher.search(query, collector); + @Override + public void preProcess(SearchContext context) { + delegateAggsProcessor.preProcess(context); - if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) { - queryResult.terminatedEarly(false); + if (isHybridQuery(context.query(), context)) { + // adding collector manager for hybrid query + CollectorManager collectorManager; + try { + collectorManager = HybridCollectorManager.createHybridCollectorManager(context); + } catch (IOException e) { + throw new RuntimeException(e); + } + Map, CollectorManager> collectorManagersByManagerClass = context + .queryCollectorManagers(); + collectorManagersByManagerClass.put(HybridCollectorManager.class, collectorManager); + } } - setTopDocsInQueryResult(queryResult, collector, searchContext); - - return shouldRescore; - } - - private void setTopDocsInQueryResult( - final QuerySearchResult queryResult, - final HybridTopScoreDocCollector collector, - final SearchContext searchContext - ) { - final List topDocs = collector.topDocs(); - final float maxScore = getMaxScore(topDocs); - final boolean isSingleShard = searchContext.numberOfShards() == 1; - final TopDocs newTopDocs = getNewTopDocs(getTotalHits(searchContext, topDocs, isSingleShard), topDocs); - final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore); - queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort())); - } - - private TopDocs getNewTopDocs(final TotalHits totalHits, final List topDocs) { - ScoreDoc[] scoreDocs = new ScoreDoc[0]; - if (Objects.nonNull(topDocs)) { - // for a single shard case we need to do score processing at coordinator level. - // this is workaround for current core behaviour, for single shard fetch phase is executed - // right after query phase and processors are called after actual fetch is done - // find any valid doc Id, or set it to -1 if there is not a single match - int delimiterDocId = topDocs.stream() - .filter(Objects::nonNull) - .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs)) - .map(topDoc -> topDoc.scoreDocs) - .filter(scoreDoc -> scoreDoc.length > 0) - .map(scoreDoc -> scoreDoc[0].doc) - .findFirst() - .orElse(-1); - if (delimiterDocId == -1) { - return new TopDocs(totalHits, scoreDocs); - } - // format scores using following template: - // doc_id | magic_number_1 - // doc_id | magic_number_2 - // ... - // doc_id | magic_number_2 - // ... - // doc_id | magic_number_2 - // ... - // doc_id | magic_number_1 - List result = new ArrayList<>(); - result.add(createStartStopElementForHybridSearchResults(delimiterDocId)); - for (TopDocs topDoc : topDocs) { - if (Objects.isNull(topDoc) || Objects.isNull(topDoc.scoreDocs)) { - result.add(createDelimiterElementForHybridSearchResults(delimiterDocId)); - continue; + @Override + public void postProcess(SearchContext context) { + if (isHybridQuery(context.query(), context)) { + if (!context.shouldUseConcurrentSearch()) { + reduceCollectorResults(context); } - result.add(createDelimiterElementForHybridSearchResults(delimiterDocId)); - result.addAll(Arrays.asList(topDoc.scoreDocs)); + updateQueryResult(context.queryResult(), context); } - result.add(createStartStopElementForHybridSearchResults(delimiterDocId)); - scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new); - } - return new TopDocs(totalHits, scoreDocs); - } - private TotalHits getTotalHits(final SearchContext searchContext, final List topDocs, final boolean isSingleShard) { - int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo(); - final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED - ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO - : TotalHits.Relation.EQUAL_TO; - if (topDocs == null || topDocs.isEmpty()) { - return new TotalHits(0, relation); + delegateAggsProcessor.postProcess(context); } - long maxTotalHits = topDocs.get(0).totalHits.value; - int totalSize = 0; - for (TopDocs topDoc : topDocs) { - maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value); - if (isSingleShard) { - totalSize += topDoc.totalHits.value + 1; + + private void reduceCollectorResults(SearchContext context) { + CollectorManager collectorManager = context.queryCollectorManagers() + .get(HybridCollectorManager.class); + try { + final Collection collectors = List.of(collectorManager.newCollector()); + collectorManager.reduce(collectors).reduce(context.queryResult()); + } catch (IOException e) { + throw new QueryPhaseExecutionException(context.shardTarget(), "failed to execute hybrid query aggregation processor", e); } } - // add 1 qty per each sub-query and + 2 for start and stop delimiters - totalSize += 2; - if (isSingleShard) { - // for single shard we need to update total size as this is how many docs are fetched in Fetch phase - searchContext.size(totalSize); - } - return new TotalHits(maxTotalHits, relation); - } - - private float getMaxScore(final List topDocs) { - if (topDocs.isEmpty()) { - return 0.0f; - } else { - return topDocs.stream() - .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0]) - .map(scoreDoc -> scoreDoc.score) - .max(Float::compare) - .get(); + private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) { + boolean isSingleShard = searchContext.numberOfShards() == 1; + if (isSingleShard) { + searchContext.size(queryResult.queryResult().topDocs().topDocs.scoreDocs.length); + } } } - - private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) { - return sortAndFormats == null ? null : sortAndFormats.formats; - } - - private int getMaxDepthLimit(final SearchContext searchContext) { - Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings(); - return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue(); - } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java index 1a2e3f26e..20f00185f 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java @@ -219,12 +219,7 @@ public void testMaxScoreFailures_whenScorerThrowsException_thenFail() { when(scorer.iterator()).thenReturn(iterator(docs)); when(scorer.getMaxScore(anyInt())).thenThrow(new IOException("Test exception")); - HybridQueryScorer hybridQueryScorerWithAllNonNullSubScorers = new HybridQueryScorer(weight, Arrays.asList(scorer)); - - RuntimeException runtimeException = expectThrows( - RuntimeException.class, - () -> hybridQueryScorerWithAllNonNullSubScorers.getMaxScore(Integer.MAX_VALUE) - ); + IOException runtimeException = expectThrows(IOException.class, () -> new HybridQueryScorer(weight, Arrays.asList(scorer))); assertTrue(runtimeException.getMessage().contains("Test exception")); } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java new file mode 100644 index 000000000..1c919b581 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java @@ -0,0 +1,237 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import lombok.SneakyThrows; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.action.OriginalIndices; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.AggregationProcessor; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +public class HybridAggregationProcessorTests extends OpenSearchQueryTestCase { + + static final String TEXT_FIELD_NAME = "field"; + static final String TERM_QUERY_TEXT = "keyword"; + + @SneakyThrows + public void testAggregationProcessorDelegate_whenPreAndPostAreCalled_thenSuccessful() { + AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); + HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor = + new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate); + + SearchContext searchContext = mock(SearchContext.class); + hybridAggregationProcessor.preProcess(searchContext); + verify(mockAggsProcessorDelegate).preProcess(any()); + + hybridAggregationProcessor.postProcess(searchContext); + verify(mockAggsProcessorDelegate).postProcess(any()); + } + + @SneakyThrows + public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSuccessful() { + AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); + HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor = + new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate); + + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + hybridAggregationProcessor.preProcess(searchContext); + + assertEquals(1, classCollectorManagerMap.size()); + assertTrue(classCollectorManagerMap.containsKey(HybridCollectorManager.class)); + CollectorManager hybridCollectorManager = classCollectorManagerMap.get( + HybridCollectorManager.class + ); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager); + + // setup query result for post processing + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f) } + + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + when(searchContext.queryResult()).thenReturn(querySearchResult); + + // set captor on collector manager to track if reduce has been called + CollectorManager hybridCollectorManagerSpy = spy(hybridCollectorManager); + classCollectorManagerMap.put(HybridCollectorManager.class, hybridCollectorManagerSpy); + + hybridAggregationProcessor.postProcess(searchContext); + + verify(hybridCollectorManagerSpy).reduce(any()); + } + + @SneakyThrows + public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessful() { + AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); + HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor = + new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate); + + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + hybridAggregationProcessor.preProcess(searchContext); + + assertEquals(1, classCollectorManagerMap.size()); + assertTrue(classCollectorManagerMap.containsKey(HybridCollectorManager.class)); + CollectorManager hybridCollectorManager = classCollectorManagerMap.get( + HybridCollectorManager.class + ); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorConcurrentSearchManager); + + // setup query result for post processing + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f) } + + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + when(searchContext.queryResult()).thenReturn(querySearchResult); + + // set captor on collector manager to track if reduce has been called + CollectorManager hybridCollectorManagerSpy = spy(hybridCollectorManager); + classCollectorManagerMap.put(HybridCollectorManager.class, hybridCollectorManagerSpy); + + hybridAggregationProcessor.postProcess(searchContext); + + verifyNoInteractions(hybridCollectorManagerSpy); + } + + @SneakyThrows + public void testCollectorManager_whenNotHybridQueryAndNotConcurrentSearch_thenSuccessful() { + AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); + HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor = + new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate); + + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Query termQuery = termSubQuery.toQuery(mockQueryShardContext); + + when(searchContext.query()).thenReturn(termQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + hybridAggregationProcessor.preProcess(searchContext); + + assertTrue(classCollectorManagerMap.isEmpty()); + + // setup query result for post processing + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f) } + + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + when(searchContext.queryResult()).thenReturn(querySearchResult); + + hybridAggregationProcessor.postProcess(searchContext); + + assertTrue(classCollectorManagerMap.isEmpty()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java new file mode 100644 index 000000000..2951dd666 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -0,0 +1,196 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import lombok.SneakyThrows; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BulkScorer; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.BoostingQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.query.HybridQueryWeight; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; + +public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { + + private static final String TEXT_FIELD_NAME = "field"; + private static final String TERM_QUERY_TEXT = "keyword"; + + private static final float DELTA_FOR_ASSERTION = 0.001f; + private static final float MAX_SCORE = 0.611f; + + @SneakyThrows + public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + assertNotNull(hybridCollectorManager); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager); + + Collector collector = hybridCollectorManager.newCollector(); + assertNotNull(collector); + assertTrue(collector instanceof HybridTopScoreDocCollector); + + Collector secondCollector = hybridCollectorManager.newCollector(); + assertSame(collector, secondCollector); + } + + @SneakyThrows + public void testNewCollector_whenConcurrentSearch_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + assertNotNull(hybridCollectorManager); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorConcurrentSearchManager); + + Collector collector = hybridCollectorManager.newCollector(); + assertNotNull(collector); + assertTrue(collector instanceof HybridTopScoreDocCollector); + + Collector secondCollector = hybridCollectorManager.newCollector(); + assertNotSame(collector, secondCollector); + } + + @SneakyThrows + public void testReduce_whenMatchedDocs_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + ); + when(searchContext.query()).thenReturn(hybridQueryWithTerm); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexReader.numDocs()).thenReturn(1); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + when(searchContext.size()).thenReturn(1); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId, TERM_QUERY_TEXT, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + IndexSearcher searcher = newSearcher(reader); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + HybridTopScoreDocCollector collector = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector(); + + Weight weight = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + collector.setWeight(weight); + LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); + LeafCollector leafCollector = collector.getLeafCollector(leafReaderContext); + BulkScorer scorer = weight.bulkScorer(leafReaderContext); + scorer.score(leafCollector, leafReaderContext.reader().getLiveDocs()); + leafCollector.finish(); + + final Collection collectors = List.of(collector); + Object results = hybridCollectorManager.reduce(collectors); + + assertNotNull(results); + ReduceableSearchResult reduceableSearchResult = ((ReduceableSearchResult) results); + QuerySearchResult querySearchResult = new QuerySearchResult(); + reduceableSearchResult.reduce(querySearchResult); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + + assertNotNull(topDocsAndMaxScore); + assertEquals(1, topDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocsAndMaxScore.topDocs.totalHits.relation); + float maxScore = topDocsAndMaxScore.maxScore; + assertTrue(maxScore > 0); + ScoreDoc[] scoreDocs = topDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(4, scoreDocs.length); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, DELTA_FOR_ASSERTION); + assertEquals(maxScore, scoreDocs[2].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[3].score, DELTA_FOR_ASSERTION); + + w.close(); + reader.close(); + directory.close(); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index e609eec05..602c87440 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -20,10 +20,12 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.LinkedList; import java.util.List; import java.util.Set; import java.util.UUID; +import java.util.stream.Collectors; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.TextField; @@ -61,6 +63,7 @@ import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.AggregationProcessor; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QueryCollectorContext; @@ -159,7 +162,7 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { releaseResources(directory, w, reader); - verify(hybridQueryPhaseSearcher, atLeastOnce()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean()); + verify(hybridQueryPhaseSearcher, atLeastOnce()).searchWith(any(), any(), any(), any(), anyBoolean(), anyBoolean()); } @SneakyThrows @@ -226,7 +229,7 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() releaseResources(directory, w, reader); - verify(hybridQueryPhaseSearcher, never()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean()); + verify(hybridQueryPhaseSearcher, never()).extractHybridQuery(any(), any()); } @SneakyThrows @@ -305,17 +308,8 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { assertEquals(1, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(4, scoreDocs.length); - assertTrue(isHybridQueryStartStopElement(scoreDocs[0])); - assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); - List compoundTopDocs = getSubQueryResultsForSingleShard(topDocs); - assertNotNull(compoundTopDocs); - assertEquals(1, compoundTopDocs.size()); - TopDocs subQueryTopDocs = compoundTopDocs.get(0); - assertEquals(1, subQueryTopDocs.totalHits.value); - assertNotNull(subQueryTopDocs.scoreDocs); - assertEquals(1, subQueryTopDocs.scoreDocs.length); - ScoreDoc scoreDoc = subQueryTopDocs.scoreDocs[0]; + assertEquals(1, scoreDocs.length); + ScoreDoc scoreDoc = scoreDocs[0]; assertNotNull(scoreDoc); int actualDocId = Integer.parseInt(reader.document(scoreDoc.doc).getField("id").stringValue()); assertEquals(docId1, actualDocId); @@ -403,24 +397,10 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes assertEquals(4, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(10, scoreDocs.length); - assertTrue(isHybridQueryStartStopElement(scoreDocs[0])); - assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); - List compoundTopDocs = getSubQueryResultsForSingleShard(topDocs); - assertNotNull(compoundTopDocs); - assertEquals(3, compoundTopDocs.size()); - - TopDocs subQueryTopDocs1 = compoundTopDocs.get(0); - List expectedIds1 = List.of(docId1); - assertQueryResults(subQueryTopDocs1, expectedIds1, reader); - - TopDocs subQueryTopDocs2 = compoundTopDocs.get(1); - List expectedIds2 = List.of(); - assertQueryResults(subQueryTopDocs2, expectedIds2, reader); - - TopDocs subQueryTopDocs3 = compoundTopDocs.get(2); - List expectedIds3 = List.of(docId1, docId2, docId3, docId4); - assertQueryResults(subQueryTopDocs3, expectedIds3, reader); + assertEquals(4, scoreDocs.length); + List expectedIds = List.of(0, 1, 2, 3); + List actualDocIds = Arrays.stream(scoreDocs).map(sd -> sd.doc).collect(Collectors.toList()); + assertEquals(expectedIds, actualDocIds); releaseResources(directory, w, reader); } @@ -726,20 +706,10 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then assertTrue(topDocs.totalHits.value > 0); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertTrue(scoreDocs.length > 0); - assertTrue(isHybridQueryStartStopElement(scoreDocs[0])); - assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); - List compoundTopDocs = getSubQueryResultsForSingleShard(topDocs); - assertNotNull(compoundTopDocs); - assertEquals(2, compoundTopDocs.size()); - - TopDocs subQueryTopDocs1 = compoundTopDocs.get(0); - List expectedIds1 = List.of(docId1); - assertQueryResults(subQueryTopDocs1, expectedIds1, reader); - - TopDocs subQueryTopDocs2 = compoundTopDocs.get(1); - List expectedIds2 = List.of(); - assertQueryResults(subQueryTopDocs2, expectedIds2, reader); + assertEquals(1, scoreDocs.length); + ScoreDoc scoreDoc = scoreDocs[0]; + assertTrue(scoreDoc.score > 0); + assertEquals(0, scoreDoc.doc); releaseResources(directory, w, reader); } @@ -831,6 +801,15 @@ public void testBoolQuery_whenTooManyNestedLevels_thenSuccess() { releaseResources(directory, w, reader); } + @SneakyThrows + public void testAggsProcessor_whenGettingAggsProcessor_thenSuccess() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + SearchContext searchContext = mock(SearchContext.class); + AggregationProcessor aggregationProcessor = hybridQueryPhaseSearcher.aggregationProcessor(searchContext); + assertNotNull(aggregationProcessor); + assertTrue(aggregationProcessor instanceof HybridQueryPhaseSearcher.HybridAggregationProcessor); + } + @SneakyThrows private void assertQueryResults(TopDocs subQueryTopDocs, List expectedDocIds, IndexReader reader) { assertEquals(expectedDocIds.size(), subQueryTopDocs.totalHits.value);