From d48c06ca974207f21f2f6bf2b7b4f8d9c72a9967 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 29 Feb 2024 18:23:15 -0800 Subject: [PATCH 1/6] Initial version, included tests Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + .../neuralsearch/query/HybridQuery.java | 9 +- .../query/HybridQueryBuilder.java | 3 +- .../neuralsearch/query/HybridQueryScorer.java | 225 ++++++++++++++++- .../neuralsearch/query/HybridQueryWeight.java | 87 ++++++- .../query/HybridScorePropagator.java | 91 +++++++ .../search/HybridTopScoreDocCollector.java | 28 ++- .../search/query/HybridCollectorManager.java | 230 +++++++++++++++++ .../query/HybridQueryPhaseSearcher.java | 231 +++++------------ .../query/HybridQueryScorerTests.java | 7 +- .../HybridAggregationProcessorTests.java | 237 ++++++++++++++++++ .../query/HybridCollectorManagerTests.java | 196 +++++++++++++++ .../query/HybridQueryPhaseSearcherTests.java | 69 ++--- 13 files changed, 1173 insertions(+), 241 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/query/HybridScorePropagator.java create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java create mode 100644 src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 528231b07..2ddf281f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes +- Adding two phase iterator for hybrid query ([#624](https://github.com/opensearch-project/neural-search/pull/624)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java index 8846f6977..01d271cdd 100644 --- 
a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -12,7 +12,6 @@ import java.util.List; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; @@ -77,12 +76,12 @@ public String toString(String field) { /** * Re-writes queries into primitive queries. Callers are expected to call rewrite multiple times if necessary, * until the rewritten query is the same as the original query. - * @param reader + * @param indexSearcher * @return * @throws IOException */ @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (subQueries.isEmpty()) { return new MatchNoDocsQuery("empty HybridQuery"); } @@ -90,7 +89,7 @@ public Query rewrite(IndexReader reader) throws IOException { boolean actuallyRewritten = false; List rewrittenSubQueries = new ArrayList<>(); for (Query subQuery : subQueries) { - Query rewrittenSub = subQuery.rewrite(reader); + Query rewrittenSub = subQuery.rewrite(indexSearcher); /* we keep rewrite sub-query unless it's not equal to itself, it may take multiple levels of recursive calls queries need to be rewritten from high-level clauses into lower-level clauses because low-level clauses perform better. 
For hybrid query we need to track progress of re-write for all sub-queries */ @@ -102,7 +101,7 @@ public Query rewrite(IndexReader reader) throws IOException { return new HybridQuery(rewrittenSubQueries); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } /** diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index aa4242c2e..46c087894 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -27,7 +27,6 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryShardException; -import org.opensearch.index.query.Rewriteable; import org.opensearch.index.query.QueryBuilderVisitor; import lombok.Getter; @@ -290,7 +289,7 @@ private void writeQueries(StreamOutput out, List queries private Collection toQueries(Collection queryBuilders, QueryShardContext context) throws QueryShardException { List queries = queryBuilders.stream().map(qb -> { try { - return Rewriteable.rewrite(qb, context).toQuery(context); + return qb.rewrite(context).toQuery(context); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java index 5abfd0b5e..188a90209 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java @@ -6,6 +6,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -18,10 +19,15 @@ import org.apache.lucene.search.DisjunctionDISIApproximation; import 
org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.Weight; import lombok.Getter; +import org.apache.lucene.util.PriorityQueue; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; /** * Class abstracts functionality of Scorer for hybrid query. When iterating over documents in increasing @@ -40,12 +46,60 @@ public final class HybridQueryScorer extends Scorer { private final Map> queryToIndex; + private final DocIdSetIterator approximation; + HybridScorePropagator disjunctionBlockPropagator; + private final TwoPhase twoPhase; + public HybridQueryScorer(Weight weight, List subScorers) throws IOException { + this(weight, subScorers, ScoreMode.TOP_SCORES); + } + + public HybridQueryScorer(Weight weight, List subScorers, ScoreMode scoreMode) throws IOException { super(weight); + // max this.subScorers = Collections.unmodifiableList(subScorers); + // custom subScores = new float[subScorers.size()]; this.queryToIndex = mapQueryToIndex(); + // base this.subScorersPQ = initializeSubScorersPQ(); + // base + boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES; + this.approximation = new HybridDisjunctionDISIApproximation(this.subScorersPQ); + // max + if (scoreMode == ScoreMode.TOP_SCORES) { + this.disjunctionBlockPropagator = new HybridScorePropagator(subScorers); + } else { + this.disjunctionBlockPropagator = null; + } + // base + boolean hasApproximation = false; + float sumMatchCost = 0; + long sumApproxCost = 0; + // Compute matchCost as the average over the matchCost of the subScorers. + // This is weighted by the cost, which is an expected number of matching documents. + for (DisiWrapper w : subScorersPQ) { + long costWeight = (w.cost <= 1) ? 
1 : w.cost; + sumApproxCost += costWeight; + if (w.twoPhaseView != null) { + hasApproximation = true; + sumMatchCost += w.matchCost * costWeight; + } + } + if (!hasApproximation) { // no sub scorer supports approximations + twoPhase = null; + } else { + final float matchCost = sumMatchCost / sumApproxCost; + twoPhase = new TwoPhase(approximation, matchCost, subScorersPQ, needsScores); + } + } + + @Override + public int advanceShallow(int target) throws IOException { + if (disjunctionBlockPropagator != null) { + return disjunctionBlockPropagator.advanceShallow(target); + } + return super.advanceShallow(target); } /** @@ -55,11 +109,14 @@ public HybridQueryScorer(Weight weight, List subScorers) throws IOExcept */ @Override public float score() throws IOException { - DisiWrapper topList = subScorersPQ.topList(); + return score(getSubMatches()); + } + + private float score(DisiWrapper topList) throws IOException { float totalScore = 0.0f; for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) { // check if this doc has match in the subQuery. If not, add score as 0.0 and continue - if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) { + if (disiWrapper.scorer.docID() == NO_MORE_DOCS) { continue; } totalScore += disiWrapper.scorer.score(); @@ -67,13 +124,30 @@ public float score() throws IOException { return totalScore; } + DisiWrapper getSubMatches() throws IOException { + if (twoPhase == null) { + return subScorersPQ.topList(); + } else { + return twoPhase.getSubMatches(); + } + } + /** * Return a DocIdSetIterator over matching documents. 
* @return DocIdSetIterator object */ @Override public DocIdSetIterator iterator() { - return new DisjunctionDISIApproximation(this.subScorersPQ); + if (twoPhase != null) { + return TwoPhaseIterator.asDocIdSetIterator(twoPhase); + } else { + return approximation; + } + } + + @Override + public TwoPhaseIterator twoPhaseIterator() { + return twoPhase; } /** @@ -93,12 +167,28 @@ public float getMaxScore(int upTo) throws IOException { }).max(Float::compare).orElse(0.0f); } + @Override + public void setMinCompetitiveScore(float minScore) throws IOException { + if (disjunctionBlockPropagator != null) { + disjunctionBlockPropagator.setMinCompetitiveScore(minScore); + } + + for (Scorer scorer : subScorers) { + if (Objects.nonNull(scorer)) { + scorer.setMinCompetitiveScore(minScore); + } + } + } + /** * Returns the doc ID that is currently being scored. * @return document id */ @Override public int docID() { + if (subScorersPQ.size() == 0) { + return NO_MORE_DOCS; + } return subScorersPQ.top().doc; } @@ -169,4 +259,133 @@ private DisiPriorityQueue initializeSubScorersPQ() { } return subScorersPQ; } + + @Override + public Collection getChildren() throws IOException { + ArrayList children = new ArrayList<>(); + for (DisiWrapper scorer = getSubMatches(); scorer != null; scorer = scorer.next) { + children.add(new ChildScorable(scorer.scorer, "SHOULD")); + } + return children; + } + + static class TwoPhase extends TwoPhaseIterator { + private final float matchCost; + // list of verified matches on the current doc + DisiWrapper verifiedMatches; + // priority queue of approximations on the current doc that have not been verified yet + final PriorityQueue unverifiedMatches; + DisiPriorityQueue subScorers; + boolean needsScores; + + private TwoPhase(DocIdSetIterator approximation, float matchCost, DisiPriorityQueue subScorers, boolean needsScores) { + super(approximation); + this.matchCost = matchCost; + this.subScorers = subScorers; + unverifiedMatches = new 
PriorityQueue<>(subScorers.size()) { + @Override + protected boolean lessThan(DisiWrapper a, DisiWrapper b) { + return a.matchCost < b.matchCost; + } + }; + this.needsScores = needsScores; + } + + DisiWrapper getSubMatches() throws IOException { + // iteration order does not matter + for (DisiWrapper w : unverifiedMatches) { + if (w.twoPhaseView.matches()) { + w.next = verifiedMatches; + verifiedMatches = w; + } + } + unverifiedMatches.clear(); + return verifiedMatches; + } + + @Override + public boolean matches() throws IOException { + verifiedMatches = null; + unverifiedMatches.clear(); + + for (DisiWrapper w = subScorers.topList(); w != null;) { + DisiWrapper next = w.next; + + if (w.twoPhaseView == null) { + // implicitly verified, move it to verifiedMatches + w.next = verifiedMatches; + verifiedMatches = w; + + if (!needsScores) { + // we can stop here + return true; + } + } else { + unverifiedMatches.add(w); + } + w = next; + } + + if (verifiedMatches != null) { + return true; + } + + // verify subs that have a two-phase iterator + // least-costly ones first + while (unverifiedMatches.size() > 0) { + DisiWrapper w = unverifiedMatches.pop(); + if (w.twoPhaseView.matches()) { + w.next = null; + verifiedMatches = w; + return true; + } + } + + return false; + } + + @Override + public float matchCost() { + return matchCost; + } + } + + static class HybridDisjunctionDISIApproximation extends DocIdSetIterator { + final DocIdSetIterator delegate; + final DisiPriorityQueue subIterators; + + public HybridDisjunctionDISIApproximation(DisiPriorityQueue subIterators) { + delegate = new DisjunctionDISIApproximation(subIterators); + this.subIterators = subIterators; + } + + @Override + public long cost() { + return delegate.cost(); + } + + @Override + public int docID() { + if (subIterators.size() == 0) { + return NO_MORE_DOCS; + } + return delegate.docID(); + } + + @Override + public int nextDoc() throws IOException { + if (subIterators.size() == 0) { + return 
NO_MORE_DOCS; + } + return delegate.nextDoc(); + } + + @Override + public int advance(int target) throws IOException { + if (subIterators.size() == 0) { + return NO_MORE_DOCS; + } + return delegate.advance(target); + } + } } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java index 69ee5015f..76bdd5f00 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java @@ -5,10 +5,12 @@ package org.opensearch.neuralsearch.query; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; +import lombok.RequiredArgsConstructor; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; @@ -16,6 +18,7 @@ import org.apache.lucene.search.MatchesUtils; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; /** @@ -23,18 +26,18 @@ */ public final class HybridQueryWeight extends Weight { - private final HybridQuery queries; // The Weights for our subqueries, in 1-1 correspondence private final List weights; private final ScoreMode scoreMode; + static final int BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD = 16; + /** * Construct the Weight for this Query searched by searcher. Recursively construct subquery weights. 
*/ public HybridQueryWeight(HybridQuery hybridQuery, IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { super(hybridQuery); - this.queries = hybridQuery; weights = hybridQuery.getSubQueries().stream().map(q -> { try { return searcher.createWeight(q, scoreMode, boost); @@ -65,6 +68,20 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException { return MatchesUtils.fromSubMatches(mis); } + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + List scorerSuppliers = new ArrayList<>(); + for (Weight w : weights) { + ScorerSupplier ss = w.scorerSupplier(context); + scorerSuppliers.add(ss); + } + + if (scorerSuppliers.isEmpty()) { + return null; + } + return new HybridScorerSupplier(scorerSuppliers, this, scoreMode); + } + /** * Create the scorer used to score our associated Query * @@ -75,19 +92,12 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException { */ @Override public Scorer scorer(LeafReaderContext context) throws IOException { - List scorers = weights.stream().map(w -> { - try { - return w.scorer(context); - } catch (IOException e) { - throw new RuntimeException(e); - } - }).collect(Collectors.toList()); - // if there are no matches in any of the scorers (sub-queries) we need to return - // scorer as null to avoid problems with disi result iterators - if (scorers.stream().allMatch(Objects::isNull)) { + ScorerSupplier supplier = scorerSupplier(context); + if (supplier == null) { return null; } - return new HybridQueryScorer(this, scorers); + supplier.setTopLevelScoringClause(); + return supplier.get(Long.MAX_VALUE); } /** @@ -98,6 +108,11 @@ public Scorer scorer(LeafReaderContext context) throws IOException { */ @Override public boolean isCacheable(LeafReaderContext ctx) { + if (weights.size() > BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD) { + // Disallow caching large queries to not encourage users + // to build large queries + return false; + } 
return weights.stream().allMatch(w -> w.isCacheable(ctx)); } @@ -113,4 +128,50 @@ public boolean isCacheable(LeafReaderContext ctx) { public Explanation explain(LeafReaderContext context, int doc) throws IOException { throw new UnsupportedOperationException("Explain is not supported"); } + + @RequiredArgsConstructor + static class HybridScorerSupplier extends ScorerSupplier { + private long cost = -1; + private final List scorerSuppliers; + private final Weight weight; + private final ScoreMode scoreMode; + + @Override + public Scorer get(long leadCost) throws IOException { + List tScorers = new ArrayList<>(); + for (ScorerSupplier ss : scorerSuppliers) { + if (Objects.nonNull(ss)) { + tScorers.add(ss.get(leadCost)); + } else { + tScorers.add(null); + } + } + return new HybridQueryScorer(weight, tScorers, scoreMode); + } + + @Override + public long cost() { + if (cost == -1) { + long cost = 0; + for (ScorerSupplier ss : scorerSuppliers) { + if (Objects.nonNull(ss)) { + cost += ss.cost(); + } + } + this.cost = cost; + } + return cost; + } + + @Override + public void setTopLevelScoringClause() throws IOException { + for (ScorerSupplier ss : scorerSuppliers) { + // sub scorers need to be able to skip too as calls to setMinCompetitiveScore get + // propagated + if (Objects.nonNull(ss)) { + ss.setTopLevelScoringClause(); + } + } + } + }; } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridScorePropagator.java b/src/main/java/org/opensearch/neuralsearch/query/HybridScorePropagator.java new file mode 100644 index 000000000..92e1bbf7e --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridScorePropagator.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; 
+import java.util.Comparator; +import java.util.Objects; + +public class HybridScorePropagator { + + private static final Comparator MAX_SCORE_COMPARATOR = Comparator.comparing((Scorer s) -> { + try { + return s.getMaxScore(DocIdSetIterator.NO_MORE_DOCS); + } catch (IOException e) { + throw new RuntimeException(e); + } + }).thenComparing(s -> s.iterator().cost()); + + private final Scorer[] scorers; + private final float[] maxScores; + private int leadIndex = 0; + + HybridScorePropagator(Collection scorers) throws IOException { + this.scorers = scorers.stream().filter(Objects::nonNull).toArray(Scorer[]::new); + for (Scorer scorer : this.scorers) { + scorer.advanceShallow(0); + } + Arrays.sort(this.scorers, MAX_SCORE_COMPARATOR); + + maxScores = new float[this.scorers.length]; + for (int i = 0; i < this.scorers.length; ++i) { + maxScores[i] = this.scorers[i].getMaxScore(DocIdSetIterator.NO_MORE_DOCS); + } + } + + /** See {@link Scorer#advanceShallow(int)}. */ + int advanceShallow(int target) throws IOException { + // For scorers that are below the lead index, just propagate. + for (int i = 0; i < leadIndex; ++i) { + Scorer s = scorers[i]; + if (s.docID() < target) { + s.advanceShallow(target); + } + } + + // For scorers above the lead index, we take the minimum + // boundary. + Scorer leadScorer = scorers[leadIndex]; + int upTo = leadScorer.advanceShallow(Math.max(leadScorer.docID(), target)); + + for (int i = leadIndex + 1; i < scorers.length; ++i) { + Scorer scorer = scorers[i]; + if (scorer.docID() <= target) { + upTo = Math.min(scorer.advanceShallow(target), upTo); + } + } + + // If the maximum scoring clauses are beyond `target`, then we use their + // docID as a boundary. It helps not consider them when computing the + // maximum score and get a lower score upper bound. 
+ for (int i = scorers.length - 1; i > leadIndex; --i) { + Scorer scorer = scorers[i]; + if (scorer.docID() > target) { + upTo = Math.min(upTo, scorer.docID() - 1); + } else { + break; + } + } + + return upTo; + } + + /** + * Set the minimum competitive score to filter out clauses that score less than this threshold. + * + * @see Scorer#setMinCompetitiveScore + */ + void setMinCompetitiveScore(float minScore) throws IOException { + // Update the lead index if necessary + while (leadIndex < maxScores.length - 1 && minScore > maxScores[leadIndex]) { + leadIndex++; + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java index 8b7a12d29..9190bfeac 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java +++ b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java @@ -19,7 +19,6 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TopScoreDocCollector; import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.PriorityQueue; import org.opensearch.neuralsearch.query.HybridQueryScorer; @@ -47,16 +46,35 @@ public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThreshol } @Override - public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + public LeafCollector getLeafCollector(LeafReaderContext context) { docBase = context.docBase; - return new TopScoreDocCollector.ScorerLeafCollector() { + return new LeafCollector() { HybridQueryScorer compoundQueryScorer; @Override public void setScorer(Scorable scorer) throws IOException { - super.setScorer(scorer); - compoundQueryScorer = (HybridQueryScorer) scorer; + if (scorer instanceof HybridQueryScorer) { + compoundQueryScorer = (HybridQueryScorer) scorer; + } else { + 
compoundQueryScorer = getHybridQueryScorer(scorer); + } + } + + private HybridQueryScorer getHybridQueryScorer(final Scorable scorer) throws IOException { + if (scorer == null) { + return null; + } + if (scorer instanceof HybridQueryScorer) { + return (HybridQueryScorer) scorer; + } + for (Scorable.ChildScorable childScorable : scorer.getChildren()) { + HybridQueryScorer hybridQueryScorer = getHybridQueryScorer(childScorable.child); + if (hybridQueryScorer != null) { + return hybridQueryScorer; + } + } + return null; } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java new file mode 100644 index 000000000..36a9002e8 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -0,0 +1,230 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import lombok.RequiredArgsConstructor; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.neuralsearch.search.HitsThresholdChecker; +import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.MultiCollectorWrapper; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; +import org.opensearch.search.sort.SortAndFormats; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import 
java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; + +@RequiredArgsConstructor +public abstract class HybridCollectorManager implements CollectorManager { + + private final int numHits; + private final HitsThresholdChecker hitsThresholdChecker; + private final boolean isSingleShard; + private final int trackTotalHitsUpTo; + private final SortAndFormats sortAndFormats; + + public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException { + final IndexReader reader = searchContext.searcher().getIndexReader(); + final int totalNumDocs = Math.max(0, reader.numDocs()); + boolean isSingleShard = searchContext.numberOfShards() == 1; + int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs); + int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo(); + + return searchContext.shouldUseConcurrentSearch() + ? 
new HybridCollectorConcurrentSearchManager( + numDocs, + new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())), + isSingleShard, + trackTotalHitsUpTo, + searchContext.sort() + ) + : new HybridCollectorNonConcurrentManager( + numDocs, + new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())), + isSingleShard, + trackTotalHitsUpTo, + searchContext.sort() + ); + } + + @Override + abstract public Collector newCollector(); + + Collector getCollector() { + Collector hybridcollector = new HybridTopScoreDocCollector(numHits, hitsThresholdChecker); + return hybridcollector; + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) { + final List hybridTopScoreDocCollectors = new ArrayList<>(); + + for (final Collector collector : collectors) { + if (collector instanceof MultiCollectorWrapper) { + for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) { + if (sub instanceof HybridTopScoreDocCollector) { + hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) sub); + } + } + } else if (collector instanceof HybridTopScoreDocCollector) { + hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) collector); + } + } + + if (!hybridTopScoreDocCollectors.isEmpty()) { + HybridTopScoreDocCollector hybridTopScoreDocCollector = hybridTopScoreDocCollectors.stream() + .findFirst() + .orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query")); + List topDocs = hybridTopScoreDocCollector.topDocs(); + TopDocs newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard), topDocs); + float maxScore = getMaxScore(topDocs); + TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore); + return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); }; + } + throw new IllegalStateException("cannot collect results of hybrid search query, there are no 
proper score collectors"); + } + + private TopDocs getNewTopDocs(final TotalHits totalHits, final List topDocs) { + ScoreDoc[] scoreDocs = new ScoreDoc[0]; + if (Objects.nonNull(topDocs)) { + // for a single shard case we need to do score processing at coordinator level. + // this is a workaround for current core behaviour, for single shard fetch phase is executed + // right after query phase and processors are called after actual fetch is done + // find any valid doc Id, or set it to -1 if there is not a single match + int delimiterDocId = topDocs.stream() + .filter(Objects::nonNull) + .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs)) + .map(topDoc -> topDoc.scoreDocs) + .filter(scoreDoc -> scoreDoc.length > 0) + .map(scoreDoc -> scoreDoc[0].doc) + .findFirst() + .orElse(-1); + if (delimiterDocId == -1) { + return new TopDocs(totalHits, scoreDocs); + } + // format scores using following template: + // doc_id | magic_number_1 + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_1 + List result = new ArrayList<>(); + result.add(createStartStopElementForHybridSearchResults(delimiterDocId)); + for (TopDocs topDoc : topDocs) { + if (Objects.isNull(topDoc) || Objects.isNull(topDoc.scoreDocs)) { + result.add(createDelimiterElementForHybridSearchResults(delimiterDocId)); + continue; + } + result.add(createDelimiterElementForHybridSearchResults(delimiterDocId)); + result.addAll(Arrays.asList(topDoc.scoreDocs)); + } + result.add(createStartStopElementForHybridSearchResults(delimiterDocId)); + scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new); + } + return new TopDocs(totalHits, scoreDocs); + } + + private TotalHits getTotalHits(int trackTotalHitsUpTo, final List topDocs, final boolean isSingleShard) { + final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED + ? 
TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO; + if (topDocs == null || topDocs.isEmpty()) { + return new TotalHits(0, relation); + } + + List scoreDocs = topDocs.stream() + .map(topdDoc -> topdDoc.scoreDocs) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + Set uniqueDocIds = new HashSet<>(); + for (ScoreDoc[] scoreDocsArray : scoreDocs) { + uniqueDocIds.addAll(Arrays.stream(scoreDocsArray).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList())); + } + long maxTotalHits = uniqueDocIds.size(); + + return new TotalHits(maxTotalHits, relation); + } + + private float getMaxScore(final List topDocs) { + if (topDocs.isEmpty()) { + return 0.0f; + } else { + return topDocs.stream() + .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0]) + .map(scoreDoc -> scoreDoc.score) + .max(Float::compare) + .get(); + } + } + + private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) { + return sortAndFormats == null ? 
null : sortAndFormats.formats; + } + + static class HybridCollectorNonConcurrentManager extends HybridCollectorManager { + Collector maxScoreCollector; + + public HybridCollectorNonConcurrentManager( + int numHits, + HitsThresholdChecker hitsThresholdChecker, + boolean isSingleShard, + int trackTotalHitsUpTo, + SortAndFormats sortAndFormats + ) { + super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats); + } + + @Override + public Collector newCollector() { + if (Objects.isNull(maxScoreCollector)) { + maxScoreCollector = getCollector(); + return maxScoreCollector; + } else { + Collector toReturnCollector = maxScoreCollector; + maxScoreCollector = null; + return toReturnCollector; + } + } + } + + static class HybridCollectorConcurrentSearchManager extends HybridCollectorManager { + + public HybridCollectorConcurrentSearchManager( + int numHits, + HitsThresholdChecker hitsThresholdChecker, + boolean isSingleShard, + int trackTotalHitsUpTo, + SortAndFormats sortAndFormats + ) { + super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats); + } + + @Override + public Collector newCollector() { + return getCollector(); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index bf05fdc9d..5fc6017f2 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -4,46 +4,34 @@ */ package org.opensearch.neuralsearch.search.query; -import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; -import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; -import static 
org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext; - import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; +import java.util.Collection; import java.util.LinkedList; import java.util.List; -import java.util.Objects; +import java.util.Map; -import org.apache.lucene.index.IndexReader; +import com.google.common.annotations.VisibleForTesting; +import lombok.AllArgsConstructor; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHitCountCollector; -import org.apache.lucene.search.TotalHits; -import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.common.settings.Settings; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.SeqNoFieldMapper; import org.opensearch.index.search.NestedHelper; import org.opensearch.neuralsearch.query.HybridQuery; -import org.opensearch.neuralsearch.search.HitsThresholdChecker; -import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; -import org.opensearch.search.DocValueFormat; +import org.opensearch.search.aggregations.AggregationProcessor; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QueryCollectorContext; import org.opensearch.search.query.QueryPhase; +import org.opensearch.search.query.QueryPhaseExecutionException; import org.opensearch.search.query.QueryPhaseSearcherWrapper; import org.opensearch.search.query.QuerySearchResult; -import org.opensearch.search.query.TopDocsCollectorContext; -import org.opensearch.search.rescore.RescoreContext; -import 
org.opensearch.search.sort.SortAndFormats; - -import com.google.common.annotations.VisibleForTesting; +import org.opensearch.search.query.ReduceableSearchResult; import lombok.extern.log4j.Log4j2; @@ -66,15 +54,17 @@ public boolean searchWith( final boolean hasFilterCollector, final boolean hasTimeout ) throws IOException { - if (isHybridQuery(query, searchContext)) { + if (!isHybridQuery(query, searchContext)) { + validateQuery(searchContext, query); + return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); + } else { Query hybridQuery = extractHybridQuery(searchContext, query); - return searchWithCollector(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); + return super.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); } - validateQuery(searchContext, query); - return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } - private boolean isHybridQuery(final Query query, final SearchContext searchContext) { + @VisibleForTesting + static boolean isHybridQuery(final Query query, final SearchContext searchContext) { if (query instanceof HybridQuery) { return true; } else if (isWrappedHybridQuery(query) && hasNestedFieldOrNestedDocs(query, searchContext)) { @@ -103,7 +93,7 @@ private boolean isHybridQuery(final Query query, final SearchContext searchConte // we have already checked if query in instance of Boolean in higher level else if condition return ((BooleanQuery) query).clauses() .stream() - .filter(clause -> clause.getQuery() instanceof HybridQuery == false) + .filter(clause -> !(clause.getQuery() instanceof HybridQuery)) .allMatch(clause -> { return clause.getOccur() == BooleanClause.Occur.FILTER && clause.getQuery() instanceof FieldExistsQuery @@ -113,16 +103,17 @@ private boolean isHybridQuery(final Query query, final SearchContext searchConte return false; } - private boolean 
hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { + private static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query); } - private boolean isWrappedHybridQuery(final Query query) { + private static boolean isWrappedHybridQuery(final Query query) { return query instanceof BooleanQuery && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); } - private Query extractHybridQuery(final SearchContext searchContext, final Query query) { + @VisibleForTesting + protected Query extractHybridQuery(final SearchContext searchContext, final Query query) { if (hasNestedFieldOrNestedDocs(query, searchContext) && isWrappedHybridQuery(query) && ((BooleanQuery) query).clauses().size() > 0) { @@ -180,152 +171,68 @@ private void validateNestedBooleanQuery(final Query query, final int level) { } } - @VisibleForTesting - protected boolean searchWithCollector( - final SearchContext searchContext, - final ContextIndexSearcher searcher, - final Query query, - final LinkedList collectors, - final boolean hasFilterCollector, - final boolean hasTimeout - ) throws IOException { - log.debug("searching with custom doc collector, shard {}", searchContext.shardTarget().getShardId()); + private int getMaxDepthLimit(final SearchContext searchContext) { + Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings(); + return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue(); + } - final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector); - collectors.addFirst(topDocsFactory); - if (searchContext.size() == 0) { - final TotalHitCountCollector collector = new TotalHitCountCollector(); - searcher.search(query, collector); - return 
false; - } - final IndexReader reader = searchContext.searcher().getIndexReader(); - int totalNumDocs = Math.max(0, reader.numDocs()); - int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs); - final boolean shouldRescore = !searchContext.rescore().isEmpty(); - if (shouldRescore) { - for (RescoreContext rescoreContext : searchContext.rescore()) { - numDocs = Math.max(numDocs, rescoreContext.getWindowSize()); - } - } + @Override + public AggregationProcessor aggregationProcessor(SearchContext searchContext) { + AggregationProcessor coreAggProcessor = super.aggregationProcessor(searchContext); + return new HybridAggregationProcessor(coreAggProcessor); + } - final QuerySearchResult queryResult = searchContext.queryResult(); + @AllArgsConstructor + public static class HybridAggregationProcessor implements AggregationProcessor { - final HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector( - numDocs, - new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())) - ); + private final AggregationProcessor delegateAggsProcessor; - searcher.search(query, collector); + @Override + public void preProcess(SearchContext context) { + delegateAggsProcessor.preProcess(context); - if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) { - queryResult.terminatedEarly(false); + if (isHybridQuery(context.query(), context)) { + // adding collector manager for hybrid query + CollectorManager collectorManager; + try { + collectorManager = HybridCollectorManager.createHybridCollectorManager(context); + } catch (IOException e) { + throw new RuntimeException(e); + } + Map, CollectorManager> collectorManagersByManagerClass = context + .queryCollectorManagers(); + collectorManagersByManagerClass.put(HybridCollectorManager.class, collectorManager); + } } - setTopDocsInQueryResult(queryResult, collector, searchContext); - - return shouldRescore; - } - - private void 
setTopDocsInQueryResult( - final QuerySearchResult queryResult, - final HybridTopScoreDocCollector collector, - final SearchContext searchContext - ) { - final List topDocs = collector.topDocs(); - final float maxScore = getMaxScore(topDocs); - final boolean isSingleShard = searchContext.numberOfShards() == 1; - final TopDocs newTopDocs = getNewTopDocs(getTotalHits(searchContext, topDocs, isSingleShard), topDocs); - final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore); - queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort())); - } - - private TopDocs getNewTopDocs(final TotalHits totalHits, final List topDocs) { - ScoreDoc[] scoreDocs = new ScoreDoc[0]; - if (Objects.nonNull(topDocs)) { - // for a single shard case we need to do score processing at coordinator level. - // this is workaround for current core behaviour, for single shard fetch phase is executed - // right after query phase and processors are called after actual fetch is done - // find any valid doc Id, or set it to -1 if there is not a single match - int delimiterDocId = topDocs.stream() - .filter(Objects::nonNull) - .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs)) - .map(topDoc -> topDoc.scoreDocs) - .filter(scoreDoc -> scoreDoc.length > 0) - .map(scoreDoc -> scoreDoc[0].doc) - .findFirst() - .orElse(-1); - if (delimiterDocId == -1) { - return new TopDocs(totalHits, scoreDocs); - } - // format scores using following template: - // doc_id | magic_number_1 - // doc_id | magic_number_2 - // ... - // doc_id | magic_number_2 - // ... - // doc_id | magic_number_2 - // ... 
- // doc_id | magic_number_1 - List result = new ArrayList<>(); - result.add(createStartStopElementForHybridSearchResults(delimiterDocId)); - for (TopDocs topDoc : topDocs) { - if (Objects.isNull(topDoc) || Objects.isNull(topDoc.scoreDocs)) { - result.add(createDelimiterElementForHybridSearchResults(delimiterDocId)); - continue; + @Override + public void postProcess(SearchContext context) { + if (isHybridQuery(context.query(), context)) { + if (!context.shouldUseConcurrentSearch()) { + reduceCollectorResults(context); } - result.add(createDelimiterElementForHybridSearchResults(delimiterDocId)); - result.addAll(Arrays.asList(topDoc.scoreDocs)); + updateQueryResult(context.queryResult(), context); } - result.add(createStartStopElementForHybridSearchResults(delimiterDocId)); - scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new); - } - return new TopDocs(totalHits, scoreDocs); - } - private TotalHits getTotalHits(final SearchContext searchContext, final List topDocs, final boolean isSingleShard) { - int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo(); - final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED - ? 
TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO - : TotalHits.Relation.EQUAL_TO; - if (topDocs == null || topDocs.isEmpty()) { - return new TotalHits(0, relation); + delegateAggsProcessor.postProcess(context); } - long maxTotalHits = topDocs.get(0).totalHits.value; - int totalSize = 0; - for (TopDocs topDoc : topDocs) { - maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value); - if (isSingleShard) { - totalSize += topDoc.totalHits.value + 1; + + private void reduceCollectorResults(SearchContext context) { + CollectorManager collectorManager = context.queryCollectorManagers() + .get(HybridCollectorManager.class); + try { + final Collection collectors = List.of(collectorManager.newCollector()); + collectorManager.reduce(collectors).reduce(context.queryResult()); + } catch (IOException e) { + throw new QueryPhaseExecutionException(context.shardTarget(), "failed to execute hybrid query aggregation processor", e); } } - // add 1 qty per each sub-query and + 2 for start and stop delimiters - totalSize += 2; - if (isSingleShard) { - // for single shard we need to update total size as this is how many docs are fetched in Fetch phase - searchContext.size(totalSize); - } - return new TotalHits(maxTotalHits, relation); - } - - private float getMaxScore(final List topDocs) { - if (topDocs.isEmpty()) { - return 0.0f; - } else { - return topDocs.stream() - .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0]) - .map(scoreDoc -> scoreDoc.score) - .max(Float::compare) - .get(); + private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) { + boolean isSingleShard = searchContext.numberOfShards() == 1; + if (isSingleShard) { + searchContext.size(queryResult.queryResult().topDocs().topDocs.scoreDocs.length); + } } } - - private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) { - return sortAndFormats == null ? 
null : sortAndFormats.formats; - } - - private int getMaxDepthLimit(final SearchContext searchContext) { - Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings(); - return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue(); - } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java index 1a2e3f26e..20f00185f 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java @@ -219,12 +219,7 @@ public void testMaxScoreFailures_whenScorerThrowsException_thenFail() { when(scorer.iterator()).thenReturn(iterator(docs)); when(scorer.getMaxScore(anyInt())).thenThrow(new IOException("Test exception")); - HybridQueryScorer hybridQueryScorerWithAllNonNullSubScorers = new HybridQueryScorer(weight, Arrays.asList(scorer)); - - RuntimeException runtimeException = expectThrows( - RuntimeException.class, - () -> hybridQueryScorerWithAllNonNullSubScorers.getMaxScore(Integer.MAX_VALUE) - ); + IOException runtimeException = expectThrows(IOException.class, () -> new HybridQueryScorer(weight, Arrays.asList(scorer))); assertTrue(runtimeException.getMessage().contains("Test exception")); } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java new file mode 100644 index 000000000..1c919b581 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java @@ -0,0 +1,237 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import lombok.SneakyThrows; +import org.apache.lucene.index.IndexReader; +import 
org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.action.OriginalIndices; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.AggregationProcessor; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +public class HybridAggregationProcessorTests extends OpenSearchQueryTestCase { + + static final String TEXT_FIELD_NAME = "field"; + static final String TERM_QUERY_TEXT = "keyword"; + + @SneakyThrows + public void testAggregationProcessorDelegate_whenPreAndPostAreCalled_thenSuccessful() { + AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); + HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor = + 
new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate); + + SearchContext searchContext = mock(SearchContext.class); + hybridAggregationProcessor.preProcess(searchContext); + verify(mockAggsProcessorDelegate).preProcess(any()); + + hybridAggregationProcessor.postProcess(searchContext); + verify(mockAggsProcessorDelegate).postProcess(any()); + } + + @SneakyThrows + public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSuccessful() { + AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); + HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor = + new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate); + + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + hybridAggregationProcessor.preProcess(searchContext); + + assertEquals(1, classCollectorManagerMap.size()); + 
assertTrue(classCollectorManagerMap.containsKey(HybridCollectorManager.class)); + CollectorManager hybridCollectorManager = classCollectorManagerMap.get( + HybridCollectorManager.class + ); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager); + + // setup query result for post processing + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f) } + + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + when(searchContext.queryResult()).thenReturn(querySearchResult); + + // set captor on collector manager to track if reduce has been called + CollectorManager hybridCollectorManagerSpy = spy(hybridCollectorManager); + classCollectorManagerMap.put(HybridCollectorManager.class, hybridCollectorManagerSpy); + + hybridAggregationProcessor.postProcess(searchContext); + + verify(hybridCollectorManagerSpy).reduce(any()); + } + + @SneakyThrows + public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessful() { + AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); + HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor = + new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate); + + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + 
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + hybridAggregationProcessor.preProcess(searchContext); + + assertEquals(1, classCollectorManagerMap.size()); + assertTrue(classCollectorManagerMap.containsKey(HybridCollectorManager.class)); + CollectorManager hybridCollectorManager = classCollectorManagerMap.get( + HybridCollectorManager.class + ); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorConcurrentSearchManager); + + // setup query result for post processing + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f) } + + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + when(searchContext.queryResult()).thenReturn(querySearchResult); + + // set captor on collector manager to track if reduce has been called + 
CollectorManager hybridCollectorManagerSpy = spy(hybridCollectorManager); + classCollectorManagerMap.put(HybridCollectorManager.class, hybridCollectorManagerSpy); + + hybridAggregationProcessor.postProcess(searchContext); + + verifyNoInteractions(hybridCollectorManagerSpy); + } + + @SneakyThrows + public void testCollectorManager_whenNotHybridQueryAndNotConcurrentSearch_thenSuccessful() { + AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); + HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor = + new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate); + + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Query termQuery = termSubQuery.toQuery(mockQueryShardContext); + + when(searchContext.query()).thenReturn(termQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + hybridAggregationProcessor.preProcess(searchContext); + + assertTrue(classCollectorManagerMap.isEmpty()); + + // setup query result for post processing + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + 
OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f) } + + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + when(searchContext.queryResult()).thenReturn(querySearchResult); + + hybridAggregationProcessor.postProcess(searchContext); + + assertTrue(classCollectorManagerMap.isEmpty()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java new file mode 100644 index 000000000..2951dd666 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -0,0 +1,196 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import lombok.SneakyThrows; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BulkScorer; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.Weight; +import 
org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.BoostingQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.query.HybridQueryWeight; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; + +public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { + + private static final String TEXT_FIELD_NAME = "field"; + private static final String TERM_QUERY_TEXT = "keyword"; + + private static final float DELTA_FOR_ASSERTION = 0.001f; + private static final float MAX_SCORE = 0.611f; + + @SneakyThrows + public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) 
createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + assertNotNull(hybridCollectorManager); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager); + + Collector collector = hybridCollectorManager.newCollector(); + assertNotNull(collector); + assertTrue(collector instanceof HybridTopScoreDocCollector); + + Collector secondCollector = hybridCollectorManager.newCollector(); + assertSame(collector, secondCollector); + } + + @SneakyThrows + public void testNewCollector_whenConcurrentSearch_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + HybridQuery hybridQuery = new 
HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + assertNotNull(hybridCollectorManager); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorConcurrentSearchManager); + + Collector collector = hybridCollectorManager.newCollector(); + assertNotNull(collector); + assertTrue(collector instanceof HybridTopScoreDocCollector); + + Collector secondCollector = hybridCollectorManager.newCollector(); + assertNotSame(collector, secondCollector); + } + + @SneakyThrows + public void testReduce_whenMatchedDocs_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + ); + when(searchContext.query()).thenReturn(hybridQueryWithTerm); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexReader.numDocs()).thenReturn(1); + 
when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + when(searchContext.size()).thenReturn(1); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId, TERM_QUERY_TEXT, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + IndexSearcher searcher = newSearcher(reader); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + HybridTopScoreDocCollector collector = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector(); + + Weight weight = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + collector.setWeight(weight); + LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); + LeafCollector leafCollector = collector.getLeafCollector(leafReaderContext); + BulkScorer scorer = weight.bulkScorer(leafReaderContext); + scorer.score(leafCollector, leafReaderContext.reader().getLiveDocs()); + leafCollector.finish(); + + final Collection collectors = List.of(collector); + Object results = hybridCollectorManager.reduce(collectors); + + assertNotNull(results); + ReduceableSearchResult reduceableSearchResult = ((ReduceableSearchResult) 
results); + QuerySearchResult querySearchResult = new QuerySearchResult(); + reduceableSearchResult.reduce(querySearchResult); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + + assertNotNull(topDocsAndMaxScore); + assertEquals(1, topDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocsAndMaxScore.topDocs.totalHits.relation); + float maxScore = topDocsAndMaxScore.maxScore; + assertTrue(maxScore > 0); + ScoreDoc[] scoreDocs = topDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(4, scoreDocs.length); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, DELTA_FOR_ASSERTION); + assertEquals(maxScore, scoreDocs[2].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[3].score, DELTA_FOR_ASSERTION); + + w.close(); + reader.close(); + directory.close(); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index e609eec05..602c87440 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -20,10 +20,12 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.LinkedList; import java.util.List; import java.util.Set; import java.util.UUID; +import java.util.stream.Collectors; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.TextField; @@ -61,6 +63,7 @@ import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.AggregationProcessor; import 
org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QueryCollectorContext; @@ -159,7 +162,7 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { releaseResources(directory, w, reader); - verify(hybridQueryPhaseSearcher, atLeastOnce()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean()); + verify(hybridQueryPhaseSearcher, atLeastOnce()).searchWith(any(), any(), any(), any(), anyBoolean(), anyBoolean()); } @SneakyThrows @@ -226,7 +229,7 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() releaseResources(directory, w, reader); - verify(hybridQueryPhaseSearcher, never()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean()); + verify(hybridQueryPhaseSearcher, never()).extractHybridQuery(any(), any()); } @SneakyThrows @@ -305,17 +308,8 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { assertEquals(1, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(4, scoreDocs.length); - assertTrue(isHybridQueryStartStopElement(scoreDocs[0])); - assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); - List compoundTopDocs = getSubQueryResultsForSingleShard(topDocs); - assertNotNull(compoundTopDocs); - assertEquals(1, compoundTopDocs.size()); - TopDocs subQueryTopDocs = compoundTopDocs.get(0); - assertEquals(1, subQueryTopDocs.totalHits.value); - assertNotNull(subQueryTopDocs.scoreDocs); - assertEquals(1, subQueryTopDocs.scoreDocs.length); - ScoreDoc scoreDoc = subQueryTopDocs.scoreDocs[0]; + assertEquals(1, scoreDocs.length); + ScoreDoc scoreDoc = scoreDocs[0]; assertNotNull(scoreDoc); int actualDocId = Integer.parseInt(reader.document(scoreDoc.doc).getField("id").stringValue()); assertEquals(docId1, actualDocId); @@ -403,24 +397,10 @@ public void 
testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes assertEquals(4, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(10, scoreDocs.length); - assertTrue(isHybridQueryStartStopElement(scoreDocs[0])); - assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); - List compoundTopDocs = getSubQueryResultsForSingleShard(topDocs); - assertNotNull(compoundTopDocs); - assertEquals(3, compoundTopDocs.size()); - - TopDocs subQueryTopDocs1 = compoundTopDocs.get(0); - List expectedIds1 = List.of(docId1); - assertQueryResults(subQueryTopDocs1, expectedIds1, reader); - - TopDocs subQueryTopDocs2 = compoundTopDocs.get(1); - List expectedIds2 = List.of(); - assertQueryResults(subQueryTopDocs2, expectedIds2, reader); - - TopDocs subQueryTopDocs3 = compoundTopDocs.get(2); - List expectedIds3 = List.of(docId1, docId2, docId3, docId4); - assertQueryResults(subQueryTopDocs3, expectedIds3, reader); + assertEquals(4, scoreDocs.length); + List expectedIds = List.of(0, 1, 2, 3); + List actualDocIds = Arrays.stream(scoreDocs).map(sd -> sd.doc).collect(Collectors.toList()); + assertEquals(expectedIds, actualDocIds); releaseResources(directory, w, reader); } @@ -726,20 +706,10 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then assertTrue(topDocs.totalHits.value > 0); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertTrue(scoreDocs.length > 0); - assertTrue(isHybridQueryStartStopElement(scoreDocs[0])); - assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); - List compoundTopDocs = getSubQueryResultsForSingleShard(topDocs); - assertNotNull(compoundTopDocs); - assertEquals(2, compoundTopDocs.size()); - - TopDocs subQueryTopDocs1 = compoundTopDocs.get(0); - List expectedIds1 = List.of(docId1); - assertQueryResults(subQueryTopDocs1, expectedIds1, reader); - - TopDocs subQueryTopDocs2 = compoundTopDocs.get(1); - List 
expectedIds2 = List.of(); - assertQueryResults(subQueryTopDocs2, expectedIds2, reader); + assertEquals(1, scoreDocs.length); + ScoreDoc scoreDoc = scoreDocs[0]; + assertTrue(scoreDoc.score > 0); + assertEquals(0, scoreDoc.doc); releaseResources(directory, w, reader); } @@ -831,6 +801,15 @@ public void testBoolQuery_whenTooManyNestedLevels_thenSuccess() { releaseResources(directory, w, reader); } + @SneakyThrows + public void testAggsProcessor_whenGettingAggsProcessor_thenSuccess() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + SearchContext searchContext = mock(SearchContext.class); + AggregationProcessor aggregationProcessor = hybridQueryPhaseSearcher.aggregationProcessor(searchContext); + assertNotNull(aggregationProcessor); + assertTrue(aggregationProcessor instanceof HybridQueryPhaseSearcher.HybridAggregationProcessor); + } + @SneakyThrows private void assertQueryResults(TopDocs subQueryTopDocs, List expectedDocIds, IndexReader reader) { assertEquals(expectedDocIds.size(), subQueryTopDocs.totalHits.value); From fe113a6218591beeddf55bd69b62742620dcc0a8 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Sun, 3 Mar 2024 17:14:29 -0800 Subject: [PATCH 2/6] Adding unit tests, minor refactoring Signed-off-by: Martin Gaievski --- .../neuralsearch/query/HybridQueryScorer.java | 80 ++++++------- ...> HybridScoreBlockBoundaryPropagator.java} | 14 ++- .../search/query/HybridCollectorManager.java | 18 +++ .../query/HybridQueryScorerTests.java | 94 +++++++++++++++ ...bridScoreBlockBoundaryPropagatorTests.java | 113 ++++++++++++++++++ .../HybridTopScoreDocCollectorTests.java | 109 +++++++++++++++++ .../query/HybridCollectorManagerTests.java | 25 ++-- 7 files changed, 401 insertions(+), 52 deletions(-) rename src/main/java/org/opensearch/neuralsearch/query/{HybridScorePropagator.java => HybridScoreBlockBoundaryPropagator.java} (78%) create mode 100644 
src/test/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagatorTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java index 188a90209..042df3fcb 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java @@ -27,8 +27,6 @@ import lombok.Getter; import org.apache.lucene.util.PriorityQueue; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - /** * Class abstracts functionality of Scorer for hybrid query. When iterating over documents in increasing * order of doc id, this class fills up array of scores per sub-query for each doc id. Order in array of scores @@ -47,7 +45,7 @@ public final class HybridQueryScorer extends Scorer { private final Map> queryToIndex; private final DocIdSetIterator approximation; - HybridScorePropagator disjunctionBlockPropagator; + HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator; private final TwoPhase twoPhase; public HybridQueryScorer(Weight weight, List subScorers) throws IOException { @@ -56,23 +54,19 @@ public HybridQueryScorer(Weight weight, List subScorers) throws IOExcept public HybridQueryScorer(Weight weight, List subScorers, ScoreMode scoreMode) throws IOException { super(weight); - // max this.subScorers = Collections.unmodifiableList(subScorers); - // custom subScores = new float[subScorers.size()]; this.queryToIndex = mapQueryToIndex(); - // base this.subScorersPQ = initializeSubScorersPQ(); - // base boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES; - this.approximation = new HybridDisjunctionDISIApproximation(this.subScorersPQ); - // max + + this.approximation = new HybridSubqueriesDISIApproximation(this.subScorersPQ); if (scoreMode == ScoreMode.TOP_SCORES) { - this.disjunctionBlockPropagator = new HybridScorePropagator(subScorers); + 
this.disjunctionBlockPropagator = new HybridScoreBlockBoundaryPropagator(subScorers); } else { this.disjunctionBlockPropagator = null; } - // base + boolean hasApproximation = false; float sumMatchCost = 0; long sumApproxCost = 0; @@ -116,7 +110,7 @@ private float score(DisiWrapper topList) throws IOException { float totalScore = 0.0f; for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) { // check if this doc has match in the subQuery. If not, add score as 0.0 and continue - if (disiWrapper.scorer.docID() == NO_MORE_DOCS) { + if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) { continue; } totalScore += disiWrapper.scorer.score(); @@ -187,7 +181,7 @@ public void setMinCompetitiveScore(float minScore) throws IOException { @Override public int docID() { if (subScorersPQ.size() == 0) { - return NO_MORE_DOCS; + return DocIdSetIterator.NO_MORE_DOCS; } return subScorersPQ.top().doc; } @@ -269,6 +263,10 @@ public Collection getChildren() throws IOException { return children; } + /** + * Object returned by Scorer.twoPhaseIterator() to provide an approximation of a DocIdSetIterator. + * After calling nextDoc() or advance(int) on the iterator returned by approximation(), you need to check matches() to confirm if the retrieved document ID is a match. 
+ */ static class TwoPhase extends TwoPhaseIterator { private final float matchCost; // list of verified matches on the current doc @@ -292,11 +290,10 @@ protected boolean lessThan(DisiWrapper a, DisiWrapper b) { } DisiWrapper getSubMatches() throws IOException { - // iteration order does not matter - for (DisiWrapper w : unverifiedMatches) { - if (w.twoPhaseView.matches()) { - w.next = verifiedMatches; - verifiedMatches = w; + for (DisiWrapper wrapper : unverifiedMatches) { + if (wrapper.twoPhaseView.matches()) { + wrapper.next = verifiedMatches; + verifiedMatches = wrapper; } } unverifiedMatches.clear(); @@ -308,39 +305,38 @@ public boolean matches() throws IOException { verifiedMatches = null; unverifiedMatches.clear(); - for (DisiWrapper w = subScorers.topList(); w != null;) { - DisiWrapper next = w.next; + for (DisiWrapper wrapper = subScorers.topList(); wrapper != null;) { + DisiWrapper next = wrapper.next; - if (w.twoPhaseView == null) { + if (Objects.isNull(wrapper.twoPhaseView)) { // implicitly verified, move it to verifiedMatches - w.next = verifiedMatches; - verifiedMatches = w; + wrapper.next = verifiedMatches; + verifiedMatches = wrapper; if (!needsScores) { // we can stop here return true; } } else { - unverifiedMatches.add(w); + unverifiedMatches.add(wrapper); } - w = next; + wrapper = next; } - if (verifiedMatches != null) { + if (Objects.nonNull(verifiedMatches)) { return true; } // verify subs that have an two-phase iterator // least-costly ones first while (unverifiedMatches.size() > 0) { - DisiWrapper w = unverifiedMatches.pop(); - if (w.twoPhaseView.matches()) { - w.next = null; - verifiedMatches = w; + DisiWrapper wrapper = unverifiedMatches.pop(); + if (wrapper.twoPhaseView.matches()) { + wrapper.next = null; + verifiedMatches = wrapper; return true; } } - return false; } @@ -350,18 +346,22 @@ public float matchCost() { } } - static class HybridDisjunctionDISIApproximation extends DocIdSetIterator { - final DocIdSetIterator delegate; + /** + 
* A DocIdSetIterator which is a disjunction of the approximations of the provided iterators and supports + * sub iterators that return empty results + */ + static class HybridSubqueriesDISIApproximation extends DocIdSetIterator { + final DocIdSetIterator docIdSetIterator; final DisiPriorityQueue subIterators; - public HybridDisjunctionDISIApproximation(DisiPriorityQueue subIterators) { - delegate = new DisjunctionDISIApproximation(subIterators); + public HybridSubqueriesDISIApproximation(final DisiPriorityQueue subIterators) { + docIdSetIterator = new DisjunctionDISIApproximation(subIterators); this.subIterators = subIterators; } @Override public long cost() { - return delegate.cost(); + return docIdSetIterator.cost(); } @Override @@ -369,7 +369,7 @@ public int docID() { if (subIterators.size() == 0) { return NO_MORE_DOCS; } - return delegate.docID(); + return docIdSetIterator.docID(); } @Override @@ -377,15 +377,15 @@ public int nextDoc() throws IOException { if (subIterators.size() == 0) { return NO_MORE_DOCS; } - return delegate.nextDoc(); + return docIdSetIterator.nextDoc(); } @Override - public int advance(int target) throws IOException { + public int advance(final int target) throws IOException { if (subIterators.size() == 0) { return NO_MORE_DOCS; } - return delegate.advance(target); + return docIdSetIterator.advance(target); } } } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridScorePropagator.java b/src/main/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagator.java similarity index 78% rename from src/main/java/org/opensearch/neuralsearch/query/HybridScorePropagator.java rename to src/main/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagator.java index 92e1bbf7e..6b47a098d 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridScorePropagator.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagator.java @@ -13,7 +13,16 @@ import 
java.util.Comparator; import java.util.Objects; -public class HybridScorePropagator { +/** + * This class functions as a utility for propagating block boundaries within disjunctions. + * In disjunctions, where a match occurs if any subclause matches, a common approach might involve returning + * the minimum block boundary across all clauses. However, this method can introduce performance challenges, + * particularly when dealing with high minimum competitive scores and clauses with low scores that no longer + * significantly contribute to the iteration process. Therefore, this class computes block boundaries solely for clauses + * with a maximum score equal to or exceeding the minimum competitive score, or for the clause with the maximum + * score if such a clause is absent. + */ +public class HybridScoreBlockBoundaryPropagator { private static final Comparator MAX_SCORE_COMPARATOR = Comparator.comparing((Scorer s) -> { try { @@ -27,7 +36,7 @@ public class HybridScorePropagator { private final float[] maxScores; private int leadIndex = 0; - HybridScorePropagator(Collection scorers) throws IOException { + HybridScoreBlockBoundaryPropagator(final Collection scorers) throws IOException { this.scorers = scorers.stream().filter(Objects::nonNull).toArray(Scorer[]::new); for (Scorer scorer : this.scorers) { scorer.advanceShallow(0); @@ -73,7 +82,6 @@ int advanceShallow(int target) throws IOException { break; } } - return upTo; } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 36a9002e8..1d715a14c 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -34,6 +34,10 @@ import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; import 
static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; +/** + * Collector manager based on HybridTopScoreDocCollector that allows users to parallelize counting the number of hits. + * In most cases it will be wrapped in MultiCollectorManager. + */ @RequiredArgsConstructor public abstract class HybridCollectorManager implements CollectorManager { @@ -43,6 +47,12 @@ public abstract class HybridCollectorManager implements CollectorManager uniqueDocs = new HashSet<>(); + while (uniqueDocs.size() < numDocs) { + uniqueDocs.add(random().nextInt(maxDoc)); + } + final int[] docs = new int[numDocs]; + int i = 0; + for (int doc : uniqueDocs) { + docs[i++] = doc; + } + Arrays.sort(docs); + final float[] scores1 = new float[numDocs]; + for (i = 0; i < numDocs; ++i) { + scores1[i] = random().nextFloat(); + } + final float[] scores2 = new float[numDocs]; + for (i = 0; i < numDocs; ++i) { + scores2[i] = random().nextFloat(); + } + + Weight weight = mock(Weight.class); + + HybridQueryScorer queryScorer = new HybridQueryScorer( + weight, + Arrays.asList( + scorerWithTwoPhaseIterator(docs, scores1, fakeWeight(new MatchAllDocsQuery()), maxDoc), + scorerWithTwoPhaseIterator(docs, scores2, fakeWeight(new MatchNoDocsQuery()), maxDoc) + ) + ); + + int doc = -1; + int idx = 0; + while (doc != DocIdSetIterator.NO_MORE_DOCS) { + doc = queryScorer.iterator().nextDoc(); + if (idx == docs.length) { + assertEquals(DocIdSetIterator.NO_MORE_DOCS, doc); + } else { + assertEquals(docs[idx], doc); + assertEquals(scores1[idx] + scores2[idx], queryScorer.score(), 0.001f); + } + idx++; + } + } + + protected static Scorer scorerWithTwoPhaseIterator(final int[] docs, final float[] scores, Weight weight, int maxDoc) { + final DocIdSetIterator iterator = DocIdSetIterator.all(maxDoc); + return new Scorer(weight) { + + int lastScoredDoc = -1; + + public DocIdSetIterator iterator() { + return 
TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator()); + } + + @Override + public int docID() { + return iterator.docID(); + } + + @Override + public float score() { + assertNotEquals("score() called twice on doc " + docID(), lastScoredDoc, docID()); + lastScoredDoc = docID(); + final int idx = Arrays.binarySearch(docs, docID()); + return scores[idx]; + } + + @Override + public float getMaxScore(int upTo) { + return Float.MAX_VALUE; + } + + @Override + public TwoPhaseIterator twoPhaseIterator() { + return new TwoPhaseIterator(iterator) { + + @Override + public boolean matches() { + return Arrays.binarySearch(docs, iterator.docID()) >= 0; + } + + @Override + public float matchCost() { + return 10; + } + }; + } + }; + } + private Pair generateDocuments(int maxDocId) { final int numDocs = RandomizedTest.randomIntBetween(1, maxDocId / 2); final int[] docs = new int[numDocs]; diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagatorTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagatorTests.java new file mode 100644 index 000000000..5bf0948ea --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagatorTests.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +public class HybridScoreBlockBoundaryPropagatorTests extends OpenSearchQueryTestCase { + + public void testAdvanceShallow_whenMinCompetitiveScoreSet_thenSuccessful() throws IOException { + Scorer scorer1 
= new MockScorer(10, 0.6f); + Scorer scorer2 = new MockScorer(40, 1.5f); + Scorer scorer3 = new MockScorer(30, 2f); + Scorer scorer4 = new MockScorer(120, 4f); + + List scorers = Arrays.asList(scorer1, scorer2, scorer3, scorer4); + Collections.shuffle(scorers, random()); + HybridScoreBlockBoundaryPropagator propagator = new HybridScoreBlockBoundaryPropagator(scorers); + assertEquals(10, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(0.1f); + assertEquals(10, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(0.8f); + assertEquals(30, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(1.4f); + assertEquals(30, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(1.9f); + assertEquals(30, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(2.5f); + assertEquals(120, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(7f); + assertEquals(120, propagator.advanceShallow(0)); + } + + private static class MockWeight extends Weight { + + MockWeight() { + super(new MatchNoDocsQuery()); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return null; + } + + @Override + public Scorer scorer(LeafReaderContext context) { + return null; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + } + + private static class MockScorer extends Scorer { + + final int boundary; + final float maxScore; + + MockScorer(int boundary, float maxScore) throws IOException { + super(new MockWeight()); + this.boundary = boundary; + this.maxScore = maxScore; + } + + @Override + public int docID() { + return 0; + } + + @Override + public float score() { + throw new UnsupportedOperationException(); + } + + @Override + public DocIdSetIterator iterator() { + throw new UnsupportedOperationException(); + } + + @Override + public void setMinCompetitiveScore(float minCompetitiveScore) {} + + @Override + public float 
getMaxScore(int upTo) throws IOException { + return maxScore; + } + + @Override + public int advanceShallow(int target) { + assert target <= boundary; + return boundary; + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java index b67a1ee05..ad5a955c4 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java @@ -11,6 +11,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.List; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -27,12 +28,15 @@ import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.apache.lucene.util.PriorityQueue; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; @@ -399,4 +403,109 @@ public void testTrackTotalHits_whenTotalHitsSetIntegerMaxValue_thenSuccessful() reader.close(); directory.close(); } + + @SneakyThrows + public void testCompoundScorer_whenHybridScorerIsChildScorer_thenSuccessful() { + HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( + NUM_DOCS, + new HitsThresholdChecker(Integer.MAX_VALUE) + ); + + final Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, 
newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_1, FIELD_1_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_2, FIELD_2_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_3, FIELD_3_VALUE, ft)); + w.commit(); + + DirectoryReader reader = DirectoryReader.open(w); + + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + LeafCollector leafCollector = hybridTopScoreDocCollector.getLeafCollector(leafReaderContext); + + assertNotNull(leafCollector); + + Weight weight = mock(Weight.class); + Weight subQueryWeight = mock(Weight.class); + Scorer subQueryScorer = mock(Scorer.class); + when(subQueryScorer.getWeight()).thenReturn(subQueryWeight); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + when(subQueryScorer.iterator()).thenReturn(iterator); + + HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(weight, Arrays.asList(subQueryScorer)); + + Scorer scorer = mock(Scorer.class); + Collection childrenCollectors = List.of(new Scorable.ChildScorable(hybridQueryScorer, "MUST")); + when(scorer.getChildren()).thenReturn(childrenCollectors); + leafCollector.setScorer(scorer); + int nextDoc = hybridQueryScorer.iterator().nextDoc(); + leafCollector.collect(nextDoc); + + assertNotNull(hybridTopScoreDocCollector.getCompoundScores()); + PriorityQueue[] compoundScoresPQ = hybridTopScoreDocCollector.getCompoundScores(); + assertEquals(1, compoundScoresPQ.length); + PriorityQueue scoreDoc = compoundScoresPQ[0]; + assertNotNull(scoreDoc); + assertNotNull(scoreDoc.top()); + + w.close(); + reader.close(); + directory.close(); + } + + @SneakyThrows + public void testCompoundScorer_whenHybridScorerIsTopLevelScorer_thenSuccessful() { + 
HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( + NUM_DOCS, + new HitsThresholdChecker(Integer.MAX_VALUE) + ); + + final Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_1, FIELD_1_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_2, FIELD_2_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_3, FIELD_3_VALUE, ft)); + w.commit(); + + DirectoryReader reader = DirectoryReader.open(w); + + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + LeafCollector leafCollector = hybridTopScoreDocCollector.getLeafCollector(leafReaderContext); + + assertNotNull(leafCollector); + + Weight weight = mock(Weight.class); + Weight subQueryWeight = mock(Weight.class); + Scorer subQueryScorer = mock(Scorer.class); + when(subQueryScorer.getWeight()).thenReturn(subQueryWeight); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + when(subQueryScorer.iterator()).thenReturn(iterator); + + HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(weight, Arrays.asList(subQueryScorer)); + + leafCollector.setScorer(hybridQueryScorer); + int nextDoc = hybridQueryScorer.iterator().nextDoc(); + leafCollector.collect(nextDoc); + + assertNotNull(hybridTopScoreDocCollector.getCompoundScores()); + PriorityQueue[] compoundScoresPQ = hybridTopScoreDocCollector.getCompoundScores(); + assertEquals(1, compoundScoresPQ.length); + PriorityQueue scoreDoc = compoundScoresPQ[0]; + assertNotNull(scoreDoc); + assertNotNull(scoreDoc.top()); + + w.close(); + reader.close(); + directory.close(); + } } diff --git 
a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index 2951dd666..f9d616716 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -53,10 +53,11 @@ public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { private static final String TEXT_FIELD_NAME = "field"; - private static final String TERM_QUERY_TEXT = "keyword"; - + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String QUERY1 = "hello"; private static final float DELTA_FOR_ASSERTION = 0.001f; - private static final float MAX_SCORE = 0.611f; @SneakyThrows public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() { @@ -64,7 +65,7 @@ public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); when(searchContext.query()).thenReturn(hybridQuery); @@ -95,7 +96,7 @@ public void testNewCollector_whenConcurrentSearch_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) 
createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); when(searchContext.query()).thenReturn(hybridQuery); @@ -128,12 +129,12 @@ public void testReduce_whenMatchedDocs_thenSuccessful() { when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)) ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); - when(indexReader.numDocs()).thenReturn(1); + when(indexReader.numDocs()).thenReturn(3); when(indexSearcher.getIndexReader()).thenReturn(indexReader); when(searchContext.searcher()).thenReturn(indexSearcher); when(searchContext.size()).thenReturn(1); @@ -150,8 +151,14 @@ public void testReduce_whenMatchedDocs_thenSuccessful() { ft.setIndexOptions(random().nextBoolean() ? 
IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); ft.setOmitNorms(random().nextBoolean()); ft.freeze(); - int docId = RandomizedTest.randomInt(); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId, TERM_QUERY_TEXT, ft)); + + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.flush(); w.commit(); IndexReader reader = DirectoryReader.open(w); From 253781c6e84e7306306e0e601d91d9b30f00ac78 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 6 Mar 2024 13:13:35 -0800 Subject: [PATCH 3/6] Adding links to java doc for TwoPhase class Signed-off-by: Martin Gaievski --- .../opensearch/neuralsearch/query/HybridQueryScorer.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java index 042df3fcb..44a3810e4 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java @@ -264,8 +264,11 @@ public Collection getChildren() throws IOException { } /** - * Object returned by Scorer.twoPhaseIterator() to provide an approximation of a DocIdSetIterator. - * After calling nextDoc() or advance(int) on the iterator returned by approximation(), you need to check matches() to confirm if the retrieved document ID is a match. + * Object returned by {@link Scorer#twoPhaseIterator()} to provide an approximation of a {@link DocIdSetIterator}. 
+ * After calling {@link DocIdSetIterator#nextDoc()} or {@link DocIdSetIterator#advance(int)} on the iterator + * returned by approximation(), you need to check {@link TwoPhaseIterator#matches()} to confirm if the retrieved + * document ID is a match. Implementation inspired by identical class for + * DisjunctionScorer */ static class TwoPhase extends TwoPhaseIterator { private final float matchCost; From 67df195bcff00de0743d78f9045af1117bab3604 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 6 Mar 2024 16:24:32 -0800 Subject: [PATCH 4/6] Addressing Jacks and Navneets comments Signed-off-by: Martin Gaievski --- CHANGELOG.md | 2 +- .../query/HybridQueryBuilder.java | 2 +- .../neuralsearch/query/HybridQueryScorer.java | 6 +- .../neuralsearch/query/HybridQueryWeight.java | 9 +-- .../search/HybridTopScoreDocCollector.java | 4 + .../query/HybridAggregationProcessor.java | 77 +++++++++++++++++++ .../search/query/HybridCollectorManager.java | 14 +++- .../query/HybridQueryPhaseSearcher.java | 62 --------------- .../HybridAggregationProcessorTests.java | 12 +-- .../query/HybridQueryPhaseSearcherTests.java | 2 +- 10 files changed, 108 insertions(+), 82 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ddf281f6..8dcdc721b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes -- Adding two phase iterator for hybrid query ([#624](https://github.com/opensearch-project/neural-search/pull/624)) +- Fix runtime exceptions in hybrid query for case when sub-query scorer returns a TwoPhase iterator that is incompatible with DISI iterator ([#624](https://github.com/opensearch-project/neural-search/pull/624)) ### Infrastructure ### Documentation ### Maintenance diff --git 
a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index 46c087894..60d9fd639 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -53,7 +53,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder> queryToIndex; private final DocIdSetIterator approximation; - HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator; + private final HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator; private final TwoPhase twoPhase; - public HybridQueryScorer(Weight weight, List subScorers) throws IOException { + public HybridQueryScorer(final Weight weight, final List subScorers) throws IOException { this(weight, subScorers, ScoreMode.TOP_SCORES); } - public HybridQueryScorer(Weight weight, List subScorers, ScoreMode scoreMode) throws IOException { + HybridQueryScorer(final Weight weight, final List subScorers, final ScoreMode scoreMode) throws IOException { super(weight); this.subScorers = Collections.unmodifiableList(subScorers); subScores = new float[subScorers.size()]; diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java index 76bdd5f00..facb79694 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java @@ -21,6 +21,8 @@ import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; +import static org.opensearch.neuralsearch.query.HybridQueryBuilder.MAX_NUMBER_OF_SUB_QUERIES; + /** * Calculates query weights and build query scorers for hybrid query. 
*/ @@ -31,8 +33,6 @@ public final class HybridQueryWeight extends Weight { private final ScoreMode scoreMode; - static final int BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD = 16; - /** * Construct the Weight for this Query searched by searcher. Recursively construct subquery weights. */ @@ -108,9 +108,8 @@ public Scorer scorer(LeafReaderContext context) throws IOException { */ @Override public boolean isCacheable(LeafReaderContext ctx) { - if (weights.size() > BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD) { - // Disallow caching large queries to not encourage users - // to build large queries + if (weights.size() > MAX_NUMBER_OF_SUB_QUERIES) { + // this situation should never happen, but in case it does, such a query will not be cached return false; } return weights.stream().allMatch(w -> w.isCacheable(ctx)); diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java index 9190bfeac..79b134b38 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java +++ b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java @@ -8,6 +8,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.Objects; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -79,6 +80,9 @@ private HybridQueryScorer getHybridQueryScorer(final Scorable scorer) throws IOE @Override public void collect(int doc) throws IOException { + if (Objects.isNull(compoundQueryScorer)) { + throw new IllegalArgumentException("scorers are null for all sub-queries in hybrid query"); + } float[] subScoresByQuery = compoundQueryScorer.hybridScores(); // iterate over results for each query if (compoundScores == null) { diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java 
new file mode 100644 index 000000000..8b584feea --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import lombok.AllArgsConstructor; +import org.apache.lucene.search.CollectorManager; +import org.opensearch.search.aggregations.AggregationProcessor; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QueryPhaseExecutionException; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; + +import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.isHybridQuery; + +/** + * Defines logic for pre- and post-phases of document scores collection. Responsible for registering custom + * collector manager for hybrid query (pre phase) and reducing results (post phase) + */ +@AllArgsConstructor +public class HybridAggregationProcessor implements AggregationProcessor { + + private final AggregationProcessor delegateAggsProcessor; + + @Override + public void preProcess(SearchContext context) { + delegateAggsProcessor.preProcess(context); + + if (isHybridQuery(context.query(), context)) { + // adding collector manager for hybrid query + CollectorManager collectorManager; + try { + collectorManager = HybridCollectorManager.createHybridCollectorManager(context); + } catch (IOException e) { + throw new RuntimeException(e); + } + context.queryCollectorManagers().put(HybridCollectorManager.class, collectorManager); + } + } + + @Override + public void postProcess(SearchContext context) { + if (isHybridQuery(context.query(), context)) { + // for case when concurrent search is not enabled (default as of 2.12 release) reduce for collector + // managers is not called, and we have 
to call it manually. This is required as we format final + // result of hybrid query in {@link HybridTopScoreCollector#reduce} + if (!context.shouldUseConcurrentSearch()) { + reduceCollectorResults(context); + } + updateQueryResult(context.queryResult(), context); + } + + delegateAggsProcessor.postProcess(context); + } + + private void reduceCollectorResults(SearchContext context) { + CollectorManager collectorManager = context.queryCollectorManagers().get(HybridCollectorManager.class); + try { + final Collection collectors = List.of(collectorManager.newCollector()); + collectorManager.reduce(collectors).reduce(context.queryResult()); + } catch (IOException e) { + throw new QueryPhaseExecutionException(context.shardTarget(), "failed to execute hybrid query aggregation processor", e); + } + } + + private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) { + boolean isSingleShard = searchContext.numberOfShards() == 1; + if (isSingleShard) { + searchContext.size(queryResult.queryResult().topDocs().topDocs.scoreDocs.length); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 1d715a14c..40b10c5f3 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -85,10 +85,22 @@ Collector getCollector() { return hybridcollector; } + /** + * Reduce the results from hybrid scores collector into a format specific for hybrid search query: + * - start + * - sub-query-delimiter + * - scores + * - stop + * Ignore other collectors if they are present in the context + * @param collectors collection of collectors after they have been executed and collected documents and scores + * @return search results that can be reduced by the caller + */ @Override public 
ReduceableSearchResult reduce(Collection collectors) { final List hybridTopScoreDocCollectors = new ArrayList<>(); - + // check if collector for hybrid query scores is part of this search context. It can be wrapped into MultiCollectorWrapper + // in case multiple collector managers are registered. We use hybrid scores collector to format scores into + // format specific for hybrid search query: start, sub-query-delimiter, scores, stop for (final Collector collector : collectors) { if (collector instanceof MultiCollectorWrapper) { for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) { diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index 5fc6017f2..6461c698e 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -5,17 +5,12 @@ package org.opensearch.neuralsearch.search.query; import java.io.IOException; -import java.util.Collection; import java.util.LinkedList; import java.util.List; -import java.util.Map; import com.google.common.annotations.VisibleForTesting; -import lombok.AllArgsConstructor; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.Collector; -import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.opensearch.common.settings.Settings; @@ -28,10 +23,7 @@ import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QueryCollectorContext; import org.opensearch.search.query.QueryPhase; -import org.opensearch.search.query.QueryPhaseExecutionException; import org.opensearch.search.query.QueryPhaseSearcherWrapper; -import org.opensearch.search.query.QuerySearchResult; 
-import org.opensearch.search.query.ReduceableSearchResult; import lombok.extern.log4j.Log4j2; @@ -181,58 +173,4 @@ public AggregationProcessor aggregationProcessor(SearchContext searchContext) { AggregationProcessor coreAggProcessor = super.aggregationProcessor(searchContext); return new HybridAggregationProcessor(coreAggProcessor); } - - @AllArgsConstructor - public static class HybridAggregationProcessor implements AggregationProcessor { - - private final AggregationProcessor delegateAggsProcessor; - - @Override - public void preProcess(SearchContext context) { - delegateAggsProcessor.preProcess(context); - - if (isHybridQuery(context.query(), context)) { - // adding collector manager for hybrid query - CollectorManager collectorManager; - try { - collectorManager = HybridCollectorManager.createHybridCollectorManager(context); - } catch (IOException e) { - throw new RuntimeException(e); - } - Map, CollectorManager> collectorManagersByManagerClass = context - .queryCollectorManagers(); - collectorManagersByManagerClass.put(HybridCollectorManager.class, collectorManager); - } - } - - @Override - public void postProcess(SearchContext context) { - if (isHybridQuery(context.query(), context)) { - if (!context.shouldUseConcurrentSearch()) { - reduceCollectorResults(context); - } - updateQueryResult(context.queryResult(), context); - } - - delegateAggsProcessor.postProcess(context); - } - - private void reduceCollectorResults(SearchContext context) { - CollectorManager collectorManager = context.queryCollectorManagers() - .get(HybridCollectorManager.class); - try { - final Collection collectors = List.of(collectorManager.newCollector()); - collectorManager.reduce(collectors).reduce(context.queryResult()); - } catch (IOException e) { - throw new QueryPhaseExecutionException(context.shardTarget(), "failed to execute hybrid query aggregation processor", e); - } - } - - private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) { 
- boolean isSingleShard = searchContext.numberOfShards() == 1; - if (isSingleShard) { - searchContext.size(queryResult.queryResult().topDocs().topDocs.scoreDocs.length); - } - } - } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java index 1c919b581..f44e762f0 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java @@ -49,8 +49,7 @@ public class HybridAggregationProcessorTests extends OpenSearchQueryTestCase { @SneakyThrows public void testAggregationProcessorDelegate_whenPreAndPostAreCalled_thenSuccessful() { AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); - HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor = - new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate); + HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate); SearchContext searchContext = mock(SearchContext.class); hybridAggregationProcessor.preProcess(searchContext); @@ -63,8 +62,7 @@ public void testAggregationProcessorDelegate_whenPreAndPostAreCalled_thenSuccess @SneakyThrows public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSuccessful() { AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); - HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor = - new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate); + HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate); SearchContext searchContext = mock(SearchContext.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); @@ 
-124,8 +122,7 @@ public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSucce @SneakyThrows public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessful() { AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); - HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor = - new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate); + HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate); SearchContext searchContext = mock(SearchContext.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); @@ -185,8 +182,7 @@ public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessf @SneakyThrows public void testCollectorManager_whenNotHybridQueryAndNotConcurrentSearch_thenSuccessful() { AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); - HybridQueryPhaseSearcher.HybridAggregationProcessor hybridAggregationProcessor = - new HybridQueryPhaseSearcher.HybridAggregationProcessor(mockAggsProcessorDelegate); + HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate); SearchContext searchContext = mock(SearchContext.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index 602c87440..2aebbb5d8 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -807,7 +807,7 @@ public void testAggsProcessor_whenGettingAggsProcessor_thenSuccess() { SearchContext searchContext = mock(SearchContext.class); AggregationProcessor 
aggregationProcessor = hybridQueryPhaseSearcher.aggregationProcessor(searchContext); assertNotNull(aggregationProcessor); - assertTrue(aggregationProcessor instanceof HybridQueryPhaseSearcher.HybridAggregationProcessor); + assertTrue(aggregationProcessor instanceof HybridAggregationProcessor); } @SneakyThrows From 763061b871bd68d819341d5d1e7812676f30b1ef Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 7 Mar 2024 09:56:47 -0800 Subject: [PATCH 5/6] Adding logs and code refs to code comments Signed-off-by: Martin Gaievski --- .../search/HybridTopScoreDocCollector.java | 15 ++++++++++++++- .../search/query/HybridAggregationProcessor.java | 8 +++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java index 79b134b38..4418841f4 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java +++ b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java @@ -56,9 +56,15 @@ public LeafCollector getLeafCollector(LeafReaderContext context) { @Override public void setScorer(Scorable scorer) throws IOException { if (scorer instanceof HybridQueryScorer) { + log.debug("passed scorer is of type HybridQueryScorer, saving it for collecting documents and scores"); compoundQueryScorer = (HybridQueryScorer) scorer; } else { compoundQueryScorer = getHybridQueryScorer(scorer); + if (Objects.isNull(compoundQueryScorer)) { + log.error( + String.format(Locale.ROOT, "cannot find scorer of type HybridQueryScorer in a hierarchy of scorer %s", scorer) + ); + } } } @@ -71,7 +77,14 @@ private HybridQueryScorer getHybridQueryScorer(final Scorable scorer) throws IOE } for (Scorable.ChildScorable childScorable : scorer.getChildren()) { HybridQueryScorer hybridQueryScorer = getHybridQueryScorer(childScorable.child); - if (hybridQueryScorer != null) { + if 
(Objects.nonNull(hybridQueryScorer)) { + log.debug( + String.format( + Locale.ROOT, + "found hybrid query scorer, it's child of scorer %s", + childScorable.child.getClass().getSimpleName() + ) + ); return hybridQueryScorer; } } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java index 8b584feea..42c27821f 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java @@ -47,8 +47,14 @@ public void preProcess(SearchContext context) { public void postProcess(SearchContext context) { if (isHybridQuery(context.query(), context)) { // for case when concurrent search is not enabled (default as of 2.12 release) reduce for collector - // managers is not called, and we have to call it manually. This is required as we format final + // managers is not called + // (https://github.com/opensearch-project/OpenSearch/blob/2.12/server/src/main/java/org/opensearch/search/query/QueryPhase.java#L333-L373) + // and we have to call it manually. 
This is required as we format final // result of hybrid query in {@link HybridTopScoreCollector#reduce} + // when concurrent search is enabled then reduce method is called as part of the search {@see + // ConcurrentQueryPhaseSearcher#searchWithCollectorManager} + // corresponding call in Lucene + // https://github.com/apache/lucene/blob/branch_9_10/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java#L700 if (!context.shouldUseConcurrentSearch()) { reduceCollectorResults(context); } From 65553e7864352abdf4c92e339c2a22e5d5b91b44 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 8 Mar 2024 13:59:04 -0800 Subject: [PATCH 6/6] Refactor non-concurrent collector manager Signed-off-by: Martin Gaievski --- .../query/HybridAggregationProcessor.java | 9 +++---- .../search/query/HybridCollectorManager.java | 27 +++++++------------ .../query/HybridCollectorManagerTests.java | 4 +-- 3 files changed, 15 insertions(+), 25 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java index 42c27821f..4e9070748 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java @@ -6,6 +6,7 @@ import lombok.AllArgsConstructor; import org.apache.lucene.search.CollectorManager; +import org.opensearch.search.aggregations.AggregationInitializationException; import org.opensearch.search.aggregations.AggregationProcessor; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QueryPhaseExecutionException; @@ -13,7 +14,6 @@ import org.opensearch.search.query.ReduceableSearchResult; import java.io.IOException; -import java.util.Collection; import java.util.List; import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.isHybridQuery; @@ -36,8 +36,8 @@ 
public void preProcess(SearchContext context) { CollectorManager collectorManager; try { collectorManager = HybridCollectorManager.createHybridCollectorManager(context); - } catch (IOException e) { - throw new RuntimeException(e); + } catch (IOException exception) { + throw new AggregationInitializationException("could not initialize hybrid aggregation processor", exception); } context.queryCollectorManagers().put(HybridCollectorManager.class, collectorManager); } @@ -67,8 +67,7 @@ public void postProcess(SearchContext context) { private void reduceCollectorResults(SearchContext context) { CollectorManager collectorManager = context.queryCollectorManagers().get(HybridCollectorManager.class); try { - final Collection collectors = List.of(collectorManager.newCollector()); - collectorManager.reduce(collectors).reduce(context.queryResult()); + collectorManager.reduce(List.of()).reduce(context.queryResult()); } catch (IOException e) { throw new QueryPhaseExecutionException(context.shardTarget(), "failed to execute hybrid query aggregation processor", e); } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 40b10c5f3..a5de898ab 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -78,9 +78,7 @@ public static CollectorManager createHybridCollectorManager(final SearchContext } @Override - abstract public Collector newCollector(); - - Collector getCollector() { + public Collector newCollector() { Collector hybridcollector = new HybridTopScoreDocCollector(numHits, hitsThresholdChecker); return hybridcollector; } @@ -211,7 +209,7 @@ private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats * use saved state of collector */ static class HybridCollectorNonConcurrentManager extends 
HybridCollectorManager { - Collector maxScoreCollector; + private final Collector scoreCollector; public HybridCollectorNonConcurrentManager( int numHits, @@ -221,18 +219,18 @@ public HybridCollectorNonConcurrentManager( SortAndFormats sortAndFormats ) { super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats); + scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null"); } @Override public Collector newCollector() { - if (Objects.isNull(maxScoreCollector)) { - maxScoreCollector = getCollector(); - return maxScoreCollector; - } else { - Collector toReturnCollector = maxScoreCollector; - maxScoreCollector = null; - return toReturnCollector; - } + return scoreCollector; + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) { + assert collectors.isEmpty() : "reduce on HybridCollectorNonConcurrentManager called with non-empty collectors"; + return super.reduce(List.of(scoreCollector)); } } @@ -251,10 +249,5 @@ public HybridCollectorConcurrentSearchManager( ) { super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats); } - - @Override - public Collector newCollector() { - return getCollector(); - } } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index f9d616716..65d6f3d8a 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -39,7 +39,6 @@ import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.query.ReduceableSearchResult; -import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -175,8 +174,7 @@ public void testReduce_whenMatchedDocs_thenSuccessful() { 
scorer.score(leafCollector, leafReaderContext.reader().getLiveDocs()); leafCollector.finish(); - final Collection collectors = List.of(collector); - Object results = hybridCollectorManager.reduce(collectors); + Object results = hybridCollectorManager.reduce(List.of()); assertNotNull(results); ReduceableSearchResult reduceableSearchResult = ((ReduceableSearchResult) results);