Skip to content

Commit

Permalink
Initial version, included tests
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Mar 3, 2024
1 parent b97dbe8 commit d48c06c
Show file tree
Hide file tree
Showing 13 changed files with 1,173 additions and 241 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
- Adding two phase iterator for hybrid query ([#624](https://github.com/opensearch-project/neural-search/pull/624))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.util.List;
import java.util.Objects;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
Expand Down Expand Up @@ -77,20 +76,20 @@ public String toString(String field) {
/**
* Re-writes queries into primitive queries. Callers are expected to call rewrite multiple times if necessary,
* until the rewritten query is the same as the original query.
* @param reader
* @param indexSearcher
* @return
* @throws IOException
*/
@Override
public Query rewrite(IndexReader reader) throws IOException {
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (subQueries.isEmpty()) {
return new MatchNoDocsQuery("empty HybridQuery");
}

boolean actuallyRewritten = false;
List<Query> rewrittenSubQueries = new ArrayList<>();
for (Query subQuery : subQueries) {
Query rewrittenSub = subQuery.rewrite(reader);
Query rewrittenSub = subQuery.rewrite(indexSearcher);
/* we keep rewrite sub-query unless it's not equal to itself, it may take multiple levels of recursive calls
queries need to be rewritten from high-level clauses into lower-level clauses because low-level clauses
perform better. For hybrid query we need to track progress of re-write for all sub-queries */
Expand All @@ -102,7 +101,7 @@ public Query rewrite(IndexReader reader) throws IOException {
return new HybridQuery(rewrittenSubQueries);
}

return super.rewrite(reader);
return super.rewrite(indexSearcher);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.QueryShardException;
import org.opensearch.index.query.Rewriteable;
import org.opensearch.index.query.QueryBuilderVisitor;

import lombok.Getter;
Expand Down Expand Up @@ -290,7 +289,7 @@ private void writeQueries(StreamOutput out, List<? extends QueryBuilder> queries
private Collection<Query> toQueries(Collection<QueryBuilder> queryBuilders, QueryShardContext context) throws QueryShardException {
List<Query> queries = queryBuilders.stream().map(qb -> {
try {
return Rewriteable.rewrite(qb, context).toQuery(context);
return qb.rewrite(context).toQuery(context);
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
225 changes: 222 additions & 3 deletions src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand All @@ -18,10 +19,15 @@
import org.apache.lucene.search.DisjunctionDISIApproximation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;

import lombok.Getter;
import org.apache.lucene.util.PriorityQueue;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

/**
* Class abstracts functionality of Scorer for hybrid query. When iterating over documents in increasing
Expand All @@ -40,12 +46,60 @@ public final class HybridQueryScorer extends Scorer {

private final Map<Query, List<Integer>> queryToIndex;

private final DocIdSetIterator approximation;
HybridScorePropagator disjunctionBlockPropagator;
private final TwoPhase twoPhase;

/**
* Convenience constructor: delegates to the three-argument constructor with {@link ScoreMode#TOP_SCORES}.
* @param weight weight handed on to the three-argument constructor
* @param subScorers scorers of the sub-queries, handed on to the three-argument constructor
* @throws IOException declared by the delegated constructor
*/
public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOException {
this(weight, subScorers, ScoreMode.TOP_SCORES);
}

/**
* Builds the scorer from the scorers of the sub-queries. Sets up, in order: the priority queue of sub-scorers,
* the disjunction approximation over that queue, a score propagator (only for {@link ScoreMode#TOP_SCORES}),
* and a two-phase iterator (only when at least one wrapper in the queue has a two-phase view).
* @param weight passed to the superclass constructor
* @param subScorers scorers of the sub-queries; kept behind an unmodifiable list wrapper
* @param scoreMode controls whether the propagator is created and whether scores are needed during matching
* @throws IOException declared by this constructor; the visible code does not show which call raises it
*/
public HybridQueryScorer(Weight weight, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
super(weight);
// max
// NOTE(review): the "max" / "custom" / "base" tags presumably name where each step was adapted from - confirm
this.subScorers = Collections.unmodifiableList(subScorers);
// custom
// one score slot per sub-scorer
subScores = new float[subScorers.size()];
this.queryToIndex = mapQueryToIndex();
// base
this.subScorersPQ = initializeSubScorersPQ();
// base
// scores are needed for every mode except COMPLETE_NO_SCORES; the flag is handed to TwoPhase below
boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;
this.approximation = new HybridDisjunctionDISIApproximation(this.subScorersPQ);
// max
// the propagator exists only in TOP_SCORES mode; advanceShallow and setMinCompetitiveScore null-check it
if (scoreMode == ScoreMode.TOP_SCORES) {
this.disjunctionBlockPropagator = new HybridScorePropagator(subScorers);
} else {
this.disjunctionBlockPropagator = null;

Check warning on line 73 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L73

Added line #L73 was not covered by tests
}
// base
boolean hasApproximation = false;
float sumMatchCost = 0;
long sumApproxCost = 0;
// Compute matchCost as the average over the matchCost of the subScorers.
// This is weighted by the cost, which is an expected number of matching documents.
for (DisiWrapper w : subScorersPQ) {
// cost is floored at 1; presumably so that sumApproxCost cannot be zero in the division below - confirm
long costWeight = (w.cost <= 1) ? 1 : w.cost;
sumApproxCost += costWeight;
// only wrappers with a two-phase view contribute to the match cost
if (w.twoPhaseView != null) {
hasApproximation = true;
sumMatchCost += w.matchCost * costWeight;

Check warning on line 86 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L85-L86

Added lines #L85 - L86 were not covered by tests
}
}
if (!hasApproximation) { // no sub scorer supports approximations
twoPhase = null;
} else {
// weighted average: the numerator sums only two-phase wrappers, the denominator sums all wrappers
final float matchCost = sumMatchCost / sumApproxCost;
twoPhase = new TwoPhase(approximation, matchCost, subScorersPQ, needsScores);

Check warning on line 93 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L92-L93

Added lines #L92 - L93 were not covered by tests
}
}

/**
* Forwards the shallow advance to the score propagator when one was created (TOP_SCORES mode),
* otherwise falls back to the superclass implementation.
* @param target document id handed on to the propagator or to the superclass
* @return the value returned by the propagator or by the superclass
* @throws IOException declared by the delegated calls
*/
@Override
public int advanceShallow(int target) throws IOException {
// the propagator is null unless the scorer was built with ScoreMode.TOP_SCORES
if (disjunctionBlockPropagator != null) {
return disjunctionBlockPropagator.advanceShallow(target);

Check warning on line 100 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L100

Added line #L100 was not covered by tests
}
return super.advanceShallow(target);

Check warning on line 102 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L102

Added line #L102 was not covered by tests
}

/**
Expand All @@ -55,25 +109,45 @@ public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOExcept
*/
@Override
public float score() throws IOException {
// NOTE(review): topList is never read in this method; in this diff view the line looks like the
// pre-change statement that the next line replaces - confirm against the actual file
DisiWrapper topList = subScorersPQ.topList();
// sums the scores of the sub-scorers returned by getSubMatches()
return score(getSubMatches());
}

/**
* Sums the scores of all sub-scorers in the given linked list, skipping exhausted ones.
* @param topList head of a list of wrappers chained through {@code DisiWrapper.next}; may be null, which yields 0
* @return sum of {@code scorer.score()} over the wrappers whose scorer is not at NO_MORE_DOCS
* @throws IOException declared by the sub-scorer score() calls
*/
private float score(DisiWrapper topList) throws IOException {
float totalScore = 0.0f;
for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
// check if this doc has match in the subQuery. If not, add score as 0.0 and continue
// NOTE(review): the two consecutive if-lines below look like the before/after forms of one diff line
// (qualified constant vs. static import); only one of them should exist in the real file - confirm
if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
if (disiWrapper.scorer.docID() == NO_MORE_DOCS) {
continue;
}
totalScore += disiWrapper.scorer.score();
}
return totalScore;
}

/**
* Returns the head of the list of sub-scorer wrappers for the current document.
* Without a two-phase iterator this is the top list of the priority queue; with one, it is
* whatever {@code twoPhase.getSubMatches()} returns.
* @return head of a list chained through {@code DisiWrapper.next}
* @throws IOException declared by the two-phase path
*/
DisiWrapper getSubMatches() throws IOException {
// twoPhase is null when no wrapper in the queue had a two-phase view at construction time
if (twoPhase == null) {
return subScorersPQ.topList();
} else {
return twoPhase.getSubMatches();

Check warning on line 131 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L131

Added line #L131 was not covered by tests
}
}

/**
* Return a DocIdSetIterator over matching documents. When a two-phase iterator exists it is wrapped with
* {@link TwoPhaseIterator#asDocIdSetIterator(TwoPhaseIterator)}; otherwise the approximation built in the
* constructor is returned.
* @return DocIdSetIterator object
*/
@Override
public DocIdSetIterator iterator() {
// NOTE(review): this unconditional return makes the rest of the method unreachable; in this diff view it
// looks like the pre-change statement that the if/else below replaces - confirm against the actual file
return new DisjunctionDISIApproximation(this.subScorersPQ);
if (twoPhase != null) {
return TwoPhaseIterator.asDocIdSetIterator(twoPhase);

Check warning on line 142 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L142

Added line #L142 was not covered by tests
} else {
return approximation;
}
}

/**
* Returns the two-phase iterator created in the constructor.
* @return the two-phase iterator, or null when no sub-scorer wrapper had a two-phase view
*/
@Override
public TwoPhaseIterator twoPhaseIterator() {
return twoPhase;
}

/**
Expand All @@ -93,12 +167,28 @@ public float getMaxScore(int upTo) throws IOException {
}).max(Float::compare).orElse(0.0f);
}

/**
* Forwards the minimum competitive score to the score propagator (when present) and to every
* non-null sub-scorer.
* @param minScore value handed on unchanged to the propagator and to each sub-scorer
* @throws IOException declared by the delegated calls
*/
@Override
public void setMinCompetitiveScore(float minScore) throws IOException {
// the propagator is null unless the scorer was built with ScoreMode.TOP_SCORES
if (disjunctionBlockPropagator != null) {
disjunctionBlockPropagator.setMinCompetitiveScore(minScore);

Check warning on line 173 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L173

Added line #L173 was not covered by tests
}

// NOTE(review): score() returns the SUM of the sub-scores, yet each sub-scorer receives the full minScore.
// A sub-scorer may then skip documents whose individual score is below minScore even though the summed
// score would be competitive - confirm this forwarding is intended for hybrid scoring
for (Scorer scorer : subScorers) {
// the list may contain null entries, hence the guard
if (Objects.nonNull(scorer)) {
scorer.setMinCompetitiveScore(minScore);

Check warning on line 178 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L178

Added line #L178 was not covered by tests
}
}
}

Check warning on line 181 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L180-L181

Added lines #L180 - L181 were not covered by tests

/**
* Returns the doc ID that is currently being scored: the doc of the wrapper at the top of the
* sub-scorer priority queue.
* @return document id, or NO_MORE_DOCS when the priority queue holds no sub-scorers
*/
@Override
public int docID() {
// guard: with an empty queue there is no top wrapper to read the doc from
if (subScorersPQ.size() == 0) {
return NO_MORE_DOCS;

Check warning on line 190 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L190

Added line #L190 was not covered by tests
}
return subScorersPQ.top().doc;
}

Expand Down Expand Up @@ -169,4 +259,133 @@ private DisiPriorityQueue initializeSubScorersPQ() {
}
return subScorersPQ;
}

/**
* Lists the sub-scorers returned by {@code getSubMatches()} as child scorables.
* @return one {@code ChildScorable} per wrapper in the list, each created with the relationship "SHOULD"
* @throws IOException declared by getSubMatches()
*/
@Override
public Collection<ChildScorable> getChildren() throws IOException {
ArrayList<ChildScorable> children = new ArrayList<>();

Check warning on line 265 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L265

Added line #L265 was not covered by tests
// walk the list chained through DisiWrapper.next
for (DisiWrapper scorer = getSubMatches(); scorer != null; scorer = scorer.next) {
children.add(new ChildScorable(scorer.scorer, "SHOULD"));

Check warning on line 267 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L267

Added line #L267 was not covered by tests
}
return children;

Check warning on line 269 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L269

Added line #L269 was not covered by tests
}

/**
* Two-phase iterator over the disjunction approximation. {@code matches()} sorts the sub-scorer wrappers
* positioned on the current doc into verified ones (no two-phase view, or view confirmed) and unverified ones
* (two-phase view not yet checked); {@code getSubMatches()} finishes verification of whatever is left.
*/
static class TwoPhase extends TwoPhaseIterator {
// value reported by matchCost(); computed by the enclosing scorer's constructor
private final float matchCost;
// list of verified matches on the current doc
DisiWrapper verifiedMatches;
// priority queue of approximations on the current doc that have not been verified yet
final PriorityQueue<DisiWrapper> unverifiedMatches;
// queue of all sub-scorer wrappers; matches() reads its top list
DisiPriorityQueue subScorers;
// when false, matches() returns as soon as one implicitly verified wrapper is found
boolean needsScores;

/**
* @param approximation iterator passed to the superclass constructor
* @param matchCost value returned by matchCost()
* @param subScorers queue of sub-scorer wrappers; its size is used as capacity of the unverified queue
* @param needsScores whether all matching sub-scorers must be identified, not just the first one
*/
private TwoPhase(DocIdSetIterator approximation, float matchCost, DisiPriorityQueue subScorers, boolean needsScores) {
super(approximation);
this.matchCost = matchCost;
this.subScorers = subScorers;
unverifiedMatches = new PriorityQueue<>(subScorers.size()) {

Check warning on line 285 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L282-L285

Added lines #L282 - L285 were not covered by tests
// orders wrappers by ascending matchCost
@Override
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.matchCost < b.matchCost;
}
};
this.needsScores = needsScores;
}

Check warning on line 292 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L291-L292

Added lines #L291 - L292 were not covered by tests

/**
* Verifies every wrapper still waiting in the unverified queue, prepends the ones that match to the
* verified list, empties the queue and returns the verified list.
* @return head of the verified list chained through {@code DisiWrapper.next}
* @throws IOException declared by the two-phase matches() calls
*/
DisiWrapper getSubMatches() throws IOException {
// iteration order does not matter
for (DisiWrapper w : unverifiedMatches) {
if (w.twoPhaseView.matches()) {
// push w onto the front of the verified list
w.next = verifiedMatches;
verifiedMatches = w;

Check warning on line 299 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L298-L299

Added lines #L298 - L299 were not covered by tests
}
}
unverifiedMatches.clear();
return verifiedMatches;

Check warning on line 303 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L301-L303

Added lines #L301 - L303 were not covered by tests
}

/**
* Decides whether the current doc of the approximation matches. Resets the verified list and the
* unverified queue, then classifies each wrapper of the top list. Returns true as soon as one match is
* known; wrappers left in the unverified queue are checked later by getSubMatches().
* @return true when at least one sub-scorer matches the current doc, false otherwise
* @throws IOException declared by the two-phase matches() calls
*/
@Override
public boolean matches() throws IOException {
verifiedMatches = null;
unverifiedMatches.clear();

Check warning on line 309 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L308-L309

Added lines #L308 - L309 were not covered by tests

for (DisiWrapper w = subScorers.topList(); w != null;) {
// remember the successor first: w.next is overwritten below when w is moved to the verified list
DisiWrapper next = w.next;

Check warning on line 312 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L312

Added line #L312 was not covered by tests

if (w.twoPhaseView == null) {
// implicitly verified, move it to verifiedMatches
w.next = verifiedMatches;
verifiedMatches = w;

Check warning on line 317 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L316-L317

Added lines #L316 - L317 were not covered by tests

if (!needsScores) {
// we can stop here
return true;

Check warning on line 321 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L321

Added line #L321 was not covered by tests
}
} else {
unverifiedMatches.add(w);

Check warning on line 324 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L324

Added line #L324 was not covered by tests
}
w = next;
}

Check warning on line 327 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L326-L327

Added lines #L326 - L327 were not covered by tests

// a match is already known; the unverified queue is left as is for getSubMatches()
if (verifiedMatches != null) {
return true;

Check warning on line 330 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L330

Added line #L330 was not covered by tests
}

// verify subs that have a two-phase iterator
// least-costly ones first
while (unverifiedMatches.size() > 0) {
DisiWrapper w = unverifiedMatches.pop();

Check warning on line 336 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L336

Added line #L336 was not covered by tests
if (w.twoPhaseView.matches()) {
// first confirmed match becomes the sole element of the verified list
w.next = null;
verifiedMatches = w;
return true;

Check warning on line 340 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L338-L340

Added lines #L338 - L340 were not covered by tests
}
}

Check warning on line 342 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L342

Added line #L342 was not covered by tests

return false;

Check warning on line 344 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L344

Added line #L344 was not covered by tests
}

/**
* @return the match cost supplied at construction time
*/
@Override
public float matchCost() {
return matchCost;

Check warning on line 349 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L349

Added line #L349 was not covered by tests
}
}

/**
* Wrapper around {@link DisjunctionDISIApproximation} that answers NO_MORE_DOCS from docID(), nextDoc() and
* advance() when the queue of sub-iterators is empty, instead of calling the delegate. cost() always goes
* to the delegate.
*/
static class HybridDisjunctionDISIApproximation extends DocIdSetIterator {
// the wrapped disjunction approximation
final DocIdSetIterator delegate;
// same queue the delegate was built from; only its size is consulted here
final DisiPriorityQueue subIterators;

/**
* @param subIterators queue of sub-iterator wrappers used to build the delegate
*/
public HybridDisjunctionDISIApproximation(DisiPriorityQueue subIterators) {
delegate = new DisjunctionDISIApproximation(subIterators);
this.subIterators = subIterators;
}

/**
* @return the delegate's cost
*/
@Override
public long cost() {
return delegate.cost();

Check warning on line 364 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L364

Added line #L364 was not covered by tests
}

/**
* @return NO_MORE_DOCS when there are no sub-iterators, otherwise the delegate's current doc
*/
@Override
public int docID() {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;
}
return delegate.docID();
}

/**
* @return NO_MORE_DOCS when there are no sub-iterators, otherwise the delegate's next doc
* @throws IOException declared by the delegate
*/
@Override
public int nextDoc() throws IOException {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;

Check warning on line 378 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L378

Added line #L378 was not covered by tests
}
return delegate.nextDoc();
}

/**
* @param target document id handed on to the delegate
* @return NO_MORE_DOCS when there are no sub-iterators, otherwise the delegate's result for target
* @throws IOException declared by the delegate
*/
@Override
public int advance(int target) throws IOException {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;

Check warning on line 386 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L386

Added line #L386 was not covered by tests
}
return delegate.advance(target);

Check warning on line 388 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L388

Added line #L388 was not covered by tests
}
}
}
Loading

0 comments on commit d48c06c

Please sign in to comment.