-
Notifications
You must be signed in to change notification settings - Fork 24.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ML] Adding support for Cohere inference service #104559
[ML] Adding support for Cohere inference service #104559
Conversation
import java.util.Locale; | ||
|
||
/** | ||
* Defines the type of request, whether the request is to ingest a document or search for a document. | ||
*/ | ||
public enum InputType implements Writeable { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I decided to remove these because the cohere task settings need to write it as an optional enum so I figured it'd be better to just leave it up to the caller to handle how it is written.
@@ -23,6 +23,20 @@ public Model(ModelConfigurations configurations, ModelSecrets secrets) { | |||
this.secrets = Objects.requireNonNull(secrets); | |||
} | |||
|
|||
public Model(Model model, TaskSettings taskSettings) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These just make it easier to construct the Model
when these various settings change but the secrets do not.
public Embedding(StreamInput in) throws IOException { | ||
this(in.readCollectionAsImmutableList(StreamInput::readFloat)); | ||
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Double check that this is correct. This was added so we can support returning bytes
or floats
. I thought the best way to do that was to used a named writable but happy to change it to something else if there's a better solution. The writables that we support now are the FloatValue
and ByteValue
classes above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you able to elaborate a bit on why we need implementations of EmbeddingValue rather than directly using primitives or Numbers?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could be completely wrong but this was my thinking 😅 :
In this constructor public Embedding(StreamInput in)
, I don't know if I need to read floats or bytes (since we're allowing either to occur now with the addition of cohere). So I can't simply do in.readByte()
or in.readFloat()
. So I think the way to "know" is to use a named writable which led me to add the EmbeddingValue
and the ByteValue
, FloatValue
. Because the in.readNamedWriteable()
will figure out what type it is and read the appropriate one. That way when the results get serialized and another node needs to read them, it can determine whether to read them as floats or bytes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rather than directly using primitives or Numbers
It does seem excessive to hold an array of EmbeddingValue objects when a primitive would do the job. TextEmbeddingResults
is a named writable one option is create the TextEmbeddingByteResults
and ByteEmbeddings
classes for the byte variant. That would avoid the conditionals in the streaming functions.
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) { | ||
values = in.readNamedWriteableCollectionAsList(EmbeddingValue.class); | ||
} else { | ||
values = convertFloatsToEmbeddingValues(in.readCollectionAsImmutableList(StreamInput::readFloat)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we actually want nodes converting bytes into floats when sending to other nodes. I believe we'll prevent that by using the cohere service's minimum support version. So I don't think we'll ever need to execute this code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we don't actually want to run this code, would it be better to log the values and not actually do the conversion?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah good question. Dave and I ran into this with this class as well: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java#L215-L224
which should be similar in that a few of those if-blocks shouldn't really happen. We decided for completeness that we'd handle the cases. I'm happy to remove it if we think otherwise though. It just gets complicated with the nodes on separate versions during upgrades. It should be possible but I suppose if we have a bug in the future we could get in a case where it was 🤷♂️
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think at least we should be adding in an assertion here so that any tests that hit this code path fail
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With Dave's suggestion I think we can get away with the transport version check all together. It'll add duplicate code though. I'll add a comment on the new class.
@@ -22,10 +22,5 @@ | |||
exports org.elasticsearch.xpack.inference.registry; | |||
exports org.elasticsearch.xpack.inference.rest; | |||
exports org.elasticsearch.xpack.inference.services; | |||
exports org.elasticsearch.xpack.inference.external.http.sender; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@davidkyle do we need these? I had added them a while ago but now I get all kinds of warnings saying that various classes aren't exported. If I remove them they go away 🤷♂️ . I don't think we'd want to export most of this stuff out of the module anyway.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
++ we shouldn't need to export that package, thanks for tidying up
private static ByteValue parseEmbeddingInt8Entry(XContentParser parser) throws IOException { | ||
XContentParser.Token token = parser.currentToken(); | ||
XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); | ||
var parsedByte = parser.shortValue(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't see a byteValue()
method. Should I add one to the parsing library? That'd probably be a follow up PR.
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.services.cohere; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This lives here instead of embeddings
because I think a few other cohere APIs support this field as a body parameter
@@ -100,14 +100,6 @@ public OpenAiServiceSettings( | |||
this(createOptionalUri(uri), organizationId, similarity, dimensions, maxInputTokens); | |||
} | |||
|
|||
private static URI createOptionalUri(String url) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved to ServiceUtils.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the great self-reviewing!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the great self-reviewing!
This ^. Thanks
@@ -46,7 +46,7 @@ protected Writeable.Reader<InferenceAction.Response> instanceReader() { | |||
@Override | |||
protected InferenceAction.Response createTestInstance() { | |||
var result = switch (randomIntBetween(0, 2)) { | |||
case 0 -> TextEmbeddingResultsTests.createRandomResults(); | |||
case 0 -> TextEmbeddingResultsTests.createRandomFloatResults(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@davidkyle just a heads up. I was getting some bwc failures which I think are because of the TextEmbeddingResults
conversion to namedwritables. So I'm forcing this to only use floats. Maybe there's an error in the way I've written that logic?
var request = createRequest("url", "secret", List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS); | ||
|
||
var httpRequest = request.createRequest(); | ||
MatcherAssert.assertThat(httpRequest, instanceOf(HttpPost.class)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I get warnings when using assertThat
directly now. It says it's deprecated and to use MatcherAssert.assertThat
instead.
Pinging @elastic/ml-core (Team:ML) |
Hi @jonathan-buttner, I've created a changelog YAML for you. |
I forgot to pass through the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
What a huge new feature! Amazing how many new bits were needed for this, and how many new tests there are.
I added some comments where I got confused or think there could be some improvement, but overall this seems great.
public Embedding(StreamInput in) throws IOException { | ||
this(in.readCollectionAsImmutableList(StreamInput::readFloat)); | ||
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you able to elaborate a bit on why we need implementations of EmbeddingValue rather than directly using primitives or Numbers?
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) { | ||
values = in.readNamedWriteableCollectionAsList(EmbeddingValue.class); | ||
} else { | ||
values = convertFloatsToEmbeddingValues(in.readCollectionAsImmutableList(StreamInput::readFloat)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we don't actually want to run this code, would it be better to log the values and not actually do the conversion?
...in/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java
Outdated
Show resolved
Hide resolved
@@ -35,5 +39,13 @@ public static ElasticsearchStatusException createInternalServerError(Throwable e | |||
return new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR, e); | |||
} | |||
|
|||
public static String getErrorMessage(@Nullable URI uri, String message) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
perhaps could rename this to constructFailedToSendRequestMessage
or constructMessageForFailedToSendRequest
throw new ElasticsearchStatusException( | ||
Strings.format("Failed to construct %s URL", service), | ||
RestStatus.INTERNAL_SERVER_ERROR, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason we don't log the URL here, but only the service name?
Also, should this be an internal server error, or is it more like an illegal argument exception or an illegal state exception? Are we expecting that the service URL should have already been validated by this point, and ever getting to this point in the code is the exception?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason we don't log the URL here, but only the service name?
Yeah we probably should just to be safe.
Are we expecting that the service URL should have already been validated by this point, and ever getting to this point in the code is the exception?
Yeah, if we actually get an exception I would expect it to be a programmatic error. It would be generated by a method like this one:
static URI buildDefaultUri() throws URISyntaxException {
return new URIBuilder().setScheme("https")
.setHost(CohereUtils.HOST)
.setPathSegments(CohereUtils.VERSION_1, CohereUtils.EMBEDDINGS_PATH)
.build();
}
Which essentially just uses static strings and calls the apache library. There's no user input. So unless one of those strings is like ^^
which we hardcode in, we really shouldn't get an exception.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, since we're passing in the URISyntaxException
into the ElasticsearchStatusException
. I believe it'll generate a caused_by
section that gets returned which would include the invalid URL.
Here's an example:
{
"error": {
"root_cause": [
{
"type": "status_exception",
"reason": "Failed to construct Cohere URL"
}
],
"type": "status_exception",
"reason": "Failed to construct Cohere URL",
"caused_by": {
"type": "u_r_i_syntax_exception",
"reason": "Illegal character in authority at index 8: https://^^/v1/embed"
}
},
"status": 500
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The URI may be private. It definitly should not be logged but it is ok to include in the REST response as the URI is known to the user who made the request in the first place. The URISyntaxException
provides enough information for the user to debug.
Serverless will log responses with 500 status codes which is a good reason to avoid a 500 here. The code building the URI is good, it will only throw on bad inputs so this can be a BAD_REQUEST status.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We talked about this more offline. I'll switch this one to a bad request and then create a follow up PR to remove places where we log the URL and try to think of some other information we could log to help us with triaging errors. Maybe we could hash the url or something and log that instead.
...java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java
Show resolved
Hide resolved
@@ -100,14 +100,6 @@ public OpenAiServiceSettings( | |||
this(createOptionalUri(uri), organizationId, similarity, dimensions, maxInputTokens); | |||
} | |||
|
|||
private static URI createOptionalUri(String url) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the great self-reviewing!
...rence/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java
Outdated
Show resolved
Hide resolved
...sticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java
Show resolved
Hide resolved
...ence/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good.
Left some minor comments, I'll take another pass later as it is a big PR
@@ -22,10 +22,5 @@ | |||
exports org.elasticsearch.xpack.inference.registry; | |||
exports org.elasticsearch.xpack.inference.rest; | |||
exports org.elasticsearch.xpack.inference.services; | |||
exports org.elasticsearch.xpack.inference.external.http.sender; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
++ we shouldn't need to export that package, thanks for tidying up
throw new ElasticsearchStatusException( | ||
Strings.format("Failed to construct %s URL", service), | ||
RestStatus.INTERNAL_SERVER_ERROR, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The URI may be private. It definitly should not be logged but it is ok to include in the REST response as the URI is known to the user who made the request in the first place. The URISyntaxException
provides enough information for the user to debug.
Serverless will log responses with 500 status codes which is a good reason to avoid a 500 here. The code building the URI is good, it will only throw on bad inputs so this can be a BAD_REQUEST status.
/** | ||
* Maps the {@link InputType} to the expected value for cohere for the input_type field in the request | ||
*/ | ||
private static final Map<InputType, String> INPUT_TYPE_MAPPING = Map.of( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can avoid the map by using the ordinal
value of the enum as the index to an array
private static final String [] INPUT_TYPE_MAPPING = {SEARCH_DOCUMENT, SEARCH_QUERY}
// lookup
INPUT_TYPE_MAPPING[inputType.ordinal()]
Add an assertion somewhere that the arrays are the same size
assert INPUT_TYPE_MAPPING.length = InputType.values().length
* @param truncation Specifies how the API will handle inputs longer than the maximum token length | ||
*/ | ||
public record CohereEmbeddingsTaskSettings( | ||
@Nullable String model, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if model belongs in Task Settings or Service Settings.
The idea of task settings is that they can be overridden per request, but changing the model would produce incompatible results. The same argument can be applied to embedding type as the results would be incompatible with previously indexed results.
OpenAiEmbeddingsTaskSettings
has set the precedent for model name in task settings but it is worth reviewing the idea.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah good point. I'll move them to the service settings specific to text embeddings. If we wanted to move them out of the task settings for openai, what would be the process there? Maybe the easiest thing to do would be leave them in the task settings but also add them to the service settings and let the service settings take priority?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
leave them in the task settings but also add them to the service settings and let the service settings take priority?
It might be possible to automatically update the config to do this
@@ -100,14 +100,6 @@ public OpenAiServiceSettings( | |||
this(createOptionalUri(uri), organizationId, similarity, dimensions, maxInputTokens); | |||
} | |||
|
|||
private static URI createOptionalUri(String url) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the great self-reviewing!
This ^. Thanks
public Embedding(StreamInput in) throws IOException { | ||
this(in.readCollectionAsImmutableList(StreamInput::readFloat)); | ||
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rather than directly using primitives or Numbers
It does seem excessive to hold an array of EmbeddingValue objects when a primitive would do the job. TextEmbeddingResults
is a named writable one option is create the TextEmbeddingByteResults
and ByteEmbeddings
classes for the byte variant. That would avoid the conditionals in the streaming functions.
public interface EmbeddingValue extends NamedWriteable, ToXContentFragment { | ||
Number getValue(); | ||
public interface EmbeddingInt { | ||
int getSize(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm open to other ideas here. I did this to make determining the embedding size easier within the ServiceUtils
class. The other option would be a second instanceof
check for TextEmbeddingByteResults
.
* @return the size of the text embedding | ||
* @throws IllegalStateException if the list of embeddings is empty | ||
*/ | ||
int getFirstEmbeddingSize() throws IllegalStateException; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also needed for simplifying ServiceUtils
* ] | ||
* } | ||
*/ | ||
public record TextEmbeddingByteResults(List<Embedding> embeddings) implements InferenceServiceResults, TextEmbedding { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Basically a duplicate of TextEmbeddingResults
. I was trying to think of an easy way to remove the duplication but I think we'd have to push logic up into some base class but that doesn't fully work because of the constructors that take a stream TextEmbeddingByteResults(StreamInput in)
. I think we could remove some of this by making the inner Embedding
class a named writeable too but I'm not sure that's much better than what I had before with FloatValue
and ByteValue
.
@davidkyle let me know if this is what you were thinking. I suppose we could put off refactoring this until we support a third type. I think we'd still have to add in a bunch of if-blocks
though if we do go the named writeable route.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fundamentally these results classes are different types (byte vs float), it may be verbose but it's good to have the different classes
@@ -98,6 +98,16 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() { | |||
namedWriteables.add( | |||
new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new) | |||
); | |||
namedWriteables.add( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm keeping CohereServiceSettings
it's own class because I think we'll use it again for other inference actions for cohere. The way I structured CohereEmbeddingsServiceSettings
was to make the common settings be in CohereServiceSettings
(e.g. the model name, dimensions etc) and put the embedding_type
field within CohereEmbeddingsServiceSettings
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See my comment in CohereEmbeddingsServiceSettings
about a different way to serialise CohereServiceSettings
that avoids adding this entry.
@@ -108,25 +110,6 @@ public static String mustBeNonEmptyString(String settingName, String scope) { | |||
return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName); | |||
} | |||
|
|||
public static String mustBeNonEmptyList(String settingName, String scope) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had added these but they weren't used so removing them.
@@ -52,9 +52,7 @@ public static HuggingFaceServiceSettings fromMap(Map<String, Object> map) { | |||
|
|||
public static URI extractUri(Map<String, Object> map, String fieldName, ValidationException validationException) { | |||
String parsedUrl = extractRequiredString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); | |||
if (parsedUrl == null) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
convertToUri
will just return null if passed null now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* ] | ||
* } | ||
*/ | ||
public record TextEmbeddingByteResults(List<Embedding> embeddings) implements InferenceServiceResults, TextEmbedding { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fundamentally these results classes are different types (byte vs float), it may be verbose but it's good to have the different classes
} | ||
|
||
public CohereEmbeddingsServiceSettings(StreamInput in) throws IOException { | ||
commonSettings = in.readNamedWriteable(CohereServiceSettings.class); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
commonSettings = in.readNamedWriteable(CohereServiceSettings.class); | |
commonSettings = new CohereServiceSettings(in); |
And the writeTo()
is changed to commonSettings.writeTo(out)
then I think you can avoid added the CohereServiceSettings.class
entry to the named writeable registry.
@@ -98,6 +98,16 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() { | |||
namedWriteables.add( | |||
new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new) | |||
); | |||
namedWriteables.add( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See my comment in CohereEmbeddingsServiceSettings
about a different way to serialise CohereServiceSettings
that avoids adding this entry.
* Starting cohere * Making progress on cohere * Filling out the embedding types * Working cohere * Fixing tests * Removing rate limit error message * Fixing a few comments * Update docs/changelog/104559.yaml * Addressing most feedback * Using separate named writeables for byte results and floats * Fixing a few comments and adding tests * Fixing mutation issue * Removing cohere service settings from named writeable registry
This PR adds support for interacting with Cohere's rest APIs in the _inference API. This only adds support for text embedding.
Cohere text embedding docs: https://docs.cohere.com/reference/embed
Not included
Notable differences from the Cohere API
embedding_types
field and users can provide multiple values (an array of strings) and the response of the request will include embeddings for all the ones requested. For example a response could have floats and signed byte, and unsigned byte embeddings. To keep with our simplified API we only allow a single embedding type in a response. So this PR addsembedding_type
as a single string field instead of an array.input_type
, Cohere's accepted values aresearch_document
(for ingest) andsearch_query
(for searching). Our API supportsingest
andsearch
. I thought that was more intuitive but can easily change it to something else if preferredNotable Changes
TextEmbeddingResults
class now supportsbytes
in addition tofloats
Response formats
Floats
Bytes
Testing
Add the service
Performan inference