From 914ce1cb84c04e73d1a91143a37b189cd8cc9148 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 16 Apr 2024 17:10:36 +0200 Subject: [PATCH] PR feedback --- .../inference/mapper/SemanticTextField.java | 89 +++---------------- .../mapper/SemanticTextFieldMapper.java | 2 +- .../mapper/SemanticTextFieldMapperTests.java | 6 +- .../mapper/SemanticTextFieldTests.java | 77 +++++++++++++++- 4 files changed, 93 insertions(+), 81 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index 1c5c56b7b9da6..2184c94f2b73b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -8,35 +8,26 @@ package org.elasticsearch.xpack.inference.mapper; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Tuple; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.DeprecationHandler; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContent; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.support.MapXContentParser; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import java.io.IOException; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; @@ -100,6 +91,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder.endObject(); } + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(); + sb.append("task_type=").append(taskType); + if (dimensions != null) { + sb.append(", dimensions=").append(dimensions); + } + if (similarity != null) { + sb.append(", similarity=").append(similarity); + } + return sb.toString(); + } + private void validate() { switch (taskType) { case TEXT_EMBEDDING: @@ -257,71 +261,4 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws MODEL_SETTINGS_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD); MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD); } - - /** - * Converts the provided {@link ChunkedInferenceServiceResults} into a list of {@link Chunk}. - */ - public static List toSemanticTextFieldChunks( - String field, - String inferenceId, - List results, - XContentType contentType - ) { - List chunks = new ArrayList<>(); - for (var result : results) { - if (result instanceof ChunkedSparseEmbeddingResults textExpansionResults) { - for (var chunk : textExpansionResults.getChunkedResults()) { - chunks.add(new Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.weightedTokens()))); - } - } else if (result instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { - for (var chunk : textEmbeddingResults.getChunks()) { - chunks.add(new Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.embedding()))); - } - } else { - throw new ElasticsearchStatusException( - "Invalid inference results format for field [{}] with inference id [{}], got {}", - RestStatus.BAD_REQUEST, - field, - inferenceId, - result.getWriteableName() - ); - } - } - return chunks; - } - - /** - * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}. - */ - private static BytesReference toBytesReference(XContent xContent, double[] value) { - try { - XContentBuilder b = XContentBuilder.builder(xContent); - b.startArray(); - for (double v : value) { - b.value(v); - } - b.endArray(); - return BytesReference.bytes(b); - } catch (IOException exc) { - throw new RuntimeException(exc); - } - } - - /** - * Serialises the {@link TextExpansionResults.WeightedToken} list, according to the provided {@link XContent}, - * into a {@link BytesReference}. - */ - private static BytesReference toBytesReference(XContent xContent, List tokens) { - try { - XContentBuilder b = XContentBuilder.builder(xContent); - b.startObject(); - for (var weightedToken : tokens) { - weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS); - } - b.endObject(); - return BytesReference.bytes(b); - } catch (IOException exc) { - throw new RuntimeException(exc); - } - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index a515cbf03741c..f23b6f3a66899 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -90,7 +90,7 @@ public static class Builder extends FieldMapper.Builder { (n, c, o) -> SemanticTextField.parseModelSettingsFromMap(o), mapper -> ((SemanticTextFieldType) mapper.fieldType()).modelSettings, XContentBuilder::field, - (m) -> m == null ? "null" : Strings.toString(m) + Objects::toString ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings); private final Parameter> meta = Parameter.metaParam(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 9392a3184e3ac..2e05ed38c5a81 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -264,7 +264,7 @@ public void testUpdateModelSettings() throws IOException { ); assertThat( exc.getMessage(), - containsString("Cannot update parameter [model_settings] " + "from [{\"task_type\":\"sparse_embedding\"}] to [null]") + containsString("Cannot update parameter [model_settings] " + "from [task_type=sparse_embedding] to [null]") ); } { @@ -289,8 +289,8 @@ public void testUpdateModelSettings() throws IOException { exc.getMessage(), containsString( "Cannot update parameter [model_settings] " - + "from [{\"task_type\":\"sparse_embedding\"}] " - + "to [{\"task_type\":\"text_embedding\",\"dimensions\":10,\"similarity\":\"cosine\"}]" + + "from [task_type=sparse_embedding] " + + "to [task_type=text_embedding, dimensions=10, similarity=cosine]" ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index c5ab6eb1f9e15..05a7fbf280deb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -7,13 +7,18 @@ package org.elasticsearch.xpack.inference.mapper; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; @@ -30,7 +35,6 @@ import java.util.function.Predicate; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; import static org.hamcrest.Matchers.equalTo; public class SemanticTextFieldTests extends AbstractXContentTestCase { @@ -143,6 +147,77 @@ public static SemanticTextField randomSemanticText(String fieldName, Model model ); } + /** + * Converts the provided {@link ChunkedInferenceServiceResults} into a list of {@link SemanticTextField.Chunk}. + */ + private static List toSemanticTextFieldChunks( + String field, + String inferenceId, + List results, + XContentType contentType + ) { + List chunks = new ArrayList<>(); + for (var result : results) { + if (result instanceof ChunkedSparseEmbeddingResults textExpansionResults) { + for (var chunk : textExpansionResults.getChunkedResults()) { + chunks.add( + new SemanticTextField.Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.weightedTokens())) + ); + } + } else if (result instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { + for (var chunk : textEmbeddingResults.getChunks()) { + chunks.add( + new SemanticTextField.Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.embedding())) + ); + } + } else { + throw new ElasticsearchStatusException( + "Invalid inference results format for field [{}] with inference id [{}], got {}", + RestStatus.BAD_REQUEST, + field, + inferenceId, + result.getWriteableName() + ); + } + } + return chunks; + } + + /** + * Serialises the {@link TextExpansionResults.WeightedToken} list, according to the provided {@link XContent}, + * into a {@link BytesReference}. + */ + private static BytesReference toBytesReference(XContent xContent, List tokens) { + try { + XContentBuilder b = XContentBuilder.builder(xContent); + b.startObject(); + for (var weightedToken : tokens) { + weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS); + } + b.endObject(); + return BytesReference.bytes(b); + } catch (IOException exc) { + throw new RuntimeException(exc); + } + } + + /** + * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}. + */ + private static BytesReference toBytesReference(XContent xContent, double[] value) { + try { + XContentBuilder b = XContentBuilder.builder(xContent); + b.startArray(); + for (double v : value) { + b.value(v); + } + b.endArray(); + return BytesReference.bytes(b); + } catch (IOException exc) { + throw new RuntimeException(exc); + } + } + public static Model randomModel(TaskType taskType) { String serviceName = randomAlphaOfLengthBetween(5, 10); String inferenceId = randomAlphaOfLengthBetween(5, 10);