Skip to content

Commit

Permalink
Optimize the max score tracking in the Query Phase of Hybrid Search (o…
Browse files Browse the repository at this point in the history
  • Loading branch information
vibrantvarun authored May 28, 2024
1 parent afd1215 commit 806042c
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Pass empty doc collector instead of top docs collector to improve hybrid query latencies by 20% ([#731](https://github.com/opensearch-project/neural-search/pull/731))
- Optimize parameter parsing in text chunking processor ([#733](https://github.com/opensearch-project/neural-search/pull/733))
- Use lazy initialization for priority queue of hits and scores to improve latencies by 20% ([#746](https://github.com/opensearch-project/neural-search/pull/746))
- Optimize max score calculation in the Query Phase of the Hybrid Search ([#765](https://github.com/opensearch-project/neural-search/pull/765))
### Bug Fixes
- Total hit count fix in Hybrid Query ([756](https://github.com/opensearch-project/neural-search/pull/756))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ public class HybridTopScoreDocCollector implements Collector {
private int[] collectedHitsPerSubQuery;
private final int numOfHits;
private PriorityQueue<ScoreDoc>[] compoundScores;
@Getter
private float maxScore = 0.0f;

public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThresholdChecker) {
numOfHits = numHits;
Expand Down Expand Up @@ -115,6 +117,7 @@ public void collect(int doc) throws IOException {
collectedHitsPerSubQuery[i]++;
PriorityQueue<ScoreDoc> pq = compoundScores[i];
ScoreDoc currentDoc = new ScoreDoc(doc + docBase, score);
maxScore = Math.max(currentDoc.score, maxScore);
// this way we're inserting into heap and do nothing else unless we reach the capacity
// after that we pull out the lowest score element on each insert
pq.insertWithOverflow(currentDoc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ public ReduceableSearchResult reduce(Collection<Collector> collectors) {
getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard, hybridTopScoreDocCollector.getTotalHits()),
topDocs
);
float maxScore = getMaxScore(topDocs);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore());
return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); };
}
throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
Expand Down Expand Up @@ -212,18 +211,6 @@ private TotalHits getTotalHits(
return new TotalHits(maxTotalHits, relation);
}

/**
 * Returns the maximum score across all sub-query results.
 * <p>
 * Relies on each {@link TopDocs} having its {@code scoreDocs} sorted in descending
 * score order, so only the first entry of each sub-result needs to be inspected
 * (matches the original's use of {@code scoreDocs[0]} — assumption to confirm at call sites).
 *
 * @param topDocs per-sub-query top documents; may be empty
 * @return the highest top-hit score, treating a sub-result with no hits as 0.0f;
 *         0.0f when {@code topDocs} itself is empty
 */
private float getMaxScore(final List<TopDocs> topDocs) {
    if (topDocs.isEmpty()) {
        return 0.0f;
    }
    // Seed with -Infinity so a single (possibly negative) top score is returned as-is,
    // mirroring the original stream max. Avoids allocating a placeholder ScoreDoc per
    // empty sub-result and the Optional.get() at the end of the stream pipeline.
    float maxScore = Float.NEGATIVE_INFINITY;
    for (TopDocs docs : topDocs) {
        // Empty sub-results contribute 0.0f, exactly as the old placeholder ScoreDoc did.
        float topScore = docs.scoreDocs.length == 0 ? 0.0f : docs.scoreDocs[0].score;
        maxScore = Math.max(maxScore, topScore);
    }
    return maxScore;
}

/**
 * Extracts the doc-value formats used to render sort values.
 *
 * @param sortAndFormats the sort specification for this search, possibly null
 * @return the associated formats, or null when no sort was specified
 */
private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
    if (sortAndFormats == null) {
        return null;
    }
    return sortAndFormats.formats;
}
Expand Down
30 changes: 30 additions & 0 deletions src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,36 @@ public void testTotalHits_whenResultSizeIsLessThenDefaultSize_thenSuccessful() {
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

@SneakyThrows
public void testMaxScoreCalculation_whenMaxScoreIsTrackedAtCollectorLevel_thenSuccessful() {
    initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME);

    // Hybrid query with two sub-queries: a single term query and a bool query
    // combining two term clauses.
    TermQueryBuilder singleTermQuery = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
    BoolQueryBuilder boolOfTwoTerms = new BoolQueryBuilder().should(
        QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4)
    ).should(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5));

    HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder();
    hybridQueryBuilderNeuralThenTerm.add(singleTermQuery);
    hybridQueryBuilderNeuralThenTerm.add(boolOfTwoTerms);

    Map<String, Object> searchResponseAsMap = search(
        TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME,
        hybridQueryBuilderNeuralThenTerm,
        null,
        10,
        null
    );

    // The max_score reported by the response must equal the largest _score among the hits
    // (0.0 identity keeps behavior for an empty hit list).
    double maxScore = getMaxScore(searchResponseAsMap).get();
    double maxScoreExpected = getNestedHits(searchResponseAsMap).stream()
        .mapToDouble(hit -> (double) hit.get("_score"))
        .reduce(0.0, Math::max);
    assertEquals(maxScoreExpected, maxScore, 0.0000001);
}

/**
* Tests complex query with multiple nested sub-queries, where some sub-queries are same
* {
Expand Down

0 comments on commit 806042c

Please sign in to comment.