From 901ae891ed52c0efa65d830301678d42e9a431ce Mon Sep 17 00:00:00 2001
From: Varun Jain
Date: Fri, 24 May 2024 13:40:14 -0700
Subject: [PATCH] Bug fix for total hits counts mismatch in hybrid query (#757)

(cherry picked from commit 70d0975305365fb8264ccbce05655f717a467f2d)
---
 CHANGELOG.md                                  |  1 +
 .../processor/combination/ScoreCombiner.java  | 31 ++-----
 .../search/HybridTopScoreDocCollector.java    | 28 ++++--
 .../search/query/HybridCollectorManager.java  | 25 ++---
 .../ScoreCombinationTechniqueTests.java       |  2 +-
 .../neuralsearch/query/HybridQueryIT.java     | 29 ++++++
 .../HybridTopScoreDocCollectorTests.java      | 91 +++++++++++++++++--
 .../query/HybridQueryPhaseSearcherTests.java  |  4 +-
 8 files changed, 151 insertions(+), 60 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f6c0aea38..0fe4b36ab 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Optimize parameter parsing in text chunking processor ([#733](https://github.com/opensearch-project/neural-search/pull/733))
 - Use lazy initialization for priority queue of hits and scores to improve latencies by 20% ([#746](https://github.com/opensearch-project/neural-search/pull/746))
 ### Bug Fixes
+- Total hit count fix in Hybrid Query ([#756](https://github.com/opensearch-project/neural-search/pull/756))
 ### Infrastructure
 ### Documentation
 ### Maintenance
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
index 278d2fdfc..09d9e83f2 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
@@ -5,12 +5,10 @@
 package org.opensearch.neuralsearch.processor.combination;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
-import java.util.Set;
 import java.util.stream.Collectors;
 
 import org.apache.lucene.search.ScoreDoc;
@@ -80,16 +78,15 @@ private List<ScoreDoc> getCombinedScoreDocs(
         final CompoundTopDocs compoundQueryTopDocs,
         final Map<Integer, Float> combinedNormalizedScoresByDocId,
         final List<Integer> sortedScores,
-        final int maxHits
+        final long maxHits
     ) {
-        ScoreDoc[] finalScoreDocs = new ScoreDoc[maxHits];
-
+        List<ScoreDoc> scoreDocs = new ArrayList<>();
         int shardId = compoundQueryTopDocs.getScoreDocs().get(0).shardIndex;
         for (int j = 0; j < maxHits && j < sortedScores.size(); j++) {
             int docId = sortedScores.get(j);
-            finalScoreDocs[j] = new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId);
+            scoreDocs.add(new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId));
         }
-        return Arrays.stream(finalScoreDocs).collect(Collectors.toList());
+        return scoreDocs;
     }
 
     public Map<Integer, float[]> getNormalizedScoresPerDocument(final List<TopDocs> topDocsPerSubQuery) {
@@ -123,8 +120,8 @@ private void updateQueryTopDocsWithCombinedScores(
         final Map<Integer, Float> combinedNormalizedScoresByDocId,
         final List<Integer> sortedScores
     ) {
-        // - count max number of hits among sub-queries
-        int maxHits = getMaxHits(topDocsPerSubQuery);
+        // - max number of hits will be the same as the value passed in from the QueryPhase
+        long maxHits = compoundQueryTopDocs.getTotalHits().value;
         // - update query search results with normalized scores
         compoundQueryTopDocs.setScoreDocs(
             getCombinedScoreDocs(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sortedScores, maxHits)
@@ -132,21 +129,7 @@ private void updateQueryTopDocsWithCombinedScores(
         compoundQueryTopDocs.setTotalHits(getTotalHits(topDocsPerSubQuery, maxHits));
     }
 
-    /**
-     * Get max hits as number of unique doc ids from results of all sub-queries
-     * @param topDocsPerSubQuery list of topDocs objects for one shard
-     * @return number of unique doc ids
-     */
-    protected int getMaxHits(final List<TopDocs> topDocsPerSubQuery) {
-        Set<Integer> docIds = topDocsPerSubQuery.stream()
-            .filter(topDocs -> Objects.nonNull(topDocs.scoreDocs))
-            .flatMap(topDocs -> Arrays.stream(topDocs.scoreDocs))
-            .map(scoreDoc -> scoreDoc.doc)
-            .collect(Collectors.toSet());
-        return docIds.size();
-    }
-
-    private TotalHits getTotalHits(final List<TopDocs> topDocsPerSubQuery, int maxHits) {
+    private TotalHits getTotalHits(final List<TopDocs> topDocsPerSubQuery, final long maxHits) {
         TotalHits.Relation totalHits = TotalHits.Relation.EQUAL_TO;
         if (topDocsPerSubQuery.stream().anyMatch(topDocs -> topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)) {
             totalHits = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java
index a9068af4b..308756909 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java
@@ -9,9 +9,8 @@
 import java.util.List;
 import java.util.Locale;
 import java.util.Objects;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
 
+import lombok.Getter;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.HitQueue;
@@ -35,7 +34,9 @@ public class HybridTopScoreDocCollector implements Collector {
     private int docBase;
     private final HitsThresholdChecker hitsThresholdChecker;
     private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;
-    private int[] totalHits;
+    @Getter
+    private int totalHits;
+    private int[] collectedHitsPerSubQuery;
     private final int numOfHits;
     private PriorityQueue<ScoreDoc>[] compoundScores;
 
@@ -94,7 +95,6 @@ public void collect(int doc) throws IOException {
             if (Objects.isNull(compoundQueryScorer)) {
                 throw new IllegalArgumentException("scorers are null for all sub-queries in hybrid query");
             }
-
             float[] subScoresByQuery = compoundQueryScorer.hybridScores();
             // iterate over results for each query
             if (compoundScores == null) {
@@ -102,15 +102,17 @@ public void collect(int doc) throws IOException {
                 for (int i = 0; i < subScoresByQuery.length; i++) {
                     compoundScores[i] = new HitQueue(numOfHits, false);
                 }
-                totalHits = new int[subScoresByQuery.length];
+                collectedHitsPerSubQuery = new int[subScoresByQuery.length];
             }
+            // Increment the total hit count, which represents the number of unique docs found on the shard
+            totalHits++;
             for (int i = 0; i < subScoresByQuery.length; i++) {
                 float score = subScoresByQuery[i];
                 // if score is 0.0 there is no hits for that sub-query
                 if (score == 0) {
                     continue;
                 }
-                totalHits[i]++;
+                collectedHitsPerSubQuery[i]++;
                 PriorityQueue<ScoreDoc> pq = compoundScores[i];
                 ScoreDoc currentDoc = new ScoreDoc(doc + docBase, score);
                 // this way we're inserting into heap and do nothing else unless we reach the capacity
@@ -134,9 +136,17 @@ public List<TopDocs> topDocs() {
         if (compoundScores == null) {
             return new ArrayList<>();
         }
-        final List<TopDocs> topDocs = IntStream.range(0, compoundScores.length)
-            .mapToObj(i -> topDocsPerQuery(0, Math.min(totalHits[i], compoundScores[i].size()), compoundScores[i], totalHits[i]))
-            .collect(Collectors.toList());
+        final List<TopDocs> topDocs = new ArrayList<>();
+        for (int i = 0; i < compoundScores.length; i++) {
+            topDocs.add(
+                topDocsPerQuery(
+                    0,
+                    Math.min(collectedHitsPerSubQuery[i], compoundScores[i].size()),
+                    compoundScores[i],
+                    collectedHitsPerSubQuery[i]
+                )
+            );
+        }
         return topDocs;
     }
 
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
index e9d97c3b3..120cd1428 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
@@ -31,11 +31,8 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
-import java.util.Set;
-import java.util.stream.Collectors;
 
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
@@ -145,7 +142,10 @@ public ReduceableSearchResult reduce(Collection<Collector> collectors) {
             .findFirst()
             .orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query"));
         List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
-        TopDocs newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard), topDocs);
+        TopDocs newTopDocs = getNewTopDocs(
+            getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard, hybridTopScoreDocCollector.getTotalHits()),
+            topDocs
+        );
         float maxScore = getMaxScore(topDocs);
         TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
         return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); };
@@ -196,7 +196,12 @@ private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> top
         return new TopDocs(totalHits, scoreDocs);
     }
 
-    private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<TopDocs> topDocs, final boolean isSingleShard) {
+    private TotalHits getTotalHits(
+        int trackTotalHitsUpTo,
+        final List<TopDocs> topDocs,
+        final boolean isSingleShard,
+        final long maxTotalHits
+    ) {
         final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
             ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
             : TotalHits.Relation.EQUAL_TO;
@@ -204,16 +209,6 @@ private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<TopDocs> topDo
             return new TotalHits(0, relation);
         }
 
-        List<ScoreDoc[]> scoreDocs = topDocs.stream()
-            .map(topdDoc -> topdDoc.scoreDocs)
-            .filter(Objects::nonNull)
-            .collect(Collectors.toList());
-        Set<Integer> uniqueDocIds = new HashSet<>();
-        for (ScoreDoc[] scoreDocsArray : scoreDocs) {
-            uniqueDocIds.addAll(Arrays.stream(scoreDocsArray).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList()));
-        }
-        long maxTotalHits = uniqueDocIds.size();
-
         return new TotalHits(maxTotalHits, relation);
     }
 
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java
index da9b34f22..c97abe1a4 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java
@@ -27,7 +27,7 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc
 
         final List<CompoundTopDocs> queryTopDocs = List.of(
             new CompoundTopDocs(
-                new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                new TotalHits(5, TotalHits.Relation.EQUAL_TO),
                 List.of(
                     new TopDocs(
                         new TotalHits(3, TotalHits.Relation.EQUAL_TO),
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
index be6942232..15e941ff2 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
@@ -177,6 +177,35 @@ public void testComplexQuery_whenMultipleSubqueries_thenSuccessful() {
         }
     }
 
+    @SneakyThrows
+    public void testTotalHits_whenResultSizeIsLessThanDefaultSize_thenSuccessful() {
+        initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME);
+        createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
+        TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
+        TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);
+        TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5);
+        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
+        boolQueryBuilder.should(termQueryBuilder2).should(termQueryBuilder3);
+
+        HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder();
+        hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1);
+        hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder);
+        Map<String, Object> searchResponseAsMap = search(
+            TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME,
+            hybridQueryBuilderNeuralThenTerm,
+            null,
+            1,
+            Map.of("search_pipeline", SEARCH_PIPELINE)
+        );
+
+        assertEquals(1, getHitCount(searchResponseAsMap));
+        Map<String, Object> total = getTotalHits(searchResponseAsMap);
+        assertNotNull(total.get("value"));
+        assertEquals(3, total.get("value"));
+        assertNotNull(total.get("relation"));
+        assertEquals(RELATION_EQUAL_TO, total.get("relation"));
+    }
+
     /**
      * Tests complex query with multiple nested sub-queries, where some sub-queries are same
      * {
diff --git a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java
index 96d32f7d5..351ec680c 100644
--- a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java
@@ -4,6 +4,16 @@
  */
 package org.opensearch.neuralsearch.search;
 
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.MatchNoDocsQuery;
+import org.apache.lucene.search.Scorable;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.Weight;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
@@ -13,6 +23,9 @@
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+import java.util.HashSet;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import java.util.stream.Stream;
@@ -24,16 +37,6 @@
 import org.apache.lucene.index.IndexOptions;
 import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.search.DocIdSetIterator;
-import org.apache.lucene.search.LeafCollector;
-import org.apache.lucene.search.MatchAllDocsQuery;
-import org.apache.lucene.search.MatchNoDocsQuery;
-import org.apache.lucene.search.Scorable;
-import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.search.ScoreMode;
-import org.apache.lucene.search.Scorer;
-import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.search.Weight;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.analysis.MockAnalyzer;
 import org.opensearch.index.mapper.TextFieldMapper;
@@ -50,6 +53,7 @@ public class HybridTopScoreDocCollectorTests extends OpenSearchQueryTestCase {
     private static final String TEST_QUERY_TEXT = "greeting";
     private static final String TEST_QUERY_TEXT2 = "salute";
     private static final int NUM_DOCS = 4;
+    private static final int NUM_HITS = 1;
     private static final int TOTAL_HITS_UP_TO = 1000;
 
     private static final int DOC_ID_1 = RandomUtils.nextInt(0, 100_000);
@@ -493,4 +497,71 @@ public void testCompoundScorer_whenHybridScorerIsTopLevelScorer_thenSuccessful()
         reader.close();
         directory.close();
     }
+
+    @SneakyThrows
+    public void testTotalHitsCountValidation_whenTotalHitsCollectedAtTopLevelInCollector_thenSuccessful() {
+        final Directory directory = newDirectory();
+        final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
+        FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
+        ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
+        ft.setOmitNorms(random().nextBoolean());
+        ft.freeze();
+
+        w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_1, FIELD_1_VALUE, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_2, FIELD_2_VALUE, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_3, FIELD_3_VALUE, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_4, FIELD_4_VALUE, ft));
+        w.commit();
+
+        DirectoryReader reader = DirectoryReader.open(w);
+
+        LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
+
+        HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector(
+            NUM_HITS,
+            new HitsThresholdChecker(Integer.MAX_VALUE)
+        );
+        LeafCollector leafCollector = hybridTopScoreDocCollector.getLeafCollector(leafReaderContext);
+        assertNotNull(leafCollector);
+
+        Weight weight = mock(Weight.class);
+        int[] docIdsForQuery1 = new int[] { DOC_ID_1, DOC_ID_2 };
+        Arrays.sort(docIdsForQuery1);
+        int[] docIdsForQuery2 = new int[] { DOC_ID_3, DOC_ID_4 };
+        Arrays.sort(docIdsForQuery2);
+        final List<Float> scores = Stream.generate(() -> random().nextFloat()).limit(NUM_DOCS).collect(Collectors.toList());
+        HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
+            weight,
+            Arrays.asList(
+                scorer(docIdsForQuery1, scores, fakeWeight(new MatchAllDocsQuery())),
+                scorer(docIdsForQuery2, scores, fakeWeight(new MatchAllDocsQuery()))
+            )
+        );
+
+        leafCollector.setScorer(hybridQueryScorer);
+        DocIdSetIterator iterator = hybridQueryScorer.iterator();
+        int nextDoc = iterator.nextDoc();
+        while (nextDoc != NO_MORE_DOCS) {
+            leafCollector.collect(nextDoc);
+            nextDoc = iterator.nextDoc();
+        }
+
+        List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
+        long totalHits = hybridTopScoreDocCollector.getTotalHits();
+        List<ScoreDoc[]> scoreDocs = topDocs.stream()
+            .map(topDoc -> topDoc.scoreDocs)
+            .filter(Objects::nonNull)
+            .collect(Collectors.toList());
+        Set<Integer> uniqueDocIds = new HashSet<>();
+        for (ScoreDoc[] scoreDocsArray : scoreDocs) {
+            uniqueDocIds.addAll(Arrays.stream(scoreDocsArray).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList()));
+        }
+        long maxTotalHits = uniqueDocIds.size();
+        assertEquals(4, totalHits);
+        // Unique docs in the returned top docs will be 2, one per sub-query
+        assertEquals(2, maxTotalHits);
+        w.close();
+        reader.close();
+        directory.close();
+    }
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java
index ff9616637..ed637f2a0 100644
--- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java
@@ -1025,7 +1025,9 @@ private static IndexMetadata getIndexMetadata() {
             RemoteStoreEnums.PathType.NAME,
             HASHED_PREFIX.name(),
             RemoteStoreEnums.PathHashAlgorithm.NAME,
-            RemoteStoreEnums.PathHashAlgorithm.FNV_1A_BASE64.name()
+            RemoteStoreEnums.PathHashAlgorithm.FNV_1A_BASE64.name(),
+            IndexMetadata.TRANSLOG_METADATA_KEY,
+            "false"
         );
         Settings idxSettings = Settings.builder()
             .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT)
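
Reviewer sketch, appended after the diff and not part of the commit (git tooling ignores trailing text after the last hunk): a minimal standalone model of why the previous total-hits computation undercounted. It assumes the same shape as the new collector test, two sub-queries matching docs {0, 1} and {2, 3} with a requested size of 1; the class name and doc ids below are illustrative only, not code from the repository.

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

// Simplified model of the bug and the fix; not repository code.
public class TotalHitsSketch {
    public static void main(String[] args) {
        // Two sub-queries on one shard: the first matches docs {0, 1}, the second {2, 3}.
        List<Set<Integer>> matchedDocsPerSubQuery = List.of(Set.of(0, 1), Set.of(2, 3));
        int size = 1; // requested result size, i.e. how many hits each sub-query queue keeps

        // Old behavior: total hits derived from unique doc ids across the already
        // truncated top-k results of each sub-query, so docs cut by the queues are lost.
        Set<Integer> uniqueTopDocs = matchedDocsPerSubQuery.stream()
            .flatMap(docs -> docs.stream().sorted().limit(size))
            .collect(Collectors.toSet());

        // New behavior: every unique doc collected on the shard is counted, mirroring
        // the totalHits++ per collected doc in HybridTopScoreDocCollector.collect.
        Set<Integer> allCollectedDocs = matchedDocsPerSubQuery.stream()
            .flatMap(Set::stream)
            .collect(Collectors.toSet());

        System.out.println("old total hits = " + uniqueTopDocs.size());    // prints 2
        System.out.println("new total hits = " + allCollectedDocs.size()); // prints 4
    }
}

With size 1 only one doc per sub-query survives the per-sub-query hit queues, so the old unique-id count over the truncated results yields 2, while the fixed collector counts all 4 matched docs, which is what the assertEquals(4, totalHits) assertion in the new collector test verifies.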