diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index 95072a24c..d6fda3740 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -234,12 +234,12 @@ public static NeuralSparseQueryBuilder fromXContent(XContentParser parser) throw throw new IllegalArgumentException(String.format(Locale.ROOT, "%s field can not be empty", MODEL_ID_FIELD.getPreferredName())); } - if (sparseEncodingQueryBuilder.neuralSparseTwoPhaseParameters.pruning_ratio() <= 0 - || sparseEncodingQueryBuilder.neuralSparseTwoPhaseParameters.pruning_ratio() >= 1) { + if (sparseEncodingQueryBuilder.neuralSparseTwoPhaseParameters.pruning_ratio() < 0 + || sparseEncodingQueryBuilder.neuralSparseTwoPhaseParameters.pruning_ratio() > 1) { throw new IllegalArgumentException( String.format( Locale.ROOT, - "[%s] %s field value must be in range (0,1)", + "[%s] %s field value must be in range (0,1]", NeuralSparseTwoPhaseParameters.NAME.getPreferredName(), NeuralSparseTwoPhaseParameters.PRUNING_RATIO.getPreferredName() ) @@ -434,7 +434,7 @@ private Map getFilteredScoreTokens(boolean aboveThreshold, float } return queryTokens.entrySet() .stream() - .filter(entry -> (aboveThreshold == (entry.getValue() > threshold))) + .filter(entry -> (aboveThreshold == (entry.getValue() >= threshold))) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); } diff --git a/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java b/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java index 5258a4c3a..99334e8cb 100644 --- a/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java +++ b/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java @@ -44,6 +44,7 @@ public final class NeuralSearchSettings { public static final Setting NEURAL_SPARSE_TWO_PHASE_DEFAULT_WINDOW_SIZE_EXPANSION = Setting.floatSetting( "plugins.neural_search.neural_sparse.two_phase.default_window_size_expansion", 5f, + 1f, Setting.Property.NodeScope, Setting.Property.Dynamic ); @@ -51,6 +52,8 @@ public final class NeuralSearchSettings { public static final Setting NEURAL_SPARSE_TWO_PHASE_DEFAULT_PRUNING_RATIO = Setting.floatSetting( "plugins.neural_search.neural_sparse.two_phase.default_pruning_ratio", 0.4f, + 0.0f, + 1.0f, Setting.Property.NodeScope, Setting.Property.Dynamic ); @@ -59,6 +62,7 @@ public final class NeuralSearchSettings { public static final Setting NEURAL_SPARSE_TWO_PHASE_MAX_WINDOW_SIZE = Setting.intSetting( "plugins.neural_search.neural_sparse.two_phase.max_window_size", 10000, + 0, Setting.Property.NodeScope, Setting.Property.Dynamic ); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index ceac0488d..297bd8a16 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.query; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -17,6 +18,7 @@ import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TOKENS_FIELD; import java.io.IOException; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -47,9 +49,11 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.settings.NeuralSearchSettings; import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils; @@ -230,15 +234,14 @@ public void testFromXContent_whenBuiltWithIllegalTwoPhaseWindowSizeExpansion_the "query_text": "string", "model_id": "string", "two_phase_settings":{ - "window_size_expansion": 0.5, + "window_size_expansion": 0.4, "pruning_ratio": 0.4, "enabled": true } } } */ - NeuralSparseTwoPhaseParameters parameters = NeuralSparseTwoPhaseParameters.getDefaultSettings() - .window_size_expansion(NeuralSparseTwoPhaseParametersTests.TEST_WINDOW_SIZE_EXPANSION); + NeuralSparseTwoPhaseParameters parameters = NeuralSparseTwoPhaseParameters.getDefaultSettings().window_size_expansion(0.4f); XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() .startObject(FIELD_NAME) @@ -261,13 +264,13 @@ public void testFromXContent_whenBuiltWithIllegalTwoPhasePruningRate_thenBuildSu "model_id": "string", "two_phase_settings":{ "window_size_expansion": 0.5, - "pruning_ratio": 0, // or 1 + "pruning_ratio": -0.001, // or 1.001 "enabled": true } } } */ - NeuralSparseTwoPhaseParameters parameters = NeuralSparseTwoPhaseParameters.getDefaultSettings().pruning_ratio(0f); + NeuralSparseTwoPhaseParameters parameters = NeuralSparseTwoPhaseParameters.getDefaultSettings().pruning_ratio(-0.001f); XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() .startObject(FIELD_NAME) @@ -280,7 +283,7 @@ public void testFromXContent_whenBuiltWithIllegalTwoPhasePruningRate_thenBuildSu contentParser.nextToken(); expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser)); - parameters.pruning_ratio(1f); + parameters.pruning_ratio(1.001f); XContentBuilder xContentBuilder2 = XContentFactory.jsonBuilder() .startObject() .startObject(FIELD_NAME) @@ -880,4 +883,43 @@ public void testBuildFeatureFieldQueryFormTokens() { assertSame(booleanQuery.clauses().size(), 2); } + @SneakyThrows + public void testTokenDividedByScores_whenDefaultSettings() { + Map map = new HashMap<>(); + for (int i = 1; i < 11; i++) { + map.put(String.valueOf(i), (float) i); + } + final Supplier> tokenSupplier = () -> map; + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName("rank_features") + .queryText(QUERY_TEXT) + .neuralSparseTwoPhaseParameters(NeuralSparseTwoPhaseParameters.getDefaultSettings()) + .modelId(MODEL_ID) + .queryTokensSupplier(tokenSupplier); + QueryShardContext context = mock(QueryShardContext.class); + MappedFieldType mappedFieldType = mock(MappedFieldType.class); + when(mappedFieldType.typeName()).thenReturn("rank_features"); + when(context.fieldMapper(anyString())).thenReturn(mappedFieldType); + NeuralSparseQuery neuralSparseQuery = (NeuralSparseQuery) sparseEncodingQueryBuilder.doToQuery(context); + BooleanQuery highScoreTokenQuery = (BooleanQuery) neuralSparseQuery.getHighScoreTokenQuery(); + BooleanQuery lowScoreTokenQuery = (BooleanQuery) neuralSparseQuery.getLowScoreTokenQuery(); + assertNotNull(highScoreTokenQuery.clauses()); + assertNotNull(lowScoreTokenQuery.clauses()); + assertEquals(highScoreTokenQuery.clauses().size(), 7); + assertEquals(lowScoreTokenQuery.clauses().size(), 3); + sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName("rank_features") + .queryText(QUERY_TEXT) + .neuralSparseTwoPhaseParameters( + new NeuralSparseTwoPhaseParameters().enabled(true).window_size_expansion(5f).pruning_ratio(0.6f) + ) + .modelId(MODEL_ID) + .queryTokensSupplier(tokenSupplier); + neuralSparseQuery = (NeuralSparseQuery) sparseEncodingQueryBuilder.doToQuery(context); + highScoreTokenQuery = (BooleanQuery) neuralSparseQuery.getHighScoreTokenQuery(); + lowScoreTokenQuery = (BooleanQuery) neuralSparseQuery.getLowScoreTokenQuery(); + assertNotNull(highScoreTokenQuery.clauses()); + assertNotNull(lowScoreTokenQuery.clauses()); + assertEquals(highScoreTokenQuery.clauses().size(), 5); + assertEquals(lowScoreTokenQuery.clauses().size(), 5); + } + } diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java index 005839db0..3454baede 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java @@ -4,11 +4,16 @@ */ package org.opensearch.neuralsearch.query; +import org.opensearch.index.query.ConstantScoreQueryBuilder; +import org.opensearch.index.query.DisMaxQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.functionscore.FunctionScoreQueryBuilder; import org.opensearch.neuralsearch.BaseNeuralSearchIT; import static org.opensearch.neuralsearch.TestUtils.createRandomTokenWeightMap; import static org.opensearch.neuralsearch.TestUtils.objectToFloat; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -22,6 +27,12 @@ import lombok.SneakyThrows; public class NeuralSparseQueryIT extends BaseNeuralSearchIT { + + private static final String TWO_PHASE_ENABLED_SETTING_KEY = "plugins.neural_search.neural_sparse.two_phase.default_enabled"; + private static final String TWO_PHASE_WINDOW_SIZE_EXPANSION_SETTING_KEY = + "plugins.neural_search.neural_sparse.two_phase.default_window_size_expansion"; + private static final String TWO_PHASE_PRUNE_RATIO_SETTING_KEY = "plugins.neural_search.neural_sparse.two_phase.default_pruning_ratio"; + private static final String TWO_PHASE_MAX_WINDOW_SIZE_SETTING_KEY = "plugins.neural_search.neural_sparse.two_phase.max_window_size"; private static final String TEST_BASIC_INDEX_NAME = "test-sparse-basic-index"; private static final String TEST_MULTI_NEURAL_SPARSE_FIELD_INDEX_NAME = "test-sparse-multi-field-index"; private static final String TEST_TEXT_AND_NEURAL_SPARSE_FIELD_INDEX_NAME = "test-sparse-text-and-field-index"; @@ -44,6 +55,19 @@ public void setUp() throws Exception { updateClusterSettings(); } + @SneakyThrows + private void updateTwoPhaseClusterSettings(boolean enabled, float windowSizeExpansion, float ratio, int maxWindowSize) { + updateClusterSettings(TWO_PHASE_ENABLED_SETTING_KEY, enabled); + updateClusterSettings(TWO_PHASE_WINDOW_SIZE_EXPANSION_SETTING_KEY, windowSizeExpansion); + updateClusterSettings(TWO_PHASE_PRUNE_RATIO_SETTING_KEY, ratio); + updateClusterSettings(TWO_PHASE_MAX_WINDOW_SIZE_SETTING_KEY, maxWindowSize); + } + + @SneakyThrows + private NeuralSparseTwoPhaseParameters getDefaultTwoPhaseParameter(boolean enabled, float windowSizeExpansion, float ratio) { + return new NeuralSparseTwoPhaseParameters().enabled(enabled).window_size_expansion(windowSizeExpansion).pruning_ratio(ratio); + } + /** * Tests basic query with boost: * { @@ -117,6 +141,47 @@ public void testBasicQueryUsingQueryTokens() { } } + /** + * Tests basic query with boost: + * { + * "query": { + * "neural_sparse": { + * "text_sparse": { + * "query_tokens": { + * "hello": float, + * "world": float, + * "a": float, + * "b": float, + * "c": float + * }, + * "boost": 2 + * } + * } + * } + * } + */ + @SneakyThrows + public void testBasicQueryUsingQueryTokens_whenTwoPhaseEnabled() { + try { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + Map queryTokens = createRandomTokenWeightMap(TEST_TOKENS); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryTokensSupplier(() -> queryTokens) + .boost(2.0f) + .neuralSparseTwoPhaseParameters( + new NeuralSparseTwoPhaseParameters().pruning_ratio(0.8f).window_size_expansion(5.0f).enabled(true) + ); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = 2 * computeExpectedScore(testRankFeaturesDoc, sparseEncodingQueryBuilder.queryTokensSupplier().get()); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); + } finally { + wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, null, null); + } + } + /** * Tests rescore query: * { @@ -209,6 +274,58 @@ public void testBooleanQuery_withMultipleSparseEncodingQueries() { } } + /** + * Tests bool should query with query text when two phase enabled: + * { + * "query": { + * "bool" : { + * "should": [ + * "neural_sparse": { + * "field1": { + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd" + * } + * }, + * "neural_sparse": { + * "field2": { + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd" + * } + * } + * ] + * } + * } + * } + */ + @SneakyThrows + public void testBooleanQuery_withMultipleSparseEncodingQueries_whenTwoPhaseEnabled() { + String modelId = null; + try { + initializeIndexIfNotExist(TEST_MULTI_NEURAL_SPARSE_FIELD_INDEX_NAME); + modelId = prepareSparseEncodingModel(); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder1 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .neuralSparseTwoPhaseParameters(getDefaultTwoPhaseParameter(true, 2.0f, 0.4f)); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder2 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_2) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .neuralSparseTwoPhaseParameters(getDefaultTwoPhaseParameter(true, 2.0f, 0.4f)); + + boolQueryBuilder.should(sparseEncodingQueryBuilder1).should(sparseEncodingQueryBuilder2); + + Map searchResponseAsMap = search(TEST_MULTI_NEURAL_SPARSE_FIELD_INDEX_NAME, boolQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = 2 * computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); + } finally { + wipeOfTestResources(TEST_MULTI_NEURAL_SPARSE_FIELD_INDEX_NAME, null, modelId, null); + } + } + /** * Tests bool should query with query text: * { @@ -312,4 +429,410 @@ protected void initializeIndexIfNotExist(String indexName) { } } + /** + * Tests the neuralSparseQuery when twoPhase enabled with DSL query: + * { + * "query": { + * "bool": { + * "should": [ + * { + * "neural_sparse": { + * "field": "test-sparse-encoding-1", + * "query_text": "TEST_QUERY_TEXT", + * "model_id": "dcsdcasd", + * "boost": 2.0, + * "neural_sparse_two_phase": { + * "enable": true, + * "window_size_expansion": 2.0, + * "pruning_ratio": 0.4 + * } + * } + * } + * ] + * } + * } + * } + */ + @SneakyThrows + public void testBasicQueryUsingQueryText_whenTwoPhaseEnabled_thenGetExpectedScore() { + String modelId = null; + try { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + modelId = prepareSparseEncodingModel(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .boost(2.0f) + .neuralSparseTwoPhaseParameters(getDefaultTwoPhaseParameter(true, 2.0f, 0.4f)); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = 2 * computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); + } finally { + wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); + } + } + + @SneakyThrows + public void testBasicQueryUsingQueryText_whenTwoPhaseEnabledAndDisabled_thenGetSameScore() { + String modelId = null; + try { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + modelId = prepareSparseEncodingModel(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .boost(2.0f); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + assertEquals("1", firstInnerHit.get("_id")); + float scoreWithoutTwoPhase = objectToFloat(firstInnerHit.get("_score")); + + sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .boost(2.0f) + .neuralSparseTwoPhaseParameters( + new NeuralSparseTwoPhaseParameters().enabled(true).pruning_ratio(0.3f).window_size_expansion(6.0f) + ); + searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); + firstInnerHit = getFirstInnerHit(searchResponseAsMap); + assertEquals("1", firstInnerHit.get("_id")); + float scoreWithTwoPhase = objectToFloat(firstInnerHit.get("_score")); + assertEquals(scoreWithTwoPhase, scoreWithoutTwoPhase, DELTA); + } finally { + wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); + } + } + + @SneakyThrows + public void testUpdateTwoPhaseSettings_whenTwoPhasedSettingsOverEdge_thenFail() { + expectThrows(ResponseException.class, () -> updateTwoPhaseClusterSettings(true, 50000f, 1.4f, 10000)); + expectThrows(ResponseException.class, () -> updateTwoPhaseClusterSettings(true, 50000f, -1f, 10000)); + expectThrows(ResponseException.class, () -> updateTwoPhaseClusterSettings(true, -10f, 1.4f, 10000)); + } + + /** + * Tests neuralSparseQuery as rescoreQuery with DSL query: + * { + * "query": { + * "match_all": {} + * }, + * "rescore": { + * "query": { + * "bool": { + * "should": [ + * { + * "neural_sparse": { + * "field": "test-sparse-encoding-1", + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd", + * "boost": 2.0, + * "neural_sparse_two_phase": { + * "enable": true, + * "window_size_expansion": 4.0, + * "pruning_ratio": 0.5 + * } + * } + * } + * ] + * } + * } + * } + * } + */ + @SneakyThrows + public void testNeuralSparseQueryAsRescoreQuery_whenTwoPhase_thenGetExpectedScore() { + String modelId = null; + try { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + modelId = prepareSparseEncodingModel(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .boost(2.0f) + .neuralSparseTwoPhaseParameters(getDefaultTwoPhaseParameter(true, 4.0f, 0.5f)); + QueryBuilder queryBuilder = new MatchAllQueryBuilder(); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, queryBuilder, sparseEncodingQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = 2 * computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); + } finally { + wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); + } + } + + /** + * Tests multi neuralSparseQuery in BooleanQuery with DSL query: + * { + * "query": { + * "bool": { + * "should": [ + * { + * "neural_sparse": { + * "field": "test-sparse-encoding-1", + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd", + * "boost": 2.0, + * "neural_sparse_two_phase": { + * "enable": true, + * "window_size_expansion": 4.0, + * "pruning_ratio": 0.2 + * } + * } + * }, + * { + * "neural_sparse": { + * "field": "test-sparse-encoding-1", + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd", + * "boost": 2.0, + * "neural_sparse_two_phase": { + * "enable": true, + * "window_size_expansion": 4.0, + * "pruning_ratio": 0.2 + * } + * } + * } + * ] + * } + * } + * } + */ + @SneakyThrows + public void testMultiNeuralSparseQuery_whenTwoPhase_thenGetExpectedScore() { + String modelId = null; + try { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + modelId = prepareSparseEncodingModel(); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .boost(2.0f) + .neuralSparseTwoPhaseParameters(getDefaultTwoPhaseParameter(true, 4.0f, 0.2f)); + boolQueryBuilder.should(sparseEncodingQueryBuilder); + boolQueryBuilder.should(sparseEncodingQueryBuilder); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, boolQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = 4 * computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); + } finally { + wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); + } + } + + @SneakyThrows + public void testMultiNeuralSparseQuery_whenTwoPhaseAndFilter_thenGetExpectedScore() { + String modelId = null; + try { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + modelId = prepareSparseEncodingModel(); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .boost(2.0f) + .neuralSparseTwoPhaseParameters(getDefaultTwoPhaseParameter(true, 5.0f, 0.8f)); + boolQueryBuilder.should(sparseEncodingQueryBuilder); + boolQueryBuilder.filter(sparseEncodingQueryBuilder); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, boolQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = 2 * computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); + } finally { + wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); + } + } + + @SneakyThrows + public void testMultiNeuralSparseQuery_whenTwoPhaseAndMultiBoolean_thenGetExpectedScore() { + String modelId = null; + try { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + modelId = prepareSparseEncodingModel(); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder1 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .boost(1.0f) + .neuralSparseTwoPhaseParameters(getDefaultTwoPhaseParameter(true, 5.0f, 0.6f)); + boolQueryBuilder.should(sparseEncodingQueryBuilder1); + boolQueryBuilder.should(sparseEncodingQueryBuilder1); + BoolQueryBuilder subBoolQueryBuilder = new BoolQueryBuilder(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder2 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .boost(2.0f) + .neuralSparseTwoPhaseParameters(getDefaultTwoPhaseParameter(true, 5.0f, 0.6f)); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder3 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .boost(3.0f) + .neuralSparseTwoPhaseParameters(getDefaultTwoPhaseParameter(true, 5.0f, 0.6f)); + subBoolQueryBuilder.should(sparseEncodingQueryBuilder2); + subBoolQueryBuilder.should(sparseEncodingQueryBuilder3); + subBoolQueryBuilder.boost(2.0f); + boolQueryBuilder.should(subBoolQueryBuilder); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, boolQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = 12 * computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); + } finally { + wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); + } + } + + @SneakyThrows + public void testMultiNeuralSparseQuery_whenTwoPhaseAndNoLowScoreToken_thenGetExpectedScore() { + try { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + Map queryTokens = new HashMap<>(); + queryTokens.put("hello", 10.0f); + queryTokens.put("world", 10.0f); + queryTokens.put("a", 10.0f); + queryTokens.put("b", 10.0f); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryTokensSupplier(() -> queryTokens) + .boost(2.0f) + .neuralSparseTwoPhaseParameters(getDefaultTwoPhaseParameter(true, 5.0f, 0.6f)); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = 2 * computeExpectedScore(testRankFeaturesDoc, sparseEncodingQueryBuilder.queryTokensSupplier().get()); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); + } finally { + wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, null, null); + } + } + + /** + * Tests constantScoreQuery with query text: + * { + * "query": { + * "constant_score" : { + * "filter": [ + * "neural_sparse": { + * "field1": { + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd" + * } + * } + * ] + * } + * } + * } + */ + @SneakyThrows + public void testNeuralSParseQuery_whenTwoPhaseAndNestedInConstantScoreQuery_thenSuccess() { + String modelId = null; + try { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + modelId = prepareSparseEncodingModel(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .boost(1.0f) + .neuralSparseTwoPhaseParameters(getDefaultTwoPhaseParameter(true, 5.0f, 0.6f)); + ConstantScoreQueryBuilder constantScoreQueryBuilder = new ConstantScoreQueryBuilder(sparseEncodingQueryBuilder); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, constantScoreQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + assertEquals("1", firstInnerHit.get("_id")); + assertEquals(1.0f, objectToFloat(firstInnerHit.get("_score")), DELTA); + } finally { + wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); + } + } + + /** + * Tests disjunctionMaxQuery with query text: + * { + * "query": { + * "dis_max" : { + * "queries": [ + * { + * "neural_sparse": { + * "field1": { + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd" + * } + * } + * }, + * { + * "match_all":{} + * } + * ] + * } + * } + * } + */ + @SneakyThrows + public void testNeuralSParseQuery_whenTwoPhaseAndNestedInDisjunctionMaxQuery_thenSuccess() { + String modelId = null; + try { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + modelId = prepareSparseEncodingModel(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .boost(5.0f) + .neuralSparseTwoPhaseParameters(getDefaultTwoPhaseParameter(true, 5.0f, 0.6f)); + DisMaxQueryBuilder disMaxQueryBuilder = new DisMaxQueryBuilder(); + disMaxQueryBuilder.add(sparseEncodingQueryBuilder); + disMaxQueryBuilder.add(new MatchAllQueryBuilder()); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, disMaxQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = 5f * computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); + } finally { + wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); + } + } + + /** + * Tests functionScoreQuery with query text: + * { + * "query": { + * "function_score" : { + * "query":{ + * "neural_sparse": { + * "field1": { + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd" + * } + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testNeuralSParseQuery_whenTwoPhaseAndNestedInFunctionScoreQuery_thenSuccess() { + String modelId = null; + try { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + modelId = prepareSparseEncodingModel(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .boost(5.0f) + .neuralSparseTwoPhaseParameters(getDefaultTwoPhaseParameter(true, 5.0f, 0.6f)); + FunctionScoreQueryBuilder functionScoreQueryBuilder = new FunctionScoreQueryBuilder(sparseEncodingQueryBuilder); + functionScoreQueryBuilder.boost(2.0f); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, functionScoreQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = 10f * computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); + } finally { + wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); + } + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryTests.java index 6543dd905..edc875dee 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryTests.java @@ -38,7 +38,7 @@ public void setup() { @SneakyThrows public void testToStringMethod() { - NeuralSparseQuery neuralSparseQuery = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, lowScoreTokenQuery); + NeuralSparseQuery neuralSparseQuery = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, lowScoreTokenQuery, 5.0f); String expectedString = "NeuralSparseQuery(" + currentQuery.toString() + ',' @@ -52,7 +52,7 @@ public void testToStringMethod() { @SneakyThrows public void testRewrite_whenDifferent_thenNotSame() { IndexSearcher mockIndexSearcher = mock(IndexSearcher.class); - NeuralSparseQuery neuralSparseQuery = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, lowScoreTokenQuery); + NeuralSparseQuery neuralSparseQuery = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, lowScoreTokenQuery, 5.0f); Query rewrittenQuery = neuralSparseQuery.rewrite(mockIndexSearcher); assertNotSame(neuralSparseQuery, rewrittenQuery); assertTrue(rewrittenQuery instanceof NeuralSparseQuery); @@ -64,7 +64,8 @@ public void testRewrite_whenDifferent_thenSame() { NeuralSparseQuery neuralSparseQuery = new NeuralSparseQuery( new MatchAllDocsQuery(), new MatchAllDocsQuery(), - new MatchAllDocsQuery() + new MatchAllDocsQuery(), + 5.0f ); Query rewrittenQuery = neuralSparseQuery.rewrite(mockIndexSearcher); assertSame(neuralSparseQuery, rewrittenQuery); @@ -73,18 +74,22 @@ public void testRewrite_whenDifferent_thenSame() { @SneakyThrows public void testEqualsAndHashCode() { - NeuralSparseQuery query1 = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, lowScoreTokenQuery); - NeuralSparseQuery query2 = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, lowScoreTokenQuery); - NeuralSparseQuery query3 = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, new MatchAllDocsQuery()); + NeuralSparseQuery query1 = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, lowScoreTokenQuery, 5.0f); + NeuralSparseQuery query2 = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, lowScoreTokenQuery, 5.0f); + NeuralSparseQuery query3 = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, new MatchAllDocsQuery(), 5.0f); + NeuralSparseQuery query4 = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, lowScoreTokenQuery, 6.0f); assertEquals(query1, query2); assertNotEquals(query1, query3); + assertNotEquals(query1, query4); assertEquals(query1.hashCode(), query2.hashCode()); assertNotEquals(query1.hashCode(), query3.hashCode()); + assertNotEquals(query1.hashCode(), query4.hashCode()); + } @SneakyThrows public void testExtractLowScoreToken_thenCurrentChanged() { - NeuralSparseQuery neuralSparseQuery = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, lowScoreTokenQuery); + NeuralSparseQuery neuralSparseQuery = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, lowScoreTokenQuery, 5.0f); assertSame(neuralSparseQuery.getCurrentQuery(), currentQuery); neuralSparseQuery.extractLowScoreToken(); assertSame(neuralSparseQuery.getCurrentQuery(), highScoreTokenQuery); @@ -92,7 +97,7 @@ public void testExtractLowScoreToken_thenCurrentChanged() { @SneakyThrows public void testVisit_thenSuccess() { - NeuralSparseQuery neuralSparseQuery = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, lowScoreTokenQuery); + NeuralSparseQuery neuralSparseQuery = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, lowScoreTokenQuery, 5.0f); neuralSparseQuery.visit(QueryVisitor.EMPTY_VISITOR); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseTwoPhaseParametersTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseTwoPhaseParametersTests.java index 34e04f378..3bb1c2ed5 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseTwoPhaseParametersTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseTwoPhaseParametersTests.java @@ -39,20 +39,18 @@ public class NeuralSparseTwoPhaseParametersTests extends OpenSearchTestCase { public static int TEST_MAX_WINDOW_SIZE = 100; - public static float TEST_WINDOW_SIZE_EXPANSION = 0.5f; + public static float TEST_WINDOW_SIZE_EXPANSION = 6.0f; public static float TEST_PRUNING_RATIO = 0.5f; public static boolean TEST_ENABLED = false; public static NeuralSparseTwoPhaseParameters TWO_PHASE_PARAMETERS = new NeuralSparseTwoPhaseParameters().enabled(TEST_ENABLED) .pruning_ratio(TEST_PRUNING_RATIO) .window_size_expansion(TEST_WINDOW_SIZE_EXPANSION); - private Settings settings; private ClusterSettings clusterSettings; - private ClusterService clusterService; @Before public void setUpNeuralSparseTwoPhaseParameters() { - settings = Settings.builder().build(); + Settings settings = Settings.builder().build(); final Set> settingsSet = Stream.concat( ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), Stream.of( @@ -63,7 +61,7 @@ public void setUpNeuralSparseTwoPhaseParameters() { ) ).collect(Collectors.toSet()); clusterSettings = new ClusterSettings(settings, settingsSet); - clusterService = mock(ClusterService.class); + ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); NeuralSparseTwoPhaseParameters.initialize(clusterService, settings); } @@ -140,6 +138,35 @@ public void testFromXContentWithFullBodyThenSuccess() { assertEquals(TEST_WINDOW_SIZE_EXPANSION, neuralSparseTwoPhaseParameters.window_size_expansion(), 0); } + @SneakyThrows + public void testFromXContentWithFullBodyOntTimeThenDefaultRecoverSuccess() { + /* + { + "window_size_expansion": 0.5, + "pruning_ratio": 0.5, + "enabled": false + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NeuralSparseTwoPhaseParameters.WINDOW_SIZE_EXPANSION.getPreferredName(), TEST_WINDOW_SIZE_EXPANSION) + .field(NeuralSparseTwoPhaseParameters.PRUNING_RATIO.getPreferredName(), TEST_PRUNING_RATIO) + .field(NeuralSparseTwoPhaseParameters.ENABLED.getPreferredName(), TEST_ENABLED) + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + NeuralSparseTwoPhaseParameters neuralSparseTwoPhaseParameters = NeuralSparseTwoPhaseParameters.parseFromXContent(contentParser); + assertEquals(TEST_ENABLED, neuralSparseTwoPhaseParameters.enabled()); + assertEquals(TEST_PRUNING_RATIO, neuralSparseTwoPhaseParameters.pruning_ratio(), 0); + assertEquals(TEST_WINDOW_SIZE_EXPANSION, neuralSparseTwoPhaseParameters.window_size_expansion(), 0); + neuralSparseTwoPhaseParameters = NeuralSparseTwoPhaseParameters.getDefaultSettings(); + assertEquals(true, neuralSparseTwoPhaseParameters.enabled()); + assertEquals(0.4f, neuralSparseTwoPhaseParameters.pruning_ratio(), 1e-6); + assertEquals(5.0f, neuralSparseTwoPhaseParameters.window_size_expansion(), 1e-6); + + } + @SneakyThrows public void testFromXContentWithEmptyBodyThenSuccess() { /* @@ -256,4 +283,22 @@ public void testIsClusterOnOrAfterMinReqVersionForTwoPhaseSearchSupport() { NeuralSearchClusterUtil.instance().initialize(clusterServiceCurrent); assertTrue(NeuralSparseTwoPhaseParameters.isClusterOnOrAfterMinReqVersionForTwoPhaseSearchSupport()); } + + @SneakyThrows + public void testUpdateSettingsEffectiveness() { + Settings updatedSettings = Settings.builder() + .put("plugins.neural_search.neural_sparse.two_phase.default_enabled", true) + .put("plugins.neural_search.neural_sparse.two_phase.default_window_size_expansion", 10f) + .put("plugins.neural_search.neural_sparse.two_phase.default_pruning_ratio", 0.8f) + .build(); + + clusterSettings.applySettings(updatedSettings); + NeuralSparseTwoPhaseParameters updatedParameters = NeuralSparseTwoPhaseParameters.getDefaultSettings(); + assertEquals(true, updatedParameters.enabled()); + assertEquals(10f, updatedParameters.window_size_expansion(), 0); + assertEquals(0.8f, updatedParameters.pruning_ratio(), 0); + + setUpNeuralSparseTwoPhaseParameters(); + } + } diff --git a/src/test/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtilTests.java b/src/test/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtilTests.java new file mode 100644 index 000000000..18a287ee8 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtilTests.java @@ -0,0 +1,228 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.util; + +import lombok.SneakyThrows; +import org.apache.lucene.queries.function.FunctionScoreQuery; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.BoostQuery; +import org.apache.lucene.search.DisjunctionMaxQuery; +import org.apache.lucene.search.DoubleValuesSource; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.junit.Before; +import org.opensearch.Version; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.IndexSettings; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.query.NeuralSparseQuery; +import org.opensearch.neuralsearch.query.NeuralSparseTwoPhaseParameters; +import org.opensearch.neuralsearch.settings.NeuralSearchSettings; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.rescore.QueryRescorer; +import org.opensearch.search.rescore.RescoreContext; +import org.opensearch.test.OpenSearchTestCase; +import org.mockito.ArgumentCaptor; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addTwoPhaseNeuralSparseQuery; + +public class NeuralSparseTwoPhaseUtilTests extends OpenSearchTestCase { + + private final SearchContext mockSearchContext = mock(SearchContext.class); + + private final QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + private NeuralSparseQuery normalNeuralSparseQuery; + private final Query currentQuery = mock(Query.class); + private final Query highScoreTokenQuery = mock(Query.class); + private final Query lowScoreTokenQuery = mock(Query.class); + + protected IndexSettings createIndexSettings() { + return new IndexSettings( + IndexMetadata.builder("_index") + .settings( + Settings.builder().put("index.max_rescore_window", 10000).put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + ) + .numberOfShards(1) + .numberOfReplicas(0) + .creationDate(System.currentTimeMillis()) + .build(), + Settings.EMPTY + ); + } + + @SneakyThrows + @Before + public void testInitialize() { + normalNeuralSparseQuery = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, lowScoreTokenQuery, 5f); + IndexSettings indexSettings = createIndexSettings(); + when(mockSearchContext.getQueryShardContext()).thenReturn(mockQueryShardContext); + when(mockSearchContext.size()).thenReturn(10); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + Settings settings = Settings.builder().build(); + final Set> settingsSet = Stream.concat( + ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), + Stream.of( + NeuralSearchSettings.NEURAL_SPARSE_TWO_PHASE_DEFAULT_ENABLED, + NeuralSearchSettings.NEURAL_SPARSE_TWO_PHASE_DEFAULT_WINDOW_SIZE_EXPANSION, + NeuralSearchSettings.NEURAL_SPARSE_TWO_PHASE_DEFAULT_PRUNING_RATIO, + NeuralSearchSettings.NEURAL_SPARSE_TWO_PHASE_MAX_WINDOW_SIZE + ) + ).collect(Collectors.toSet()); + ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + NeuralSparseTwoPhaseParameters.initialize(clusterService, settings); + } + + @SneakyThrows + public void testAddTwoPhaseNeuralSparseQuery_whenQuery2WeightEmpty_thenNoRescoreAdded() { + Query query = mock(Query.class); + addTwoPhaseNeuralSparseQuery(query, mockSearchContext); + verify(mockSearchContext, never()).addRescore(any()); + } + + @SneakyThrows + public void testAddTwoPhaseNeuralSparseQuery_whenUnSupportedQuery_thenNoRescoreAdded() { + FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(normalNeuralSparseQuery, mock(DoubleValuesSource.class)); + addTwoPhaseNeuralSparseQuery(functionScoreQuery, mockSearchContext); + DisjunctionMaxQuery disjunctionMaxQuery = new DisjunctionMaxQuery(Collections.emptyList(), 1.0f); + addTwoPhaseNeuralSparseQuery(disjunctionMaxQuery, mockSearchContext); + List subQueries = new ArrayList<>(); + List filterQueries = new ArrayList<>(); + subQueries.add(normalNeuralSparseQuery); + filterQueries.add(new MatchAllDocsQuery()); + HybridQuery hybridQuery = new HybridQuery(subQueries, filterQueries); + addTwoPhaseNeuralSparseQuery(hybridQuery, mockSearchContext); + assertEquals(normalNeuralSparseQuery.getCurrentQuery(), currentQuery); + verify(mockSearchContext, never()).addRescore(any()); + } + + @SneakyThrows + public void testAddTwoPhaseNeuralSparseQuery_whenSingleEntryInQuery2Weight_thenRescoreAdded() { + NeuralSparseQuery neuralSparseQuery = new NeuralSparseQuery(mock(Query.class), mock(Query.class), mock(Query.class), 5.0f); + addTwoPhaseNeuralSparseQuery(neuralSparseQuery, mockSearchContext); + verify(mockSearchContext).addRescore(any(QueryRescorer.QueryRescoreContext.class)); + } + + @SneakyThrows + public void testAddTwoPhaseNeuralSparseQuery_whenCompoundBooleanQuery_thenRescoreAdded() { + NeuralSparseQuery neuralSparseQuery1 = new NeuralSparseQuery(mock(Query.class), mock(Query.class), mock(Query.class), 4f); + NeuralSparseQuery neuralSparseQuery2 = new NeuralSparseQuery(mock(Query.class), mock(Query.class), mock(Query.class), 4f); + BoostQuery boostQuery1 = new BoostQuery(neuralSparseQuery1, 2f); + BoostQuery boostQuery2 = new BoostQuery(neuralSparseQuery2, 3f); + BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder(); + queryBuilder.add(boostQuery1, BooleanClause.Occur.SHOULD); + queryBuilder.add(boostQuery2, BooleanClause.Occur.SHOULD); + BooleanQuery booleanQuery = queryBuilder.build(); + addTwoPhaseNeuralSparseQuery(booleanQuery, mockSearchContext); + verify(mockSearchContext).addRescore(any(QueryRescorer.QueryRescoreContext.class)); + } + + @SneakyThrows + public void testAddTwoPhaseNeuralSparseQuery_whenBooleanClauseType_thenVerifyBoosts() { + NeuralSparseQuery neuralSparseQuery1 = new NeuralSparseQuery(mock(Query.class), mock(Query.class), mock(Query.class), 5f); + NeuralSparseQuery neuralSparseQuery2 = new NeuralSparseQuery(mock(Query.class), mock(Query.class), mock(Query.class), 5f); + NeuralSparseQuery neuralSparseQuery3 = new NeuralSparseQuery(mock(Query.class), mock(Query.class), mock(Query.class), 5f); + NeuralSparseQuery neuralSparseQuery4 = new NeuralSparseQuery(mock(Query.class), mock(Query.class), mock(Query.class), 5f); + BoostQuery boostQuery1 = new BoostQuery(neuralSparseQuery1, 2f); + BoostQuery boostQuery2 = new BoostQuery(neuralSparseQuery2, 3f); + BoostQuery boostQuery3 = new BoostQuery(neuralSparseQuery3, 4f); + BoostQuery boostQuery4 = new BoostQuery(neuralSparseQuery4, 5f); + BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder(); + queryBuilder.add(boostQuery1, BooleanClause.Occur.SHOULD); + queryBuilder.add(boostQuery2, BooleanClause.Occur.MUST); + queryBuilder.add(boostQuery3, BooleanClause.Occur.FILTER); + queryBuilder.add(boostQuery4, BooleanClause.Occur.MUST_NOT); + BooleanQuery booleanQuery = queryBuilder.build(); + addTwoPhaseNeuralSparseQuery(booleanQuery, mockSearchContext); + ArgumentCaptor rtxCaptor = ArgumentCaptor.forClass(RescoreContext.class); + verify(mockSearchContext).addRescore(rtxCaptor.capture()); + QueryRescorer.QueryRescoreContext context = (QueryRescorer.QueryRescoreContext) rtxCaptor.getValue(); + BooleanQuery rescoreQuery = (BooleanQuery) context.query(); + List clauses = rescoreQuery.clauses(); + List shouldBoosts = new ArrayList<>(); + for (BooleanClause clause : clauses) { + if (clause.getOccur() == BooleanClause.Occur.SHOULD) { + Query query = clause.getQuery(); + if (query instanceof BoostQuery) { + BoostQuery boostQuery = (BoostQuery) query; + shouldBoosts.add(boostQuery.getBoost()); + } + } + } + assertEquals(shouldBoosts.size(), 2); + assertTrue("Should contain boost 2.0", shouldBoosts.contains(2.0f)); + assertTrue("Should contain boost 3.0", shouldBoosts.contains(3.0f)); + } + + @SneakyThrows + public void testWindowSize_whenNormalConditions_thenWindowSizeIsAsSet() { + NeuralSparseQuery query = normalNeuralSparseQuery; + addTwoPhaseNeuralSparseQuery(query, mockSearchContext); + ArgumentCaptor rescoreContextArgumentCaptor = ArgumentCaptor.forClass( + QueryRescorer.QueryRescoreContext.class + ); + verify(mockSearchContext).addRescore(rescoreContextArgumentCaptor.capture()); + assertEquals(50, rescoreContextArgumentCaptor.getValue().getWindowSize()); + } + + @SneakyThrows + public void testWindowSize_whenBoundaryConditions_thenThrowException() { + + NeuralSparseQuery query = new NeuralSparseQuery(new MatchAllDocsQuery(), new MatchAllDocsQuery(), new MatchAllDocsQuery(), 5000f); + NeuralSparseQuery finalQuery1 = query; + expectThrows(IllegalArgumentException.class, () -> { addTwoPhaseNeuralSparseQuery(finalQuery1, mockSearchContext); }); + query = new NeuralSparseQuery(new MatchAllDocsQuery(), new MatchAllDocsQuery(), new MatchAllDocsQuery(), Float.MAX_VALUE); + NeuralSparseQuery finalQuery = query; + expectThrows(IllegalArgumentException.class, () -> { addTwoPhaseNeuralSparseQuery(finalQuery, mockSearchContext); }); + } + + @SneakyThrows + public void testRescoreListWeightCalculation_whenMultipleRescoreContexts_thenCalculateCorrectWeight() { + QueryRescorer.QueryRescoreContext mockContext1 = mock(QueryRescorer.QueryRescoreContext.class); + QueryRescorer.QueryRescoreContext mockContext2 = mock(QueryRescorer.QueryRescoreContext.class); + when(mockContext1.queryWeight()).thenReturn(2.0f); + when(mockContext2.queryWeight()).thenReturn(3.0f); + List rescoreContextList = Arrays.asList(mockContext1, mockContext2); + when(mockSearchContext.rescore()).thenReturn(rescoreContextList); + NeuralSparseQuery query = normalNeuralSparseQuery; + addTwoPhaseNeuralSparseQuery(query, mockSearchContext); + ArgumentCaptor rtxCaptor = ArgumentCaptor.forClass(RescoreContext.class); + verify(mockSearchContext).addRescore(rtxCaptor.capture()); + QueryRescorer.QueryRescoreContext context = (QueryRescorer.QueryRescoreContext) rtxCaptor.getValue(); + assertEquals(context.rescoreQueryWeight(), 6.0f, 0.01f); + } + + @SneakyThrows + public void testEmptyRescoreListWeight_whenRescoreListEmpty_thenDefaultWeightUsed() { + when(mockSearchContext.rescore()).thenReturn(Collections.emptyList()); + NeuralSparseQuery query = normalNeuralSparseQuery; + addTwoPhaseNeuralSparseQuery(query, mockSearchContext); + ArgumentCaptor rtxCaptor = ArgumentCaptor.forClass(RescoreContext.class); + verify(mockSearchContext).addRescore(rtxCaptor.capture()); + QueryRescorer.QueryRescoreContext context = (QueryRescorer.QueryRescoreContext) rtxCaptor.getValue(); + assertEquals(context.rescoreQueryWeight(), 1.0f, 0.01f); + } + +}