Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Adding support for Cohere inference service #104559

Merged
merged 17 commits into from
Jan 24, 2024

Conversation

jonathan-buttner
Copy link
Contributor

@jonathan-buttner jonathan-buttner commented Jan 18, 2024

This PR adds support for interacting with Cohere's rest APIs in the _inference API. This only adds support for text embedding.

Cohere text embedding docs: https://docs.cohere.com/reference/embed

Not included

  • No support for tokenization, Cohere does provide tokenization APIs so chunking could leverage that API. That functionality isn't included in this PR because we don't support chunking in the other services yet either
  • No truncation support in the Cohere service implementation. Their API provides truncation support so we'll leverage that instead

Notable differences from the Cohere API

  • Cohere allows providing a embedding_types field and users can provide multiple values (an array of strings) and the response of the request will include embeddings for all the ones requested. For example a response could have floats and signed byte, and unsigned byte embeddings. To keep with our simplified API we only allow a single embedding type in a response. So this PR adds embedding_type as a single string field instead of an array.
  • input_type, Cohere's accepted values are search_document (for ingest) and search_query (for searching). Our API supports ingest and search. I thought that was more intuitive but can easily change it to something else if preferred

Notable Changes

  • Registers a single CohereService that will handle all inference operations in the future
  • The TextEmbeddingResults class now supports bytes in addition to floats
  • Corresponding specification PR

Response formats

Floats

{
    "text_embedding": [
        {
            "embedding": [
                0.7519531,
                0.34326172
            ]
        }
    ]
}

Bytes

{
    "text_embedding": [
        {
            "embedding": [
               1,
                -5
            ]
        }
    ]
}

Testing

Add the service

PUT _inference/text_embedding/cohere
{
    "service": "cohere",
    "service_settings": {
        "api_key": "<api key>"
    },
    "task_settings": {
        "model": "embed-english-v3.0",
        "input_type": "ingest",
        "embedding_type": "float",
        "truncate": "end"
    }
}

Performan inference

POST _inference/text_embedding/cohere
{
    "input": "The food was delicious."
}

@jonathan-buttner jonathan-buttner added :ml Machine learning Team:ML Meta label for the ML team cloud-deploy Publish cloud docker image for Cloud-First-Testing v8.13.0 labels Jan 18, 2024
import java.util.Locale;

/**
* Defines the type of request, whether the request is to ingest a document or search for a document.
*/
public enum InputType implements Writeable {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided to remove these because the cohere task settings need to write it as an optional enum so I figured it'd be better to just leave it up to the caller to handle how it is written.

@@ -23,6 +23,20 @@ public Model(ModelConfigurations configurations, ModelSecrets secrets) {
this.secrets = Objects.requireNonNull(secrets);
}

public Model(Model model, TaskSettings taskSettings) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These just make it easier to construct the Model when these various settings change but the secrets do not.

public Embedding(StreamInput in) throws IOException {
this(in.readCollectionAsImmutableList(StreamInput::readFloat));
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Double check that this is correct. This was added so we can support returning bytes or floats. I thought the best way to do that was to used a named writable but happy to change it to something else if there's a better solution. The writables that we support now are the FloatValue and ByteValue classes above.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you able to elaborate a bit on why we need implementations of EmbeddingValue rather than directly using primitives or Numbers?

Copy link
Contributor Author

@jonathan-buttner jonathan-buttner Jan 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could be completely wrong but this was my thinking 😅 :

In this constructor public Embedding(StreamInput in), I don't know if I need to read floats or bytes (since we're allowing either to occur now with the addition of cohere). So I can't simply do in.readByte() or in.readFloat(). So I think the way to "know" is to use a named writable which led me to add the EmbeddingValue and the ByteValue, FloatValue. Because the in.readNamedWriteable() will figure out what type it is and read the appropriate one. That way when the results get serialized and another node needs to read them, it can determine whether to read them as floats or bytes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rather than directly using primitives or Numbers

It does seem excessive to hold an array of EmbeddingValue objects when a primitive would do the job. TextEmbeddingResults is a named writable one option is create the TextEmbeddingByteResults and ByteEmbeddings classes for the byte variant. That would avoid the conditionals in the streaming functions.

if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) {
values = in.readNamedWriteableCollectionAsList(EmbeddingValue.class);
} else {
values = convertFloatsToEmbeddingValues(in.readCollectionAsImmutableList(StreamInput::readFloat));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we actually want nodes converting bytes into floats when sending to other nodes. I believe we'll prevent that by using the cohere service's minimum support version. So I don't think we'll ever need to execute this code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don't actually want to run this code, would it be better to log the values and not actually do the conversion?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah good question. Dave and I ran into this with this class as well: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java#L215-L224

which should be similar in that a few of those if-blocks shouldn't really happen. We decided for completeness that we'd handle the cases. I'm happy to remove it if we think otherwise though. It just gets complicated with the nodes on separate versions during upgrades. It should be possible but I suppose if we have a bug in the future we could get in a case where it was 🤷‍♂️

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think at least we should be adding in an assertion here so that any tests that hit this code path fail

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With Dave's suggestion I think we can get away with the transport version check all together. It'll add duplicate code though. I'll add a comment on the new class.

@@ -22,10 +22,5 @@
exports org.elasticsearch.xpack.inference.registry;
exports org.elasticsearch.xpack.inference.rest;
exports org.elasticsearch.xpack.inference.services;
exports org.elasticsearch.xpack.inference.external.http.sender;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@davidkyle do we need these? I had added them a while ago but now I get all kinds of warnings saying that various classes aren't exported. If I remove them they go away 🤷‍♂️ . I don't think we'd want to export most of this stuff out of the module anyway.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++ we shouldn't need to export that package, thanks for tidying up

private static ByteValue parseEmbeddingInt8Entry(XContentParser parser) throws IOException {
XContentParser.Token token = parser.currentToken();
XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser);
var parsedByte = parser.shortValue();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't see a byteValue() method. Should I add one to the parsing library? That'd probably be a follow up PR.

* 2.0.
*/

package org.elasticsearch.xpack.inference.services.cohere;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This lives here instead of embeddings because I think a few other cohere APIs support this field as a body parameter

@@ -100,14 +100,6 @@ public OpenAiServiceSettings(
this(createOptionalUri(uri), organizationId, similarity, dimensions, maxInputTokens);
}

private static URI createOptionalUri(String url) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to ServiceUtils.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great self-reviewing!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great self-reviewing!

This ^. Thanks

@@ -46,7 +46,7 @@ protected Writeable.Reader<InferenceAction.Response> instanceReader() {
@Override
protected InferenceAction.Response createTestInstance() {
var result = switch (randomIntBetween(0, 2)) {
case 0 -> TextEmbeddingResultsTests.createRandomResults();
case 0 -> TextEmbeddingResultsTests.createRandomFloatResults();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@davidkyle just a heads up. I was getting some bwc failures which I think are because of the TextEmbeddingResults conversion to namedwritables. So I'm forcing this to only use floats. Maybe there's an error in the way I've written that logic?

var request = createRequest("url", "secret", List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS);

var httpRequest = request.createRequest();
MatcherAssert.assertThat(httpRequest, instanceOf(HttpPost.class));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get warnings when using assertThat directly now. It says it's deprecated and to use MatcherAssert.assertThat instead.

@jonathan-buttner jonathan-buttner marked this pull request as ready for review January 22, 2024 13:35
@elasticsearchmachine
Copy link
Collaborator

Pinging @elastic/ml-core (Team:ML)

@elasticsearchmachine
Copy link
Collaborator

Hi @jonathan-buttner, I've created a changelog YAML for you.

@jonathan-buttner
Copy link
Contributor Author

I forgot to pass through the InputType from the TransportCoordinatedInferenceAction. I believe that's going to require a change to the InferenceService interface. I'll stack a new PR on this one for that change.

Copy link
Member

@maxhniebergall maxhniebergall left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

What a huge new feature! Amazing how many new bits were needed for this, and how many new tests there are.

I added some comments where I got confused or think there could be some improvement, but overall this seems great.

public Embedding(StreamInput in) throws IOException {
this(in.readCollectionAsImmutableList(StreamInput::readFloat));
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you able to elaborate a bit on why we need implementations of EmbeddingValue rather than directly using primitives or Numbers?

if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) {
values = in.readNamedWriteableCollectionAsList(EmbeddingValue.class);
} else {
values = convertFloatsToEmbeddingValues(in.readCollectionAsImmutableList(StreamInput::readFloat));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don't actually want to run this code, would it be better to log the values and not actually do the conversion?

@@ -35,5 +39,13 @@ public static ElasticsearchStatusException createInternalServerError(Throwable e
return new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR, e);
}

public static String getErrorMessage(@Nullable URI uri, String message) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps could rename this to constructFailedToSendRequestMessage or constructMessageForFailedToSendRequest

Comment on lines 32 to 34
throw new ElasticsearchStatusException(
Strings.format("Failed to construct %s URL", service),
RestStatus.INTERNAL_SERVER_ERROR,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we don't log the URL here, but only the service name?

Also, should this be an internal server error, or is it more like an illegal argument exception or an illegal state exception? Are we expecting that the service URL should have already been validated by this point, and ever getting to this point in the code is the exception?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we don't log the URL here, but only the service name?

Yeah we probably should just to be safe.

Are we expecting that the service URL should have already been validated by this point, and ever getting to this point in the code is the exception?

Yeah, if we actually get an exception I would expect it to be a programmatic error. It would be generated by a method like this one:

    static URI buildDefaultUri() throws URISyntaxException {
        return new URIBuilder().setScheme("https")
            .setHost(CohereUtils.HOST)
            .setPathSegments(CohereUtils.VERSION_1, CohereUtils.EMBEDDINGS_PATH)
            .build();
    }

Which essentially just uses static strings and calls the apache library. There's no user input. So unless one of those strings is like ^^ which we hardcode in, we really shouldn't get an exception.

Copy link
Contributor Author

@jonathan-buttner jonathan-buttner Jan 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, since we're passing in the URISyntaxException into the ElasticsearchStatusException. I believe it'll generate a caused_by section that gets returned which would include the invalid URL.

Here's an example:


{
    "error": {
        "root_cause": [
            {
                "type": "status_exception",
                "reason": "Failed to construct Cohere URL"
            }
        ],
        "type": "status_exception",
        "reason": "Failed to construct Cohere URL",
        "caused_by": {
            "type": "u_r_i_syntax_exception",
            "reason": "Illegal character in authority at index 8: https://^^/v1/embed"
        }
    },
    "status": 500
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The URI may be private. It definitly should not be logged but it is ok to include in the REST response as the URI is known to the user who made the request in the first place. The URISyntaxException provides enough information for the user to debug.

Serverless will log responses with 500 status codes which is a good reason to avoid a 500 here. The code building the URI is good, it will only throw on bad inputs so this can be a BAD_REQUEST status.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We talked about this more offline. I'll switch this one to a bad request and then create a follow up PR to remove places where we log the URL and try to think of some other information we could log to help us with triaging errors. Maybe we could hash the url or something and log that instead.

@@ -100,14 +100,6 @@ public OpenAiServiceSettings(
this(createOptionalUri(uri), organizationId, similarity, dimensions, maxInputTokens);
}

private static URI createOptionalUri(String url) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great self-reviewing!

Copy link
Member

@davidkyle davidkyle left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good.

Left some minor comments, I'll take another pass later as it is a big PR

@@ -22,10 +22,5 @@
exports org.elasticsearch.xpack.inference.registry;
exports org.elasticsearch.xpack.inference.rest;
exports org.elasticsearch.xpack.inference.services;
exports org.elasticsearch.xpack.inference.external.http.sender;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++ we shouldn't need to export that package, thanks for tidying up

Comment on lines 32 to 34
throw new ElasticsearchStatusException(
Strings.format("Failed to construct %s URL", service),
RestStatus.INTERNAL_SERVER_ERROR,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The URI may be private. It definitly should not be logged but it is ok to include in the REST response as the URI is known to the user who made the request in the first place. The URISyntaxException provides enough information for the user to debug.

Serverless will log responses with 500 status codes which is a good reason to avoid a 500 here. The code building the URI is good, it will only throw on bad inputs so this can be a BAD_REQUEST status.

/**
* Maps the {@link InputType} to the expected value for cohere for the input_type field in the request
*/
private static final Map<InputType, String> INPUT_TYPE_MAPPING = Map.of(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can avoid the map by using the ordinal value of the enum as the index to an array

private static final String [] INPUT_TYPE_MAPPING  = {SEARCH_DOCUMENT, SEARCH_QUERY}

// lookup
INPUT_TYPE_MAPPING[inputType.ordinal()]

Add an assertion somewhere that the arrays are the same size

assert INPUT_TYPE_MAPPING.length = InputType.values().length

* @param truncation Specifies how the API will handle inputs longer than the maximum token length
*/
public record CohereEmbeddingsTaskSettings(
@Nullable String model,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if model belongs in Task Settings or Service Settings.

The idea of task settings is that they can be overridden per request, but changing the model would produce incompatible results. The same argument can be applied to embedding type as the results would be incompatible with previously indexed results.

OpenAiEmbeddingsTaskSettings has set the precedent for model name in task settings but it is worth reviewing the idea.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah good point. I'll move them to the service settings specific to text embeddings. If we wanted to move them out of the task settings for openai, what would be the process there? Maybe the easiest thing to do would be leave them in the task settings but also add them to the service settings and let the service settings take priority?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leave them in the task settings but also add them to the service settings and let the service settings take priority?

It might be possible to automatically update the config to do this

@@ -100,14 +100,6 @@ public OpenAiServiceSettings(
this(createOptionalUri(uri), organizationId, similarity, dimensions, maxInputTokens);
}

private static URI createOptionalUri(String url) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great self-reviewing!

This ^. Thanks

public Embedding(StreamInput in) throws IOException {
this(in.readCollectionAsImmutableList(StreamInput::readFloat));
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rather than directly using primitives or Numbers

It does seem excessive to hold an array of EmbeddingValue objects when a primitive would do the job. TextEmbeddingResults is a named writable one option is create the TextEmbeddingByteResults and ByteEmbeddings classes for the byte variant. That would avoid the conditionals in the streaming functions.

public interface EmbeddingValue extends NamedWriteable, ToXContentFragment {
Number getValue();
public interface EmbeddingInt {
int getSize();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm open to other ideas here. I did this to make determining the embedding size easier within the ServiceUtils class. The other option would be a second instanceof check for TextEmbeddingByteResults.

* @return the size of the text embedding
* @throws IllegalStateException if the list of embeddings is empty
*/
int getFirstEmbeddingSize() throws IllegalStateException;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also needed for simplifying ServiceUtils

* ]
* }
*/
public record TextEmbeddingByteResults(List<Embedding> embeddings) implements InferenceServiceResults, TextEmbedding {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically a duplicate of TextEmbeddingResults. I was trying to think of an easy way to remove the duplication but I think we'd have to push logic up into some base class but that doesn't fully work because of the constructors that take a stream TextEmbeddingByteResults(StreamInput in). I think we could remove some of this by making the inner Embedding class a named writeable too but I'm not sure that's much better than what I had before with FloatValue and ByteValue.

@davidkyle let me know if this is what you were thinking. I suppose we could put off refactoring this until we support a third type. I think we'd still have to add in a bunch of if-blocks though if we do go the named writeable route.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fundamentally these results classes are different types (byte vs float), it may be verbose but it's good to have the different classes

@@ -98,6 +98,16 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new)
);
namedWriteables.add(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm keeping CohereServiceSettings it's own class because I think we'll use it again for other inference actions for cohere. The way I structured CohereEmbeddingsServiceSettings was to make the common settings be in CohereServiceSettings (e.g. the model name, dimensions etc) and put the embedding_type field within CohereEmbeddingsServiceSettings.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my comment in CohereEmbeddingsServiceSettings about a different way to serialise CohereServiceSettings that avoids adding this entry.

@@ -108,25 +110,6 @@ public static String mustBeNonEmptyString(String settingName, String scope) {
return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName);
}

public static String mustBeNonEmptyList(String settingName, String scope) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had added these but they weren't used so removing them.

@@ -52,9 +52,7 @@ public static HuggingFaceServiceSettings fromMap(Map<String, Object> map) {

public static URI extractUri(Map<String, Object> map, String fieldName, ValidationException validationException) {
String parsedUrl = extractRequiredString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
if (parsedUrl == null) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

convertToUri will just return null if passed null now.

Copy link
Member

@davidkyle davidkyle left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

* ]
* }
*/
public record TextEmbeddingByteResults(List<Embedding> embeddings) implements InferenceServiceResults, TextEmbedding {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fundamentally these results classes are different types (byte vs float), it may be verbose but it's good to have the different classes

}

public CohereEmbeddingsServiceSettings(StreamInput in) throws IOException {
commonSettings = in.readNamedWriteable(CohereServiceSettings.class);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
commonSettings = in.readNamedWriteable(CohereServiceSettings.class);
commonSettings = new CohereServiceSettings(in);

And the writeTo() is changed to commonSettings.writeTo(out) then I think you can avoid added the CohereServiceSettings.class entry to the named writeable registry.

@@ -98,6 +98,16 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new)
);
namedWriteables.add(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my comment in CohereEmbeddingsServiceSettings about a different way to serialise CohereServiceSettings that avoids adding this entry.

@jonathan-buttner jonathan-buttner merged commit eb8c73f into elastic:main Jan 24, 2024
16 checks passed
@jonathan-buttner jonathan-buttner deleted the ml-inference-cohere branch January 24, 2024 21:14
henningandersen pushed a commit to henningandersen/elasticsearch that referenced this pull request Jan 25, 2024
* Starting cohere

* Making progress on cohere

* Filling out the embedding types

* Working cohere

* Fixing tests

* Removing rate limit error message

* Fixing a few comments

* Update docs/changelog/104559.yaml

* Addressing most feedback

* Using separate named writeables for byte results and floats

* Fixing a few comments and adding tests

* Fixing mutation issue

* Removing cohere service settings from named writeable registry
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cloud-deploy Publish cloud docker image for Cloud-First-Testing >enhancement :ml Machine learning Team:ML Meta label for the ML team v8.13.0
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants