Skip to content

Commit

Permalink
[ML] Default chunked inference to regular inference call for 3rd part…
Browse files Browse the repository at this point in the history
…y services (#106786)

* Working chunked call for 3rd party services

* Adding some java doc and clean up

* Fixing test

* Moving function to tests

* Adding comments
  • Loading branch information
jonathan-buttner authored Mar 27, 2024
1 parent 50dcfdc commit 2fcd3c2
Show file tree
Hide file tree
Showing 17 changed files with 1,046 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,51 @@
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;

public class ChunkedSparseEmbeddingResults implements ChunkedInferenceServiceResults {

public static final String NAME = "chunked_sparse_embedding_results";
public static final String FIELD_NAME = "sparse_embedding_chunk";

public static ChunkedSparseEmbeddingResults ofMlResult(ChunkedTextExpansionResults mlInferenceResults) {
return new ChunkedSparseEmbeddingResults(mlInferenceResults.getChunks());
}

/**
* Returns a list of {@link ChunkedSparseEmbeddingResults}. The number of entries in the list will match the input list size.
* Each {@link ChunkedSparseEmbeddingResults} will have a single chunk containing the entire results from the
* {@link SparseEmbeddingResults}.
*/
public static List<ChunkedInferenceServiceResults> of(List<String> inputs, SparseEmbeddingResults sparseEmbeddingResults) {
validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size());

var results = new ArrayList<ChunkedInferenceServiceResults>(inputs.size());
for (int i = 0; i < inputs.size(); i++) {
results.add(of(inputs.get(i), sparseEmbeddingResults.embeddings().get(i)));
}

return results;
}

public static ChunkedSparseEmbeddingResults of(String input, SparseEmbeddingResults.Embedding embedding) {
var weightedTokens = embedding.tokens()
.stream()
.map(weightedToken -> new TextExpansionResults.WeightedToken(weightedToken.token(), weightedToken.weight()))
.toList();

return new ChunkedSparseEmbeddingResults(List.of(new ChunkedTextExpansionResults.ChunkedResult(input, weightedTokens)));
}

private final List<ChunkedTextExpansionResults.ChunkedResult> chunkedResults;

public ChunkedSparseEmbeddingResults(List<ChunkedTextExpansionResults.ChunkedResult> chunks) {
Expand All @@ -43,7 +74,7 @@ public List<ChunkedTextExpansionResults.ChunkedResult> getChunkedResults() {

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startArray("sparse_embedding_chunk");
builder.startArray(FIELD_NAME);
for (ChunkedTextExpansionResults.ChunkedResult chunk : chunkedResults) {
chunk.toXContent(builder, params);
}
Expand Down Expand Up @@ -73,7 +104,10 @@ public List<? extends InferenceResults> transformToLegacyFormat() {

@Override
public Map<String, Object> asMap() {
throw new UnsupportedOperationException("Chunked results are not returned in the a map format");
return Map.of(
FIELD_NAME,
chunkedResults.stream().map(ChunkedTextExpansionResults.ChunkedResult::asMap).collect(Collectors.toList())
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;

public record ChunkedTextEmbeddingByteResults(List<EmbeddingChunk> chunks, boolean isTruncated) implements ChunkedInferenceServiceResults {

public static final String NAME = "chunked_text_embedding_service_byte_results";
public static final String FIELD_NAME = "text_embedding_byte_chunk";

/**
* Returns a list of {@link ChunkedTextEmbeddingByteResults}. The number of entries in the list will match the input list size.
* Each {@link ChunkedTextEmbeddingByteResults} will have a single chunk containing the entire results from the
* {@link TextEmbeddingByteResults}.
*/
public static List<ChunkedInferenceServiceResults> of(List<String> inputs, TextEmbeddingByteResults textEmbeddings) {
validateInputSizeAgainstEmbeddings(inputs, textEmbeddings.embeddings().size());

var results = new ArrayList<ChunkedInferenceServiceResults>(inputs.size());
for (int i = 0; i < inputs.size(); i++) {
results.add(of(inputs.get(i), textEmbeddings.embeddings().get(i).values()));
}

return results;
}

public static ChunkedTextEmbeddingByteResults of(String input, List<Byte> byteEmbeddings) {
return new ChunkedTextEmbeddingByteResults(List.of(new EmbeddingChunk(input, byteEmbeddings)), false);
}

public ChunkedTextEmbeddingByteResults(StreamInput in) throws IOException {
this(in.readCollectionAsList(EmbeddingChunk::new), in.readBoolean());
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
// TODO add isTruncated flag
builder.startArray(FIELD_NAME);
for (var embedding : chunks) {
embedding.toXContent(builder, params);
}
builder.endArray();
return builder;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeCollection(chunks);
out.writeBoolean(isTruncated);
}

@Override
public List<? extends InferenceResults> transformToCoordinationFormat() {
throw new UnsupportedOperationException("Chunked results are not returned in the coordinated action");
}

@Override
public List<? extends InferenceResults> transformToLegacyFormat() {
throw new UnsupportedOperationException("Chunked results are not returned in the legacy format");
}

@Override
public Map<String, Object> asMap() {
return Map.of(FIELD_NAME, chunks.stream().map(EmbeddingChunk::asMap).collect(Collectors.toList()));
}

@Override
public String getWriteableName() {
return NAME;
}

public List<EmbeddingChunk> getChunks() {
return chunks;
}

public record EmbeddingChunk(String matchedText, List<Byte> embedding) implements Writeable, ToXContentObject {

public EmbeddingChunk(StreamInput in) throws IOException {
this(in.readString(), in.readCollectionAsImmutableList(StreamInput::readByte));
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(matchedText);
out.writeCollection(embedding, StreamOutput::writeByte);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ChunkedNlpInferenceResults.TEXT, matchedText);

builder.startArray(ChunkedNlpInferenceResults.INFERENCE);
for (Byte value : embedding) {
builder.value(value);
}
builder.endArray();

builder.endObject();
return builder;
}

public Map<String, Object> asMap() {
var map = new HashMap<String, Object>();
map.put(ChunkedNlpInferenceResults.TEXT, matchedText);
map.put(ChunkedNlpInferenceResults.INFERENCE, embedding);
return map;
}

@Override
public String toString() {
return Strings.toString(this);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,56 @@
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;

public class ChunkedTextEmbeddingResults implements ChunkedInferenceServiceResults {

public static final String NAME = "chunked_text_embedding_service_results";

public static final String FIELD_NAME = "text_embedding_chunk";

public static ChunkedTextEmbeddingResults ofMlResult(
org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults mlInferenceResults
) {
return new ChunkedTextEmbeddingResults(mlInferenceResults.getChunks());
}

/**
* Returns a list of {@link ChunkedTextEmbeddingResults}. The number of entries in the list will match the input list size.
* Each {@link ChunkedTextEmbeddingResults} will have a single chunk containing the entire results from the
* {@link TextEmbeddingResults}.
*/
public static List<ChunkedInferenceServiceResults> of(List<String> inputs, TextEmbeddingResults textEmbeddings) {
validateInputSizeAgainstEmbeddings(inputs, textEmbeddings.embeddings().size());

var results = new ArrayList<ChunkedInferenceServiceResults>(inputs.size());
for (int i = 0; i < inputs.size(); i++) {
results.add(ChunkedTextEmbeddingResults.of(inputs.get(i), textEmbeddings.embeddings().get(i).values()));
}

return results;
}

public static ChunkedTextEmbeddingResults of(String input, List<Float> floatEmbeddings) {
double[] doubleEmbeddings = floatEmbeddings.stream().mapToDouble(ChunkedTextEmbeddingResults::floatToDouble).toArray();

return new ChunkedTextEmbeddingResults(
List.of(
new org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk(input, doubleEmbeddings)
)
);
}

private static double floatToDouble(Float aFloat) {
return aFloat != null ? aFloat : 0;
}

private final List<org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk> chunks;

public ChunkedTextEmbeddingResults(
Expand All @@ -48,7 +84,8 @@ public List<org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddi

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startArray("text_embedding_chunk");
// TODO add isTruncated flag
builder.startArray(FIELD_NAME);
for (var embedding : chunks) {
embedding.toXContent(builder, params);
}
Expand All @@ -68,7 +105,7 @@ public String getWriteableName() {

@Override
public List<? extends InferenceResults> transformToCoordinationFormat() {
throw new UnsupportedOperationException("Chunked results are not returned in the coordindated action");
throw new UnsupportedOperationException("Chunked results are not returned in the coordinated action");
}

@Override
Expand All @@ -78,7 +115,12 @@ public List<? extends InferenceResults> transformToLegacyFormat() {

@Override
public Map<String, Object> asMap() {
throw new UnsupportedOperationException("Chunked results are not returned in the a map format");
return Map.of(
FIELD_NAME,
chunks.stream()
.map(org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk::asMap)
.collect(Collectors.toList())
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;

public class ResultUtils {

public static ElasticsearchStatusException createInvalidChunkedResultException(String receivedResultName) {
return new ElasticsearchStatusException(
"Expected a chunked inference [{}] received [{}]",
RestStatus.INTERNAL_SERVER_ERROR,
ChunkedTextEmbeddingResults.NAME,
receivedResultName
);
}

private ResultUtils() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.Strings;

import java.util.List;

public class TextEmbeddingUtils {
Expand All @@ -25,6 +27,18 @@ public static int getFirstEmbeddingSize(List<EmbeddingInt> embeddings) throws Il
return embeddings.get(0).getSize();
}

/**
* Throws an exception if the number of elements in the input text list is different than the results in text embedding
* response.
*/
static void validateInputSizeAgainstEmbeddings(List<String> inputs, int embeddingSize) {
if (inputs.size() != embeddingSize) {
throw new IllegalArgumentException(
Strings.format("The number of inputs [%s] does not match the embeddings [%s]", inputs.size(), embeddingSize)
);
}
}

private TextEmbeddingUtils() {}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

public abstract class ChunkedNlpInferenceResults extends NlpInferenceResults {

static String TEXT = "text";
static String INFERENCE = "inference";
public static String TEXT = "text";
public static String INFERENCE = "inference";

ChunkedNlpInferenceResults(boolean isTruncated) {
super(isTruncated);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults.INFERENCE;
import static org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults.TEXT;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;

public class ChunkedTextEmbeddingResultsTests extends AbstractWireSerializingTestCase<ChunkedTextEmbeddingResults> {
Expand All @@ -33,6 +39,17 @@ public static ChunkedTextEmbeddingResults createRandomResults() {
return new ChunkedTextEmbeddingResults(DEFAULT_RESULTS_FIELD, chunks, randomBoolean());
}

/**
* Similar to {@link ChunkedTextEmbeddingResults.EmbeddingChunk#asMap()} but it converts the double array into a list of doubles to
* make testing equality easier.
*/
public static Map<String, Object> asMapWithListsInsteadOfArrays(ChunkedTextEmbeddingResults.EmbeddingChunk chunk) {
var map = new HashMap<String, Object>();
map.put(TEXT, chunk.matchedText());
map.put(INFERENCE, Arrays.stream(chunk.embedding()).boxed().collect(Collectors.toList()));
return map;
}

@Override
protected Writeable.Reader<ChunkedTextEmbeddingResults> instanceReader() {
return ChunkedTextEmbeddingResults::new;
Expand Down
Loading

0 comments on commit 2fcd3c2

Please sign in to comment.