Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Passing input type through to cohere request #104781

Merged
merged 8 commits into from
Jan 30, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ static TransportVersion def(int id) {
public static final TransportVersion REQUIRE_DATA_STREAM_ADDED = def(8_578_00_0);
public static final TransportVersion ML_INFERENCE_COHERE_EMBEDDINGS_ADDED = def(8_579_00_0);
public static final TransportVersion DESIRED_NODE_VERSION_OPTIONAL_STRING = def(8_580_00_0);
public static final TransportVersion ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED = def(8_581_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,13 @@ default void init(Client client) {}
* @param taskSettings Settings in the request to override the model's defaults
* @param listener Inference result listener
*/
void infer(Model model, List<String> input, Map<String, Object> taskSettings, ActionListener<InferenceServiceResults> listener);
void infer(
Model model,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
ActionListener<InferenceServiceResults> listener
);

/**
* Start or prepare the model for use.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
*/
public enum InputType {
INGEST,
SEARCH;

public static String NAME = "input_type";
SEARCH,
UNSPECIFIED;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm open to other names for this. Maybe UNKNOWN?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UNSPECIFIED is what it is, it's a good name


@Override
public String toString() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
Expand Down Expand Up @@ -96,7 +97,7 @@ public Request(StreamInput in) throws IOException {
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)) {
this.inputType = in.readEnum(InputType.class);
} else {
this.inputType = InputType.INGEST;
this.inputType = InputType.UNSPECIFIED;
}
}

Expand Down Expand Up @@ -146,11 +147,22 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(input.get(0));
}
out.writeGenericMap(taskSettings);
// in version ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED the input type enum was added, so we only want to write the enum if we're
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Double check this code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

// at that version or later
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)) {
out.writeEnum(inputType);
out.writeEnum(getInputTypeToWrite(out.getTransportVersion()));
}
}

private InputType getInputTypeToWrite(TransportVersion version) {
// in version ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED the UNSPECIFIED value was added, so if we're before that
// version other nodes won't know about it, so set it to INGEST instead
if (version.before(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED) && inputType == InputType.UNSPECIFIED) {
return InputType.INGEST;
}
return inputType;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down Expand Up @@ -203,7 +215,7 @@ public Builder setTaskSettings(Map<String, Object> taskSettings) {
}

public Request build() {
return new Request(taskType, inferenceEntityId, input, taskSettings, InputType.INGEST);
return new Request(taskType, inferenceEntityId, input, taskSettings, InputType.UNSPECIFIED);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we directly use UNSPECIFIED in this constructor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. It'd probably be clearer if we add a setter on the builder. Currently the builder is only used in the rest code path. We don't expose the input type parameter to rest requests (only internal requests) so it'll always be unspecified.

}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
Expand Down Expand Up @@ -123,11 +124,11 @@ public void infer(
Model model,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
ActionListener<InferenceServiceResults> listener
) {
switch (model.getConfigurations().getTaskType()) {
case ANY -> listener.onResponse(makeResults(input));
case SPARSE_EMBEDDING -> listener.onResponse(makeResults(input));
case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeResults(input));
default -> listener.onFailure(
new ElasticsearchStatusException(
TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ private void inferOnService(
model,
request.getInput(),
request.getTaskSettings(),
request.getInputType(),
listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults)))
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.external.action.cohere;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
Expand All @@ -28,8 +29,8 @@ public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) {
}

@Override
public ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object> taskSettings) {
var overriddenModel = model.overrideWith(taskSettings);
public ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = model.overrideWith(taskSettings, inputType);

return new CohereEmbeddingsAction(sender, overriddenModel, serviceComponents);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

package org.elasticsearch.xpack.inference.external.action.cohere;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;

import java.util.Map;

public interface CohereActionVisitor {
ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object> taskSettings);
ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings.invalidInputTypeMessage;

public record CohereEmbeddingsRequestEntity(
List<String> input,
CohereEmbeddingsTaskSettings taskSettings,
Expand All @@ -29,14 +31,6 @@ public record CohereEmbeddingsRequestEntity(

private static final String SEARCH_DOCUMENT = "search_document";
private static final String SEARCH_QUERY = "search_query";
/**
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to go with a switch case here instead because that seems a little cleaner now that we have a value we don't want to allow.

* Maps the {@link InputType} to the expected value for cohere for the input_type field in the request using the enum's ordinal.
* The order of these entries is important and needs to match the order in the enum
*/
private static final String[] INPUT_TYPE_MAPPING = { SEARCH_DOCUMENT, SEARCH_QUERY };
static {
assert INPUT_TYPE_MAPPING.length == InputType.values().length : "input type mapping was incorrectly defined";
}

private static final String TEXTS_FIELD = "texts";

Expand All @@ -56,23 +50,31 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(CohereServiceSettings.MODEL, model);
}

if (taskSettings.inputType() != null) {
builder.field(INPUT_TYPE_FIELD, covertToString(taskSettings.inputType()));
if (taskSettings.getInputType() != null) {
builder.field(INPUT_TYPE_FIELD, covertToString(taskSettings.getInputType()));
}

if (embeddingType != null) {
builder.field(EMBEDDING_TYPES_FIELD, List.of(embeddingType));
}

if (taskSettings.truncation() != null) {
builder.field(CohereServiceFields.TRUNCATE, taskSettings.truncation());
if (taskSettings.getTruncation() != null) {
builder.field(CohereServiceFields.TRUNCATE, taskSettings.getTruncation());
}

builder.endObject();
return builder;
}

private static String covertToString(InputType inputType) {
return INPUT_TYPE_MAPPING[inputType.ordinal()];
// default for testing
static String covertToString(InputType inputType) {
return switch (inputType) {
case INGEST -> SEARCH_DOCUMENT;
case SEARCH -> SEARCH_QUERY;
default -> {
assert false : invalidInputTypeMessage(inputType);
yield null;
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
Expand Down Expand Up @@ -41,16 +42,23 @@ protected ServiceComponents getServiceComponents() {
}

@Override
public void infer(Model model, List<String> input, Map<String, Object> taskSettings, ActionListener<InferenceServiceResults> listener) {
public void infer(
Model model,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
ActionListener<InferenceServiceResults> listener
) {
init();

doInfer(model, input, taskSettings, listener);
doInfer(model, input, taskSettings, inputType, listener);
}

protected abstract void doInfer(
Model model,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
ActionListener<InferenceServiceResults> listener
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
Expand Down Expand Up @@ -110,7 +110,7 @@ public static String mustBeNonEmptyString(String settingName, String scope) {
return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName);
}

public static String invalidValue(String settingName, String scope, String invalidType, String... requiredTypes) {
public static String invalidValue(String settingName, String scope, String invalidType, String[] requiredTypes) {
return Strings.format(
"[%s] Invalid value [%s] received. [%s] must be one of [%s]",
scope,
Expand Down Expand Up @@ -225,25 +225,43 @@ public static <T> T extractOptionalEnum(
Map<String, Object> map,
String settingName,
String scope,
CheckedFunction<String, T, IllegalArgumentException> converter,
T[] validTypes,
EnumConstructor<T> constructor,
T[] validValues,
ValidationException validationException
) {
var enumString = extractOptionalString(map, settingName, scope, validationException);
if (enumString == null) {
return null;
}

var validTypesAsStrings = Arrays.stream(validTypes).map(type -> type.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new);
var validValuesAsStrings = Arrays.stream(validValues).map(type -> type.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new);
try {
return converter.apply(enumString);
var createdEnum = constructor.apply(enumString);
validateEnumValue(createdEnum, validValues);

return createdEnum;
} catch (IllegalArgumentException e) {
validationException.addValidationError(invalidValue(settingName, scope, enumString, validTypesAsStrings));
validationException.addValidationError(invalidValue(settingName, scope, enumString, validValuesAsStrings));
}

return null;
}

private static <T> void validateEnumValue(T enumValue, T[] validValues) {
if (Arrays.asList(validValues).contains(enumValue) == false) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you achieve the same by passing an EnumSet<T> instead of T[] validValues? It would make the lookup easier.

Nit: in the docs a generic enum parameter is referred to as E

https://docs.oracle.com/javase/8/docs/api/java/util/EnumSet.html

throw new IllegalArgumentException(Strings.format("Enum value [%s] is not one of the acceptable values", enumValue.toString()));
}
}

/**
* Functional interface for creating an enum from a string.
* @param <T>
*/
@FunctionalInterface
public interface EnumConstructor<T> {
T apply(String name) throws IllegalArgumentException;
}

public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName) {
return format(
"Failed to parse stored model [%s] for [%s] service, please delete and add the service again",
Expand Down Expand Up @@ -272,7 +290,7 @@ public static ElasticsearchStatusException createInvalidModelException(Model mod
public static void getEmbeddingSize(Model model, InferenceService service, ActionListener<Integer> listener) {
assert model.getTaskType() == TaskType.TEXT_EMBEDDING;

service.infer(model, List.of(TEST_EMBEDDING_INPUT), Map.of(), listener.delegateFailureAndWrap((delegate, r) -> {
service.infer(model, List.of(TEST_EMBEDDING_INPUT), Map.of(), InputType.INGEST, listener.delegateFailureAndWrap((delegate, r) -> {
if (r instanceof TextEmbedding embeddingResults) {
try {
delegate.onResponse(embeddingResults.getFirstEmbeddingSize());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.services.cohere;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
Expand All @@ -30,5 +31,5 @@ protected CohereModel(CohereModel model, ServiceSettings serviceSettings) {
super(model, serviceSettings);
}

public abstract ExecutableAction accept(CohereActionVisitor creator, Map<String, Object> taskSettings);
public abstract ExecutableAction accept(CohereActionVisitor creator, Map<String, Object> taskSettings, InputType inputType);
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
Expand Down Expand Up @@ -123,6 +124,7 @@ public void doInfer(
Model model,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
ActionListener<InferenceServiceResults> listener
) {
if (model instanceof CohereModel == false) {
Expand All @@ -133,7 +135,7 @@ public void doInfer(
CohereModel cohereModel = (CohereModel) model;
var actionCreator = new CohereActionCreator(getSender(), getServiceComponents());

var action = cohereModel.accept(actionCreator, taskSettings);
var action = cohereModel.accept(actionCreator, taskSettings, inputType);
action.execute(input, listener);
}

Expand Down Expand Up @@ -174,6 +176,6 @@ private CohereEmbeddingsModel updateModelWithEmbeddingDetails(CohereEmbeddingsMo

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED;
return TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bumping the cohere service min version just in case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not quite sure what the implication of increasing the minimal supported version of a class is. Do you know what happens when we change this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it means that if a cluster has nodes with different versions and one of the nodes is on ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED and a user attempts to create the cohere entity, it will fail until all the nodes are on the version ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED. Essentially users will have to wait until an upgrade process is complete before creating this inference entity.

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.cohere.embeddings;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
Expand Down Expand Up @@ -73,16 +74,12 @@ public DefaultSecretSettings getSecretSettings() {
}

@Override
public ExecutableAction accept(CohereActionVisitor visitor, Map<String, Object> taskSettings) {
return visitor.create(this, taskSettings);
public ExecutableAction accept(CohereActionVisitor visitor, Map<String, Object> taskSettings, InputType inputType) {
return visitor.create(this, taskSettings, inputType);
}

public CohereEmbeddingsModel overrideWith(Map<String, Object> taskSettings) {
if (taskSettings == null || taskSettings.isEmpty()) {
return this;
}

public CohereEmbeddingsModel overrideWith(Map<String, Object> taskSettings, InputType inputType) {
var requestTaskSettings = CohereEmbeddingsTaskSettings.fromMap(taskSettings);
return new CohereEmbeddingsModel(this, getTaskSettings().overrideWith(requestTaskSettings));
return new CohereEmbeddingsModel(this, getTaskSettings().overrideWith(requestTaskSettings).setInputType(inputType));
}
}
Loading