Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix runtime exceptions in hybrid query for case when sub-query scorer return TwoPhase iterator that is incompatible with DISI iterator #624

Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
- Fix runtime exceptions in hybrid query for case when sub-query scorer return TwoPhase iterator that is incompatible with DISI iterator ([#624](https://github.com/opensearch-project/neural-search/pull/624))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.util.List;
import java.util.Objects;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
Expand Down Expand Up @@ -77,20 +76,20 @@ public String toString(String field) {
/**
* Re-writes queries into primitive queries. Callers are expected to call rewrite multiple times if necessary,
* until the rewritten query is the same as the original query.
* @param reader
* @param indexSearcher
* @return
* @throws IOException
*/
@Override
public Query rewrite(IndexReader reader) throws IOException {
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
if (subQueries.isEmpty()) {
return new MatchNoDocsQuery("empty HybridQuery");
}

boolean actuallyRewritten = false;
List<Query> rewrittenSubQueries = new ArrayList<>();
for (Query subQuery : subQueries) {
Query rewrittenSub = subQuery.rewrite(reader);
Query rewrittenSub = subQuery.rewrite(indexSearcher);
/* we keep rewrite sub-query unless it's not equal to itself, it may take multiple levels of recursive calls
queries need to be rewritten from high-level clauses into lower-level clauses because low-level clauses
perform better. For hybrid query we need to track progress of re-write for all sub-queries */
Expand All @@ -102,7 +101,7 @@ public Query rewrite(IndexReader reader) throws IOException {
return new HybridQuery(rewrittenSubQueries);
}

return super.rewrite(reader);
return super.rewrite(indexSearcher);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.QueryShardException;
import org.opensearch.index.query.Rewriteable;
import org.opensearch.index.query.QueryBuilderVisitor;

import lombok.Getter;
Expand All @@ -54,7 +53,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu

private String fieldName;

private static final int MAX_NUMBER_OF_SUB_QUERIES = 5;
static final int MAX_NUMBER_OF_SUB_QUERIES = 5;

public HybridQueryBuilder(StreamInput in) throws IOException {
super(in);
Expand Down Expand Up @@ -290,7 +289,7 @@ private void writeQueries(StreamOutput out, List<? extends QueryBuilder> queries
private Collection<Query> toQueries(Collection<QueryBuilder> queryBuilders, QueryShardContext context) throws QueryShardException {
List<Query> queries = queryBuilders.stream().map(qb -> {
try {
return Rewriteable.rewrite(qb, context).toQuery(context);
return qb.rewrite(context).toQuery(context);
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand All @@ -18,10 +19,13 @@
import org.apache.lucene.search.DisjunctionDISIApproximation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;

import lombok.Getter;
import org.apache.lucene.util.PriorityQueue;

/**
* Class abstracts functionality of Scorer for hybrid query. When iterating over documents in increasing
Expand All @@ -40,12 +44,56 @@

private final Map<Query, List<Integer>> queryToIndex;

public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOException {
private final DocIdSetIterator approximation;
private final HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator;
private final TwoPhase twoPhase;

public HybridQueryScorer(final Weight weight, final List<Scorer> subScorers) throws IOException {
this(weight, subScorers, ScoreMode.TOP_SCORES);
}

HybridQueryScorer(final Weight weight, final List<Scorer> subScorers, final ScoreMode scoreMode) throws IOException {
super(weight);
this.subScorers = Collections.unmodifiableList(subScorers);
subScores = new float[subScorers.size()];
this.queryToIndex = mapQueryToIndex();
this.subScorersPQ = initializeSubScorersPQ();
boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;

this.approximation = new HybridSubqueriesDISIApproximation(this.subScorersPQ);
if (scoreMode == ScoreMode.TOP_SCORES) {
this.disjunctionBlockPropagator = new HybridScoreBlockBoundaryPropagator(subScorers);
} else {
this.disjunctionBlockPropagator = null;

Check warning on line 67 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L67

Added line #L67 was not covered by tests
}

boolean hasApproximation = false;
float sumMatchCost = 0;
long sumApproxCost = 0;
// Compute matchCost as the average over the matchCost of the subScorers.
// This is weighted by the cost, which is an expected number of matching documents.
for (DisiWrapper w : subScorersPQ) {
long costWeight = (w.cost <= 1) ? 1 : w.cost;
sumApproxCost += costWeight;
if (w.twoPhaseView != null) {
hasApproximation = true;
sumMatchCost += w.matchCost * costWeight;
}
}
if (!hasApproximation) { // no sub scorer supports approximations
twoPhase = null;
} else {
final float matchCost = sumMatchCost / sumApproxCost;
twoPhase = new TwoPhase(approximation, matchCost, subScorersPQ, needsScores);
}
}

@Override
public int advanceShallow(int target) throws IOException {
if (disjunctionBlockPropagator != null) {
return disjunctionBlockPropagator.advanceShallow(target);

Check warning on line 94 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L94

Added line #L94 was not covered by tests
}
return super.advanceShallow(target);

Check warning on line 96 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L96

Added line #L96 was not covered by tests
}

/**
Expand All @@ -55,7 +103,10 @@
*/
@Override
public float score() throws IOException {
DisiWrapper topList = subScorersPQ.topList();
return score(getSubMatches());
}

private float score(DisiWrapper topList) throws IOException {
float totalScore = 0.0f;
for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
// check if this doc has match in the subQuery. If not, add score as 0.0 and continue
Expand All @@ -67,13 +118,30 @@
return totalScore;
}

DisiWrapper getSubMatches() throws IOException {
if (twoPhase == null) {
return subScorersPQ.topList();
} else {
return twoPhase.getSubMatches();
}
}

/**
* Return a DocIdSetIterator over matching documents.
* @return DocIdSetIterator object
*/
@Override
public DocIdSetIterator iterator() {
return new DisjunctionDISIApproximation(this.subScorersPQ);
if (twoPhase != null) {
return TwoPhaseIterator.asDocIdSetIterator(twoPhase);
} else {
return approximation;
}
}

@Override
public TwoPhaseIterator twoPhaseIterator() {
return twoPhase;
}

/**
Expand All @@ -93,12 +161,28 @@
}).max(Float::compare).orElse(0.0f);
}

@Override
public void setMinCompetitiveScore(float minScore) throws IOException {
if (disjunctionBlockPropagator != null) {
disjunctionBlockPropagator.setMinCompetitiveScore(minScore);

Check warning on line 167 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L167

Added line #L167 was not covered by tests
}

for (Scorer scorer : subScorers) {
if (Objects.nonNull(scorer)) {
scorer.setMinCompetitiveScore(minScore);

Check warning on line 172 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L172

Added line #L172 was not covered by tests
}
}
}

Check warning on line 175 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L174-L175

Added lines #L174 - L175 were not covered by tests

/**
* Returns the doc ID that is currently being scored.
* @return document id
*/
@Override
public int docID() {
if (subScorersPQ.size() == 0) {
return DocIdSetIterator.NO_MORE_DOCS;

Check warning on line 184 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L184

Added line #L184 was not covered by tests
}
return subScorersPQ.top().doc;
}

Expand Down Expand Up @@ -169,4 +253,142 @@
}
return subScorersPQ;
}

@Override
public Collection<ChildScorable> getChildren() throws IOException {
ArrayList<ChildScorable> children = new ArrayList<>();

Check warning on line 259 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L259

Added line #L259 was not covered by tests
for (DisiWrapper scorer = getSubMatches(); scorer != null; scorer = scorer.next) {
children.add(new ChildScorable(scorer.scorer, "SHOULD"));

Check warning on line 261 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L261

Added line #L261 was not covered by tests
}
return children;

Check warning on line 263 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L263

Added line #L263 was not covered by tests
}

/**
* Object returned by {@link Scorer#twoPhaseIterator()} to provide an approximation of a {@link DocIdSetIterator}.
* After calling {@link DocIdSetIterator#nextDoc()} or {@link DocIdSetIterator#advance(int)} on the iterator
* returned by approximation(), you need to check {@link TwoPhaseIterator#matches()} to confirm if the retrieved
* document ID is a match. Implementation inspired by identical class for
* <a href="https://github.com/apache/lucene/blob/branch_9_10/lucene/core/src/java/org/apache/lucene/search/DisjunctionScorer.java">DisjunctionScorer</a>
*/
static class TwoPhase extends TwoPhaseIterator {
private final float matchCost;
// list of verified matches on the current doc
DisiWrapper verifiedMatches;
// priority queue of approximations on the current doc that have not been verified yet
final PriorityQueue<DisiWrapper> unverifiedMatches;
DisiPriorityQueue subScorers;
boolean needsScores;

private TwoPhase(DocIdSetIterator approximation, float matchCost, DisiPriorityQueue subScorers, boolean needsScores) {
super(approximation);
this.matchCost = matchCost;
this.subScorers = subScorers;
unverifiedMatches = new PriorityQueue<>(subScorers.size()) {
@Override
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.matchCost < b.matchCost;
}
};
this.needsScores = needsScores;
}

DisiWrapper getSubMatches() throws IOException {
for (DisiWrapper wrapper : unverifiedMatches) {
if (wrapper.twoPhaseView.matches()) {
wrapper.next = verifiedMatches;
verifiedMatches = wrapper;
}
}
unverifiedMatches.clear();
return verifiedMatches;
}

@Override
public boolean matches() throws IOException {
verifiedMatches = null;
unverifiedMatches.clear();

for (DisiWrapper wrapper = subScorers.topList(); wrapper != null;) {
DisiWrapper next = wrapper.next;

if (Objects.isNull(wrapper.twoPhaseView)) {
// implicitly verified, move it to verifiedMatches
wrapper.next = verifiedMatches;
verifiedMatches = wrapper;

Check warning on line 317 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L316-L317

Added lines #L316 - L317 were not covered by tests

if (!needsScores) {
// we can stop here
return true;

Check warning on line 321 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L321

Added line #L321 was not covered by tests
}
} else {
unverifiedMatches.add(wrapper);
}
wrapper = next;
}

if (Objects.nonNull(verifiedMatches)) {
return true;

Check warning on line 330 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L330

Added line #L330 was not covered by tests
}

// verify subs that have an two-phase iterator
// least-costly ones first
while (unverifiedMatches.size() > 0) {
DisiWrapper wrapper = unverifiedMatches.pop();
if (wrapper.twoPhaseView.matches()) {
wrapper.next = null;
verifiedMatches = wrapper;
return true;
}
}
return false;
}

@Override
public float matchCost() {
return matchCost;

Check warning on line 348 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L348

Added line #L348 was not covered by tests
}
}

/**
* A DocIdSetIterator which is a disjunction of the approximations of the provided iterators and supports
* sub iterators that return empty results
*/
static class HybridSubqueriesDISIApproximation extends DocIdSetIterator {
final DocIdSetIterator docIdSetIterator;
final DisiPriorityQueue subIterators;

public HybridSubqueriesDISIApproximation(final DisiPriorityQueue subIterators) {
docIdSetIterator = new DisjunctionDISIApproximation(subIterators);
this.subIterators = subIterators;
}

@Override
public long cost() {
return docIdSetIterator.cost();

Check warning on line 367 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L367

Added line #L367 was not covered by tests
}

@Override
public int docID() {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;
}
return docIdSetIterator.docID();
}

@Override
public int nextDoc() throws IOException {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;

Check warning on line 381 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L381

Added line #L381 was not covered by tests
}
return docIdSetIterator.nextDoc();
}

@Override
public int advance(final int target) throws IOException {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;

Check warning on line 389 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L389

Added line #L389 was not covered by tests
}
return docIdSetIterator.advance(target);

Check warning on line 391 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L391

Added line #L391 was not covered by tests
}
}
}
Loading
Loading