Skip to content

Commit

Permalink
Initial version, included tests
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Mar 3, 2024
1 parent b97dbe8 commit 46775d3
Show file tree
Hide file tree
Showing 12 changed files with 1,172 additions and 241 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.util.List;
import java.util.Objects;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
Expand Down Expand Up @@ -77,20 +76,20 @@ public String toString(String field) {
/**
* Re-writes queries into primitive queries. Callers are expected to call rewrite multiple times if necessary,
* until the rewritten query is the same as the original query.
* @param reader
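* @param indexSearcher searcher that is passed to the rewrite of every sub-query and to the parent rewrite
* @return a {@code MatchNoDocsQuery} when there are no sub-queries; otherwise either a new {@code HybridQuery} built from the rewritten sub-queries or the result of the parent rewrite
* @throws IOException if rewriting a sub-query or the parent rewrite fails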
*/
@Override
public Query rewrite(IndexReader reader) throws IOException {
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (subQueries.isEmpty()) {
return new MatchNoDocsQuery("empty HybridQuery");
}

boolean actuallyRewritten = false;
List<Query> rewrittenSubQueries = new ArrayList<>();
for (Query subQuery : subQueries) {
Query rewrittenSub = subQuery.rewrite(reader);
Query rewrittenSub = subQuery.rewrite(indexSearcher);
/* we keep rewrite sub-query unless it's not equal to itself, it may take multiple levels of recursive calls
queries need to be rewritten from high-level clauses into lower-level clauses because low-level clauses
perform better. For hybrid query we need to track progress of re-write for all sub-queries */
Expand All @@ -102,7 +101,7 @@ public Query rewrite(IndexReader reader) throws IOException {
return new HybridQuery(rewrittenSubQueries);
}

return super.rewrite(reader);
return super.rewrite(indexSearcher);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.QueryShardException;
import org.opensearch.index.query.Rewriteable;
import org.opensearch.index.query.QueryBuilderVisitor;

import lombok.Getter;
Expand Down Expand Up @@ -290,7 +289,7 @@ private void writeQueries(StreamOutput out, List<? extends QueryBuilder> queries
private Collection<Query> toQueries(Collection<QueryBuilder> queryBuilders, QueryShardContext context) throws QueryShardException {
List<Query> queries = queryBuilders.stream().map(qb -> {
try {
return Rewriteable.rewrite(qb, context).toQuery(context);
return qb.rewrite(context).toQuery(context);
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
225 changes: 222 additions & 3 deletions src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand All @@ -18,10 +19,15 @@
import org.apache.lucene.search.DisjunctionDISIApproximation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;

import lombok.Getter;
import org.apache.lucene.util.PriorityQueue;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

/**
* Class abstracts functionality of Scorer for hybrid query. When iterating over documents in increasing
Expand All @@ -40,12 +46,60 @@ public final class HybridQueryScorer extends Scorer {

private final Map<Query, List<Integer>> queryToIndex;

private final DocIdSetIterator approximation;
HybridScorePropagator disjunctionBlockPropagator;
private final TwoPhase twoPhase;

/**
* Creates a scorer by delegating to the three-argument constructor with {@link ScoreMode#TOP_SCORES}.
* @param weight weight handed through to the delegated constructor
* @param subScorers scorers of the sub-queries
* @throws IOException declared by the delegated constructor
*/
public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOException {
this(weight, subScorers, ScoreMode.TOP_SCORES);
}

/**
* Creates a scorer over the given sub-scorers: builds the priority queue of sub-scorers, the disjunction
* approximation over that queue, an optional block-max propagator (only for {@link ScoreMode#TOP_SCORES})
* and an optional two-phase iterator (only when at least one queued sub-scorer exposes a two-phase view).
* NOTE(review): the short max/base/custom tags below presumably name the Lucene disjunction scorer each
* step was adapted from - confirm with the author.
* @param weight weight passed to the parent {@code Scorer}
* @param subScorers scorers of the sub-queries
* @param scoreMode score mode; decides whether scores are needed and whether the block propagator is created
* @throws IOException declared for the helpers and scorer construction invoked here
*/
public HybridQueryScorer(Weight weight, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
super(weight);
// max: keep a read-only view of the sub-scorers
this.subScorers = Collections.unmodifiableList(subScorers);
// custom: one score slot per sub-scorer, plus the query-to-index lookup
subScores = new float[subScorers.size()];
this.queryToIndex = mapQueryToIndex();
// base: priority queue of sub-scorer wrappers
this.subScorersPQ = initializeSubScorersPQ();
// base: scores are needed for every mode except COMPLETE_NO_SCORES
boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;
this.approximation = new HybridDisjunctionDISIApproximation(this.subScorersPQ);
// max: the block propagator is only created when top scores are requested
if (scoreMode == ScoreMode.TOP_SCORES) {
this.disjunctionBlockPropagator = new HybridScorePropagator(subScorers);
} else {
this.disjunctionBlockPropagator = null;

Check warning on line 73 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L73

Added line #L73 was not covered by tests
}
// base: decide whether a two-phase iterator is needed and estimate its match cost
boolean hasApproximation = false;
float sumMatchCost = 0;
long sumApproxCost = 0;
// Compute matchCost as the average over the matchCost of the subScorers.
// This is weighted by the cost, which is an expected number of matching documents.
for (DisiWrapper w : subScorersPQ) {
// clamp the weight to at least 1 so that a cost of 0 (or less) still contributes
long costWeight = (w.cost <= 1) ? 1 : w.cost;
sumApproxCost += costWeight;
if (w.twoPhaseView != null) {
hasApproximation = true;
sumMatchCost += w.matchCost * costWeight;

Check warning on line 86 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L85-L86

Added lines #L85 - L86 were not covered by tests
}
}
if (!hasApproximation) { // no sub scorer supports approximations
twoPhase = null;
} else {
// the divisor includes the weights of all queued sub-scorers, not only the two-phase ones
final float matchCost = sumMatchCost / sumApproxCost;
twoPhase = new TwoPhase(approximation, matchCost, subScorersPQ, needsScores);

Check warning on line 93 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L92-L93

Added lines #L92 - L93 were not covered by tests
}
}

/**
* Shallow-advances to the given target. Delegates to the block propagator when one was created,
* otherwise falls back to the parent implementation.
* @param target document id to shallow-advance to
* @return value returned by the propagator or by the parent implementation
* @throws IOException if the delegate fails
*/
@Override
public int advanceShallow(int target) throws IOException {
// the propagator is non-null only when the scorer was built with ScoreMode.TOP_SCORES
if (disjunctionBlockPropagator != null) {
return disjunctionBlockPropagator.advanceShallow(target);

Check warning on line 100 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L100

Added line #L100 was not covered by tests
}
return super.advanceShallow(target);

Check warning on line 102 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L102

Added line #L102 was not covered by tests
}

/**
Expand All @@ -55,25 +109,45 @@ public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOExcept
*/
@Override
public float score() throws IOException {
DisiWrapper topList = subScorersPQ.topList();
// sum the scores of the sub-scorers reported as matching the current document
return score(getSubMatches());
}

/**
* Sums the scores of all wrappers in the given linked list, skipping wrappers whose scorer is exhausted.
* @param topList head of the linked list of sub-scorer wrappers (linked through {@code next}); may be null
* @return sum of the sub-scorer scores, 0.0 when the list is empty
* @throws IOException if a sub-scorer fails to compute its score
*/
private float score(DisiWrapper topList) throws IOException {
float totalScore = 0.0f;
for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
// an exhausted sub-scorer has no match for this doc, so it contributes nothing to the total
if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
if (disiWrapper.scorer.docID() == NO_MORE_DOCS) {
continue;
}
totalScore += disiWrapper.scorer.score();
}
return totalScore;
}

/**
* Returns the head of the linked list of sub-scorer wrappers for the current document.
* Without a two-phase iterator this is the top list of the priority queue; otherwise the list
* is obtained from the two-phase iterator.
* @return head of the list of sub-matches
* @throws IOException if the two-phase iterator fails while producing the sub-matches
*/
DisiWrapper getSubMatches() throws IOException {
if (twoPhase == null) {
return subScorersPQ.topList();
} else {
return twoPhase.getSubMatches();

Check warning on line 131 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L131

Added line #L131 was not covered by tests
}
}

/**
* Return a DocIdSetIterator over matching documents. When a two-phase iterator exists it is wrapped
* as a DocIdSetIterator, otherwise the disjunction approximation built in the constructor is returned.
* @return DocIdSetIterator object
*/
@Override
public DocIdSetIterator iterator() {
return new DisjunctionDISIApproximation(this.subScorersPQ);
if (twoPhase != null) {
return TwoPhaseIterator.asDocIdSetIterator(twoPhase);

Check warning on line 142 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L142

Added line #L142 was not covered by tests
} else {
return approximation;
}
}

/**
* Returns the two-phase iterator of this scorer.
* @return the two-phase iterator, or null when no sub-scorer supports approximations
*/
@Override
public TwoPhaseIterator twoPhaseIterator() {
return twoPhase;
}

/**
Expand All @@ -93,12 +167,28 @@ public float getMaxScore(int upTo) throws IOException {
}).max(Float::compare).orElse(0.0f);
}

/**
* Forwards the minimum competitive score to the block propagator (when present) and to every
* non-null sub-scorer.
* NOTE(review): score() in this class sums the sub-scorer scores, yet the same minScore is passed
* unchanged to each sub-scorer; confirm a sub-scorer cannot skip a document whose summed score
* would still be competitive.
* @param minScore minimum score a document must reach to be competitive
* @throws IOException if the propagator or a sub-scorer fails
*/
@Override
public void setMinCompetitiveScore(float minScore) throws IOException {
if (disjunctionBlockPropagator != null) {
disjunctionBlockPropagator.setMinCompetitiveScore(minScore);

Check warning on line 173 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L173

Added line #L173 was not covered by tests
}

for (Scorer scorer : subScorers) {
// the list of sub-scorers may contain null entries, which are skipped
if (Objects.nonNull(scorer)) {
scorer.setMinCompetitiveScore(minScore);

Check warning on line 178 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L178

Added line #L178 was not covered by tests
}
}
}

Check warning on line 181 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L180-L181

Added lines #L180 - L181 were not covered by tests

/**
* Returns the doc ID that is currently being scored, taken from the top entry of the sub-scorer
* priority queue.
* @return document id, or NO_MORE_DOCS when the priority queue holds no sub-scorers
*/
@Override
public int docID() {
// an empty queue has no top element to read the doc id from
if (subScorersPQ.size() == 0) {
return NO_MORE_DOCS;

Check warning on line 190 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L190

Added line #L190 was not covered by tests
}
return subScorersPQ.top().doc;
}

Expand Down Expand Up @@ -169,4 +259,133 @@ private DisiPriorityQueue initializeSubScorersPQ() {
}
return subScorersPQ;
}

/**
* Returns the sub-scorers matching the current document as child scorables, each registered
* with the relationship "SHOULD".
* @return collection with one child per sub-match; empty when there are no sub-matches
* @throws IOException if collecting the sub-matches fails
*/
@Override
public Collection<ChildScorable> getChildren() throws IOException {
ArrayList<ChildScorable> children = new ArrayList<>();

Check warning on line 265 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L265

Added line #L265 was not covered by tests
// walk the linked list of sub-matches through the wrappers' next pointers
for (DisiWrapper scorer = getSubMatches(); scorer != null; scorer = scorer.next) {
children.add(new ChildScorable(scorer.scorer, "SHOULD"));

Check warning on line 267 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L267

Added line #L267 was not covered by tests
}
return children;

Check warning on line 269 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L269

Added line #L269 was not covered by tests
}

/**
* Two-phase iterator over a disjunction of sub-scorers. The approximation positions the sub-scorers,
* {@link #matches()} confirms that at least one of them really matches the current document, and
* {@link #getSubMatches()} completes the list of all matching sub-scorers on demand.
*/
static class TwoPhase extends TwoPhaseIterator {
// match cost reported by matchCost(), supplied by the creator of this iterator
private final float matchCost;
// list of verified matches on the current doc
DisiWrapper verifiedMatches;
// priority queue of approximations on the current doc that have not been verified yet
final PriorityQueue<DisiWrapper> unverifiedMatches;
// queue of all sub-scorer wrappers; its top list holds the candidates for the current doc
DisiPriorityQueue subScorers;
// when false, matches() returns as soon as one implicitly verified sub-scorer is found
boolean needsScores;

/**
* @param approximation approximation passed to the parent two-phase iterator
* @param matchCost value returned by {@link #matchCost()}
* @param subScorers queue of sub-scorer wrappers
* @param needsScores whether the complete list of matching sub-scorers is needed
*/
private TwoPhase(DocIdSetIterator approximation, float matchCost, DisiPriorityQueue subScorers, boolean needsScores) {
super(approximation);
this.matchCost = matchCost;
this.subScorers = subScorers;
// sized to hold every sub-scorer; ordered by match cost so the cheapest check is popped first
unverifiedMatches = new PriorityQueue<>(subScorers.size()) {

Check warning on line 285 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L282-L285

Added lines #L282 - L285 were not covered by tests
@Override
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.matchCost < b.matchCost;
}
};
this.needsScores = needsScores;
}

Check warning on line 292 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L291-L292

Added lines #L291 - L292 were not covered by tests

/**
* Verifies every sub-scorer still waiting in the unverified queue, prepends the ones that match
* to the verified list, empties the queue and returns the verified list.
* @return head of the linked list of verified matches, or null when there are none
* @throws IOException if verifying a sub-scorer fails
*/
DisiWrapper getSubMatches() throws IOException {
// iteration order does not matter
for (DisiWrapper w : unverifiedMatches) {
if (w.twoPhaseView.matches()) {
w.next = verifiedMatches;
verifiedMatches = w;

Check warning on line 299 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L298-L299

Added lines #L298 - L299 were not covered by tests
}
}
unverifiedMatches.clear();
return verifiedMatches;

Check warning on line 303 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L301-L303

Added lines #L301 - L303 were not covered by tests
}

/**
* Checks whether at least one sub-scorer matches the current document. Sub-scorers without a
* two-phase view count as verified immediately; the remaining ones are only checked, cheapest
* first, when no implicitly verified sub-scorer exists.
* @return true when at least one sub-scorer matches, false otherwise
* @throws IOException if verifying a sub-scorer fails
*/
@Override
public boolean matches() throws IOException {
// reset the state left over from the previous document
verifiedMatches = null;
unverifiedMatches.clear();

Check warning on line 309 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L308-L309

Added lines #L308 - L309 were not covered by tests

for (DisiWrapper w = subScorers.topList(); w != null;) {
// remember the successor first, because w.next is overwritten when w joins the verified list
DisiWrapper next = w.next;

Check warning on line 312 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L312

Added line #L312 was not covered by tests

if (w.twoPhaseView == null) {
// implicitly verified, move it to verifiedMatches
w.next = verifiedMatches;
verifiedMatches = w;

Check warning on line 317 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L316-L317

Added lines #L316 - L317 were not covered by tests

if (!needsScores) {
// we can stop here
return true;

Check warning on line 321 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L321

Added line #L321 was not covered by tests
}
} else {
unverifiedMatches.add(w);

Check warning on line 324 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L324

Added line #L324 was not covered by tests
}
w = next;
}

Check warning on line 327 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L326-L327

Added lines #L326 - L327 were not covered by tests

// a verified match already exists; the unverified ones are left for getSubMatches()
if (verifiedMatches != null) {
return true;

Check warning on line 330 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L330

Added line #L330 was not covered by tests
}

// verify subs that have a two-phase iterator
// least-costly ones first
while (unverifiedMatches.size() > 0) {
DisiWrapper w = unverifiedMatches.pop();

Check warning on line 336 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L336

Added line #L336 was not covered by tests
if (w.twoPhaseView.matches()) {
// first confirmed match becomes the single-element verified list
w.next = null;
verifiedMatches = w;
return true;

Check warning on line 340 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L338-L340

Added lines #L338 - L340 were not covered by tests
}
}

Check warning on line 342 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L342

Added line #L342 was not covered by tests

return false;

Check warning on line 344 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L344

Added line #L344 was not covered by tests
}

/**
* @return the match cost given to the constructor
*/
@Override
public float matchCost() {
return matchCost;

Check warning on line 349 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L349

Added line #L349 was not covered by tests
}
}

/**
* DocIdSetIterator that wraps a {@code DisjunctionDISIApproximation} built over the given queue and
* answers NO_MORE_DOCS from docID, nextDoc and advance whenever that queue is empty, instead of
* calling the delegate.
*/
static class HybridDisjunctionDISIApproximation extends DocIdSetIterator {
// wrapped disjunction approximation that does the actual iteration
final DocIdSetIterator delegate;
// queue of sub-iterators, consulted only for the emptiness checks below
final DisiPriorityQueue subIterators;

/**
* @param subIterators queue of sub-iterators the disjunction approximation is built over
*/
public HybridDisjunctionDISIApproximation(DisiPriorityQueue subIterators) {
delegate = new DisjunctionDISIApproximation(subIterators);
this.subIterators = subIterators;
}

/**
* @return cost reported by the wrapped approximation
*/
@Override
public long cost() {
return delegate.cost();

Check warning on line 364 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L364

Added line #L364 was not covered by tests
}

/**
* @return NO_MORE_DOCS when the queue is empty, otherwise the delegate's current doc id
*/
@Override
public int docID() {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;
}
return delegate.docID();
}

/**
* @return NO_MORE_DOCS when the queue is empty, otherwise the delegate's next doc id
* @throws IOException if the delegate fails
*/
@Override
public int nextDoc() throws IOException {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;

Check warning on line 378 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L378

Added line #L378 was not covered by tests
}
return delegate.nextDoc();
}

/**
* @param target document id to advance to
* @return NO_MORE_DOCS when the queue is empty, otherwise the result of advancing the delegate
* @throws IOException if the delegate fails
*/
@Override
public int advance(int target) throws IOException {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;

Check warning on line 386 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L386

Added line #L386 was not covered by tests
}
return delegate.advance(target);

Check warning on line 388 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L388

Added line #L388 was not covered by tests
}
}
}
Loading

0 comments on commit 46775d3

Please sign in to comment.