Skip to content

Commit

Permalink
Fix contracts
Browse files Browse the repository at this point in the history
  • Loading branch information
kderusso committed Mar 29, 2024
1 parent d3f3d52 commit 8cd348e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,17 @@ public SparseVectorQueryBuilder(
if (fieldName == null) {
throw new IllegalArgumentException("[" + NAME + "] requires a fieldName");
}
if (modelText == null) {
throw new IllegalArgumentException("[" + NAME + "] requires " + MODEL_TEXT.getPreferredName());
}

if ((vectorDimensions == null) == (modelId == null)) {
throw new IllegalArgumentException(
"[" + NAME + "] requires one of [" + MODEL_ID.getPreferredName() + "], or [" + VECTOR_DIMENSIONS.getPreferredName() + "]"
);
}
if (modelId != null && modelText == null) {
throw new IllegalArgumentException(
"[" + NAME + "] requires [" + MODEL_TEXT.getPreferredName() + "] when [" + MODEL_ID.getPreferredName() + "] is specified"
);
}

this.fieldName = fieldName;
this.modelText = modelText;
Expand Down Expand Up @@ -122,6 +124,10 @@ public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ADD_SPARSE_VECTOR_QUERY;
}

public TokenPruningConfig getTokenPruningConfig() {
return tokenPruningConfig;
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
if (vectorDimensionsSupplier != null) {
Expand Down Expand Up @@ -331,10 +337,6 @@ public static SparseVectorQueryBuilder fromXContent(XContentParser parser) throw
}
}

if (modelText == null) {
throw new ParsingException(parser.getTokenLocation(), "No text specified for text query");
}

if (fieldName == null) {
throw new ParsingException(parser.getTokenLocation(), "No fieldname specified for query");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.common.compress.CompressedXContent;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.extras.MapperExtrasPlugin;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.plugins.Plugin;
Expand All @@ -42,7 +43,6 @@
import java.util.List;

import static org.elasticsearch.xpack.ml.queries.SparseVectorQueryBuilder.MODEL_ID;
import static org.elasticsearch.xpack.ml.queries.SparseVectorQueryBuilder.MODEL_TEXT;
import static org.elasticsearch.xpack.ml.queries.VectorDimensionsQueryBuilder.TOKENS_FIELD;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.Matchers.either;
Expand Down Expand Up @@ -195,21 +195,21 @@ public void testIllegalValues() {
IllegalArgumentException.class,
() -> new SparseVectorQueryBuilder("field name", null, "model id", null)
);
assertEquals("[sparse_vector] requires " + MODEL_TEXT.getPreferredName(), e.getMessage());
assertEquals("[sparse_vector] requires [model_text] when [model_id] is specified", e.getMessage());
}
{
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> new SparseVectorQueryBuilder("field name", "model text", null, null)
);
assertEquals(
"[sparse_vector] requires one of ["
+ MODEL_ID.getPreferredName()
+ "], or ["
+ SparseVectorQueryBuilder.VECTOR_DIMENSIONS.getPreferredName()
+ "]",
e.getMessage()
assertEquals("[sparse_vector] requires one of [model_id], or [vector_dimensions]", e.getMessage());
}
{
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> new SparseVectorQueryBuilder("field name", "model text", "baz", VECTOR_DIMENSIONS)
);
assertEquals("[sparse_vector] requires one of [model_id], or [vector_dimensions]", e.getMessage());
}
}

Expand Down Expand Up @@ -348,6 +348,10 @@ public void testThatTokensAreCorrectlyPruned() {
SearchExecutionContext searchExecutionContext = createSearchExecutionContext();
SparseVectorQueryBuilder queryBuilder = createTestQueryBuilder();
QueryBuilder rewrittenQueryBuilder = rewriteAndFetch(queryBuilder, searchExecutionContext);
assertTrue(rewrittenQueryBuilder instanceof VectorDimensionsQueryBuilder);
if (queryBuilder.getTokenPruningConfig() == null) {
assertTrue(rewrittenQueryBuilder instanceof BoolQueryBuilder);
} else {
assertTrue(rewrittenQueryBuilder instanceof VectorDimensionsQueryBuilder);
}
}
}

0 comments on commit 8cd348e

Please sign in to comment.