diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQuery.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQuery.java index 5d41c267a..ef484d984 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQuery.java @@ -18,8 +18,8 @@ import org.apache.lucene.search.Weight; /** - * Implementation of Query interface for type "hybrid". It allows execution of multiple sub-queries and collect individual - * scores for each sub-query. + * Implementation of Query interface for type NeuralSparseQuery when TwoPhaseNeuralSparse is enabled. + * Initially, currentQuery includes all token queries; after setCurrentQueryToHighScoreTokenQuery() it is highScoreTokenQuery. */ @AllArgsConstructor @Getter @@ -86,7 +86,7 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo return currentQuery.createWeight(searcher, scoreMode, boost); } - public void extractLowScoreToken() { + public void setCurrentQueryToHighScoreTokenQuery() { this.currentQuery = highScoreTokenQuery; } } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index 853b17f91..87d8ec15a 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -25,7 +25,7 @@ import lombok.extern.log4j.Log4j2; -import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addTwoPhaseNeuralSparseQuery; +import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addSecondPhaseRescoreContextFromValidNeuralSparseQuery; import static org.opensearch.neuralsearch.util.HybridQueryUtil.hasAliasFilter; import static org.opensearch.neuralsearch.util.HybridQueryUtil.hasNestedFieldOrNestedDocs; import static
org.opensearch.neuralsearch.util.HybridQueryUtil.isHybridQuery; @@ -45,7 +45,7 @@ public boolean searchWith( final boolean hasFilterCollector, final boolean hasTimeout ) throws IOException { - addTwoPhaseNeuralSparseQuery(query, searchContext); + addSecondPhaseRescoreContextFromValidNeuralSparseQuery(query, searchContext); if (!isHybridQuery(query, searchContext)) { validateQuery(searchContext, query); return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); diff --git a/src/main/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtil.java b/src/main/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtil.java index 173385a9d..c04d0fb9f 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtil.java @@ -22,6 +22,10 @@ import static java.lang.Integer.min; import static org.opensearch.index.IndexSettings.MAX_RESCORE_WINDOW_SETTING; +/** + * Utility class that performs two-phase preprocessing for the NeuralSparseQuery. + * This includes adding the second phase query to searchContext and setting the currentQuery to highScoreTokenQuery. + */ public class NeuralSparseTwoPhaseUtil { private static float populateQueryWeightsMapAndGetWindowSizeExpansion( @@ -48,7 +52,7 @@ private static float populateQueryWeightsMapAndGetWindowSizeExpansion( } } else if (query instanceof NeuralSparseQuery) { query2Weight.put(((NeuralSparseQuery) query).getLowScoreTokenQuery(), weight); - ((NeuralSparseQuery) query).extractLowScoreToken(); + ((NeuralSparseQuery) query).setCurrentQueryToHighScoreTokenQuery(); windoSizeExpansion = max(windoSizeExpansion, ((NeuralSparseQuery) query).getRescoreWindowSizeExpansion()); } // ToDo Support for other compound query.
@@ -68,7 +72,12 @@ private static Query getNestedTwoPhaseQuery(Map query2weight) { return builder.build(); } - public static void addTwoPhaseNeuralSparseQuery(final Query query, SearchContext searchContext) { + /** + * Adds the second phase rescore context to searchContext when the query contains a valid NeuralSparseQuery. + * @param query The whole query, including the neuralSparseQuery, to be executed. + * @param searchContext The searchContext associated with this query. + */ + public static void addSecondPhaseRescoreContextFromValidNeuralSparseQuery(final Query query, SearchContext searchContext) { Map query2weight = new HashMap<>(); float windowSizeExpansion = populateQueryWeightsMapAndGetWindowSizeExpansion(query, query2weight, 1.0f, 1.0f); Query twoPhaseQuery; diff --git a/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java b/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java index 99334e8cb..3ce1aff3e 100644 --- a/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java +++ b/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java @@ -34,6 +34,9 @@ public final class NeuralSearchSettings { Setting.Property.NodeScope ); + /** + * Use this setting to manage whether a neuralSparseQuery builds a two-phase query or not. + */ public static final Setting NEURAL_SPARSE_TWO_PHASE_DEFAULT_ENABLED = Setting.boolSetting( "plugins.neural_search.neural_sparse.two_phase.default_enabled", true, @@ -41,6 +44,9 @@ public final class NeuralSearchSettings { Setting.Property.Dynamic ); + /** + * Control the number of TopDocs rescored by the second phase of NeuralSparseQuery. + */ public static final Setting NEURAL_SPARSE_TWO_PHASE_DEFAULT_WINDOW_SIZE_EXPANSION = Setting.floatSetting( "plugins.neural_search.neural_sparse.two_phase.default_window_size_expansion", 5f, @@ -49,6 +55,9 @@ public final class NeuralSearchSettings { Setting.Property.Dynamic ); + /** + * Control the token score threshold for splitting the NeuralSparseQuery.
+ */ public static final Setting NEURAL_SPARSE_TWO_PHASE_DEFAULT_PRUNING_RATIO = Setting.floatSetting( "plugins.neural_search.neural_sparse.two_phase.default_pruning_ratio", 0.4f, @@ -58,7 +67,10 @@ public final class NeuralSearchSettings { Setting.Property.Dynamic ); - // the default value is consistent with core settings MAX_RESCORE_WINDOW_SETTING + /** + * Control the max rescore window size of the second phase of NeuralSparseQuery. + * The default value is consistent with the core setting MAX_RESCORE_WINDOW_SETTING. + */ public static final Setting NEURAL_SPARSE_TWO_PHASE_MAX_WINDOW_SIZE = Setting.intSetting( "plugins.neural_search.neural_sparse.two_phase.max_window_size", 10000, diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryTests.java index edc875dee..fae2bc5c1 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryTests.java @@ -91,7 +91,7 @@ public void testEqualsAndHashCode() { public void testExtractLowScoreToken_thenCurrentChanged() { NeuralSparseQuery neuralSparseQuery = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, lowScoreTokenQuery, 5.0f); assertSame(neuralSparseQuery.getCurrentQuery(), currentQuery); - neuralSparseQuery.extractLowScoreToken(); + neuralSparseQuery.setCurrentQueryToHighScoreTokenQuery(); assertSame(neuralSparseQuery.getCurrentQuery(), highScoreTokenQuery); } diff --git a/src/test/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtilTests.java b/src/test/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtilTests.java index 18a287ee8..cde4665df 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtilTests.java @@ -45,7 +45,7 @@ import static
org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addTwoPhaseNeuralSparseQuery; +import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addSecondPhaseRescoreContextFromValidNeuralSparseQuery; public class NeuralSparseTwoPhaseUtilTests extends OpenSearchTestCase { @@ -98,22 +98,22 @@ public void testInitialize() { @SneakyThrows public void testAddTwoPhaseNeuralSparseQuery_whenQuery2WeightEmpty_thenNoRescoreAdded() { Query query = mock(Query.class); - addTwoPhaseNeuralSparseQuery(query, mockSearchContext); + addSecondPhaseRescoreContextFromValidNeuralSparseQuery(query, mockSearchContext); verify(mockSearchContext, never()).addRescore(any()); } @SneakyThrows public void testAddTwoPhaseNeuralSparseQuery_whenUnSupportedQuery_thenNoRescoreAdded() { FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(normalNeuralSparseQuery, mock(DoubleValuesSource.class)); - addTwoPhaseNeuralSparseQuery(functionScoreQuery, mockSearchContext); + addSecondPhaseRescoreContextFromValidNeuralSparseQuery(functionScoreQuery, mockSearchContext); DisjunctionMaxQuery disjunctionMaxQuery = new DisjunctionMaxQuery(Collections.emptyList(), 1.0f); - addTwoPhaseNeuralSparseQuery(disjunctionMaxQuery, mockSearchContext); + addSecondPhaseRescoreContextFromValidNeuralSparseQuery(disjunctionMaxQuery, mockSearchContext); List subQueries = new ArrayList<>(); List filterQueries = new ArrayList<>(); subQueries.add(normalNeuralSparseQuery); filterQueries.add(new MatchAllDocsQuery()); HybridQuery hybridQuery = new HybridQuery(subQueries, filterQueries); - addTwoPhaseNeuralSparseQuery(hybridQuery, mockSearchContext); + addSecondPhaseRescoreContextFromValidNeuralSparseQuery(hybridQuery, mockSearchContext); assertEquals(normalNeuralSparseQuery.getCurrentQuery(), currentQuery); verify(mockSearchContext, never()).addRescore(any()); } @@ -121,7 
+121,7 @@ public void testAddTwoPhaseNeuralSparseQuery_whenUnSupportedQuery_thenNoRescoreA @SneakyThrows public void testAddTwoPhaseNeuralSparseQuery_whenSingleEntryInQuery2Weight_thenRescoreAdded() { NeuralSparseQuery neuralSparseQuery = new NeuralSparseQuery(mock(Query.class), mock(Query.class), mock(Query.class), 5.0f); - addTwoPhaseNeuralSparseQuery(neuralSparseQuery, mockSearchContext); + addSecondPhaseRescoreContextFromValidNeuralSparseQuery(neuralSparseQuery, mockSearchContext); verify(mockSearchContext).addRescore(any(QueryRescorer.QueryRescoreContext.class)); } @@ -135,7 +135,7 @@ public void testAddTwoPhaseNeuralSparseQuery_whenCompoundBooleanQuery_thenRescor queryBuilder.add(boostQuery1, BooleanClause.Occur.SHOULD); queryBuilder.add(boostQuery2, BooleanClause.Occur.SHOULD); BooleanQuery booleanQuery = queryBuilder.build(); - addTwoPhaseNeuralSparseQuery(booleanQuery, mockSearchContext); + addSecondPhaseRescoreContextFromValidNeuralSparseQuery(booleanQuery, mockSearchContext); verify(mockSearchContext).addRescore(any(QueryRescorer.QueryRescoreContext.class)); } @@ -155,7 +155,7 @@ public void testAddTwoPhaseNeuralSparseQuery_whenBooleanClauseType_thenVerifyBoo queryBuilder.add(boostQuery3, BooleanClause.Occur.FILTER); queryBuilder.add(boostQuery4, BooleanClause.Occur.MUST_NOT); BooleanQuery booleanQuery = queryBuilder.build(); - addTwoPhaseNeuralSparseQuery(booleanQuery, mockSearchContext); + addSecondPhaseRescoreContextFromValidNeuralSparseQuery(booleanQuery, mockSearchContext); ArgumentCaptor rtxCaptor = ArgumentCaptor.forClass(RescoreContext.class); verify(mockSearchContext).addRescore(rtxCaptor.capture()); QueryRescorer.QueryRescoreContext context = (QueryRescorer.QueryRescoreContext) rtxCaptor.getValue(); @@ -179,7 +179,7 @@ public void testAddTwoPhaseNeuralSparseQuery_whenBooleanClauseType_thenVerifyBoo @SneakyThrows public void testWindowSize_whenNormalConditions_thenWindowSizeIsAsSet() { NeuralSparseQuery query = normalNeuralSparseQuery; - 
addTwoPhaseNeuralSparseQuery(query, mockSearchContext); + addSecondPhaseRescoreContextFromValidNeuralSparseQuery(query, mockSearchContext); ArgumentCaptor rescoreContextArgumentCaptor = ArgumentCaptor.forClass( QueryRescorer.QueryRescoreContext.class ); @@ -192,10 +192,10 @@ public void testWindowSize_whenBoundaryConditions_thenThrowException() { NeuralSparseQuery query = new NeuralSparseQuery(new MatchAllDocsQuery(), new MatchAllDocsQuery(), new MatchAllDocsQuery(), 5000f); NeuralSparseQuery finalQuery1 = query; - expectThrows(IllegalArgumentException.class, () -> { addTwoPhaseNeuralSparseQuery(finalQuery1, mockSearchContext); }); + expectThrows(IllegalArgumentException.class, () -> { addSecondPhaseRescoreContextFromValidNeuralSparseQuery(finalQuery1, mockSearchContext); }); query = new NeuralSparseQuery(new MatchAllDocsQuery(), new MatchAllDocsQuery(), new MatchAllDocsQuery(), Float.MAX_VALUE); NeuralSparseQuery finalQuery = query; - expectThrows(IllegalArgumentException.class, () -> { addTwoPhaseNeuralSparseQuery(finalQuery, mockSearchContext); }); + expectThrows(IllegalArgumentException.class, () -> { addSecondPhaseRescoreContextFromValidNeuralSparseQuery(finalQuery, mockSearchContext); }); } @SneakyThrows @@ -207,7 +207,7 @@ public void testRescoreListWeightCalculation_whenMultipleRescoreContexts_thenCal List rescoreContextList = Arrays.asList(mockContext1, mockContext2); when(mockSearchContext.rescore()).thenReturn(rescoreContextList); NeuralSparseQuery query = normalNeuralSparseQuery; - addTwoPhaseNeuralSparseQuery(query, mockSearchContext); + addSecondPhaseRescoreContextFromValidNeuralSparseQuery(query, mockSearchContext); ArgumentCaptor rtxCaptor = ArgumentCaptor.forClass(RescoreContext.class); verify(mockSearchContext).addRescore(rtxCaptor.capture()); QueryRescorer.QueryRescoreContext context = (QueryRescorer.QueryRescoreContext) rtxCaptor.getValue(); @@ -218,7 +218,7 @@ public void testRescoreListWeightCalculation_whenMultipleRescoreContexts_thenCal 
public void testEmptyRescoreListWeight_whenRescoreListEmpty_thenDefaultWeightUsed() { when(mockSearchContext.rescore()).thenReturn(Collections.emptyList()); NeuralSparseQuery query = normalNeuralSparseQuery; - addTwoPhaseNeuralSparseQuery(query, mockSearchContext); + addSecondPhaseRescoreContextFromValidNeuralSparseQuery(query, mockSearchContext); ArgumentCaptor rtxCaptor = ArgumentCaptor.forClass(RescoreContext.class); verify(mockSearchContext).addRescore(rtxCaptor.capture()); QueryRescorer.QueryRescoreContext context = (QueryRescorer.QueryRescoreContext) rtxCaptor.getValue();