Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Adding support for Cohere inference service #104559

Merged
merged 17 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/104559.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 104559
summary: Adding support for Cohere inference service
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_MULTI_CLUSTERS_ENRICH = def(8_576_00_0);
public static final TransportVersion NESTED_KNN_MORE_INNER_HITS = def(8_577_00_0);
public static final TransportVersion REQUIRE_DATA_STREAM_ADDED = def(8_578_00_0);
public static final TransportVersion ML_INFERENCE_COHERE_EMBEDDINGS_ADDED = def(8_579_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
16 changes: 3 additions & 13 deletions server/src/main/java/org/elasticsearch/inference/InputType.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,12 @@

package org.elasticsearch.inference;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;

import java.io.IOException;
import java.util.Locale;

/**
* Defines the type of request, whether the request is to ingest a document or search for a document.
*/
public enum InputType implements Writeable {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided to remove these because the cohere task settings need to write it as an optional enum so I figured it'd be better to just leave it up to the caller to handle how it is written.

public enum InputType {
INGEST,
SEARCH;

Expand All @@ -29,12 +24,7 @@ public String toString() {
return name().toLowerCase(Locale.ROOT);
}

public static InputType fromStream(StreamInput in) throws IOException {
return in.readEnum(InputType.class);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(this);
/**
 * Parses an {@link InputType} from its textual name.
 * <p>
 * Surrounding whitespace is trimmed and the name is upper-cased using {@link Locale#ROOT}
 * before the constant is looked up.
 *
 * @param name the textual name of the input type
 * @return the constant matching the normalized name
 */
public static InputType fromString(String name) {
    String normalized = name.trim().toUpperCase(Locale.ROOT);
    return InputType.valueOf(normalized);
}
}
14 changes: 14 additions & 0 deletions server/src/main/java/org/elasticsearch/inference/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@ public Model(ModelConfigurations configurations, ModelSecrets secrets) {
this.secrets = Objects.requireNonNull(secrets);
}

public Model(Model model, TaskSettings taskSettings) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These just make it easier to construct the Model when these various settings change but the secrets do not.

Objects.requireNonNull(model);

configurations = ModelConfigurations.of(model, taskSettings);
secrets = model.getSecrets();
}

/**
 * Creates a model from an existing one, replacing only its service settings.
 * <p>
 * The configurations are rebuilt through {@link ModelConfigurations#of(Model, ServiceSettings)} and
 * the secrets are taken unchanged from {@code model}.
 *
 * @param model the model to copy from, must not be null
 * @param serviceSettings the service settings to use for the new model
 */
public Model(Model model, ServiceSettings serviceSettings) {
    Objects.requireNonNull(model);

    this.secrets = model.getSecrets();
    this.configurations = ModelConfigurations.of(model, serviceSettings);
}

public Model(ModelConfigurations configurations) {
this(configurations, new ModelSecrets());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,32 @@ public class ModelConfigurations implements ToXContentObject, VersionedNamedWrit
public static final String TASK_SETTINGS = "task_settings";
private static final String NAME = "inference_model";

/**
 * Builds configurations that copy the model id, task type, service and service settings of
 * {@code model} but use the supplied task settings.
 *
 * @param model the model whose configurations are copied, must not be null
 * @param taskSettings the task settings to use, must not be null
 * @return the new configurations
 */
public static ModelConfigurations of(Model model, TaskSettings taskSettings) {
    Objects.requireNonNull(model);
    Objects.requireNonNull(taskSettings);

    var existing = model.getConfigurations();
    var modelId = existing.getModelId();
    var taskType = existing.getTaskType();
    var service = existing.getService();

    return new ModelConfigurations(modelId, taskType, service, model.getServiceSettings(), taskSettings);
}

/**
 * Builds configurations that copy the model id, task type, service and task settings of
 * {@code model} but use the supplied service settings.
 *
 * @param model the model whose configurations are copied, must not be null
 * @param serviceSettings the service settings to use, must not be null
 * @return the new configurations
 */
public static ModelConfigurations of(Model model, ServiceSettings serviceSettings) {
    Objects.requireNonNull(model);
    Objects.requireNonNull(serviceSettings);

    var existing = model.getConfigurations();
    var modelId = existing.getModelId();
    var taskType = existing.getTaskType();
    var service = existing.getService();

    return new ModelConfigurations(modelId, taskType, service, serviceSettings, model.getTaskSettings());
}

private final String modelId;
private final TaskType taskType;
private final String service;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public Request(StreamInput in) throws IOException {
}
this.taskSettings = in.readGenericMap();
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)) {
this.inputType = InputType.fromStream(in);
this.inputType = in.readEnum(InputType.class);
} else {
this.inputType = InputType.INGEST;
}
Expand Down Expand Up @@ -141,7 +141,7 @@ public void writeTo(StreamOutput out) throws IOException {
}
out.writeGenericMap(taskSettings);
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)) {
inputType.writeTo(out);
out.writeEnum(inputType);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

public interface EmbeddingInt {
/**
 * Returns the size of this embedding.
 * @return the number of entries in the embedding
 */
int getSize();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm open to other ideas here. I did this to make determining the embedding size easier within the ServiceUtils class. The other option would be a second instanceof check for TextEmbeddingByteResults.

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

public interface TextEmbedding {

/**
 * Returns the size of the first text embedding in the result list.
 * @return the size of the first text embedding
 * @throws IllegalStateException if the list of embeddings is empty
 */
int getFirstEmbeddingSize() throws IllegalStateException;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also needed for simplifying ServiceUtils

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
* Writes a text embedding result in the following JSON format
* {
* "text_embedding": [
* {
* "embedding": [
* 23
* ]
* },
* {
* "embedding": [
* -23
* ]
* }
* ]
* }
*/
public record TextEmbeddingByteResults(List<Embedding> embeddings) implements InferenceServiceResults, TextEmbedding {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically a duplicate of TextEmbeddingResults. I was trying to think of an easy way to remove the duplication but I think we'd have to push logic up into some base class but that doesn't fully work because of the constructors that take a stream TextEmbeddingByteResults(StreamInput in). I think we could remove some of this by making the inner Embedding class a named writeable too but I'm not sure that's much better than what I had before with FloatValue and ByteValue.

@davidkyle let me know if this is what you were thinking. I suppose we could put off refactoring this until we support a third type. I think we'd still have to add in a bunch of if-blocks though if we do go the named writeable route.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fundamentally these results classes are different types (byte vs float), it may be verbose but it's good to have the different classes

public static final String NAME = "text_embedding_service_byte_results";
public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString();

/**
 * Reads the results from a stream: the embeddings are read as a list whose entries are each
 * constructed by {@code Embedding::new}.
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
public TextEmbeddingByteResults(StreamInput in) throws IOException {
this(in.readCollectionAsList(Embedding::new));
}

/**
 * Returns the size of the first embedding, delegating to
 * {@link TextEmbeddingUtils#getFirstEmbeddingSize}.
 * <p>
 * The embeddings are copied into a list typed by the {@link EmbeddingInt} interface before being
 * handed to the helper.
 */
@Override
public int getFirstEmbeddingSize() {
    List<EmbeddingInt> sizedEmbeddings = new ArrayList<>(embeddings);
    return TextEmbeddingUtils.getFirstEmbeddingSize(sizedEmbeddings);
}

/**
 * Writes the embeddings as an array field named {@link #TEXT_EMBEDDING}, one entry per embedding.
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
    builder.startArray(TEXT_EMBEDDING);

    for (Embedding entry : embeddings) {
        entry.toXContent(builder, params);
    }

    builder.endArray();

    return builder;
}

/**
 * Writes the list of embeddings to the stream as a collection.
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeCollection(embeddings);
}

/**
 * Returns the writeable name of this result type, {@link #NAME}.
 */
@Override
public String getWriteableName() {
return NAME;
}

/**
 * Converts every embedding into a coordination-format text embedding result.
 * <p>
 * Each byte value is widened to a double, and each resulting array is wrapped in a result keyed by
 * {@link #TEXT_EMBEDDING}.
 */
@Override
public List<? extends InferenceResults> transformToCoordinationFormat() {
    return embeddings.stream().map(embedding -> {
        double[] widened = embedding.values.stream().mapToDouble(value -> value).toArray();
        return new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, widened, false);
    }).toList();
}

/**
 * Converts the embeddings into a single legacy result whose entries hold the byte values as floats.
 *
 * @return a one-element list containing the legacy result
 */
@Override
@SuppressWarnings("deprecation")
public List<? extends InferenceResults> transformToLegacyFormat() {
    var legacyEntries = embeddings.stream()
        .map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.toFloats()))
        .toList();

    return List.of(new LegacyTextEmbeddingResults(legacyEntries));
}

/**
 * Returns a map representation of the results: a single entry keyed by {@link #TEXT_EMBEDDING}
 * whose value is the list of each embedding's own map form.
 */
public Map<String, Object> asMap() {
    List<Map<String, Object>> embeddingMaps = embeddings.stream().map(Embedding::asMap).collect(Collectors.toList());

    Map<String, Object> result = new LinkedHashMap<>();
    result.put(TEXT_EMBEDDING, embeddingMaps);
    return result;
}

/**
 * A single text embedding whose values are bytes.
 *
 * @param values the embedding values
 */
public record Embedding(List<Byte> values) implements Writeable, ToXContentObject, EmbeddingInt {
    public static final String EMBEDDING = "embedding";

    /**
     * Reads an embedding from a stream as an immutable list of bytes.
     */
    public Embedding(StreamInput in) throws IOException {
        this(in.readCollectionAsImmutableList(StreamInput::readByte));
    }

    /**
     * Returns the number of values in this embedding.
     */
    @Override
    public int getSize() {
        return values.size();
    }

    /**
     * Writes the values to the stream as a collection of bytes.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeCollection(values, StreamOutput::writeByte);
    }

    /**
     * Writes this embedding as an object containing a single array field named {@link #EMBEDDING}.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.startArray(EMBEDDING);

        for (Byte element : values) {
            builder.value(element);
        }

        builder.endArray();
        builder.endObject();

        return builder;
    }

    /**
     * Returns the values converted to floats.
     */
    public List<Float> toFloats() {
        return values.stream().map(element -> element.floatValue()).toList();
    }

    /**
     * Returns a single-entry map from {@link #EMBEDDING} to the values.
     */
    public Map<String, Object> asMap() {
        return Map.of(EMBEDDING, values);
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -40,7 +41,7 @@
* ]
* }
*/
public record TextEmbeddingResults(List<Embedding> embeddings) implements InferenceServiceResults {
public record TextEmbeddingResults(List<Embedding> embeddings) implements InferenceServiceResults, TextEmbedding {
public static final String NAME = "text_embedding_service_results";
public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString();

Expand All @@ -58,6 +59,11 @@ public TextEmbeddingResults(StreamInput in) throws IOException {
);
}

/**
 * Returns the size of the first embedding, delegating to
 * {@link TextEmbeddingUtils#getFirstEmbeddingSize}.
 * <p>
 * The embeddings are copied into a list typed by the {@link EmbeddingInt} interface before being
 * handed to the helper.
 */
@Override
public int getFirstEmbeddingSize() {
    List<EmbeddingInt> sizedEmbeddings = new ArrayList<>(embeddings);
    return TextEmbeddingUtils.getFirstEmbeddingSize(sizedEmbeddings);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startArray(TEXT_EMBEDDING);
Expand Down Expand Up @@ -103,13 +109,18 @@ public Map<String, Object> asMap() {
return map;
}

public record Embedding(List<Float> values) implements Writeable, ToXContentObject {
public record Embedding(List<Float> values) implements Writeable, ToXContentObject, EmbeddingInt {
public static final String EMBEDDING = "embedding";

/**
 * Reads an embedding from a stream as an immutable list of floats.
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
public Embedding(StreamInput in) throws IOException {
this(in.readCollectionAsImmutableList(StreamInput::readFloat));
}

/**
 * Returns the number of values in this embedding.
 */
@Override
public int getSize() {
return values.size();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeCollection(values, StreamOutput::writeFloat);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

import java.util.List;

/**
 * Static helpers shared by the text embedding result types. Not instantiable.
 */
public final class TextEmbeddingUtils {

    /**
     * Returns the first text embedding entry's array size.
     * <p>
     * The parameter accepts a list of any {@link EmbeddingInt} subtype, so a caller holding a list of
     * a concrete embedding type can pass it directly instead of copying it into a new list first.
     *
     * @param embeddings the list of embeddings
     * @return the size of the text embedding
     * @throws IllegalStateException if the list of embeddings is empty
     */
    public static int getFirstEmbeddingSize(List<? extends EmbeddingInt> embeddings) throws IllegalStateException {
        if (embeddings.isEmpty()) {
            throw new IllegalStateException("Embeddings list is empty");
        }

        return embeddings.get(0).getSize();
    }

    private TextEmbeddingUtils() {}

}
5 changes: 0 additions & 5 deletions x-pack/plugin/inference/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,5 @@
exports org.elasticsearch.xpack.inference.registry;
exports org.elasticsearch.xpack.inference.rest;
exports org.elasticsearch.xpack.inference.services;
exports org.elasticsearch.xpack.inference.external.http.sender;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@davidkyle do we need these? I had added them a while ago but now I get all kinds of warnings saying that various classes aren't exported. If I remove them they go away 🤷‍♂️ . I don't think we'd want to export most of this stuff out of the module anyway.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++ we shouldn't need to export that package, thanks for tidying up

exports org.elasticsearch.xpack.inference.external.http;
exports org.elasticsearch.xpack.inference.services.elser;
exports org.elasticsearch.xpack.inference.services.huggingface.elser;
exports org.elasticsearch.xpack.inference.services.openai;
exports org.elasticsearch.xpack.inference;
}
Loading