[Backport 2.x] Bug fix for total hits counts mismatch in hybrid query #760

Merged 1 commit on May 24, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Optimize parameter parsing in text chunking processor ([#733](https://github.com/opensearch-project/neural-search/pull/733))
- Use lazy initialization for priority queue of hits and scores to improve latencies by 20% ([#746](https://github.com/opensearch-project/neural-search/pull/746))
### Bug Fixes
- Total hit count fix in Hybrid Query ([#756](https://github.com/opensearch-project/neural-search/pull/756))
### Infrastructure
### Documentation
### Maintenance
@@ -5,12 +5,10 @@
package org.opensearch.neuralsearch.processor.combination;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.lucene.search.ScoreDoc;
@@ -80,16 +78,15 @@ private List<ScoreDoc> getCombinedScoreDocs(
final CompoundTopDocs compoundQueryTopDocs,
final Map<Integer, Float> combinedNormalizedScoresByDocId,
final List<Integer> sortedScores,
final int maxHits
final long maxHits
) {
ScoreDoc[] finalScoreDocs = new ScoreDoc[maxHits];

List<ScoreDoc> scoreDocs = new ArrayList<>();
int shardId = compoundQueryTopDocs.getScoreDocs().get(0).shardIndex;
for (int j = 0; j < maxHits && j < sortedScores.size(); j++) {
int docId = sortedScores.get(j);
finalScoreDocs[j] = new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId);
scoreDocs.add(new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId));
}
return Arrays.stream(finalScoreDocs).collect(Collectors.toList());
return scoreDocs;
}

public Map<Integer, float[]> getNormalizedScoresPerDocument(final List<TopDocs> topDocsPerSubQuery) {
@@ -123,30 +120,16 @@ private void updateQueryTopDocsWithCombinedScores(
final Map<Integer, Float> combinedNormalizedScoresByDocId,
final List<Integer> sortedScores
) {
// - count max number of hits among sub-queries
int maxHits = getMaxHits(topDocsPerSubQuery);
// - max number of hits is the value that is passed down from the QueryPhase
long maxHits = compoundQueryTopDocs.getTotalHits().value;
// - update query search results with normalized scores
compoundQueryTopDocs.setScoreDocs(
getCombinedScoreDocs(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sortedScores, maxHits)
);
compoundQueryTopDocs.setTotalHits(getTotalHits(topDocsPerSubQuery, maxHits));
}

/**
* Get max hits as number of unique doc ids from results of all sub-queries
* @param topDocsPerSubQuery list of topDocs objects for one shard
* @return number of unique doc ids
*/
protected int getMaxHits(final List<TopDocs> topDocsPerSubQuery) {
Set<Integer> docIds = topDocsPerSubQuery.stream()
.filter(topDocs -> Objects.nonNull(topDocs.scoreDocs))
.flatMap(topDocs -> Arrays.stream(topDocs.scoreDocs))
.map(scoreDoc -> scoreDoc.doc)
.collect(Collectors.toSet());
return docIds.size();
}

private TotalHits getTotalHits(final List<TopDocs> topDocsPerSubQuery, int maxHits) {
private TotalHits getTotalHits(final List<TopDocs> topDocsPerSubQuery, final long maxHits) {
TotalHits.Relation totalHits = TotalHits.Relation.EQUAL_TO;
if (topDocsPerSubQuery.stream().anyMatch(topDocs -> topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)) {
totalHits = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
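
The combiner change above stops re-deriving the hit count from whatever doc ids happen to be retained and instead trusts the total passed down from the query phase, building the combined ScoreDoc list incrementally rather than pre-sizing an array by that count. Below is a minimal, self-contained sketch (hypothetical class and variable names, not the plugin code) of why pre-sizing by the shard-level total breaks once that total can exceed the number of docs actually kept per sub-query.

```java
import java.util.ArrayList;
import java.util.List;

// Standalone illustration only; the hit counts are made-up numbers.
public class PreSizedArrayPitfall {
    public static void main(String[] args) {
        long totalHitsOnShard = 25; // every doc matched by at least one sub-query
        int retainedDocs = 10;      // docs actually kept after normalization, bounded by "size"

        // Old shape: new ScoreDoc[totalHitsOnShard] would leave 15 null slots here
        // (and a long value no longer fits an array dimension at all).
        List<Integer> combined = new ArrayList<>();
        for (int j = 0; j < totalHitsOnShard && j < retainedDocs; j++) {
            combined.add(j); // stand-in for new ScoreDoc(docId, score, shardId)
        }
        System.out.println("combined docs: " + combined.size());        // 10
        System.out.println("reported total hits: " + totalHitsOnShard); // 25
    }
}
```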
@@ -9,9 +9,8 @@
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import lombok.Getter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.HitQueue;
@@ -35,7 +34,9 @@ public class HybridTopScoreDocCollector implements Collector {
private int docBase;
private final HitsThresholdChecker hitsThresholdChecker;
private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;
private int[] totalHits;
@Getter
private int totalHits;
private int[] collectedHitsPerSubQuery;
private final int numOfHits;
private PriorityQueue<ScoreDoc>[] compoundScores;

@@ -94,23 +95,24 @@ public void collect(int doc) throws IOException {
if (Objects.isNull(compoundQueryScorer)) {
throw new IllegalArgumentException("scorers are null for all sub-queries in hybrid query");
}

float[] subScoresByQuery = compoundQueryScorer.hybridScores();
// iterate over results for each query
if (compoundScores == null) {
compoundScores = new PriorityQueue[subScoresByQuery.length];
for (int i = 0; i < subScoresByQuery.length; i++) {
compoundScores[i] = new HitQueue(numOfHits, false);
}
totalHits = new int[subScoresByQuery.length];
collectedHitsPerSubQuery = new int[subScoresByQuery.length];
}
// Increment the total hit count, which represents the number of unique docs found on the shard
totalHits++;
for (int i = 0; i < subScoresByQuery.length; i++) {
float score = subScoresByQuery[i];
// if the score is 0.0 there are no hits for that sub-query
if (score == 0) {
continue;
}
totalHits[i]++;
collectedHitsPerSubQuery[i]++;
PriorityQueue<ScoreDoc> pq = compoundScores[i];
ScoreDoc currentDoc = new ScoreDoc(doc + docBase, score);
// this way we're inserting into the heap and doing nothing else unless we reach capacity
@@ -134,9 +136,17 @@ public List<TopDocs> topDocs() {
if (compoundScores == null) {
return new ArrayList<>();
}
final List<TopDocs> topDocs = IntStream.range(0, compoundScores.length)
.mapToObj(i -> topDocsPerQuery(0, Math.min(totalHits[i], compoundScores[i].size()), compoundScores[i], totalHits[i]))
.collect(Collectors.toList());
final List<TopDocs> topDocs = new ArrayList<>();
for (int i = 0; i < compoundScores.length; i++) {
topDocs.add(
topDocsPerQuery(
0,
Math.min(collectedHitsPerSubQuery[i], compoundScores[i].size()),
compoundScores[i],
collectedHitsPerSubQuery[i]
)
);
}
return topDocs;
}
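
With this change the collector keeps two separate tallies: a single shard-level counter bumped once per unique doc that any sub-query matched, and a per-sub-query counter that only grows when that sub-query produced a non-zero score. A rough, self-contained sketch of that counting scheme follows, using assumed inputs; HitQueue and ScoreDoc are the Lucene classes already imported above.

```java
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.PriorityQueue;

public class HitCountingSketch {
    public static void main(String[] args) {
        int numOfHits = 2; // capacity of each per-sub-query queue
        // per-doc scores for two sub-queries; 0.0f means "no hit for that sub-query"
        float[][] subScoresByDoc = {
            { 0.9f, 0.0f },  // doc 0: only sub-query 0 matched
            { 0.0f, 0.7f },  // doc 1: only sub-query 1 matched
            { 0.4f, 0.3f }   // doc 2: both matched
        };

        int totalHits = 0;                           // unique docs seen on the shard
        int[] collectedHitsPerSubQuery = new int[2]; // hits retained per sub-query
        PriorityQueue<ScoreDoc>[] compoundScores = new PriorityQueue[] {
            new HitQueue(numOfHits, false),
            new HitQueue(numOfHits, false)
        };

        for (int doc = 0; doc < subScoresByDoc.length; doc++) {
            totalHits++; // counted once per doc, no matter how many sub-queries hit it
            for (int i = 0; i < subScoresByDoc[doc].length; i++) {
                float score = subScoresByDoc[doc][i];
                if (score == 0.0f) {
                    continue;
                }
                collectedHitsPerSubQuery[i]++;
                compoundScores[i].insertWithOverflow(new ScoreDoc(doc, score));
            }
        }

        System.out.println("totalHits = " + totalHits);                         // 3
        System.out.println("sub-query 0 hits = " + collectedHitsPerSubQuery[0]); // 2
        System.out.println("sub-query 1 hits = " + collectedHitsPerSubQuery[1]); // 2
    }
}
```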

@@ -31,11 +31,8 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
@@ -145,7 +142,10 @@ public ReduceableSearchResult reduce(Collection<Collector> collectors) {
.findFirst()
.orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query"));
List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
TopDocs newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard), topDocs);
TopDocs newTopDocs = getNewTopDocs(
getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard, hybridTopScoreDocCollector.getTotalHits()),
topDocs
);
float maxScore = getMaxScore(topDocs);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); };
@@ -196,24 +196,19 @@ private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> top
return new TopDocs(totalHits, scoreDocs);
}

private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<TopDocs> topDocs, final boolean isSingleShard) {
private TotalHits getTotalHits(
int trackTotalHitsUpTo,
final List<TopDocs> topDocs,
final boolean isSingleShard,
final long maxTotalHits
) {
final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
if (topDocs == null || topDocs.isEmpty()) {
return new TotalHits(0, relation);
}

List<ScoreDoc[]> scoreDocs = topDocs.stream()
.map(topdDoc -> topdDoc.scoreDocs)
.filter(Objects::nonNull)
.collect(Collectors.toList());
Set<Integer> uniqueDocIds = new HashSet<>();
for (ScoreDoc[] scoreDocsArray : scoreDocs) {
uniqueDocIds.addAll(Arrays.stream(scoreDocsArray).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList()));
}
long maxTotalHits = uniqueDocIds.size();

return new TotalHits(maxTotalHits, relation);
}
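
At reduce time the manager no longer recounts unique doc ids from the retained top docs; it takes the count accumulated by the collector and only decides the relation based on whether total-hit tracking is disabled. A hedged sketch of that contract is below; the TRACK_TOTAL_HITS_DISABLED value of -1 is an assumption standing in for SearchContext's constant.

```java
import org.apache.lucene.search.TotalHits;

public class TotalHitsRelationSketch {
    // assumed stand-in for SearchContext.TRACK_TOTAL_HITS_DISABLED
    private static final int TRACK_TOTAL_HITS_DISABLED = -1;

    static TotalHits totalHitsFromCollector(int trackTotalHitsUpTo, long collectedTotalHits) {
        TotalHits.Relation relation = trackTotalHitsUpTo == TRACK_TOTAL_HITS_DISABLED
            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO // tracking disabled: count is a lower bound
            : TotalHits.Relation.EQUAL_TO;                // otherwise the collector's count is exact
        return new TotalHits(collectedTotalHits, relation);
    }

    public static void main(String[] args) {
        System.out.println(totalHitsFromCollector(1000, 3));                      // "3 hits"
        System.out.println(totalHitsFromCollector(TRACK_TOTAL_HITS_DISABLED, 3)); // "3+ hits"
    }
}
```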

@@ -27,7 +27,7 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc

final List<CompoundTopDocs> queryTopDocs = List.of(
new CompoundTopDocs(
new TotalHits(3, TotalHits.Relation.EQUAL_TO),
new TotalHits(5, TotalHits.Relation.EQUAL_TO),
List.of(
new TopDocs(
new TotalHits(3, TotalHits.Relation.EQUAL_TO),
@@ -177,6 +177,35 @@ public void testComplexQuery_whenMultipleSubqueries_thenSuccessful() {
}
}

@SneakyThrows
public void testTotalHits_whenResultSizeIsLessThenDefaultSize_thenSuccessful() {
initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME);
createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);
TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5);
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.should(termQueryBuilder2).should(termQueryBuilder3);

HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder();
hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1);
hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder);
Map<String, Object> searchResponseAsMap = search(
TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME,
hybridQueryBuilderNeuralThenTerm,
null,
1,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertEquals(1, getHitCount(searchResponseAsMap));
Map<String, Object> total = getTotalHits(searchResponseAsMap);
assertNotNull(total.get("value"));
assertEquals(3, total.get("value"));
assertNotNull(total.get("relation"));
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}
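
The assertions above hinge on the difference between the number of returned hits, capped by the request size of 1, and hits.total.value, which should now report all three matching documents. A tiny standalone sketch of that response shape (hypothetical values, not output of the test helpers):

```java
import java.util.List;
import java.util.Map;

public class TotalVsReturnedHits {
    public static void main(String[] args) {
        // shape mirrors the "hits" section of an OpenSearch response; values are made up
        Map<String, Object> hits = Map.of(
            "total", Map.of("value", 3, "relation", "eq"),
            "hits", List.of(Map.of("_id", "1")) // size=1, so only one document is returned
        );

        int returned = ((List<?>) hits.get("hits")).size();
        int total = (int) ((Map<?, ?>) hits.get("total")).get("value");
        System.out.println("returned=" + returned + ", total=" + total); // returned=1, total=3
    }
}
```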

/**
* Tests complex query with multiple nested sub-queries, where some sub-queries are same
* {
@@ -4,6 +4,16 @@
*/
package org.opensearch.neuralsearch.search;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
@@ -13,6 +23,9 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.HashSet;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
@@ -24,16 +37,6 @@
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.opensearch.index.mapper.TextFieldMapper;
@@ -50,6 +53,7 @@ public class HybridTopScoreDocCollectorTests extends OpenSearchQueryTestCase {
private static final String TEST_QUERY_TEXT = "greeting";
private static final String TEST_QUERY_TEXT2 = "salute";
private static final int NUM_DOCS = 4;
private static final int NUM_HITS = 1;
private static final int TOTAL_HITS_UP_TO = 1000;

private static final int DOC_ID_1 = RandomUtils.nextInt(0, 100_000);
@@ -493,4 +497,71 @@ public void testCompoundScorer_whenHybridScorerIsTopLevelScorer_thenSuccessful()
reader.close();
directory.close();
}

@SneakyThrows
public void testTotalHitsCountValidation_whenTotalHitsCollectedAtTopLevelInCollector_thenSuccessful() {
final Directory directory = newDirectory();
final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
ft.setOmitNorms(random().nextBoolean());
ft.freeze();

w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_1, FIELD_1_VALUE, ft));
w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_2, FIELD_2_VALUE, ft));
w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_3, FIELD_3_VALUE, ft));
w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_4, FIELD_4_VALUE, ft));
w.commit();

DirectoryReader reader = DirectoryReader.open(w);

LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);

HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector(
NUM_HITS,
new HitsThresholdChecker(Integer.MAX_VALUE)
);
LeafCollector leafCollector = hybridTopScoreDocCollector.getLeafCollector(leafReaderContext);
assertNotNull(leafCollector);

Weight weight = mock(Weight.class);
int[] docIdsForQuery1 = new int[] { DOC_ID_1, DOC_ID_2 };
Arrays.sort(docIdsForQuery1);
int[] docIdsForQuery2 = new int[] { DOC_ID_3, DOC_ID_4 };
Arrays.sort(docIdsForQuery2);
final List<Float> scores = Stream.generate(() -> random().nextFloat()).limit(NUM_DOCS).collect(Collectors.toList());
HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
weight,
Arrays.asList(
scorer(docIdsForQuery1, scores, fakeWeight(new MatchAllDocsQuery())),
scorer(docIdsForQuery2, scores, fakeWeight(new MatchAllDocsQuery()))
)
);

leafCollector.setScorer(hybridQueryScorer);
DocIdSetIterator iterator = hybridQueryScorer.iterator();
int nextDoc = iterator.nextDoc();
while (nextDoc != NO_MORE_DOCS) {
leafCollector.collect(nextDoc);
nextDoc = iterator.nextDoc();
}

List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
long totalHits = hybridTopScoreDocCollector.getTotalHits();
List<ScoreDoc[]> scoreDocs = topDocs.stream()
.map(topdDoc -> topdDoc.scoreDocs)
.filter(Objects::nonNull)
.collect(Collectors.toList());
Set<Integer> uniqueDocIds = new HashSet<>();
for (ScoreDoc[] scoreDocsArray : scoreDocs) {
uniqueDocIds.addAll(Arrays.stream(scoreDocsArray).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList()));
}
long maxTotalHits = uniqueDocIds.size();
assertEquals(4, totalHits);
// Total unique docs on the shard will be 2, i.e. one per sub-query, since NUM_HITS is 1
assertEquals(2, maxTotalHits);
w.close();
reader.close();
directory.close();
}
}
@@ -1025,7 +1025,9 @@ private static IndexMetadata getIndexMetadata() {
RemoteStoreEnums.PathType.NAME,
HASHED_PREFIX.name(),
RemoteStoreEnums.PathHashAlgorithm.NAME,
RemoteStoreEnums.PathHashAlgorithm.FNV_1A_BASE64.name()
RemoteStoreEnums.PathHashAlgorithm.FNV_1A_BASE64.name(),
IndexMetadata.TRANSLOG_METADATA_KEY,
"false"
);
Settings idxSettings = Settings.builder()
.put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT)