diff --git a/CHANGELOG.md b/CHANGELOG.md index b4e0c9309..d8d0cf309 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.13...2.x) ### Features +- Support k-NN radial search parameters in neural search([#697](https://github.com/opensearch-project/neural-search/pull/697)) ### Enhancements - Allowing execution of hybrid query on index alias with filters ([#670](https://github.com/opensearch-project/neural-search/pull/670)) - Allowing query by raw tokens in neural_sparse query ([#693](https://github.com/opensearch-project/neural-search/pull/693)) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 58562826d..67ba99e66 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,8 +31,8 @@ To send us a pull request, please: 1. Fork the repository. 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. -3. Include tests that check your new feature or bug fix. Ideally, we're looking for unit, integration, and BWC tests, but that depends on how big and critical your change is. -If you're adding an integration test and it is using local ML models, please make sure that the number of model deployments is limited, and you're using the smallest possible model. +3. Include tests that check your new feature or bug fix. Ideally, we're looking for unit, integration, and BWC tests, but that depends on how big and critical your change is. +If you're adding an integration test and it is using local ML models, please make sure that the number of model deployments is limited, and you're using the smallest possible model. Each model deployment consumes resources, and having too many models may cause unexpected test failures. 4. Ensure local tests pass. 5. Commit to your fork using clear commit messages. diff --git a/qa/restart-upgrade/build.gradle b/qa/restart-upgrade/build.gradle index 1a6d0a104..1e63baa35 100644 --- a/qa/restart-upgrade/build.gradle +++ b/qa/restart-upgrade/build.gradle @@ -65,7 +65,7 @@ task testAgainstOldCluster(type: StandaloneRestIntegTestTask) { systemProperty 'tests.skip_delete_model_index', 'true' systemProperty 'tests.plugin_bwc_version', ext.neural_search_bwc_version - //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 + // Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 // because these features were released in 2.11 version. if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10")){ filter { @@ -83,6 +83,13 @@ task testAgainstOldCluster(type: StandaloneRestIntegTestTask) { } } + // Excluding the k-NN radial search tests because we introduce this feature in 2.14 + if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10") || ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12") || ext.neural_search_bwc_version.startsWith("2.13")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.KnnRadialSearchIT.*" + } + } + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}") nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") systemProperty 'tests.security.manager', 'false' @@ -107,7 +114,7 @@ task testAgainstNewCluster(type: StandaloneRestIntegTestTask) { systemProperty 'tests.is_old_cluster', 'false' systemProperty 'tests.plugin_bwc_version', ext.neural_search_bwc_version - //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 + // Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 // because these features were released in 2.11 version. if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10")){ filter { @@ -125,6 +132,13 @@ task testAgainstNewCluster(type: StandaloneRestIntegTestTask) { } } + // Excluding the k-NN radial search tests because we introduce this feature in 2.14 + if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10") || ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12") || ext.neural_search_bwc_version.startsWith("2.13")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.KnnRadialSearchIT.*" + } + } + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}") nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") systemProperty 'tests.security.manager', 'false' diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java new file mode 100644 index 000000000..8a6dfcde3 --- /dev/null +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.bwc; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Map; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_IMAGE_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.getModelId; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; + +public class KnnRadialSearchIT extends AbstractRestartUpgradeRestTestCase { + private static final String PIPELINE_NAME = "radial-search-pipeline"; + private static final String TEST_FIELD = "passage_text"; + private static final String TEST_IMAGE_FIELD = "passage_image"; + private static final String TEXT = "Hello world"; + private static final String TEXT_1 = "Hello world a"; + private static final String TEST_IMAGE_TEXT = "/9j/4AAQSkZJRgABAQAASABIAAD"; + private static final String TEST_IMAGE_TEXT_1 = "/9j/4AAQSkZJRgbdwoeicfhoid"; + + // Test rolling-upgrade with kNN radial search + // Create Text Image Embedding Processor, Ingestion Pipeline and add document + // Validate radial query, pipeline and document count in restart-upgrade scenario + public void testKnnRadialSearch_E2EFlow() throws Exception { + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + + if (isRunningAgainstOldCluster()) { + String modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + createPipelineForTextImageProcessor(modelId, PIPELINE_NAME); + createIndexWithConfiguration( + getIndexNameForTest(), + Files.readString(Path.of(classLoader.getResource("processor/IndexMappingMultipleShard.json").toURI())), + PIPELINE_NAME + ); + addDocument(getIndexNameForTest(), "0", TEST_FIELD, TEXT, TEST_IMAGE_FIELD, TEST_IMAGE_TEXT); + } else { + String modelId = null; + try { + modelId = getModelId(getIngestionPipeline(PIPELINE_NAME), TEXT_IMAGE_EMBEDDING_PROCESSOR); + loadModel(modelId); + addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_1, TEST_IMAGE_FIELD, TEST_IMAGE_TEXT_1); + validateIndexQuery(modelId); + } finally { + wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, null); + } + } + } + + private void validateIndexQuery(final String modelId) { + NeuralQueryBuilder neuralQueryBuilderWithMinScoreQuery = new NeuralQueryBuilder( + "passage_embedding", + TEXT, + TEST_IMAGE_TEXT, + modelId, + null, + null, + 0.01f, + null, + null + ); + Map responseWithMinScoreQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1); + assertNotNull(responseWithMinScoreQuery); + + NeuralQueryBuilder neuralQueryBuilderWithMaxDistanceQuery = new NeuralQueryBuilder( + "passage_embedding", + TEXT, + TEST_IMAGE_TEXT, + modelId, + null, + 100000f, + null, + null, + null + ); + Map responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1); + assertNotNull(responseWithMaxDistanceQuery); + } +} diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java index e6749d778..cbe210911 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java @@ -53,7 +53,17 @@ public void testTextImageEmbeddingProcessor_E2EFlow() throws Exception { private void validateTestIndex(final String modelId) throws Exception { int docCount = getDocCount(getIndexNameForTest()); assertEquals(2, docCount); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder("passage_embedding", TEXT, TEST_IMAGE_TEXT, modelId, 1, null, null); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + "passage_embedding", + TEXT, + TEST_IMAGE_TEXT, + modelId, + 1, + null, + null, + null, + null + ); Map response = search(getIndexNameForTest(), neuralQueryBuilder, 1); assertNotNull(response); } diff --git a/qa/rolling-upgrade/build.gradle b/qa/rolling-upgrade/build.gradle index 591e83d58..609724936 100644 --- a/qa/rolling-upgrade/build.gradle +++ b/qa/rolling-upgrade/build.gradle @@ -83,6 +83,13 @@ task testAgainstOldCluster(type: StandaloneRestIntegTestTask) { } } + // Excluding the k-NN radial search tests because we introduce this feature in 2.14 + if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10") || ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12") || ext.neural_search_bwc_version.startsWith("2.13")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.KnnRadialSearchIT.*" + } + } + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}") nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") systemProperty 'tests.security.manager', 'false' @@ -126,6 +133,20 @@ task testAgainstOneThirdUpgradedCluster(type: StandaloneRestIntegTestTask) { } } + // Excluding the text chunking processor test because we introduce this feature in 2.13 + if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10") || ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.TextChunkingProcessorIT.*" + } + } + + // Excluding the k-NN radial search tests because we introduce this feature in 2.14 + if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10") || ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12") || ext.neural_search_bwc_version.startsWith("2.13")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.KnnRadialSearchIT.*" + } + } + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}") nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") systemProperty 'tests.security.manager', 'false' @@ -150,7 +171,7 @@ task testAgainstTwoThirdsUpgradedCluster(type: StandaloneRestIntegTestTask) { systemProperty 'tests.skip_delete_model_index', 'true' systemProperty 'tests.plugin_bwc_version', ext.neural_search_bwc_version - //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 + // Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 // because these features were released in 2.11 version. if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10")){ filter { diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java new file mode 100644 index 000000000..15be7a15b --- /dev/null +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.bwc; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Map; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_IMAGE_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.getModelId; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; + +public class KnnRadialSearchIT extends AbstractRollingUpgradeTestCase { + private static final String PIPELINE_NAME = "radial-search-pipeline"; + private static final String TEST_FIELD = "passage_text"; + private static final String TEST_IMAGE_FIELD = "passage_image"; + private static final String TEXT = "Hello world"; + private static final String TEXT_MIXED = "Hello world mixed"; + private static final String TEXT_UPGRADED = "Hello world upgraded"; + private static final String TEST_IMAGE_TEXT = "/9j/4AAQSkZJRgABAQAASABIAAD"; + private static final String TEST_IMAGE_TEXT_MIXED = "/9j/4AAQSkZJRgbdwoeicfhoid"; + private static final String TEST_IMAGE_TEXT_UPGRADED = "/9j/4AAQSkZJR8eydhgfwceocvlk"; + + private static final int NUM_DOCS_PER_ROUND = 1; + private static String modelId = ""; + + // Test rolling-upgrade with kNN radial search + // Create Text Image Embedding Processor, Ingestion Pipeline and add document + // Validate radial query, pipeline and document count in rolling-upgrade scenario + public void testKnnRadialSearch_E2EFlow() throws Exception { + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + switch (getClusterType()) { + case OLD: + modelId = uploadTextImageEmbeddingModel(); + loadModel(modelId); + createPipelineForTextImageProcessor(modelId, PIPELINE_NAME); + createIndexWithConfiguration( + getIndexNameForTest(), + Files.readString(Path.of(classLoader.getResource("processor/IndexMappings.json").toURI())), + PIPELINE_NAME + ); + addDocument(getIndexNameForTest(), "0", TEST_FIELD, TEXT, TEST_IMAGE_FIELD, TEST_IMAGE_TEXT); + break; + case MIXED: + modelId = getModelId(getIngestionPipeline(PIPELINE_NAME), TEXT_IMAGE_EMBEDDING_PROCESSOR); + int totalDocsCountMixed; + if (isFirstMixedRound()) { + totalDocsCountMixed = NUM_DOCS_PER_ROUND; + validateIndexQueryOnUpgrade(totalDocsCountMixed, modelId, TEXT, TEST_IMAGE_TEXT); + addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, TEST_IMAGE_FIELD, TEST_IMAGE_TEXT_MIXED); + } else { + totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND; + validateIndexQueryOnUpgrade(totalDocsCountMixed, modelId, TEXT_MIXED, TEST_IMAGE_TEXT_MIXED); + } + break; + case UPGRADED: + try { + modelId = getModelId(getIngestionPipeline(PIPELINE_NAME), TEXT_IMAGE_EMBEDDING_PROCESSOR); + int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND; + loadModel(modelId); + addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, TEST_IMAGE_FIELD, TEST_IMAGE_TEXT_UPGRADED); + validateIndexQueryOnUpgrade(totalDocsCountUpgraded, modelId, TEXT_UPGRADED, TEST_IMAGE_TEXT_UPGRADED); + } finally { + wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, null); + } + break; + default: + throw new IllegalStateException("Unexpected value: " + getClusterType()); + } + } + + private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String modelId, final String text, final String imageText) + throws Exception { + int docCount = getDocCount(getIndexNameForTest()); + assertEquals(numberOfDocs, docCount); + loadModel(modelId); + + NeuralQueryBuilder neuralQueryBuilderWithMinScoreQuery = new NeuralQueryBuilder( + "passage_embedding", + text, + imageText, + modelId, + null, + null, + 0.01f, + null, + null + ); + Map responseWithMinScore = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1); + assertNotNull(responseWithMinScore); + + NeuralQueryBuilder neuralQueryBuilderWithMaxDistanceQuery = new NeuralQueryBuilder( + "passage_embedding", + text, + imageText, + modelId, + null, + 100000f, + null, + null, + null + ); + Map responseWithMaxScore = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1); + assertNotNull(responseWithMaxScore); + } +} diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java index b91ec1322..c0efe0694 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java @@ -76,8 +76,18 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod int docCount = getDocCount(getIndexNameForTest()); assertEquals(numberOfDocs, docCount); loadModel(modelId); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder("passage_embedding", text, imageText, modelId, 1, null, null); - Map response = search(getIndexNameForTest(), neuralQueryBuilder, 1); - assertNotNull(response); + NeuralQueryBuilder neuralQueryBuilderWithKQuery = new NeuralQueryBuilder( + "passage_embedding", + text, + imageText, + modelId, + 1, + null, + null, + null, + null + ); + Map responseWithKQuery = search(getIndexNameForTest(), neuralQueryBuilderWithKQuery, 1); + assertNotNull(responseWithKQuery); } } diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index d27061e36..986d6d96c 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -75,6 +75,12 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder @VisibleForTesting static final ParseField K_FIELD = new ParseField("k"); + @VisibleForTesting + static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance"); + + @VisibleForTesting + static final ParseField MIN_SCORE_FIELD = new ParseField("min_score"); + private static final int DEFAULT_K = 10; private static MLCommonsClientAccessor ML_CLIENT; @@ -87,13 +93,16 @@ public static void initialize(MLCommonsClientAccessor mlClient) { private String queryText; private String queryImage; private String modelId; - private int k = DEFAULT_K; + private Integer k = null; + private Float maxDistance = null; + private Float minScore = null; @VisibleForTesting @Getter(AccessLevel.PACKAGE) @Setter(AccessLevel.PACKAGE) private Supplier vectorSupplier; private QueryBuilder filter; private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0; + private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0; /** * Constructor from stream input @@ -111,8 +120,16 @@ public NeuralQueryBuilder(StreamInput in) throws IOException { } else { this.modelId = in.readString(); } - this.k = in.readVInt(); + if (isClusterOnOrAfterMinReqVersionForRadialSearch()) { + this.k = in.readOptionalInt(); + } else { + this.k = in.readVInt(); + } this.filter = in.readOptionalNamedWriteable(QueryBuilder.class); + if (isClusterOnOrAfterMinReqVersionForRadialSearch()) { + this.maxDistance = in.readOptionalFloat(); + this.minScore = in.readOptionalFloat(); + } } @Override @@ -125,8 +142,16 @@ protected void doWriteTo(StreamOutput out) throws IOException { } else { out.writeString(this.modelId); } - out.writeVInt(this.k); + if (isClusterOnOrAfterMinReqVersionForRadialSearch()) { + out.writeOptionalInt(this.k); + } else { + out.writeVInt(this.k); + } out.writeOptionalNamedWriteable(this.filter); + if (isClusterOnOrAfterMinReqVersionForRadialSearch()) { + out.writeOptionalFloat(this.maxDistance); + out.writeOptionalFloat(this.minScore); + } } @Override @@ -137,10 +162,18 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws if (Objects.nonNull(modelId)) { xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); } - xContentBuilder.field(K_FIELD.getPreferredName(), k); + if (Objects.nonNull(k)) { + xContentBuilder.field(K_FIELD.getPreferredName(), k); + } if (Objects.nonNull(filter)) { xContentBuilder.field(FILTER_FIELD.getPreferredName(), filter); } + if (Objects.nonNull(maxDistance)) { + xContentBuilder.field(MAX_DISTANCE_FIELD.getPreferredName(), maxDistance); + } + if (Objects.nonNull(minScore)) { + xContentBuilder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); + } printBoostAndQueryName(xContentBuilder); xContentBuilder.endObject(); xContentBuilder.endObject(); @@ -193,6 +226,12 @@ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOEx if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query"); } + + boolean queryTypeIsProvided = validateKNNQueryType(neuralQueryBuilder); + if (queryTypeIsProvided == false) { + neuralQueryBuilder.k(DEFAULT_K); + } + return neuralQueryBuilder; } @@ -215,6 +254,10 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n neuralQueryBuilder.queryName(parser.text()); } else if (BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { neuralQueryBuilder.boost(parser.floatValue()); + } else if (MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + neuralQueryBuilder.maxDistance(parser.floatValue()); + } else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + neuralQueryBuilder.minScore(parser.floatValue()); } else { throw new ParsingException( parser.getTokenLocation(), @@ -246,7 +289,18 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { // create a new builder. Once the supplier's value gets set, we return a KNNQueryBuilder. Otherwise, we just // return the current unmodified query builder. if (vectorSupplier() != null) { - return vectorSupplier().get() == null ? this : new KNNQueryBuilder(fieldName(), vectorSupplier.get(), k(), filter()); + if (vectorSupplier().get() == null) { + return this; + } + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName(), vectorSupplier.get()).filter(filter()); + if (maxDistance != null) { + knnQueryBuilder.maxDistance(maxDistance); + } else if (minScore != null) { + knnQueryBuilder.minScore(minScore); + } else { + knnQueryBuilder.k(k); + } + return knnQueryBuilder; } SetOnce vectorSetOnce = new SetOnce<>(); @@ -263,7 +317,17 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { actionListener.onResponse(null); }, actionListener::onFailure))) ); - return new NeuralQueryBuilder(fieldName(), queryText(), queryImage(), modelId(), k(), vectorSetOnce::get, filter()); + return new NeuralQueryBuilder( + fieldName(), + queryText(), + queryImage(), + modelId(), + k(), + maxDistance(), + minScore(), + vectorSetOnce::get, + filter() + ); } @Override @@ -298,4 +362,25 @@ public String getWriteableName() { private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() { return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID); } + + private static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() { + return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH); + } + + private static boolean validateKNNQueryType(NeuralQueryBuilder neuralQueryBuilder) { + int queryCount = 0; + if (neuralQueryBuilder.k() != null) { + queryCount++; + } + if (neuralQueryBuilder.maxDistance() != null) { + queryCount++; + } + if (neuralQueryBuilder.minScore() != null) { + queryCount++; + } + if (queryCount > 1) { + throw new IllegalArgumentException("Only one of k, max_distance, or min_score can be provided"); + } + return queryCount == 1; + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java index 750278ca3..2199acbf9 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -94,6 +94,8 @@ public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() { modelId, 5, null, + null, + null, null ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); @@ -142,6 +144,8 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu modelId, 5, null, + null, + null, null ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); @@ -179,6 +183,8 @@ public void testQueryMatches_whenMultipleShards_thenSuccessful() { modelId, 6, null, + null, + null, null ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java index b3478984c..b7686126b 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -224,7 +224,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); hybridQueryBuilderDefaultNorm.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) ); hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -248,7 +248,9 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf ); HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); - hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)); + hybridQueryBuilderL2Norm.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) + ); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapL2Norm = search( @@ -297,7 +299,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); hybridQueryBuilderDefaultNorm.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) ); hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -321,7 +323,9 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess ); HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); - hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)); + hybridQueryBuilderL2Norm.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) + ); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapL2Norm = search( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java index 175ea08fe..3a72049eb 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java @@ -90,7 +90,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); hybridQueryBuilderArithmeticMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) ); hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -115,7 +115,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); hybridQueryBuilderHarmonicMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) ); hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -140,7 +140,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); hybridQueryBuilderGeometricMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) ); hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -190,7 +190,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); hybridQueryBuilderArithmeticMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) ); hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -215,7 +215,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); hybridQueryBuilderHarmonicMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) ); hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -240,7 +240,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); hybridQueryBuilderGeometricMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) ); hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 7beb02dcc..076d441ce 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -294,7 +294,7 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() { NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryTwoSubQueries.queries().get(0); assertEquals(VECTOR_FIELD_NAME, neuralQueryBuilder.fieldName()); assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); - assertEquals(K, neuralQueryBuilder.k()); + assertEquals(K, (int) neuralQueryBuilder.k()); assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); assertEquals(BOOST, neuralQueryBuilder.boost(), 0f); // verify term query @@ -602,7 +602,7 @@ public void testRewrite_whenMultipleSubQueries_thenReturnBuilderForEachSubQuery( assertTrue(queryBuilders.get(0) instanceof KNNQueryBuilder); KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilders.get(0); assertEquals(neuralQueryBuilder.fieldName(), knnQueryBuilder.fieldName()); - assertEquals(neuralQueryBuilder.k(), knnQueryBuilder.getK()); + assertEquals((int) neuralQueryBuilder.k(), knnQueryBuilder.getK()); assertTrue(queryBuilders.get(1) instanceof TermQueryBuilder); TermQueryBuilder termQueryBuilder = (TermQueryBuilder) queryBuilders.get(1); assertEquals(termSubQuery.fieldName(), termQueryBuilder.fieldName()); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index dd63abbea..804905d73 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -14,6 +14,8 @@ import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; +import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MAX_DISTANCE_FIELD; +import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MIN_SCORE_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.NAME; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.QUERY_IMAGE_FIELD; @@ -64,7 +66,9 @@ public class NeuralQueryBuilderTests extends OpenSearchTestCase { private static final String QUERY_TEXT = "Hello world!"; private static final String IMAGE_TEXT = "base641234567890"; private static final String MODEL_ID = "mfgfgdsfgfdgsde"; - private static final int K = 10; + private static final Integer K = 10; + private static final Float MAX_DISTANCE = 1.0f; + private static final Float MIN_SCORE = 0.985f; private static final float BOOST = 1.8f; private static final String QUERY_NAME = "queryName"; private static final Supplier TEST_VECTOR_SUPPLIER = () -> new float[10]; @@ -645,7 +649,7 @@ public void testRewrite_whenVectorSupplierAndVectorSet_thenReturnKNNQueryBuilder assertTrue(queryBuilder instanceof KNNQueryBuilder); KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilder; assertEquals(neuralQueryBuilder.fieldName(), knnQueryBuilder.fieldName()); - assertEquals(neuralQueryBuilder.k(), knnQueryBuilder.getK()); + assertEquals((int) neuralQueryBuilder.k(), knnQueryBuilder.getK()); assertArrayEquals(TEST_VECTOR_SUPPLIER.get(), (float[]) knnQueryBuilder.vector(), 0.0f); } @@ -677,6 +681,104 @@ public void testQueryCreation_whenCreateQueryWithDoToQuery_thenFail() { assertEquals("Query cannot be created by NeuralQueryBuilder directly", exception.getMessage()); } + @SneakyThrows + public void testFromXContent_whenBuiltWithDefaults_whenBuiltWithMaxDistance_thenBuildSuccessfully() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "query_image": "string", + "model_id": "string", + "max_distance": float + } + } + */ + setUpClusterService(Version.V_2_14_0); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(MAX_DISTANCE_FIELD.getPreferredName(), MAX_DISTANCE) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.fromXContent(contentParser); + + assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName()); + assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); + assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); + assertEquals(MAX_DISTANCE, neuralQueryBuilder.maxDistance()); + } + + @SneakyThrows + public void testFromXContent_whenBuiltWithDefaults_whenBuiltWithMinScore_thenBuildSuccessfully() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "query_image": "string", + "model_id": "string", + "min_score": float + } + } + */ + setUpClusterService(Version.V_2_14_0); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(MIN_SCORE_FIELD.getPreferredName(), MIN_SCORE) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.fromXContent(contentParser); + + assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName()); + assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); + assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); + assertEquals(MIN_SCORE, neuralQueryBuilder.minScore()); + } + + @SneakyThrows + public void testFromXContent_whenBuiltWithDefaults_whenBuiltWithMinScoreAndK_thenFail() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "query_image": "string", + "model_id": "string", + "min_score": float, + "k": int + } + } + */ + setUpClusterService(Version.V_2_14_0); + XContentBuilder xContentBuilder = null; + try { + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(MIN_SCORE_FIELD.getPreferredName(), MIN_SCORE) + .field(K_FIELD.getPreferredName(), K) + .endObject() + .endObject(); + } catch (IOException e) { + fail("Failed to create XContentBuilder"); + } + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(IllegalArgumentException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); + } + private void setUpClusterService(Version version) { ClusterService clusterService = NeuralSearchClusterTestUtils.mockClusterService(version); NeuralSearchClusterUtil.instance().initialize(clusterService); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java index 9cc9dda71..c49a6a8bf 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java @@ -70,9 +70,33 @@ public void setUp() throws Exception { * } * } * } + * and query with radial search max distance and min score: + * { + * "query": { + * "neural": { + * "text_knn": { + * "query_text": "Hello world", + * "model_id": "dcsdcasd", + * "max_distance": 100.0f, + * } + * } + * } + * } + * { + * "query": { + * "neural": { + * "text_knn": { + * "query_text": "Hello world", + * "model_id": "dcsdcasd", + * "min_score": 0.01f, + * } + * } + * } + * } + * */ @SneakyThrows - public void testQueryWithBoostAndImageQuery() { + public void testQueryWithBoostAndImageQueryAndRadialQuery() { String modelId = null; try { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); @@ -84,6 +108,8 @@ public void testQueryWithBoostAndImageQuery() { modelId, 1, null, + null, + null, null ); @@ -103,6 +129,8 @@ public void testQueryWithBoostAndImageQuery() { modelId, 1, null, + null, + null, null ); Map searchResponseAsMapMultimodalQuery = search(TEST_BASIC_INDEX_NAME, neuralQueryBuilderMultimodalQuery, 1); @@ -115,6 +143,61 @@ public void testQueryWithBoostAndImageQuery() { objectToFloat(firstInnerHitMultimodalQuery.get("_score")), DELTA_FOR_SCORE_ASSERTION ); + + // To save test resources, IT tests for radial search are added below. + // Context: https://github.com/opensearch-project/neural-search/pull/697#discussion_r1571549776 + + // Test radial search max distance query + NeuralQueryBuilder neuralQueryWithMaxDistanceBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_QUERY_TEXT, + "", + modelId, + null, + 100.0f, + null, + null, + null + ); + + Map searchResponseAsMapWithMaxDistanceQuery = search( + TEST_BASIC_INDEX_NAME, + neuralQueryWithMaxDistanceBuilder, + 1 + ); + Map firstInnerHitWithMaxDistanceQuery = getFirstInnerHit(searchResponseAsMapWithMaxDistanceQuery); + + assertEquals("1", firstInnerHitWithMaxDistanceQuery.get("_id")); + float expectedScoreWithMaxDistanceQuery = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + assertEquals( + expectedScoreWithMaxDistanceQuery, + objectToFloat(firstInnerHitWithMaxDistanceQuery.get("_score")), + DELTA_FOR_SCORE_ASSERTION + ); + + // Test radial search min score query + NeuralQueryBuilder neuralQueryWithMinScoreBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_QUERY_TEXT, + "", + modelId, + null, + null, + 0.01f, + null, + null + ); + + Map searchResponseAsMapWithMinScoreQuery = search(TEST_BASIC_INDEX_NAME, neuralQueryWithMinScoreBuilder, 1); + Map firstInnerHitWithMinScoreQuery = getFirstInnerHit(searchResponseAsMapWithMinScoreQuery); + + assertEquals("1", firstInnerHitWithMinScoreQuery.get("_id")); + float expectedScoreWithMinScoreQuery = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + assertEquals( + expectedScoreWithMinScoreQuery, + objectToFloat(firstInnerHitWithMinScoreQuery.get("_score")), + DELTA_FOR_SCORE_ASSERTION + ); } finally { wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); } @@ -154,6 +237,8 @@ public void testRescoreQuery() { modelId, 1, null, + null, + null, null ); @@ -229,6 +314,8 @@ public void testBooleanQuery_withMultipleNeuralQueries() { modelId, 1, null, + null, + null, null ); NeuralQueryBuilder neuralQueryBuilder2 = new NeuralQueryBuilder( @@ -238,6 +325,8 @@ public void testBooleanQuery_withMultipleNeuralQueries() { modelId, 1, null, + null, + null, null ); @@ -263,6 +352,8 @@ public void testBooleanQuery_withMultipleNeuralQueries() { modelId, 1, null, + null, + null, null ); @@ -316,6 +407,8 @@ public void testNestedQuery() { modelId, 1, null, + null, + null, null ); @@ -364,6 +457,8 @@ public void testFilterQuery() { modelId, 1, null, + null, + null, new MatchQueryBuilder("_id", "3") ); Map searchResponseAsMap = search(TEST_MULTI_DOC_INDEX_NAME, neuralQueryBuilder, 3);