From 935829aa65c45f1e6a1cfe65e17eaeb4383dc1d5 Mon Sep 17 00:00:00 2001 From: Tejas Shah Date: Wed, 3 Jul 2024 20:10:23 -0700 Subject: [PATCH] Adds method_parameters in neural search query to support ef_search (#787) (#814) Signed-off-by: Tejas Shah --- CHANGELOG.md | 1 + .../neuralsearch/bwc/HybridSearchIT.java | 15 ++++-- .../neuralsearch/bwc/KnnRadialSearchIT.java | 2 + .../neuralsearch/bwc/MultiModalSearchIT.java | 1 + .../neuralsearch/bwc/HybridSearchIT.java | 13 ++++- .../neuralsearch/bwc/KnnRadialSearchIT.java | 2 + .../neuralsearch/bwc/MultiModalSearchIT.java | 1 + .../common/MinClusterVersionUtil.java | 50 +++++++++++++++++++ .../query/NeuralQueryBuilder.java | 44 ++++++++-------- .../processor/NormalizationProcessorIT.java | 3 ++ .../processor/ScoreCombinationIT.java | 8 +-- .../processor/ScoreNormalizationIT.java | 12 ++--- .../query/NeuralQueryBuilderTests.java | 39 ++++++++++++++- .../neuralsearch/query/NeuralQueryIT.java | 14 +++++- 14 files changed, 165 insertions(+), 40 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a72fcdaa..16a4669c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.15...2.x) ### Features ### Enhancements +* Adds dynamic knn query parameters efsearch and nprobes [#814](https://github.com/opensearch-project/neural-search/pull/814/) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index f5289fe79..845396dd0 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -10,6 +10,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; + import org.opensearch.index.query.MatchQueryBuilder; import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; @@ -69,6 +70,7 @@ private void validateNormalizationProcessor(final String fileName, final String loadModel(modelId); addDocuments(getIndexNameForTest(), false); validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName); + validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName, Map.of("ef_search", 100)); } finally { wipeOfTestResources(getIndexNameForTest(), pipelineName, modelId, searchPipelineName); } @@ -96,10 +98,14 @@ private void createSearchPipeline(final String pipelineName) { ); } - private void validateTestIndex(final String modelId, final String index, final String searchPipeline) throws Exception { + private void validateTestIndex(final String modelId, final String index, final String searchPipeline) { + validateTestIndex(modelId, index, searchPipeline, null); + } + + private void validateTestIndex(final String modelId, final String index, final String searchPipeline, Map methodParameters) { int docCount = getDocCount(index); assertEquals(6, docCount); - HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId); + HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, methodParameters); Map searchResponseAsMap = search(index, hybridQueryBuilder, null, 1, Map.of("search_pipeline", searchPipeline)); assertNotNull(searchResponseAsMap); int hits = getHitCount(searchResponseAsMap); @@ -110,12 +116,15 @@ private void validateTestIndex(final String modelId, final String index, final S } } - private HybridQueryBuilder getQueryBuilder(final String modelId) { + private HybridQueryBuilder getQueryBuilder(final String modelId, Map methodParameters) { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(); neuralQueryBuilder.fieldName("passage_embedding"); neuralQueryBuilder.modelId(modelId); neuralQueryBuilder.queryText(QUERY); neuralQueryBuilder.k(5); + if (methodParameters != null) { + neuralQueryBuilder.methodParameters(methodParameters); + } MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY); diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java index 8a6dfcde3..ece2bbb9e 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java @@ -60,6 +60,7 @@ private void validateIndexQuery(final String modelId) { null, 0.01f, null, + null, null ); Map responseWithMinScoreQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1); @@ -74,6 +75,7 @@ private void validateIndexQuery(final String modelId) { 100000f, null, null, + null, null ); Map responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1); diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java index afa29bab5..54d993b35 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java @@ -62,6 +62,7 @@ private void validateTestIndex(final String modelId) throws Exception { null, null, null, + null, null ); Map response = search(getIndexNameForTest(), neuralQueryBuilder, 1); diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index 903ffc9be..ba2ff7979 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -73,6 +73,7 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr loadModel(modelId); addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null); validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId); + validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, Map.of("ef_search", 100)); } finally { wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME); } @@ -83,10 +84,15 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr } private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId) throws Exception { + validateTestIndexOnUpgrade(numberOfDocs, modelId, null); + } + + private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId, Map methodParameters) + throws Exception { int docCount = getDocCount(getIndexNameForTest()); assertEquals(numberOfDocs, docCount); loadModel(modelId); - HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId); + HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, methodParameters); Map searchResponseAsMap = search( getIndexNameForTest(), hybridQueryBuilder, @@ -103,12 +109,15 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod } } - private HybridQueryBuilder getQueryBuilder(final String modelId) { + private HybridQueryBuilder getQueryBuilder(final String modelId, final Map methodParameters) { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(); neuralQueryBuilder.fieldName("passage_embedding"); neuralQueryBuilder.modelId(modelId); neuralQueryBuilder.queryText(QUERY); neuralQueryBuilder.k(5); + if (methodParameters != null) { + neuralQueryBuilder.methodParameters(methodParameters); + } MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY); diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java index 15be7a15b..17d15898b 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java @@ -86,6 +86,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo null, 0.01f, null, + null, null ); Map responseWithMinScore = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1); @@ -100,6 +101,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo 100000f, null, null, + null, null ); Map responseWithMaxScore = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1); diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java index 1154f1e51..8e0ff7568 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java @@ -85,6 +85,7 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod null, null, null, + null, null ); Map responseWithKQuery = search(getIndexNameForTest(), neuralQueryBuilderWithKQuery, 1); diff --git a/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java b/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java new file mode 100644 index 000000000..160b2fa4d --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.common; + +import com.google.common.collect.ImmutableMap; +import org.opensearch.Version; +import org.opensearch.knn.index.IndexUtil; +import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; + +import java.util.Map; + +import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD; +import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD; + +/** + * A util class which holds the logic to determine the min version supported by the request parameters + */ +public final class MinClusterVersionUtil { + + private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0; + private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0; + + // Note this minimal version will act as a override + private static final Map MINIMAL_VERSION_NEURAL = ImmutableMap.builder() + .put(MODEL_ID_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID) + .put(MAX_DISTANCE_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH) + .put(MIN_SCORE_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH) + .build(); + + public static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() { + return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID); + } + + public static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() { + return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH); + } + + public static boolean isClusterOnOrAfterMinReqVersion(String key) { + Version version; + if (MINIMAL_VERSION_NEURAL.containsKey(key)) { + version = MINIMAL_VERSION_NEURAL.get(key); + } else { + version = IndexUtil.minimalRequiredVersionMap.get(key); + } + return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(version); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index e7e081f2b..8e1b6b36b 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -5,6 +5,12 @@ package org.opensearch.neuralsearch.query; import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD; +import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion; +import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport; +import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch; import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray; import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_IMAGE; import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_TEXT; @@ -19,7 +25,6 @@ import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.lucene.search.Query; -import org.opensearch.Version; import org.opensearch.common.SetOnce; import org.opensearch.core.ParseField; import org.opensearch.core.action.ActionListener; @@ -34,8 +39,9 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.index.query.parser.MethodParametersParser; +import org.opensearch.neuralsearch.common.MinClusterVersionUtil; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import com.google.common.annotations.VisibleForTesting; @@ -69,18 +75,11 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder @VisibleForTesting static final ParseField QUERY_IMAGE_FIELD = new ParseField("query_image"); - @VisibleForTesting - static final ParseField MODEL_ID_FIELD = new ParseField("model_id"); + public static final ParseField MODEL_ID_FIELD = new ParseField("model_id"); @VisibleForTesting static final ParseField K_FIELD = new ParseField("k"); - @VisibleForTesting - static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance"); - - @VisibleForTesting - static final ParseField MIN_SCORE_FIELD = new ParseField("min_score"); - private static final int DEFAULT_K = 10; private static MLCommonsClientAccessor ML_CLIENT; @@ -101,8 +100,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) { @Setter(AccessLevel.PACKAGE) private Supplier vectorSupplier; private QueryBuilder filter; - private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0; - private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0; + private Map methodParameters; /** * Constructor from stream input @@ -130,6 +128,9 @@ public NeuralQueryBuilder(StreamInput in) throws IOException { this.maxDistance = in.readOptionalFloat(); this.minScore = in.readOptionalFloat(); } + if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) { + this.methodParameters = MethodParametersParser.streamInput(in, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion); + } } @Override @@ -152,6 +153,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeOptionalFloat(this.maxDistance); out.writeOptionalFloat(this.minScore); } + if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) { + MethodParametersParser.streamOutput(out, methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion); + } } @Override @@ -174,6 +178,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws if (Objects.nonNull(minScore)) { xContentBuilder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); } + if (Objects.nonNull(methodParameters)) { + MethodParametersParser.doXContent(xContentBuilder, methodParameters); + } printBoostAndQueryName(xContentBuilder); xContentBuilder.endObject(); xContentBuilder.endObject(); @@ -267,6 +274,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n } else if (token == XContentParser.Token.START_OBJECT) { if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { neuralQueryBuilder.filter(parseInnerQueryBuilder(parser)); + } else if (METHOD_PARAMS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + neuralQueryBuilder.methodParameters(MethodParametersParser.fromXContent(parser)); } } else { throw new ParsingException( @@ -325,7 +334,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { maxDistance(), minScore(), vectorSetOnce::get, - filter() + filter(), + methodParameters() ); } @@ -358,14 +368,6 @@ public String getWriteableName() { return NAME; } - private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() { - return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID); - } - - private static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() { - return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH); - } - private static boolean validateKNNQueryType(NeuralQueryBuilder neuralQueryBuilder) { int queryCount = 0; if (neuralQueryBuilder.k() != null) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java index a34863ee3..7477fe63b 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -96,6 +96,7 @@ public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() { null, null, null, + null, null ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); @@ -146,6 +147,7 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu null, null, null, + null, null ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); @@ -185,6 +187,7 @@ public void testQueryMatches_whenMultipleShards_thenSuccessful() { null, null, null, + null, null ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java index 800dc6129..ad2460103 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -224,7 +224,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); hybridQueryBuilderDefaultNorm.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) ); hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -249,7 +249,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); hybridQueryBuilderL2Norm.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) ); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -299,7 +299,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); hybridQueryBuilderDefaultNorm.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) ); hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -324,7 +324,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); hybridQueryBuilderL2Norm.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) ); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java index 9f201b4bd..7700c9f6a 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java @@ -85,7 +85,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); hybridQueryBuilderArithmeticMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) ); hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -110,7 +110,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); hybridQueryBuilderHarmonicMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) ); hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -135,7 +135,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); hybridQueryBuilderGeometricMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) ); hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -185,7 +185,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); hybridQueryBuilderArithmeticMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) ); hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -210,7 +210,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); hybridQueryBuilderHarmonicMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) ); hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -235,7 +235,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); hybridQueryBuilderGeometricMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) ); hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index f3c763764..9ecb93b81 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -12,10 +12,10 @@ import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD; import static org.opensearch.neuralsearch.util.TestUtils.xContentBuilderToMap; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; -import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MAX_DISTANCE_FIELD; -import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MIN_SCORE_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.NAME; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.QUERY_IMAGE_FIELD; @@ -107,6 +107,41 @@ public void testFromXContent_whenBuiltWithDefaults_thenBuildSuccessfully() { assertEquals(K, neuralQueryBuilder.k()); } + @SneakyThrows + public void testFromXContent_withMethodParameters_thenBuildSuccessfully() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "query_image": "string", + "model_id": "string", + "k": int + } + } + */ + setUpClusterService(Version.V_2_10_0); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .startObject("method_parameters") + .field("ef_search", 1000) + .endObject() + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.fromXContent(contentParser); + + assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName()); + assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); + assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); + assertEquals(K, neuralQueryBuilder.k()); + } + @SneakyThrows public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { /* diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java index b17f7f151..0e5d86e72 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java @@ -110,6 +110,7 @@ public void testQueryWithBoostAndImageQueryAndRadialQuery() { null, null, null, + null, null ); @@ -131,7 +132,8 @@ public void testQueryWithBoostAndImageQueryAndRadialQuery() { null, null, null, - null + null, + Map.of("ef_search", 10) ); Map searchResponseAsMapMultimodalQuery = search(TEST_BASIC_INDEX_NAME, neuralQueryBuilderMultimodalQuery, 1); Map firstInnerHitMultimodalQuery = getFirstInnerHit(searchResponseAsMapMultimodalQuery); @@ -157,6 +159,7 @@ public void testQueryWithBoostAndImageQueryAndRadialQuery() { 100.0f, null, null, + null, null ); @@ -185,6 +188,7 @@ public void testQueryWithBoostAndImageQueryAndRadialQuery() { null, 0.01f, null, + null, null ); @@ -239,6 +243,7 @@ public void testRescoreQuery() { null, null, null, + null, null ); @@ -316,6 +321,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { null, null, null, + null, null ); NeuralQueryBuilder neuralQueryBuilder2 = new NeuralQueryBuilder( @@ -327,6 +333,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { null, null, null, + null, null ); @@ -354,6 +361,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { null, null, null, + null, null ); @@ -409,6 +417,7 @@ public void testNestedQuery() { null, null, null, + null, null ); @@ -459,7 +468,8 @@ public void testFilterQuery() { null, null, null, - new MatchQueryBuilder("_id", "3") + new MatchQueryBuilder("_id", "3"), + null ); Map searchResponseAsMap = search(TEST_MULTI_DOC_INDEX_NAME, neuralQueryBuilder, 3); assertEquals(1, getHitCount(searchResponseAsMap));