-
Notifications
You must be signed in to change notification settings - Fork 25k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ML] Adding support for Cohere inference service #104559
Changes from all commits
244da75
e1008cf
f94c32c
58c707d
25dcec1
153785a
4a127ff
3ebed6d
a4e576f
b4e2443
82001b2
83caa87
bcf193d
8102408
6c8343f
de682b5
e841368
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pr: 104559 | ||
summary: Adding support for Cohere inference service | ||
area: Machine Learning | ||
type: enhancement | ||
issues: [] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,20 @@ public Model(ModelConfigurations configurations, ModelSecrets secrets) { | |
this.secrets = Objects.requireNonNull(secrets); | ||
} | ||
|
||
public Model(Model model, TaskSettings taskSettings) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These just make it easier to construct the |
||
Objects.requireNonNull(model); | ||
|
||
configurations = ModelConfigurations.of(model, taskSettings); | ||
secrets = model.getSecrets(); | ||
} | ||
|
||
public Model(Model model, ServiceSettings serviceSettings) { | ||
Objects.requireNonNull(model); | ||
|
||
configurations = ModelConfigurations.of(model, serviceSettings); | ||
secrets = model.getSecrets(); | ||
} | ||
|
||
public Model(ModelConfigurations configurations) { | ||
this(configurations, new ModelSecrets()); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.inference.results; | ||
|
||
public interface EmbeddingInt { | ||
int getSize(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm open to other ideas here. I did this to make determining the embedding size easier within the |
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.inference.results; | ||
|
||
public interface TextEmbedding { | ||
|
||
/** | ||
* Returns the first text embedding entry in the result list's array size. | ||
* @return the size of the text embedding | ||
* @throws IllegalStateException if the list of embeddings is empty | ||
*/ | ||
int getFirstEmbeddingSize() throws IllegalStateException; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also needed for simplifying |
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.inference.results; | ||
|
||
import org.elasticsearch.common.Strings; | ||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
import org.elasticsearch.common.io.stream.Writeable; | ||
import org.elasticsearch.inference.InferenceResults; | ||
import org.elasticsearch.inference.InferenceServiceResults; | ||
import org.elasticsearch.inference.TaskType; | ||
import org.elasticsearch.xcontent.ToXContentObject; | ||
import org.elasticsearch.xcontent.XContentBuilder; | ||
|
||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.LinkedHashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.stream.Collectors; | ||
|
||
/** | ||
* Writes a text embedding result in the follow json format | ||
* { | ||
* "text_embedding": [ | ||
* { | ||
* "embedding": [ | ||
* 23 | ||
* ] | ||
* }, | ||
* { | ||
* "embedding": [ | ||
* -23 | ||
* ] | ||
* } | ||
* ] | ||
* } | ||
*/ | ||
public record TextEmbeddingByteResults(List<Embedding> embeddings) implements InferenceServiceResults, TextEmbedding { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Basically a duplicate of @davidkyle let me know if this is what you were thinking. I suppose we could put off refactoring this until we support a third type. I think we'd still have to add in a bunch of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fundamentally these results classes are different types (byte vs float), it may be verbose but it's good to have the different classes |
||
public static final String NAME = "text_embedding_service_byte_results"; | ||
public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString(); | ||
|
||
public TextEmbeddingByteResults(StreamInput in) throws IOException { | ||
this(in.readCollectionAsList(Embedding::new)); | ||
} | ||
|
||
@Override | ||
public int getFirstEmbeddingSize() { | ||
return TextEmbeddingUtils.getFirstEmbeddingSize(new ArrayList<>(embeddings)); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startArray(TEXT_EMBEDDING); | ||
for (Embedding embedding : embeddings) { | ||
embedding.toXContent(builder, params); | ||
} | ||
builder.endArray(); | ||
return builder; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeCollection(embeddings); | ||
} | ||
|
||
@Override | ||
public String getWriteableName() { | ||
return NAME; | ||
} | ||
|
||
@Override | ||
public List<? extends InferenceResults> transformToCoordinationFormat() { | ||
return embeddings.stream() | ||
.map(embedding -> embedding.values.stream().mapToDouble(value -> value).toArray()) | ||
.map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false)) | ||
.toList(); | ||
} | ||
|
||
@Override | ||
@SuppressWarnings("deprecation") | ||
public List<? extends InferenceResults> transformToLegacyFormat() { | ||
var legacyEmbedding = new LegacyTextEmbeddingResults( | ||
embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.toFloats())).toList() | ||
); | ||
|
||
return List.of(legacyEmbedding); | ||
} | ||
|
||
public Map<String, Object> asMap() { | ||
Map<String, Object> map = new LinkedHashMap<>(); | ||
map.put(TEXT_EMBEDDING, embeddings.stream().map(Embedding::asMap).collect(Collectors.toList())); | ||
|
||
return map; | ||
} | ||
|
||
public record Embedding(List<Byte> values) implements Writeable, ToXContentObject, EmbeddingInt { | ||
public static final String EMBEDDING = "embedding"; | ||
|
||
public Embedding(StreamInput in) throws IOException { | ||
this(in.readCollectionAsImmutableList(StreamInput::readByte)); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeCollection(values, StreamOutput::writeByte); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
|
||
builder.startArray(EMBEDDING); | ||
for (Byte value : values) { | ||
builder.value(value); | ||
} | ||
builder.endArray(); | ||
|
||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return Strings.toString(this); | ||
} | ||
|
||
public Map<String, Object> asMap() { | ||
return Map.of(EMBEDDING, values); | ||
} | ||
|
||
public List<Float> toFloats() { | ||
return values.stream().map(Byte::floatValue).toList(); | ||
} | ||
|
||
@Override | ||
public int getSize() { | ||
return values().size(); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.inference.results; | ||
|
||
import java.util.List; | ||
|
||
public class TextEmbeddingUtils { | ||
|
||
/** | ||
* Returns the first text embedding entry's array size. | ||
* @param embeddings the list of embeddings | ||
* @return the size of the text embedding | ||
* @throws IllegalStateException if the list of embeddings is empty | ||
*/ | ||
public static int getFirstEmbeddingSize(List<EmbeddingInt> embeddings) throws IllegalStateException { | ||
if (embeddings.isEmpty()) { | ||
throw new IllegalStateException("Embeddings list is empty"); | ||
} | ||
|
||
return embeddings.get(0).getSize(); | ||
} | ||
|
||
private TextEmbeddingUtils() {} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,10 +22,5 @@ | |
exports org.elasticsearch.xpack.inference.registry; | ||
exports org.elasticsearch.xpack.inference.rest; | ||
exports org.elasticsearch.xpack.inference.services; | ||
exports org.elasticsearch.xpack.inference.external.http.sender; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @davidkyle do we need these? I had added them a while ago but now I get all kinds of warnings saying that various classes aren't exported. If I remove them they go away 🤷♂️ . I don't think we'd want to export most of this stuff out of the module anyway. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ++ we shouldn't need to export that package, thanks for tidying up |
||
exports org.elasticsearch.xpack.inference.external.http; | ||
exports org.elasticsearch.xpack.inference.services.elser; | ||
exports org.elasticsearch.xpack.inference.services.huggingface.elser; | ||
exports org.elasticsearch.xpack.inference.services.openai; | ||
exports org.elasticsearch.xpack.inference; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I decided to remove these because the cohere task settings need to write it as an optional enum so I figured it'd be better to just leave it up to the caller to handle how it is written.