diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java index 2b02d9634..e8745b042 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java @@ -23,6 +23,7 @@ import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.stream.Collectors; /** @@ -81,15 +82,18 @@ protected NeuralSparseTwoPhaseProcessor( * @return request the search request that add the two-phase rescore query of neural sparse query. */ @Override - public SearchRequest processRequest(SearchRequest request) throws Exception { + public SearchRequest processRequest(final SearchRequest request) throws Exception { if (!enabled || ratio == 0f) { return request; } QueryBuilder queryBuilder = request.source().query(); + // Collect the nested NeuralSparseQueryBuilder in the whole query. Multimap queryBuilderMap = ArrayListMultimap.create(); - collectNeuralSparseQueryBuilder(queryBuilder, queryBuilderMap, 1.0f); + queryBuilderMap = collectNeuralSparseQueryBuilder(queryBuilder, 1.0f); if (queryBuilderMap.isEmpty()) return request; + // Make a nestedQueryBuilder which includes all the two-phase QueryBuilder. QueryBuilder nestedTwoPhaseQueryBuilder = getNestedQueryBuilderFromNeuralSparseQueryBuilderMap(queryBuilderMap); + // Add it to the rescorer. nestedTwoPhaseQueryBuilder.boost(getOriginQueryWeightAfterRescore(request.source())); RescorerBuilder twoPhaseRescorer = new QueryRescorerBuilder(nestedTwoPhaseQueryBuilder); int requestSize = request.source().size(); @@ -156,15 +160,15 @@ public NeuralSparseTwoPhaseProcessor create( ) throws IllegalArgumentException { boolean enabled = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, ENABLE_KEY, true); - Map map = ConfigurationUtils.readOptionalMap(TYPE, tag, config, PARAMETER_KEY); + Map twoPhaseConfigMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, PARAMETER_KEY); float ratio = 0.4f; float window_expansion = 5.0f; int max_window_size = 10000; - if (map != null) { - ratio = ((Number) map.getOrDefault(RATIO_KEY, ratio)).floatValue(); - window_expansion = ((Number) map.getOrDefault(EXPANSION_KEY, window_expansion)).floatValue(); - max_window_size = ((Number) map.getOrDefault(MAX_WINDOW_SIZE_KEY, max_window_size)).intValue(); + if (Objects.nonNull(twoPhaseConfigMap)) { + ratio = ((Number) twoPhaseConfigMap.getOrDefault(RATIO_KEY, ratio)).floatValue(); + window_expansion = ((Number) twoPhaseConfigMap.getOrDefault(EXPANSION_KEY, window_expansion)).floatValue(); + max_window_size = ((Number) twoPhaseConfigMap.getOrDefault(MAX_WINDOW_SIZE_KEY, max_window_size)).intValue(); } return new NeuralSparseTwoPhaseProcessor(tag, description, ignoreFailure, enabled, ratio, window_expansion, max_window_size); @@ -191,40 +195,25 @@ private float getOriginQueryWeightAfterRescore(final SearchSourceBuilder searchS .reduce(1.0f, (a, b) -> a * b); } - private void collectNeuralSparseQueryBuilder( - final QueryBuilder queryBuilder, - final Multimap neuralSparseQueryBuilderFloatMap, - float baseBoost - ) { - if (queryBuilder instanceof BoolQueryBuilder) { - collectNeuralSparseQueryBuilder((BoolQueryBuilder) queryBuilder, neuralSparseQueryBuilderFloatMap, baseBoost); - } else if (queryBuilder instanceof NeuralSparseQueryBuilder) { - collectNeuralSparseQueryBuilder((NeuralSparseQueryBuilder) queryBuilder, neuralSparseQueryBuilderFloatMap, baseBoost); - } - } + private Multimap collectNeuralSparseQueryBuilder(final QueryBuilder queryBuilder, float baseBoost) { + Multimap result = ArrayListMultimap.create(); - private void collectNeuralSparseQueryBuilder( - final BoolQueryBuilder queryBuilder, - final Multimap neuralSparseQueryBuilderFloatMap, - float baseBoost - ) { - baseBoost *= queryBuilder.boost(); - for (QueryBuilder subQuery : queryBuilder.should()) { - if (subQuery instanceof BoolQueryBuilder) { - collectNeuralSparseQueryBuilder(subQuery, neuralSparseQueryBuilderFloatMap, baseBoost); - } else if (subQuery instanceof NeuralSparseQueryBuilder) { - collectNeuralSparseQueryBuilder((NeuralSparseQueryBuilder) subQuery, neuralSparseQueryBuilderFloatMap, baseBoost); + if (queryBuilder instanceof BoolQueryBuilder) { + BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) queryBuilder; + float updatedBoost = baseBoost * boolQueryBuilder.boost(); + for (QueryBuilder subQuery : boolQueryBuilder.should()) { + Multimap subResult = collectNeuralSparseQueryBuilder(subQuery, updatedBoost); + result.putAll(subResult); } + } else if (queryBuilder instanceof NeuralSparseQueryBuilder) { + NeuralSparseQueryBuilder neuralSparseQueryBuilder = (NeuralSparseQueryBuilder) queryBuilder; + float updatedBoost = baseBoost * neuralSparseQueryBuilder.boost(); + NeuralSparseQueryBuilder modifiedQueryBuilder = neuralSparseQueryBuilder.getCopyNeuralSparseQueryBuilderForTwoPhase(ratio); + result.put(modifiedQueryBuilder, updatedBoost); } - } - - private void collectNeuralSparseQueryBuilder( - final NeuralSparseQueryBuilder queryBuilder, - final Multimap neuralSparseQueryBuilderFloatMap, - float baseBoost - ) { - baseBoost *= queryBuilder.boost(); - neuralSparseQueryBuilderFloatMap.put(queryBuilder.getCopyNeuralSparseQueryBuilderForTwoPhase(ratio), baseBoost); + // We only support BoostQuery, BooleanQuery and NeuralSparseQuery now. For other compound query type which are not support now, will + // do nothing and just quit. + return result; } } diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index c5de34804..8fb135434 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -83,7 +83,7 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder> twoPhaseQueryTokensSupplier; private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0; public static void initialize(MLCommonsClientAccessor mlClient) { @@ -138,7 +138,7 @@ public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float this.queryTokensSupplier(splitTokens.get(true)::get); copy.queryTokensSupplier(splitTokens.get(false)::get); } else copy.queryTokensSupplier(new SetOnce>()::get); - this.twoPhaseNeuralSparseQueryBuilder = copy; + this.twoPhaseQueryTokensSupplier = copy.twoPhaseQueryTokensSupplier(); return copy; } @@ -300,7 +300,7 @@ private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBui @Override protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { - if (queryTokensSupplier != null) { + if (Objects.nonNull(queryTokensSupplier)) { return this; } validateForRewrite(queryText, modelId); @@ -314,7 +314,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { .maxTokenScore(maxTokenScore) .queryTokensSupplier(queryTokensSetOnce::get) .twoPhasePruneRatio(twoPhasePruneRatio) - .twoPhaseNeuralSparseQueryBuilder(twoPhaseNeuralSparseQueryBuilder); + .twoPhaseQueryTokensSupplier(twoPhaseQueryTokensSupplier); } private BiConsumer> getModelInferenceAsync(SetOnce> setOnce) { @@ -323,13 +323,13 @@ private BiConsumer> getModelInferenceAsync(SetOnce { Map queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0); - if (twoPhaseNeuralSparseQueryBuilder != null) { + if (twoPhaseQueryTokensSupplier != null) { Map>> splitSetOnce = getSplitSetOnceByScoreThreshold( queryTokens, twoPhasePruneRatio ); setOnce.set(splitSetOnce.get(true).get()); - twoPhaseNeuralSparseQueryBuilder.queryTokensSupplier(splitSetOnce.get(false)::get); + twoPhaseQueryTokensSupplier = splitSetOnce.get(false)::get; } else { setOnce.set(queryTokens); } @@ -343,8 +343,10 @@ protected Query doToQuery(QueryShardContext context) throws IOException { final MappedFieldType ft = context.fieldMapper(fieldName); validateFieldType(ft); Map queryTokens = queryTokensSupplier.get(); - if (null == queryTokens) { - if (twoPhasePruneRatio == -1f) return new MatchNoDocsQuery(); + if (Objects.isNull(queryTokens)) { + if (twoPhasePruneRatio == -1f) { + return new MatchNoDocsQuery(); + } throw new IllegalArgumentException("Query tokens cannot be null."); } BooleanQuery.Builder builder = new BooleanQuery.Builder(); @@ -379,8 +381,8 @@ protected boolean doEquals(NeuralSparseQueryBuilder obj) { if (Objects.isNull(obj) || getClass() != obj.getClass()) return false; if (Objects.isNull(queryTokensSupplier) && Objects.nonNull(obj.queryTokensSupplier)) return false; if (Objects.nonNull(queryTokensSupplier) && Objects.isNull(obj.queryTokensSupplier)) return false; - if (Objects.nonNull(twoPhaseNeuralSparseQueryBuilder) && Objects.isNull(obj.twoPhaseNeuralSparseQueryBuilder)) return false; - if (Objects.isNull(twoPhaseNeuralSparseQueryBuilder) && Objects.nonNull(obj.twoPhaseNeuralSparseQueryBuilder)) return false; + if (Objects.nonNull(twoPhaseQueryTokensSupplier) && Objects.isNull(obj.twoPhaseQueryTokensSupplier)) return false; + if (Objects.isNull(twoPhaseQueryTokensSupplier) && Objects.nonNull(obj.twoPhaseQueryTokensSupplier)) return false; EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName) .append(queryText, obj.queryText) @@ -390,8 +392,8 @@ protected boolean doEquals(NeuralSparseQueryBuilder obj) { if (Objects.nonNull(queryTokensSupplier)) { equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get()); } - if (Objects.nonNull(twoPhaseNeuralSparseQueryBuilder)) { - equalsBuilder.append(twoPhaseNeuralSparseQueryBuilder, obj.twoPhaseNeuralSparseQueryBuilder); + if (Objects.nonNull(twoPhaseQueryTokensSupplier)) { + equalsBuilder.append(twoPhaseQueryTokensSupplier, obj.twoPhaseQueryTokensSupplier); } return equalsBuilder.isEquals(); } @@ -406,8 +408,8 @@ protected int doHashCode() { if (queryTokensSupplier != null) { builder.append(queryTokensSupplier.get()); } - if (twoPhaseNeuralSparseQueryBuilder != null) { - builder.append(twoPhaseNeuralSparseQueryBuilder); + if (twoPhaseQueryTokensSupplier != null) { + builder.append(twoPhaseQueryTokensSupplier); } return builder.toHashCode(); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java index ac409d9fa..be0ee6eaa 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java @@ -459,7 +459,6 @@ public void testNeuralSParseQuery_whenTwoPhaseAndNestedInFunctionScoreQuery_then } } - @SneakyThrows protected void initializeIndexIfNotExist(String indexName) { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java index 54e194938..7478c36e1 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java @@ -4,9 +4,14 @@ */ package org.opensearch.neuralsearch.processor; +import lombok.SneakyThrows; import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.SetOnce; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.rescore.QueryRescorerBuilder; import org.opensearch.test.OpenSearchTestCase; import java.util.Collections; @@ -68,6 +73,36 @@ public void testProcessRequest_whenTwoPhaseEnabled_thenSuccess() throws Exceptio assertNotNull(searchRequest.source().rescores()); } + public void testProcessRequest_whenTwoPhaseEnabledAndNestedBoolean_thenSuccess() throws Exception { + NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); + NeuralSparseQueryBuilder neuralQueryBuilder = new NeuralSparseQueryBuilder(); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.should(neuralQueryBuilder); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(new SearchSourceBuilder().query(boolQueryBuilder)); + NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 0.5f, true, 4.0f, 10000); + processor.processRequest(searchRequest); + BoolQueryBuilder queryBuilder = (BoolQueryBuilder) searchRequest.source().query(); + NeuralSparseQueryBuilder neuralSparseQueryBuilder = (NeuralSparseQueryBuilder) queryBuilder.should().get(0); + assertEquals(neuralSparseQueryBuilder.twoPhasePruneRatio(), 0.5f, 1e-3); + assertNotNull(searchRequest.source().rescores()); + } + + public void testProcessRequestWithRescorer_whenTwoPhaseEnabled_thenSuccess() throws Exception { + NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); + NeuralSparseQueryBuilder neuralQueryBuilder = new NeuralSparseQueryBuilder(); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(new SearchSourceBuilder().query(neuralQueryBuilder)); + QueryRescorerBuilder queryRescorerBuilder = new QueryRescorerBuilder(new MatchAllQueryBuilder()); + queryRescorerBuilder.setRescoreQueryWeight(0f); + searchRequest.source().addRescorer(queryRescorerBuilder); + NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 0.5f, true, 4.0f, 10000); + processor.processRequest(searchRequest); + NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) searchRequest.source().query(); + assertEquals(queryBuilder.twoPhasePruneRatio(), 0.5f, 1e-3); + assertNotNull(searchRequest.source().rescores()); + } + public void testProcessRequest_whenTwoPhaseDisabled_thenSuccess() throws Exception { NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); NeuralSparseQueryBuilder neuralQueryBuilder = new NeuralSparseQueryBuilder(); @@ -80,6 +115,39 @@ public void testProcessRequest_whenTwoPhaseDisabled_thenSuccess() throws Excepti assertNull(searchRequest.source().rescores()); } + @SneakyThrows + public void testProcessRequest_whenTwoPhaseEnabledAndOutOfWindowSize_thenThrowException() { + NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); + NeuralSparseQueryBuilder neuralQueryBuilder = new NeuralSparseQueryBuilder(); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(new SearchSourceBuilder().query(neuralQueryBuilder)); + QueryRescorerBuilder queryRescorerBuilder = new QueryRescorerBuilder(new MatchAllQueryBuilder()); + queryRescorerBuilder.setRescoreQueryWeight(0f); + searchRequest.source().addRescorer(queryRescorerBuilder); + NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 0.5f, true, 400.0f, 100); + expectThrows(IllegalArgumentException.class, () -> processor.processRequest(searchRequest)); + } + + @SneakyThrows + public void testGetSplitSetOnceByScoreThreshold() { + Map queryTokens = new HashMap<>(); + for (int i = 0; i < 10; i++) { + queryTokens.put(String.valueOf(i), (float) i); + } + Map>> splitSetOnce = NeuralSparseTwoPhaseProcessor.getSplitSetOnceByScoreThreshold( + queryTokens, + 0.4f + ); + assertNotNull(splitSetOnce); + SetOnce> upSet = splitSetOnce.get(true); + SetOnce> downSet = splitSetOnce.get(false); + assertNotNull(upSet); + assertEquals(6, upSet.get().size()); + assertNotNull(downSet); + assertEquals(4, downSet.get().size()); + assertNotNull(splitSetOnce.get(false)); + } + public void testType() throws Exception { NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory);