diff --git a/CHANGELOG.md b/CHANGELOG.md index cac1834bc..9727bcc3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements - Allowing execution of hybrid query on index alias with filters ([#670](https://github.com/opensearch-project/neural-search/pull/670)) - Removed stream.findFirst implementation to use more native iteration implement to improve hybrid query latencies by 35% ([#706](https://github.com/opensearch-project/neural-search/pull/706)) +- Removed map of subquery to subquery index in favor of storing index as part of disi wrapper to improve hybrid query latencies by 20% ([#711](https://github.com/opensearch-project/neural-search/pull/711)) ### Bug Fixes - Add support for request_cache flag in hybrid query ([#663](https://github.com/opensearch-project/neural-search/pull/663)) ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java index 1da610f53..5afd43917 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java @@ -8,18 +8,15 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Locale; -import java.util.Map; import java.util.Objects; -import com.google.common.primitives.Ints; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.DisiPriorityQueue; import org.apache.lucene.search.DisiWrapper; import org.apache.lucene.search.DisjunctionDISIApproximation; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.TwoPhaseIterator; @@ -27,12 +24,14 @@ import lombok.Getter; import org.apache.lucene.util.PriorityQueue; +import org.opensearch.neuralsearch.search.HybridDisiWrapper; /** * Class abstracts functionality of Scorer for hybrid query. When iterating over documents in increasing * order of doc id, this class fills up array of scores per sub-query for each doc id. Order in array of scores * corresponds to order of sub-queries in an input Hybrid query. */ +@Log4j2 public final class HybridQueryScorer extends Scorer { // score for each of sub-query in this hybrid query @@ -43,8 +42,6 @@ public final class HybridQueryScorer extends Scorer { private final float[] subScores; - private final Map queryToIndex; - private final DocIdSetIterator approximation; private final HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator; private final TwoPhase twoPhase; @@ -57,7 +54,6 @@ public HybridQueryScorer(final Weight weight, final List subScorers) thr super(weight); this.subScorers = Collections.unmodifiableList(subScorers); subScores = new float[subScorers.size()]; - this.queryToIndex = mapQueryToIndex(); this.subScorersPQ = initializeSubScorersPQ(); boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES; @@ -195,69 +191,43 @@ public int docID() { public float[] hybridScores() throws IOException { float[] scores = new float[subScores.length]; DisiWrapper topList = subScorersPQ.topList(); - for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) { + if (topList instanceof HybridDisiWrapper == false) { + log.error( + String.format( + Locale.ROOT, + "Unexpected type of DISI wrapper, expected [%s] but found [%s]", + HybridDisiWrapper.class.getSimpleName(), + subScorersPQ.topList().getClass().getSimpleName() + ) + ); + throw new IllegalStateException( + "Unable to collect scores for one of the sub-queries, encountered an unexpected type of score iterator." + ); + } + for (HybridDisiWrapper disiWrapper = (HybridDisiWrapper) topList; disiWrapper != null; disiWrapper = + (HybridDisiWrapper) disiWrapper.next) { // check if this doc has match in the subQuery. If not, add score as 0.0 and continue Scorer scorer = disiWrapper.scorer; if (scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) { continue; } - Query query = scorer.getWeight().getQuery(); - int[] indexes = queryToIndex.get(query); - // we need to find the index of first sub-query that hasn't been set yet. Such score will have initial value of "0.0" - int index = -1; - for (int idx : indexes) { - if (Float.compare(scores[idx], 0.0f) == 0) { - index = idx; - break; - } - } - if (index == -1) { - throw new IllegalStateException( - String.format( - Locale.ROOT, - "cannot set score for one of hybrid search subquery [%s] and document [%d]", - query.toString(), - scorer.docID() - ) - ); - } - scores[index] = scorer.score(); + scores[disiWrapper.getSubQueryIndex()] = scorer.score(); } return scores; } - private Map mapQueryToIndex() { - // we need list as number of identical queries is unknown - Map> queryToListOfIndexes = new HashMap<>(); - int idx = 0; - for (Scorer scorer : subScorers) { - if (scorer == null) { - idx++; - continue; - } - Query query = scorer.getWeight().getQuery(); - queryToListOfIndexes.putIfAbsent(query, new ArrayList<>()); - queryToListOfIndexes.get(query).add(idx); - idx++; - } - // convert to the int array for better performance - Map queryToIndex = new HashMap<>(); - queryToListOfIndexes.forEach((key, value) -> queryToIndex.put(key, Ints.toArray(value))); - return queryToIndex; - } - private DisiPriorityQueue initializeSubScorersPQ() { - Objects.requireNonNull(queryToIndex, "should not be null"); Objects.requireNonNull(subScorers, "should not be null"); // we need to count this way in order to include all identical sub-queries - int numOfSubQueries = queryToIndex.values().stream().map(array -> array.length).reduce(0, Integer::sum); + int numOfSubQueries = subScorers.size(); DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(numOfSubQueries); - for (Scorer scorer : subScorers) { + for (int idx = 0; idx < subScorers.size(); idx++) { + Scorer scorer = subScorers.get(idx); if (scorer == null) { continue; } - final DisiWrapper w = new DisiWrapper(scorer); - subScorersPQ.add(w); + final HybridDisiWrapper disiWrapper = new HybridDisiWrapper(scorer, idx); + subScorersPQ.add(disiWrapper); } return subScorersPQ; } diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridDisiWrapper.java b/src/main/java/org/opensearch/neuralsearch/search/HybridDisiWrapper.java new file mode 100644 index 000000000..7165ce055 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/HybridDisiWrapper.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search; + +import lombok.Getter; +import org.apache.lucene.search.DisiWrapper; +import org.apache.lucene.search.Scorer; + +/** + * Wrapper for DisiWrapper, saves state of sub-queries for performance reasons + */ +@Getter +public class HybridDisiWrapper extends DisiWrapper { + // index of disi wrapper sub-query object when its part of the hybrid query + private final int subQueryIndex; + + public HybridDisiWrapper(Scorer scorer, int subQueryIndex) { + super(scorer); + this.subQueryIndex = subQueryIndex; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/HybridDisiWrapperTests.java b/src/test/java/org/opensearch/neuralsearch/search/HybridDisiWrapperTests.java new file mode 100644 index 000000000..cd6076290 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/HybridDisiWrapperTests.java @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class HybridDisiWrapperTests extends OpenSearchQueryTestCase { + + public void testSubQueryIndex_whenCreateNewInstanceAndSetIndex_thenSuccessful() { + Scorer scorer = mock(Scorer.class); + DocIdSetIterator docIdSetIterator = mock(DocIdSetIterator.class); + when(scorer.iterator()).thenReturn(docIdSetIterator); + int subQueryIndex = 2; + HybridDisiWrapper hybridDisiWrapper = new HybridDisiWrapper(scorer, subQueryIndex); + assertEquals(2, hybridDisiWrapper.getSubQueryIndex()); + } +}