Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix runtime exceptions in hybrid query for case when sub-query scorer return TwoPhase iterator that is incompatible with DISI iterator #624

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
- Adding two phase iterator for hybrid query ([#624](https://github.com/opensearch-project/neural-search/pull/624))
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.util.List;
import java.util.Objects;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
Expand Down Expand Up @@ -77,20 +76,20 @@ public String toString(String field) {
/**
* Re-writes queries into primitive queries. Callers are expected to call rewrite multiple times if necessary,
* until the rewritten query is the same as the original query.
* @param reader
* @param indexSearcher
* @return
* @throws IOException
*/
@Override
public Query rewrite(IndexReader reader) throws IOException {
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
if (subQueries.isEmpty()) {
return new MatchNoDocsQuery("empty HybridQuery");
}

boolean actuallyRewritten = false;
List<Query> rewrittenSubQueries = new ArrayList<>();
for (Query subQuery : subQueries) {
Query rewrittenSub = subQuery.rewrite(reader);
Query rewrittenSub = subQuery.rewrite(indexSearcher);
/* we keep rewrite sub-query unless it's not equal to itself, it may take multiple levels of recursive calls
queries need to be rewritten from high-level clauses into lower-level clauses because low-level clauses
perform better. For hybrid query we need to track progress of re-write for all sub-queries */
Expand All @@ -102,7 +101,7 @@ public Query rewrite(IndexReader reader) throws IOException {
return new HybridQuery(rewrittenSubQueries);
}

return super.rewrite(reader);
return super.rewrite(indexSearcher);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.QueryShardException;
import org.opensearch.index.query.Rewriteable;
import org.opensearch.index.query.QueryBuilderVisitor;

import lombok.Getter;
Expand Down Expand Up @@ -290,7 +289,7 @@ private void writeQueries(StreamOutput out, List<? extends QueryBuilder> queries
private Collection<Query> toQueries(Collection<QueryBuilder> queryBuilders, QueryShardContext context) throws QueryShardException {
List<Query> queries = queryBuilders.stream().map(qb -> {
try {
return Rewriteable.rewrite(qb, context).toQuery(context);
return qb.rewrite(context).toQuery(context);
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand All @@ -18,10 +19,13 @@
import org.apache.lucene.search.DisjunctionDISIApproximation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;

import lombok.Getter;
import org.apache.lucene.util.PriorityQueue;

/**
* Class abstracts functionality of Scorer for hybrid query. When iterating over documents in increasing
Expand All @@ -40,12 +44,56 @@

private final Map<Query, List<Integer>> queryToIndex;

private final DocIdSetIterator approximation;
HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator;
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
private final TwoPhase twoPhase;

public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOException {
this(weight, subScorers, ScoreMode.TOP_SCORES);
}

public HybridQueryScorer(Weight weight, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
super(weight);
this.subScorers = Collections.unmodifiableList(subScorers);
subScores = new float[subScorers.size()];
this.queryToIndex = mapQueryToIndex();
this.subScorersPQ = initializeSubScorersPQ();
boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;

this.approximation = new HybridSubqueriesDISIApproximation(this.subScorersPQ);
if (scoreMode == ScoreMode.TOP_SCORES) {
this.disjunctionBlockPropagator = new HybridScoreBlockBoundaryPropagator(subScorers);
} else {
this.disjunctionBlockPropagator = null;

Check warning on line 67 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L67

Added line #L67 was not covered by tests
}

boolean hasApproximation = false;
float sumMatchCost = 0;
long sumApproxCost = 0;
// Compute matchCost as the average over the matchCost of the subScorers.
// This is weighted by the cost, which is an expected number of matching documents.
for (DisiWrapper w : subScorersPQ) {
long costWeight = (w.cost <= 1) ? 1 : w.cost;
sumApproxCost += costWeight;
if (w.twoPhaseView != null) {
hasApproximation = true;
sumMatchCost += w.matchCost * costWeight;
}
}
if (!hasApproximation) { // no sub scorer supports approximations
twoPhase = null;
} else {
final float matchCost = sumMatchCost / sumApproxCost;
twoPhase = new TwoPhase(approximation, matchCost, subScorersPQ, needsScores);
}
}

@Override
public int advanceShallow(int target) throws IOException {
if (disjunctionBlockPropagator != null) {
return disjunctionBlockPropagator.advanceShallow(target);

Check warning on line 94 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L94

Added line #L94 was not covered by tests
}
return super.advanceShallow(target);

Check warning on line 96 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L96

Added line #L96 was not covered by tests
}

/**
Expand All @@ -55,7 +103,10 @@
*/
@Override
public float score() throws IOException {
DisiWrapper topList = subScorersPQ.topList();
return score(getSubMatches());
}

private float score(DisiWrapper topList) throws IOException {
float totalScore = 0.0f;
for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
// check if this doc has match in the subQuery. If not, add score as 0.0 and continue
Expand All @@ -67,13 +118,30 @@
return totalScore;
}

DisiWrapper getSubMatches() throws IOException {
if (twoPhase == null) {
return subScorersPQ.topList();
} else {
return twoPhase.getSubMatches();
}
}

/**
* Return a DocIdSetIterator over matching documents.
* @return DocIdSetIterator object
*/
@Override
public DocIdSetIterator iterator() {
return new DisjunctionDISIApproximation(this.subScorersPQ);
if (twoPhase != null) {
return TwoPhaseIterator.asDocIdSetIterator(twoPhase);
} else {
return approximation;
}
}

@Override
public TwoPhaseIterator twoPhaseIterator() {
return twoPhase;
}

/**
Expand All @@ -93,12 +161,28 @@
}).max(Float::compare).orElse(0.0f);
}

@Override
public void setMinCompetitiveScore(float minScore) throws IOException {
if (disjunctionBlockPropagator != null) {
disjunctionBlockPropagator.setMinCompetitiveScore(minScore);

Check warning on line 167 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L167

Added line #L167 was not covered by tests
}

for (Scorer scorer : subScorers) {
if (Objects.nonNull(scorer)) {
scorer.setMinCompetitiveScore(minScore);

Check warning on line 172 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L172

Added line #L172 was not covered by tests
}
}
}

Check warning on line 175 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L174-L175

Added lines #L174 - L175 were not covered by tests

/**
* Returns the doc ID that is currently being scored.
* @return document id
*/
@Override
public int docID() {
if (subScorersPQ.size() == 0) {
return DocIdSetIterator.NO_MORE_DOCS;

Check warning on line 184 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L184

Added line #L184 was not covered by tests
}
return subScorersPQ.top().doc;
}

Expand Down Expand Up @@ -169,4 +253,139 @@
}
return subScorersPQ;
}

@Override
public Collection<ChildScorable> getChildren() throws IOException {
ArrayList<ChildScorable> children = new ArrayList<>();

Check warning on line 259 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L259

Added line #L259 was not covered by tests
for (DisiWrapper scorer = getSubMatches(); scorer != null; scorer = scorer.next) {
children.add(new ChildScorable(scorer.scorer, "SHOULD"));

Check warning on line 261 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L261

Added line #L261 was not covered by tests
}
return children;

Check warning on line 263 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L263

Added line #L263 was not covered by tests
}

/**
* Object returned by Scorer.twoPhaseIterator() to provide an approximation of a DocIdSetIterator.
* After calling nextDoc() or advance(int) on the iterator returned by approximation(), you need to check matches() to confirm if the retrieved document ID is a match.
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
*/
static class TwoPhase extends TwoPhaseIterator {
private final float matchCost;
// list of verified matches on the current doc
DisiWrapper verifiedMatches;
// priority queue of approximations on the current doc that have not been verified yet
final PriorityQueue<DisiWrapper> unverifiedMatches;
DisiPriorityQueue subScorers;
boolean needsScores;

private TwoPhase(DocIdSetIterator approximation, float matchCost, DisiPriorityQueue subScorers, boolean needsScores) {
super(approximation);
this.matchCost = matchCost;
this.subScorers = subScorers;
unverifiedMatches = new PriorityQueue<>(subScorers.size()) {
@Override
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.matchCost < b.matchCost;
}
};
this.needsScores = needsScores;
}

DisiWrapper getSubMatches() throws IOException {
for (DisiWrapper wrapper : unverifiedMatches) {
if (wrapper.twoPhaseView.matches()) {
wrapper.next = verifiedMatches;
verifiedMatches = wrapper;
}
}
unverifiedMatches.clear();
return verifiedMatches;
}

@Override
public boolean matches() throws IOException {
verifiedMatches = null;
unverifiedMatches.clear();

for (DisiWrapper wrapper = subScorers.topList(); wrapper != null;) {
DisiWrapper next = wrapper.next;

if (Objects.isNull(wrapper.twoPhaseView)) {
// implicitly verified, move it to verifiedMatches
wrapper.next = verifiedMatches;
verifiedMatches = wrapper;

Check warning on line 314 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L313-L314

Added lines #L313 - L314 were not covered by tests

if (!needsScores) {
// we can stop here
return true;

Check warning on line 318 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L318

Added line #L318 was not covered by tests
}
} else {
unverifiedMatches.add(wrapper);
}
wrapper = next;
}

if (Objects.nonNull(verifiedMatches)) {
return true;

Check warning on line 327 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L327

Added line #L327 was not covered by tests
}

// verify subs that have an two-phase iterator
// least-costly ones first
while (unverifiedMatches.size() > 0) {
DisiWrapper wrapper = unverifiedMatches.pop();
if (wrapper.twoPhaseView.matches()) {
wrapper.next = null;
verifiedMatches = wrapper;
return true;
}
}
return false;
}

@Override
public float matchCost() {
return matchCost;

Check warning on line 345 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L345

Added line #L345 was not covered by tests
}
}

/**
* A DocIdSetIterator which is a disjunction of the approximations of the provided iterators and supports
* sub iterators that return empty results
*/
static class HybridSubqueriesDISIApproximation extends DocIdSetIterator {
final DocIdSetIterator docIdSetIterator;
final DisiPriorityQueue subIterators;

public HybridSubqueriesDISIApproximation(final DisiPriorityQueue subIterators) {
docIdSetIterator = new DisjunctionDISIApproximation(subIterators);
this.subIterators = subIterators;
}

@Override
public long cost() {
return docIdSetIterator.cost();

Check warning on line 364 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L364

Added line #L364 was not covered by tests
}

@Override
public int docID() {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;
}
return docIdSetIterator.docID();
}

@Override
public int nextDoc() throws IOException {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;

Check warning on line 378 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L378

Added line #L378 was not covered by tests
}
return docIdSetIterator.nextDoc();
}

@Override
public int advance(final int target) throws IOException {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;

Check warning on line 386 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L386

Added line #L386 was not covered by tests
}
return docIdSetIterator.advance(target);

Check warning on line 388 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L388

Added line #L388 was not covered by tests
}
}
}
Loading
Loading