From dff862505f3bb4d546f848f1c461527d4edb4bb5 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 31 May 2024 18:26:22 +0200 Subject: [PATCH] ChunkedResult inherits from EmbeddingChunk --- .../ChunkedSparseEmbeddingResults.java | 2 +- .../core/inference/results/Embedding.java | 8 +++- .../inference/results/SparseEmbedding.java | 9 +++++ .../results/ChunkedTextExpansionResults.java | 38 ++++--------------- .../TestSparseInferenceServiceExtension.java | 4 +- .../mapper/SemanticTextFieldTests.java | 1 + .../inference/nlp/TextExpansionProcessor.java | 1 + 7 files changed, 29 insertions(+), 34 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedSparseEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedSparseEmbeddingResults.java index 487e56f521be7..51119b2dc5091 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedSparseEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedSparseEmbeddingResults.java @@ -146,7 +146,7 @@ public int hashCode() { @Override public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) { return chunkedResults.stream() - .map(chunk -> new Chunk(chunk.matchedText(), toBytesReference(xcontent, chunk.weightedTokens()))) + .map(chunk -> new Chunk(chunk.matchedText(), toBytesReference(xcontent, chunk.embedding().embedding.tokens()))) .iterator(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/Embedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/Embedding.java index 75eeb6ff01f73..b86d10303e9f2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/Embedding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/Embedding.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.ToXContentObject; @@ -20,7 +21,7 @@ public abstract class Embedding implements Writeable, ToXContentObject { public static final String EMBEDDING = "embedding"; - public interface EmbeddingValues extends ToXContentFragment { + public interface EmbeddingValues extends ToXContentFragment, Writeable { int size(); XContentBuilder valuesToXContent(String fieldName, XContentBuilder builder, Params params) throws IOException; @@ -52,6 +53,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } + @Override + public void writeTo(StreamOutput out) throws IOException { + embedding.writeTo(out); + } + @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbedding.java index 38e9e78065598..843a0026bd2c0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbedding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbedding.java @@ -97,6 +97,10 @@ public int hashCode() { public static class WeightedTokens implements Embedding.EmbeddingValues { private final List tokens; + public WeightedTokens(StreamInput in) throws IOException { + this.tokens = in.readCollectionAsImmutableList(WeightedToken::new); + } + public WeightedTokens(List tokens) { this.tokens = tokens; } @@ -121,6 +125,11 @@ public XContentBuilder valuesToXContent(String fieldName, XContentBuilder builde return builder; } + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(tokens); + } + public List tokens() { return tokens; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ChunkedTextExpansionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ChunkedTextExpansionResults.java index f2055e0930fda..b700c2afad066 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ChunkedTextExpansionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ChunkedTextExpansionResults.java @@ -9,13 +9,12 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.results.EmbeddingChunk; +import org.elasticsearch.xpack.core.inference.results.SparseEmbedding; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import java.io.IOException; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -24,36 +23,13 @@ public class ChunkedTextExpansionResults extends ChunkedNlpInferenceResults { public static final String NAME = "chunked_text_expansion_result"; - public record ChunkedResult(String matchedText, List weightedTokens) implements Writeable, ToXContentObject { - - public ChunkedResult(StreamInput in) throws IOException { - this(in.readString(), in.readCollectionAsList(WeightedToken::new)); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(matchedText); - out.writeCollection(weightedTokens); + public static class ChunkedResult extends EmbeddingChunk { + public ChunkedResult(String matchedText, List embeddings) { + super(matchedText, new SparseEmbedding(embeddings, false)); } - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(TEXT, matchedText); - builder.startObject(INFERENCE); - for (var weightedToken : weightedTokens) { - weightedToken.toXContent(builder, params); - } - builder.endObject(); - builder.endObject(); - return builder; - } - - public Map asMap() { - var map = new HashMap(); - map.put(TEXT, matchedText); - map.put(INFERENCE, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight))); - return map; + public ChunkedResult(StreamInput in) throws IOException { + super(in.readString(), new SparseEmbedding(in)); } } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 38bed844e28fd..65add1c474d67 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -147,7 +147,9 @@ private List makeChunkedResults(List inp tokens.add(new WeightedToken("feature_" + j, generateEmbedding(input.get(i), j))); } results.add( - new ChunkedSparseEmbeddingResults(List.of(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens))) + new ChunkedSparseEmbeddingResults( + List.of(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)) + ) ); } return results; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index b16c1991dc7b3..434822c301f84 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -24,6 +24,7 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.EmbeddingChunk; import org.elasticsearch.xpack.core.inference.results.FloatEmbedding; +import org.elasticsearch.xpack.core.inference.results.SparseEmbedding; import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.model.TestModel; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java index 1b44614bf4a2b..c9bbe6250ed5f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.inference.nlp; import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbedding; import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;