From 2c556d2b512e2edb07b01bd9f796e38e4d14491e Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 6 May 2024 09:49:12 -0700 Subject: [PATCH] Pass empty QueryCollectorContext in case of hybrid query to improve latencies by 20% (#731) * Pass empty QueryCollectorContext in case of hybrid query Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + .../query/HybridQueryPhaseSearcher.java | 77 +++++++++++++++++- .../query/HybridQueryPhaseSearcherTests.java | 81 ++++++++++--------- 3 files changed, 118 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 36c6be493..479bf1877 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.14...2.x) ### Features ### Enhancements +- Pass empty doc collector instead of top docs collector to improve hybrid query latencies by 20% ([#731](https://github.com/opensearch-project/neural-search/pull/731)) ### Bug Fixes - Fix multi node "no such index" error in text chunking processor ([#713](https://github.com/opensearch-project/neural-search/pull/713)) ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index b97134f8f..53248f88c 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -10,6 +10,8 @@ import java.util.stream.Collectors; import com.google.common.annotations.VisibleForTesting; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Query; @@ -19,8 +21,10 @@ import org.opensearch.search.aggregations.AggregationProcessor; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.ConcurrentQueryPhaseSearcher; import org.opensearch.search.query.QueryCollectorContext; import org.opensearch.search.query.QueryPhase; +import org.opensearch.search.query.QueryPhaseSearcher; import org.opensearch.search.query.QueryPhaseSearcherWrapper; import lombok.extern.log4j.Log4j2; @@ -36,6 +40,14 @@ @Log4j2 public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper { + private final QueryPhaseSearcher defaultQueryPhaseSearcherWithEmptyCollectorContext; + private final QueryPhaseSearcher concurrentQueryPhaseSearcherWithEmptyCollectorContext; + + public HybridQueryPhaseSearcher() { + this.defaultQueryPhaseSearcherWithEmptyCollectorContext = new DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext(); + this.concurrentQueryPhaseSearcherWithEmptyCollectorContext = new ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext(); + } + public boolean searchWith( final SearchContext searchContext, final ContextIndexSearcher searcher, @@ -49,10 +61,17 @@ public boolean searchWith( return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } else { Query hybridQuery = extractHybridQuery(searchContext, query); - return super.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); + QueryPhaseSearcher queryPhaseSearcher = getQueryPhaseSearcher(searchContext); + return queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); } } + private QueryPhaseSearcher getQueryPhaseSearcher(final SearchContext searchContext) { + return searchContext.shouldUseConcurrentSearch() + ? concurrentQueryPhaseSearcherWithEmptyCollectorContext + : defaultQueryPhaseSearcherWithEmptyCollectorContext; + } + private static boolean isWrappedHybridQuery(final Query query) { return query instanceof BooleanQuery && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); @@ -132,4 +151,60 @@ public AggregationProcessor aggregationProcessor(SearchContext searchContext) { AggregationProcessor coreAggProcessor = super.aggregationProcessor(searchContext); return new HybridAggregationProcessor(coreAggProcessor); } + + /** + * Class that inherits ConcurrentQueryPhaseSearcher implementation but calls its search with only + * empty query collector context + */ + @NoArgsConstructor(access = AccessLevel.PRIVATE) + final class ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext extends ConcurrentQueryPhaseSearcher { + + @Override + protected boolean searchWithCollector( + SearchContext searchContext, + ContextIndexSearcher searcher, + Query query, + LinkedList collectors, + boolean hasFilterCollector, + boolean hasTimeout + ) throws IOException { + return searchWithCollector( + searchContext, + searcher, + query, + collectors, + QueryCollectorContext.EMPTY_CONTEXT, + hasFilterCollector, + hasTimeout + ); + } + } + + /** + * Class that inherits DefaultQueryPhaseSearcher implementation but calls its search with only + * empty query collector context + */ + @NoArgsConstructor(access = AccessLevel.PACKAGE) + final class DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext extends QueryPhase.DefaultQueryPhaseSearcher { + + @Override + protected boolean searchWithCollector( + SearchContext searchContext, + ContextIndexSearcher searcher, + Query query, + LinkedList collectors, + boolean hasFilterCollector, + boolean hasTimeout + ) throws IOException { + return searchWithCollector( + searchContext, + searcher, + query, + collectors, + QueryCollectorContext.EMPTY_CONTEXT, + hasFilterCollector, + hasTimeout + ); + } + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index a938b2111..e790ffb77 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -20,13 +20,11 @@ import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; import java.io.IOException; -import java.util.Arrays; +import java.util.HashMap; import java.util.LinkedList; -import java.util.List; import java.util.Map; import java.util.Set; -import java.util.UUID; -import java.util.stream.Collectors; +import java.util.concurrent.ExecutorService; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; @@ -39,6 +37,8 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; @@ -58,8 +58,6 @@ import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.MatchAllQueryBuilder; -import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; @@ -78,6 +76,7 @@ import lombok.SneakyThrows; import org.opensearch.search.query.QueryCollectorContext; import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase { private static final String VECTOR_FIELD_NAME = "vectorField"; @@ -88,13 +87,7 @@ public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase { private static final String TEST_DOC_TEXT4 = "This is really nice place to be"; private static final String QUERY_TEXT1 = "hello"; private static final String QUERY_TEXT2 = "randomkeyword"; - private static final String QUERY_TEXT3 = "place"; private static final Index dummyIndex = new Index("dummy", "dummy"); - private static final String MODEL_ID = "mfgfgdsfgfdgsde"; - private static final int K = 10; - private static final QueryBuilder TEST_FILTER = new MatchAllQueryBuilder(); - private static final UUID INDEX_UUID = UUID.randomUUID(); - private static final String TEST_INDEX = "index"; @SneakyThrows public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { @@ -306,20 +299,22 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); + CollectorManager collectorManager = HybridCollectorManager + .createHybridCollectorManager(searchContext); + Map, CollectorManager> queryCollectorManagers = new HashMap<>(); + queryCollectorManagers.put(HybridCollectorManager.class, collectorManager); + when(searchContext.queryCollectorManagers()).thenReturn(queryCollectorManagers); + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + hybridQueryPhaseSearcher.aggregationProcessor(searchContext).postProcess(searchContext); assertNotNull(querySearchResult.topDocs()); TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); TopDocs topDocs = topDocsAndMaxScore.topDocs; - assertEquals(1, topDocs.totalHits.value); + assertEquals(0, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(1, scoreDocs.length); - ScoreDoc scoreDoc = scoreDocs[0]; - assertNotNull(scoreDoc); - int actualDocId = Integer.parseInt(reader.document(scoreDoc.doc).getField("id").stringValue()); - assertEquals(docId1, actualDocId); - assertTrue(scoreDoc.score > 0.0f); + assertEquals(0, scoreDocs.length); releaseResources(directory, w, reader); } @@ -340,13 +335,7 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes ft.setOmitNorms(random().nextBoolean()); ft.freeze(); int docId1 = RandomizedTest.randomInt(); - int docId2 = RandomizedTest.randomInt(); - int docId3 = RandomizedTest.randomInt(); - int docId4 = RandomizedTest.randomInt(); w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId4, TEST_DOC_TEXT4, ft)); w.commit(); IndexReader reader = DirectoryReader.open(w); @@ -395,18 +384,22 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); + CollectorManager collectorManager = HybridCollectorManager + .createHybridCollectorManager(searchContext); + Map, CollectorManager> queryCollectorManagers = new HashMap<>(); + queryCollectorManagers.put(HybridCollectorManager.class, collectorManager); + when(searchContext.queryCollectorManagers()).thenReturn(queryCollectorManagers); + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + hybridQueryPhaseSearcher.aggregationProcessor(searchContext).postProcess(searchContext); assertNotNull(querySearchResult.topDocs()); TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); TopDocs topDocs = topDocsAndMaxScore.topDocs; - assertEquals(4, topDocs.totalHits.value); + assertEquals(0, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(4, scoreDocs.length); - List expectedIds = List.of(0, 1, 2, 3); - List actualDocIds = Arrays.stream(scoreDocs).map(sd -> sd.doc).collect(Collectors.toList()); - assertEquals(expectedIds, actualDocIds); + assertEquals(0, scoreDocs.length); releaseResources(directory, w, reader); } @@ -705,18 +698,22 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then when(searchContext.query()).thenReturn(query); + CollectorManager collectorManager = HybridCollectorManager + .createHybridCollectorManager(searchContext); + Map, CollectorManager> queryCollectorManagers = new HashMap<>(); + queryCollectorManagers.put(HybridCollectorManager.class, collectorManager); + when(searchContext.queryCollectorManagers()).thenReturn(queryCollectorManagers); + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + hybridQueryPhaseSearcher.aggregationProcessor(searchContext).postProcess(searchContext); assertNotNull(querySearchResult.topDocs()); TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); TopDocs topDocs = topDocsAndMaxScore.topDocs; - assertTrue(topDocs.totalHits.value > 0); + assertEquals(0, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(1, scoreDocs.length); - ScoreDoc scoreDoc = scoreDocs[0]; - assertTrue(scoreDoc.score > 0); - assertEquals(0, scoreDoc.doc); + assertEquals(0, scoreDocs.length); releaseResources(directory, w, reader); } @@ -979,18 +976,22 @@ public void testAliasWithFilter_whenHybridWrappedIntoBoolBecauseOfIndexAlias_the when(searchContext.query()).thenReturn(query); when(searchContext.aliasFilter()).thenReturn(termFilter); + CollectorManager collectorManager = HybridCollectorManager + .createHybridCollectorManager(searchContext); + Map, CollectorManager> queryCollectorManagers = new HashMap<>(); + queryCollectorManagers.put(HybridCollectorManager.class, collectorManager); + when(searchContext.queryCollectorManagers()).thenReturn(queryCollectorManagers); + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + hybridQueryPhaseSearcher.aggregationProcessor(searchContext).postProcess(searchContext); assertNotNull(querySearchResult.topDocs()); TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); TopDocs topDocs = topDocsAndMaxScore.topDocs; - assertTrue(topDocs.totalHits.value > 0); + assertEquals(0, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(1, scoreDocs.length); - ScoreDoc scoreDoc = scoreDocs[0]; - assertTrue(scoreDoc.score > 0); - assertEquals(0, scoreDoc.doc); + assertEquals(0, scoreDocs.length); releaseResources(directory, w, reader); }