From 8cd348ed545cf1508acaeaedde829603f8440ea3 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Fri, 29 Mar 2024 16:31:05 -0400 Subject: [PATCH] Fix contracts --- .../ml/queries/SparseVectorQueryBuilder.java | 16 +++++++------ .../SparseVectorQueryBuilderTests.java | 24 +++++++++++-------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilder.java index 710451919fb47..5ae57f46d1806 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilder.java @@ -75,15 +75,17 @@ public SparseVectorQueryBuilder( if (fieldName == null) { throw new IllegalArgumentException("[" + NAME + "] requires a fieldName"); } - if (modelText == null) { - throw new IllegalArgumentException("[" + NAME + "] requires " + MODEL_TEXT.getPreferredName()); - } if ((vectorDimensions == null) == (modelId == null)) { throw new IllegalArgumentException( "[" + NAME + "] requires one of [" + MODEL_ID.getPreferredName() + "], or [" + VECTOR_DIMENSIONS.getPreferredName() + "]" ); } + if (modelId != null && modelText == null) { + throw new IllegalArgumentException( + "[" + NAME + "] requires [" + MODEL_TEXT.getPreferredName() + "] when [" + MODEL_ID.getPreferredName() + "] is specified" + ); + } this.fieldName = fieldName; this.modelText = modelText; @@ -122,6 +124,10 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersions.ADD_SPARSE_VECTOR_QUERY; } + public TokenPruningConfig getTokenPruningConfig() { + return tokenPruningConfig; + } + @Override protected void doWriteTo(StreamOutput out) throws IOException { if (vectorDimensionsSupplier != null) { @@ -331,10 +337,6 @@ public static SparseVectorQueryBuilder fromXContent(XContentParser parser) throw } } - if (modelText == null) { - throw new ParsingException(parser.getTokenLocation(), "No text specified for text query"); - } - if (fieldName == null) { throw new ParsingException(parser.getTokenLocation(), "No fieldname specified for query"); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilderTests.java index 0e4d3e3e97d18..e5b7c5f798776 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilderTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.extras.MapperExtrasPlugin; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.plugins.Plugin; @@ -42,7 +43,6 @@ import java.util.List; import static org.elasticsearch.xpack.ml.queries.SparseVectorQueryBuilder.MODEL_ID; -import static org.elasticsearch.xpack.ml.queries.SparseVectorQueryBuilder.MODEL_TEXT; import static org.elasticsearch.xpack.ml.queries.VectorDimensionsQueryBuilder.TOKENS_FIELD; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.Matchers.either; @@ -195,21 +195,21 @@ public void testIllegalValues() { IllegalArgumentException.class, () -> new SparseVectorQueryBuilder("field name", null, "model id", null) ); - assertEquals("[sparse_vector] requires " + MODEL_TEXT.getPreferredName(), e.getMessage()); + assertEquals("[sparse_vector] requires [model_text] when [model_id] is specified", e.getMessage()); } { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, () -> new SparseVectorQueryBuilder("field name", "model text", null, null) ); - assertEquals( - "[sparse_vector] requires one of [" - + MODEL_ID.getPreferredName() - + "], or [" - + SparseVectorQueryBuilder.VECTOR_DIMENSIONS.getPreferredName() - + "]", - e.getMessage() + assertEquals("[sparse_vector] requires one of [model_id], or [vector_dimensions]", e.getMessage()); + } + { + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> new SparseVectorQueryBuilder("field name", "model text", "baz", VECTOR_DIMENSIONS) ); + assertEquals("[sparse_vector] requires one of [model_id], or [vector_dimensions]", e.getMessage()); } } @@ -348,6 +348,10 @@ public void testThatTokensAreCorrectlyPruned() { SearchExecutionContext searchExecutionContext = createSearchExecutionContext(); SparseVectorQueryBuilder queryBuilder = createTestQueryBuilder(); QueryBuilder rewrittenQueryBuilder = rewriteAndFetch(queryBuilder, searchExecutionContext); - assertTrue(rewrittenQueryBuilder instanceof VectorDimensionsQueryBuilder); + if (queryBuilder.getTokenPruningConfig() == null) { + assertTrue(rewrittenQueryBuilder instanceof BoolQueryBuilder); + } else { + assertTrue(rewrittenQueryBuilder instanceof VectorDimensionsQueryBuilder); + } } }