From cf2d2ecc730d7ab989164f155f9fd280bdc7b9bb Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Thu, 18 Apr 2024 14:27:12 -0700 Subject: [PATCH] Support k-NN radial search parameters in neural search --- .../query/NeuralQueryBuilder.java | 86 +++++++++++++- .../processor/NormalizationProcessorIT.java | 6 + .../processor/ScoreCombinationIT.java | 12 +- .../processor/ScoreNormalizationIT.java | 12 +- .../query/HybridQueryBuilderTests.java | 4 +- .../query/NeuralQueryBuilderTests.java | 106 +++++++++++++++++- .../neuralsearch/query/NeuralQueryIT.java | 74 ++++++++++++ 7 files changed, 282 insertions(+), 18 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index d27061e36..6a5b10857 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -75,6 +75,12 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder @VisibleForTesting static final ParseField K_FIELD = new ParseField("k"); + @VisibleForTesting + static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance"); + + @VisibleForTesting + static final ParseField MIN_SCORE_FIELD = new ParseField("min_score"); + private static final int DEFAULT_K = 10; private static MLCommonsClientAccessor ML_CLIENT; @@ -87,13 +93,16 @@ public static void initialize(MLCommonsClientAccessor mlClient) { private String queryText; private String queryImage; private String modelId; - private int k = DEFAULT_K; + private Integer k = null; + private Float max_distance = null; + private Float min_score = null; @VisibleForTesting @Getter(AccessLevel.PACKAGE) @Setter(AccessLevel.PACKAGE) private Supplier vectorSupplier; private QueryBuilder filter; private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0; + private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0; /** * Constructor from stream input @@ -113,6 +122,10 @@ public NeuralQueryBuilder(StreamInput in) throws IOException { } this.k = in.readVInt(); this.filter = in.readOptionalNamedWriteable(QueryBuilder.class); + if (isClusterOnOrAfterMinReqVersionForRadialSearch()) { + this.max_distance = in.readOptionalFloat(); + this.min_score = in.readOptionalFloat(); + } } @Override @@ -127,6 +140,10 @@ protected void doWriteTo(StreamOutput out) throws IOException { } out.writeVInt(this.k); out.writeOptionalNamedWriteable(this.filter); + if (isClusterOnOrAfterMinReqVersionForRadialSearch()) { + out.writeOptionalFloat(this.max_distance); + out.writeOptionalFloat(this.min_score); + } } @Override @@ -137,10 +154,18 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws if (Objects.nonNull(modelId)) { xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); } - xContentBuilder.field(K_FIELD.getPreferredName(), k); + if (Objects.nonNull(k)) { + xContentBuilder.field(K_FIELD.getPreferredName(), k); + } if (Objects.nonNull(filter)) { xContentBuilder.field(FILTER_FIELD.getPreferredName(), filter); } + if (Objects.nonNull(max_distance)) { + xContentBuilder.field(MAX_DISTANCE_FIELD.getPreferredName(), max_distance); + } + if (Objects.nonNull(min_score)) { + xContentBuilder.field(MIN_SCORE_FIELD.getPreferredName(), min_score); + } printBoostAndQueryName(xContentBuilder); xContentBuilder.endObject(); xContentBuilder.endObject(); @@ -193,6 +218,12 @@ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOEx if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query"); } + + long queryCountProvided = validateKNNQueryType(neuralQueryBuilder); + if (queryCountProvided == 0) { + neuralQueryBuilder.k(DEFAULT_K); + } + return neuralQueryBuilder; } @@ -215,6 +246,10 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n neuralQueryBuilder.queryName(parser.text()); } else if (BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { neuralQueryBuilder.boost(parser.floatValue()); + } else if (MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + neuralQueryBuilder.max_distance(parser.floatValue()); + } else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + neuralQueryBuilder.min_score(parser.floatValue()); } else { throw new ParsingException( parser.getTokenLocation(), @@ -246,7 +281,19 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { // create a new builder. Once the supplier's value gets set, we return a KNNQueryBuilder. Otherwise, we just // return the current unmodified query builder. if (vectorSupplier() != null) { - return vectorSupplier().get() == null ? this : new KNNQueryBuilder(fieldName(), vectorSupplier.get(), k(), filter()); + if (vectorSupplier().get() == null) { + return this; + } else { + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName(), vectorSupplier.get()).filter(filter()); + if (max_distance != null) { + knnQueryBuilder.maxDistance(max_distance); + } else if (min_score != null) { + knnQueryBuilder.minScore(min_score); + } else { + knnQueryBuilder.k(k); + } + return knnQueryBuilder; + } } SetOnce vectorSetOnce = new SetOnce<>(); @@ -263,7 +310,17 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { actionListener.onResponse(null); }, actionListener::onFailure))) ); - return new NeuralQueryBuilder(fieldName(), queryText(), queryImage(), modelId(), k(), vectorSetOnce::get, filter()); + return new NeuralQueryBuilder( + fieldName(), + queryText(), + queryImage(), + modelId(), + k(), + max_distance(), + min_score(), + vectorSetOnce::get, + filter() + ); } @Override @@ -298,4 +355,25 @@ public String getWriteableName() { private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() { return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID); } + + private static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() { + return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH); + } + + private static int validateKNNQueryType(NeuralQueryBuilder neuralQueryBuilder) { + int queryCountProvided = 0; + if (neuralQueryBuilder.k() != null) { + queryCountProvided++; + } + if (neuralQueryBuilder.max_distance() != null) { + queryCountProvided++; + } + if (neuralQueryBuilder.min_score() != null) { + queryCountProvided++; + } + if (queryCountProvided > 1) { + throw new IllegalArgumentException("Only one of k, max_distance, or min_score can be provided"); + } + return queryCountProvided; + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java index 750278ca3..2199acbf9 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -94,6 +94,8 @@ public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() { modelId, 5, null, + null, + null, null ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); @@ -142,6 +144,8 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu modelId, 5, null, + null, + null, null ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); @@ -179,6 +183,8 @@ public void testQueryMatches_whenMultipleShards_thenSuccessful() { modelId, 6, null, + null, + null, null ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java index b3478984c..b7686126b 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -224,7 +224,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); hybridQueryBuilderDefaultNorm.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) ); hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -248,7 +248,9 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf ); HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); - hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)); + hybridQueryBuilderL2Norm.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) + ); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapL2Norm = search( @@ -297,7 +299,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); hybridQueryBuilderDefaultNorm.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) ); hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -321,7 +323,9 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess ); HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); - hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)); + hybridQueryBuilderL2Norm.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) + ); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapL2Norm = search( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java index 175ea08fe..3a72049eb 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java @@ -90,7 +90,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); hybridQueryBuilderArithmeticMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) ); hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -115,7 +115,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); hybridQueryBuilderHarmonicMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) ); hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -140,7 +140,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); hybridQueryBuilderGeometricMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) ); hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -190,7 +190,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); hybridQueryBuilderArithmeticMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) ); hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -215,7 +215,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); hybridQueryBuilderHarmonicMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) ); hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -240,7 +240,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); hybridQueryBuilderGeometricMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null) ); hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 7beb02dcc..076d441ce 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -294,7 +294,7 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() { NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryTwoSubQueries.queries().get(0); assertEquals(VECTOR_FIELD_NAME, neuralQueryBuilder.fieldName()); assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); - assertEquals(K, neuralQueryBuilder.k()); + assertEquals(K, (int) neuralQueryBuilder.k()); assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); assertEquals(BOOST, neuralQueryBuilder.boost(), 0f); // verify term query @@ -602,7 +602,7 @@ public void testRewrite_whenMultipleSubQueries_thenReturnBuilderForEachSubQuery( assertTrue(queryBuilders.get(0) instanceof KNNQueryBuilder); KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilders.get(0); assertEquals(neuralQueryBuilder.fieldName(), knnQueryBuilder.fieldName()); - assertEquals(neuralQueryBuilder.k(), knnQueryBuilder.getK()); + assertEquals((int) neuralQueryBuilder.k(), knnQueryBuilder.getK()); assertTrue(queryBuilders.get(1) instanceof TermQueryBuilder); TermQueryBuilder termQueryBuilder = (TermQueryBuilder) queryBuilders.get(1); assertEquals(termSubQuery.fieldName(), termQueryBuilder.fieldName()); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index dd63abbea..8e954db1f 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -14,6 +14,8 @@ import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; +import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MAX_DISTANCE_FIELD; +import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MIN_SCORE_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.NAME; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.QUERY_IMAGE_FIELD; @@ -64,7 +66,9 @@ public class NeuralQueryBuilderTests extends OpenSearchTestCase { private static final String QUERY_TEXT = "Hello world!"; private static final String IMAGE_TEXT = "base641234567890"; private static final String MODEL_ID = "mfgfgdsfgfdgsde"; - private static final int K = 10; + private static final Integer K = 10; + private static final Float MAX_DISTANCE = 1.0f; + private static final Float MIN_SCORE = 0.985f; private static final float BOOST = 1.8f; private static final String QUERY_NAME = "queryName"; private static final Supplier TEST_VECTOR_SUPPLIER = () -> new float[10]; @@ -645,7 +649,7 @@ public void testRewrite_whenVectorSupplierAndVectorSet_thenReturnKNNQueryBuilder assertTrue(queryBuilder instanceof KNNQueryBuilder); KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilder; assertEquals(neuralQueryBuilder.fieldName(), knnQueryBuilder.fieldName()); - assertEquals(neuralQueryBuilder.k(), knnQueryBuilder.getK()); + assertEquals((int) neuralQueryBuilder.k(), knnQueryBuilder.getK()); assertArrayEquals(TEST_VECTOR_SUPPLIER.get(), (float[]) knnQueryBuilder.vector(), 0.0f); } @@ -677,6 +681,104 @@ public void testQueryCreation_whenCreateQueryWithDoToQuery_thenFail() { assertEquals("Query cannot be created by NeuralQueryBuilder directly", exception.getMessage()); } + @SneakyThrows + public void testFromXContent_whenBuiltWithDefaults_whenBuiltWithMaxDistance_thenBuildSuccessfully() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "query_image": "string", + "model_id": "string", + "max_distance": float + } + } + */ + setUpClusterService(Version.V_2_14_0); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(MAX_DISTANCE_FIELD.getPreferredName(), MAX_DISTANCE) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.fromXContent(contentParser); + + assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName()); + assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); + assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); + assertEquals(MAX_DISTANCE, neuralQueryBuilder.max_distance()); + } + + @SneakyThrows + public void testFromXContent_whenBuiltWithDefaults_whenBuiltWithMinScore_thenBuildSuccessfully() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "query_image": "string", + "model_id": "string", + "min_score": float + } + } + */ + setUpClusterService(Version.V_2_14_0); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(MIN_SCORE_FIELD.getPreferredName(), MIN_SCORE) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.fromXContent(contentParser); + + assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName()); + assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); + assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); + assertEquals(MIN_SCORE, neuralQueryBuilder.min_score()); + } + + @SneakyThrows + public void testFromXContent_whenBuiltWithDefaults_whenBuiltWithMinScoreAndK_thenFail() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "query_image": "string", + "model_id": "string", + "min_score": float, + "k": int + } + } + */ + setUpClusterService(Version.V_2_14_0); + XContentBuilder xContentBuilder = null; + try { + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(MIN_SCORE_FIELD.getPreferredName(), MIN_SCORE) + .field(K_FIELD.getPreferredName(), K) + .endObject() + .endObject(); + } catch (IOException e) { + fail("Failed to create XContentBuilder"); + } + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(IllegalArgumentException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); + } + private void setUpClusterService(Version version) { ClusterService clusterService = NeuralSearchClusterTestUtils.mockClusterService(version); NeuralSearchClusterUtil.instance().initialize(clusterService); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java index 9cc9dda71..4a9f83874 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java @@ -84,6 +84,8 @@ public void testQueryWithBoostAndImageQuery() { modelId, 1, null, + null, + null, null ); @@ -103,6 +105,8 @@ public void testQueryWithBoostAndImageQuery() { modelId, 1, null, + null, + null, null ); Map searchResponseAsMapMultimodalQuery = search(TEST_BASIC_INDEX_NAME, neuralQueryBuilderMultimodalQuery, 1); @@ -154,6 +158,8 @@ public void testRescoreQuery() { modelId, 1, null, + null, + null, null ); @@ -229,6 +235,8 @@ public void testBooleanQuery_withMultipleNeuralQueries() { modelId, 1, null, + null, + null, null ); NeuralQueryBuilder neuralQueryBuilder2 = new NeuralQueryBuilder( @@ -238,6 +246,8 @@ public void testBooleanQuery_withMultipleNeuralQueries() { modelId, 1, null, + null, + null, null ); @@ -263,6 +273,8 @@ public void testBooleanQuery_withMultipleNeuralQueries() { modelId, 1, null, + null, + null, null ); @@ -316,6 +328,8 @@ public void testNestedQuery() { modelId, 1, null, + null, + null, null ); @@ -364,6 +378,8 @@ public void testFilterQuery() { modelId, 1, null, + null, + null, new MatchQueryBuilder("_id", "3") ); Map searchResponseAsMap = search(TEST_MULTI_DOC_INDEX_NAME, neuralQueryBuilder, 3); @@ -377,6 +393,64 @@ public void testFilterQuery() { } } + @SneakyThrows + public void testQueryWithMaxDistance() { + String modelId = null; + try { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + modelId = prepareModel(); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_QUERY_TEXT, + "", + modelId, + null, + 100.0f, + null, + null, + null + ); + + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, neuralQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); + } + } + + @SneakyThrows + public void testQueryWithMinScore() { + String modelId = null; + try { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + modelId = prepareModel(); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_QUERY_TEXT, + "", + modelId, + null, + null, + 0.01f, + null, + null + ); + + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, neuralQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); + } + } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) {