Skip to content

Commit

Permalink
ChunkedResult inherits from EmbeddingChunk
Browse files: browse the repository at this point in the history
  • Loading branch information
carlosdelest committed May 31, 2024
1 parent ef81185 commit dff8625
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ public int hashCode() {
@Override
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
return chunkedResults.stream()
.map(chunk -> new Chunk(chunk.matchedText(), toBytesReference(xcontent, chunk.weightedTokens())))
.map(chunk -> new Chunk(chunk.matchedText(), toBytesReference(xcontent, chunk.embedding().embedding.tokens())))
.iterator();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.ToXContentObject;
Expand All @@ -20,7 +21,7 @@
public abstract class Embedding<T extends Embedding.EmbeddingValues> implements Writeable, ToXContentObject {
public static final String EMBEDDING = "embedding";

public interface EmbeddingValues extends ToXContentFragment {
public interface EmbeddingValues extends ToXContentFragment, Writeable {
int size();

XContentBuilder valuesToXContent(String fieldName, XContentBuilder builder, Params params) throws IOException;
Expand Down Expand Up @@ -52,6 +53,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder;
}

/**
 * Serializes this object by delegating to the {@code embedding} member's own
 * {@code writeTo}; nothing else is written to the stream by this method.
 *
 * @param out the stream output to write to
 * @throws IOException declared so failures from the delegated write can propagate
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
embedding.writeTo(out);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ public int hashCode() {
public static class WeightedTokens implements Embedding.EmbeddingValues {
private final List<WeightedToken> tokens;

/**
 * Stream constructor: populates {@code tokens} by calling
 * {@code readCollectionAsImmutableList} on the input, building each element
 * with the {@code WeightedToken} constructor reference.
 *
 * @param in the stream input to read the token collection from
 * @throws IOException declared so read failures can propagate
 */
public WeightedTokens(StreamInput in) throws IOException {
this.tokens = in.readCollectionAsImmutableList(WeightedToken::new);
}

/**
 * Creates an instance holding the given list of weighted tokens.
 *
 * @param tokens the tokens to hold; the reference is stored as-is, no copy is made
 */
public WeightedTokens(List<WeightedToken> tokens) {
// NOTE(review): the caller's list is kept directly, so later changes to that list
// would be visible through tokens() - confirm this sharing is intended.
this.tokens = tokens;
}
Expand All @@ -121,6 +125,11 @@ public XContentBuilder valuesToXContent(String fieldName, XContentBuilder builde
return builder;
}

/**
 * Writes the held token list to the stream with {@code writeCollection}.
 *
 * @param out the stream output to write to
 * @throws IOException declared so write failures can propagate
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeCollection(tokens);
}

/**
 * Returns the list of weighted tokens held by this instance.
 * The field itself is returned, not a copy.
 *
 * @return the {@code tokens} list
 */
public List<WeightedToken> tokens() {
return tokens;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.results.EmbeddingChunk;
import org.elasticsearch.xpack.core.inference.results.SparseEmbedding;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand All @@ -24,36 +23,13 @@
public class ChunkedTextExpansionResults extends ChunkedNlpInferenceResults {
public static final String NAME = "chunked_text_expansion_result";

public record ChunkedResult(String matchedText, List<WeightedToken> weightedTokens) implements Writeable, ToXContentObject {

public ChunkedResult(StreamInput in) throws IOException {
this(in.readString(), in.readCollectionAsList(WeightedToken::new));
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(matchedText);
out.writeCollection(weightedTokens);
public static class ChunkedResult extends EmbeddingChunk<SparseEmbedding.WeightedTokens> {
public ChunkedResult(String matchedText, List<WeightedToken> embeddings) {
super(matchedText, new SparseEmbedding(embeddings, false));
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(TEXT, matchedText);
builder.startObject(INFERENCE);
for (var weightedToken : weightedTokens) {
weightedToken.toXContent(builder, params);
}
builder.endObject();
builder.endObject();
return builder;
}

public Map<String, Object> asMap() {
var map = new HashMap<String, Object>();
map.put(TEXT, matchedText);
map.put(INFERENCE, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight)));
return map;
public ChunkedResult(StreamInput in) throws IOException {
super(in.readString(), new SparseEmbedding(in));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> inp
tokens.add(new WeightedToken("feature_" + j, generateEmbedding(input.get(i), j)));
}
results.add(
new ChunkedSparseEmbeddingResults(List.of(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)))
new ChunkedSparseEmbeddingResults(
List.of(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens))
)
);
}
return results;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.EmbeddingChunk;
import org.elasticsearch.xpack.core.inference.results.FloatEmbedding;
import org.elasticsearch.xpack.core.inference.results.SparseEmbedding;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
import org.elasticsearch.xpack.inference.model.TestModel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.ml.inference.nlp;

import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbedding;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
Expand Down

0 comments on commit dff8625

Please sign in to comment.