From ee9477dc56406295f9fd8b1e0dd84b6d05683d9e Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 25 Jan 2024 17:06:18 -0500 Subject: [PATCH 1/6] Pushing input type through to cohere request --- .../org/elasticsearch/TransportVersions.java | 1 + .../inference/InferenceService.java | 8 +- .../elasticsearch/inference/InputType.java | 5 +- .../inference/action/InferenceAction.java | 18 +- .../action/TransportInferenceAction.java | 1 + .../action/cohere/CohereActionCreator.java | 5 +- .../action/cohere/CohereActionVisitor.java | 3 +- .../cohere/CohereEmbeddingsRequestEntity.java | 30 +-- .../inference/services/SenderService.java | 12 +- .../inference/services/ServiceUtils.java | 34 +++- .../services/cohere/CohereModel.java | 3 +- .../services/cohere/CohereService.java | 6 +- .../embeddings/CohereEmbeddingsModel.java | 13 +- .../CohereEmbeddingsTaskSettings.java | 78 +++++++- .../services/elser/ElserMlNodeService.java | 9 +- .../huggingface/HuggingFaceBaseService.java | 2 + .../services/openai/OpenAiService.java | 2 + .../xpack/inference/InputTypeTests.java | 21 ++ .../action/InferenceActionRequestTests.java | 82 +++++++- .../cohere/CohereActionCreatorTests.java | 2 +- .../CohereEmbeddingsRequestEntityTests.java | 5 + .../services/SenderServiceTests.java | 2 + .../inference/services/ServiceUtilsTests.java | 46 ++++- .../services/cohere/CohereServiceTests.java | 186 +++++++++++++++++- .../CohereEmbeddingsModelTests.java | 93 ++++++++- .../CohereEmbeddingsTaskSettingsTests.java | 48 ++++- .../HuggingFaceBaseServiceTests.java | 3 +- .../huggingface/HuggingFaceServiceTests.java | 5 +- .../services/openai/OpenAiServiceTests.java | 7 +- 29 files changed, 644 insertions(+), 86 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 7b3ca0e2f069a..01c286b877bf7 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -191,6 +191,7 @@ static TransportVersion def(int id) { public static final TransportVersion NESTED_KNN_MORE_INNER_HITS = def(8_577_00_0); public static final TransportVersion REQUIRE_DATA_STREAM_ADDED = def(8_578_00_0); public static final TransportVersion ML_INFERENCE_COHERE_EMBEDDINGS_ADDED = def(8_579_00_0); + public static final TransportVersion ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED = def(8_580_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 235de51d22572..fdeb32de33877 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -78,7 +78,13 @@ default void init(Client client) {} * @param taskSettings Settings in the request to override the model's defaults * @param listener Inference result listener */ - void infer(Model model, List input, Map taskSettings, ActionListener listener); + void infer( + Model model, + List input, + Map taskSettings, + InputType inputType, + ActionListener listener + ); /** * Start or prepare the model for use. diff --git a/server/src/main/java/org/elasticsearch/inference/InputType.java b/server/src/main/java/org/elasticsearch/inference/InputType.java index ffc67995c1dda..19f28601409ac 100644 --- a/server/src/main/java/org/elasticsearch/inference/InputType.java +++ b/server/src/main/java/org/elasticsearch/inference/InputType.java @@ -15,9 +15,8 @@ */ public enum InputType { INGEST, - SEARCH; - - public static String NAME = "input_type"; + SEARCH, + UNSPECIFIED; @Override public String toString() { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index 30375e36a0e1d..046f7ea3351c6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.inference.action; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; @@ -90,7 +91,7 @@ public Request(StreamInput in) throws IOException { if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)) { this.inputType = in.readEnum(InputType.class); } else { - this.inputType = InputType.INGEST; + this.inputType = InputType.UNSPECIFIED; } } @@ -140,11 +141,22 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(input.get(0)); } out.writeGenericMap(taskSettings); + // in version ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED the input type enum was added, so we only want to write the enum if we're + // at that version or later if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)) { - out.writeEnum(inputType); + out.writeEnum(getInputTypeToWrite(out.getTransportVersion())); } } + private InputType getInputTypeToWrite(TransportVersion version) { + // in version ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED the UNSPECIFIED value was added, so if we're before that + // version other nodes won't know about it, so set it to INGEST instead + if (version.before(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED) && inputType == InputType.UNSPECIFIED) { + return InputType.INGEST; + } + return inputType; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -197,7 +209,7 @@ public Builder setTaskSettings(Map taskSettings) { } public Request build() { - return new Request(taskType, modelId, input, taskSettings, InputType.INGEST); + return new Request(taskType, modelId, input, taskSettings, InputType.UNSPECIFIED); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index db98aeccc556b..c5ea0be06af3c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -92,6 +92,7 @@ private void inferOnService( model, request.getInput(), request.getTaskSettings(), + request.getInputType(), listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults))) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java index 8c9d70f0a7323..e6b5ceb03f147 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.action.cohere; +import org.elasticsearch.inference.InputType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -28,8 +29,8 @@ public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) { } @Override - public ExecutableAction create(CohereEmbeddingsModel model, Map taskSettings) { - var overriddenModel = model.overrideWith(taskSettings); + public ExecutableAction create(CohereEmbeddingsModel model, Map taskSettings, InputType inputType) { + var overriddenModel = model.overrideWith(taskSettings, inputType); return new CohereEmbeddingsAction(sender, overriddenModel, serviceComponents); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionVisitor.java index 1500d48e3c201..cc732e7ab8dc5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionVisitor.java @@ -7,11 +7,12 @@ package org.elasticsearch.xpack.inference.external.action.cohere; +import org.elasticsearch.inference.InputType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; import java.util.Map; public interface CohereActionVisitor { - ExecutableAction create(CohereEmbeddingsModel model, Map taskSettings); + ExecutableAction create(CohereEmbeddingsModel model, Map taskSettings, InputType inputType); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java index a0b5444ee45e4..9e34af5ed6385 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java @@ -20,6 +20,8 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings.invalidInputTypeMessage; + public record CohereEmbeddingsRequestEntity( List input, CohereEmbeddingsTaskSettings taskSettings, @@ -29,14 +31,6 @@ public record CohereEmbeddingsRequestEntity( private static final String SEARCH_DOCUMENT = "search_document"; private static final String SEARCH_QUERY = "search_query"; - /** - * Maps the {@link InputType} to the expected value for cohere for the input_type field in the request using the enum's ordinal. - * The order of these entries is important and needs to match the order in the enum - */ - private static final String[] INPUT_TYPE_MAPPING = { SEARCH_DOCUMENT, SEARCH_QUERY }; - static { - assert INPUT_TYPE_MAPPING.length == InputType.values().length : "input type mapping was incorrectly defined"; - } private static final String TEXTS_FIELD = "texts"; @@ -56,23 +50,31 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(CohereServiceSettings.MODEL, model); } - if (taskSettings.inputType() != null) { - builder.field(INPUT_TYPE_FIELD, covertToString(taskSettings.inputType())); + if (taskSettings.getInputType() != null) { + builder.field(INPUT_TYPE_FIELD, covertToString(taskSettings.getInputType())); } if (embeddingType != null) { builder.field(EMBEDDING_TYPES_FIELD, List.of(embeddingType)); } - if (taskSettings.truncation() != null) { - builder.field(CohereServiceFields.TRUNCATE, taskSettings.truncation()); + if (taskSettings.getTruncation() != null) { + builder.field(CohereServiceFields.TRUNCATE, taskSettings.getTruncation()); } builder.endObject(); return builder; } - private static String covertToString(InputType inputType) { - return INPUT_TYPE_MAPPING[inputType.ordinal()]; + // default for testing + static String covertToString(InputType inputType) { + return switch (inputType) { + case INGEST -> SEARCH_DOCUMENT; + case SEARCH -> SEARCH_QUERY; + default -> { + assert false : invalidInputTypeMessage(inputType); + yield null; + } + }; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index bb45e8fd684a6..0c40863b37db2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -12,6 +12,7 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -41,16 +42,23 @@ protected ServiceComponents getServiceComponents() { } @Override - public void infer(Model model, List input, Map taskSettings, ActionListener listener) { + public void infer( + Model model, + List input, + Map taskSettings, + InputType inputType, + ActionListener listener + ) { init(); - doInfer(model, input, taskSettings, listener); + doInfer(model, input, taskSettings, inputType, listener); } protected abstract void doInfer( Model model, List input, Map taskSettings, + InputType inputType, ActionListener listener ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 7029f9ca3bf56..fbded828d6f3c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -11,10 +11,10 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.SecureString; -import org.elasticsearch.core.CheckedFunction; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; @@ -110,7 +110,7 @@ public static String mustBeNonEmptyString(String settingName, String scope) { return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName); } - public static String invalidValue(String settingName, String scope, String invalidType, String... requiredTypes) { + public static String invalidValue(String settingName, String scope, String invalidType, String[] requiredTypes) { return Strings.format( "[%s] Invalid value [%s] received. [%s] must be one of [%s]", scope, @@ -225,8 +225,8 @@ public static T extractOptionalEnum( Map map, String settingName, String scope, - CheckedFunction converter, - T[] validTypes, + EnumConstructor constructor, + T[] validValues, ValidationException validationException ) { var enumString = extractOptionalString(map, settingName, scope, validationException); @@ -234,16 +234,34 @@ public static T extractOptionalEnum( return null; } - var validTypesAsStrings = Arrays.stream(validTypes).map(type -> type.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new); + var validValuesAsStrings = Arrays.stream(validValues).map(type -> type.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new); try { - return converter.apply(enumString); + var createdEnum = constructor.apply(enumString); + validateEnumValue(createdEnum, validValues); + + return createdEnum; } catch (IllegalArgumentException e) { - validationException.addValidationError(invalidValue(settingName, scope, enumString, validTypesAsStrings)); + validationException.addValidationError(invalidValue(settingName, scope, enumString, validValuesAsStrings)); } return null; } + private static void validateEnumValue(T enumValue, T[] validValues) { + if (Arrays.asList(validValues).contains(enumValue) == false) { + throw new IllegalArgumentException(Strings.format("Enum value [%s] is not one of the acceptable values", enumValue.toString())); + } + } + + /** + * Functional interface for creating an enum from a string. + * @param + */ + @FunctionalInterface + public interface EnumConstructor { + T apply(String name) throws IllegalArgumentException; + } + public static String parsePersistedConfigErrorMsg(String modelId, String serviceName) { return format("Failed to parse stored model [%s] for [%s] service, please delete and add the service again", modelId, serviceName); } @@ -268,7 +286,7 @@ public static ElasticsearchStatusException createInvalidModelException(Model mod public static void getEmbeddingSize(Model model, InferenceService service, ActionListener listener) { assert model.getTaskType() == TaskType.TEXT_EMBEDDING; - service.infer(model, List.of(TEST_EMBEDDING_INPUT), Map.of(), listener.delegateFailureAndWrap((delegate, r) -> { + service.infer(model, List.of(TEST_EMBEDDING_INPUT), Map.of(), InputType.INGEST, listener.delegateFailureAndWrap((delegate, r) -> { if (r instanceof TextEmbedding embeddingResults) { try { delegate.onResponse(embeddingResults.getFirstEmbeddingSize()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java index 1b4843e441248..81a27e1e536f3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.cohere; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -30,5 +31,5 @@ protected CohereModel(CohereModel model, ServiceSettings serviceSettings) { super(model, serviceSettings); } - public abstract ExecutableAction accept(CohereActionVisitor creator, Map taskSettings); + public abstract ExecutableAction accept(CohereActionVisitor creator, Map taskSettings, InputType inputType); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 8783f12852ec8..3f608c977f686 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -123,6 +124,7 @@ public void doInfer( Model model, List input, Map taskSettings, + InputType inputType, ActionListener listener ) { if (model instanceof CohereModel == false) { @@ -133,7 +135,7 @@ public void doInfer( CohereModel cohereModel = (CohereModel) model; var actionCreator = new CohereActionCreator(getSender(), getServiceComponents()); - var action = cohereModel.accept(actionCreator, taskSettings); + var action = cohereModel.accept(actionCreator, taskSettings, inputType); action.execute(input, listener); } @@ -174,6 +176,6 @@ private CohereEmbeddingsModel updateModelWithEmbeddingDetails(CohereEmbeddingsMo @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED; + return TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java index c92700e87cd96..b0ff6381696e7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.cohere.embeddings; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; @@ -73,16 +74,12 @@ public DefaultSecretSettings getSecretSettings() { } @Override - public ExecutableAction accept(CohereActionVisitor visitor, Map taskSettings) { - return visitor.create(this, taskSettings); + public ExecutableAction accept(CohereActionVisitor visitor, Map taskSettings, InputType inputType) { + return visitor.create(this, taskSettings, inputType); } - public CohereEmbeddingsModel overrideWith(Map taskSettings) { - if (taskSettings == null || taskSettings.isEmpty()) { - return this; - } - + public CohereEmbeddingsModel overrideWith(Map taskSettings, InputType inputType) { var requestTaskSettings = CohereEmbeddingsTaskSettings.fromMap(taskSettings); - return new CohereEmbeddingsModel(this, getTaskSettings().overrideWith(requestTaskSettings)); + return new CohereEmbeddingsModel(this, getTaskSettings().overrideWith(requestTaskSettings).setIfAbsent(inputType)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java index 858efdb0d1ace..e48661f6ae1e6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java @@ -9,6 +9,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -20,7 +21,10 @@ import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; import java.io.IOException; +import java.util.Arrays; +import java.util.List; import java.util.Map; +import java.util.Objects; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.TRUNCATE; @@ -31,18 +35,17 @@ *

* See api docs for details. *

- * - * @param inputType Specifies the type of input you're giving to the model - * @param truncation Specifies how the API will handle inputs longer than the maximum token length */ -public record CohereEmbeddingsTaskSettings(@Nullable InputType inputType, @Nullable CohereTruncation truncation) implements TaskSettings { +public class CohereEmbeddingsTaskSettings implements TaskSettings { public static final String NAME = "cohere_embeddings_task_settings"; public static final CohereEmbeddingsTaskSettings EMPTY_SETTINGS = new CohereEmbeddingsTaskSettings(null, null); static final String INPUT_TYPE = "input_type"; + private static final InputType[] VALID_REQUEST_VALUES = { InputType.INGEST, InputType.SEARCH }; + private static final List VALID_INPUT_VALUES_LIST = Arrays.asList(VALID_REQUEST_VALUES); public static CohereEmbeddingsTaskSettings fromMap(Map map) { - if (map.isEmpty()) { + if (map == null || map.isEmpty()) { return EMPTY_SETTINGS; } @@ -53,7 +56,7 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) { INPUT_TYPE, ModelConfigurations.TASK_SETTINGS, InputType::fromString, - InputType.values(), + VALID_REQUEST_VALUES, validationException ); CohereTruncation truncation = extractOptionalEnum( @@ -72,10 +75,27 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) { return new CohereEmbeddingsTaskSettings(inputType, truncation); } + private final InputType inputType; + private final CohereTruncation truncation; + public CohereEmbeddingsTaskSettings(StreamInput in) throws IOException { this(in.readOptionalEnum(InputType.class), in.readOptionalEnum(CohereTruncation.class)); } + public CohereEmbeddingsTaskSettings(@Nullable InputType inputType, @Nullable CohereTruncation truncation) { + validateInputType(inputType); + this.inputType = inputType; + this.truncation = truncation; + } + + private static void validateInputType(InputType inputType) { + if (inputType == null) { + return; + } + + assert VALID_INPUT_VALUES_LIST.contains(inputType) : invalidInputTypeMessage(inputType); + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -90,6 +110,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } + public InputType getInputType() { + return inputType; + } + + public CohereTruncation getTruncation() { + return truncation; + } + @Override public String getWriteableName() { return NAME; @@ -106,10 +134,44 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalEnum(truncation); } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CohereEmbeddingsTaskSettings that = (CohereEmbeddingsTaskSettings) o; + return inputType == that.inputType && truncation == that.truncation; + } + + @Override + public int hashCode() { + return Objects.hash(inputType, truncation); + } + public CohereEmbeddingsTaskSettings overrideWith(CohereEmbeddingsTaskSettings requestTaskSettings) { - var inputTypeToUse = requestTaskSettings.inputType() == null ? inputType : requestTaskSettings.inputType(); - var truncationToUse = requestTaskSettings.truncation() == null ? truncation : requestTaskSettings.truncation(); + if (requestTaskSettings.equals(EMPTY_SETTINGS)) { + return this; + } + + var inputTypeToUse = requestTaskSettings.getInputType() == null ? inputType : requestTaskSettings.getInputType(); + var truncationToUse = requestTaskSettings.getTruncation() == null ? truncation : requestTaskSettings.getTruncation(); return new CohereEmbeddingsTaskSettings(inputTypeToUse, truncationToUse); } + + /** + * Sets the input type field for the task settings if it is currently null and the passed in input type value is valid. + * @param inputType the new input type to use + * @return task settings with the values set if they were previously null + */ + public CohereEmbeddingsTaskSettings setIfAbsent(InputType inputType) { + if (this.inputType != null || VALID_INPUT_VALUES_LIST.contains(inputType) == false) { + return this; + } + + return new CohereEmbeddingsTaskSettings(inputType, truncation); + } + + public static String invalidInputTypeMessage(InputType inputType) { + return Strings.format("received invalid input type value [%s]", inputType.toString()); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java index 4755c11ece9fe..9371724021330 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java @@ -19,6 +19,7 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; @@ -208,7 +209,13 @@ public void stop(String modelId, ActionListener listener) { } @Override - public void infer(Model model, List input, Map taskSettings, ActionListener listener) { + public void infer( + Model model, + List input, + Map taskSettings, + InputType inputType, + ActionListener listener + ) { // No task settings to override with requestTaskSettings if (TaskType.SPARSE_EMBEDDING.isAnyOrSame(model.getConfigurations().getTaskType()) == false) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index a7dc26b8472d1..4f31ab804a86e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -10,6 +10,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -90,6 +91,7 @@ public void doInfer( Model model, List input, Map taskSettings, + InputType inputType, ActionListener listener ) { if (model instanceof HuggingFaceModel == false) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 1bdd1abce0b45..35b1d41cad356 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -122,6 +123,7 @@ public void doInfer( Model model, List input, Map taskSettings, + InputType inputType, ActionListener listener ) { if (model instanceof OpenAiModel == false) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java new file mode 100644 index 0000000000000..088f93507d35f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; + +public class InputTypeTests extends ESTestCase { + public static InputType randomWithoutUnspecified() { + return randomFrom(InputType.INGEST, InputType.SEARCH); + } + + public static InputType[] valuesWithoutUnspecified() { + return new InputType[] { InputType.INGEST, InputType.SEARCH }; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java index ee7bfc96c1370..146cff2b44276 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java @@ -7,22 +7,26 @@ package org.elasticsearch.xpack.inference.action; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.hamcrest.Matchers.is; import static org.hamcrest.collection.IsIterableContainingInOrder.contains; -public class InferenceActionRequestTests extends AbstractWireSerializingTestCase { +public class InferenceActionRequestTests extends AbstractBWCWireSerializationTestCase { @Override protected Writeable.Reader instanceReader() { @@ -70,7 +74,7 @@ public void testParseRequest_DefaultsInputTypeToIngest() throws IOException { """; try (var parser = createParser(JsonXContent.jsonXContent, singleInputRequest)) { var request = InferenceAction.Request.parseRequest("model_id", "sparse_embedding", parser); - assertThat(request.getInputType(), is(InputType.INGEST)); + assertThat(request.getInputType(), is(InputType.UNSPECIFIED)); } } @@ -135,4 +139,76 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc default -> throw new UnsupportedOperationException(); }; } + + @Override + protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Request instance, TransportVersion version) { + if (version.before(TransportVersions.INFERENCE_MULTIPLE_INPUTS)) { + return new InferenceAction.Request( + instance.getTaskType(), + instance.getModelId(), + instance.getInput().subList(0, 1), + instance.getTaskSettings(), + InputType.UNSPECIFIED + ); + } else if (version.before(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)) { + return new InferenceAction.Request( + instance.getTaskType(), + instance.getModelId(), + instance.getInput(), + instance.getTaskSettings(), + InputType.UNSPECIFIED + ); + } else if (version.before(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED) + && instance.getInputType() == InputType.UNSPECIFIED) { + return new InferenceAction.Request( + instance.getTaskType(), + instance.getModelId(), + instance.getInput(), + instance.getTaskSettings(), + InputType.INGEST + ); + } + + return instance; + } + + public void testWriteTo_WhenVersionIsOnAfterUnspecifiedAdded() throws IOException { + assertBwcSerialization( + new InferenceAction.Request(TaskType.TEXT_EMBEDDING, "model", List.of(), Map.of(), InputType.UNSPECIFIED), + TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED + ); + } + + public void testWriteTo_WhenVersionIsBeforeUnspecifiedAdded_ButAfterInputTypeAdded_ShouldSetToIngest() throws IOException { + assertBwcSerialization( + new InferenceAction.Request(TaskType.TEXT_EMBEDDING, "model", List.of(), Map.of(), InputType.UNSPECIFIED), + TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED + ); + } + + public void testWriteTo_WhenVersionIsBeforeUnspecifiedAdded_ButAfterInputTypeAdded_ShouldSetToIngest_ManualCheck() throws IOException { + var instance = new InferenceAction.Request(TaskType.TEXT_EMBEDDING, "model", List.of(), Map.of(), InputType.UNSPECIFIED); + + InferenceAction.Request deserializedInstance = copyWriteable( + instance, + getNamedWriteableRegistry(), + instanceReader(), + TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED + ); + + assertThat(deserializedInstance.getInputType(), is(InputType.INGEST)); + } + + public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUnspecified() throws IOException { + var instance = new InferenceAction.Request(TaskType.TEXT_EMBEDDING, "model", List.of(), Map.of(), InputType.INGEST); + + InferenceAction.Request deserializedInstance = copyWriteable( + instance, + getNamedWriteableRegistry(), + instanceReader(), + TransportVersions.HOT_THREADS_AS_BYTES + ); + + assertThat(deserializedInstance.getInputType(), is(InputType.UNSPECIFIED)); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java index 67a95265f093d..e7cfc784db117 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java @@ -110,7 +110,7 @@ public void testCreate_CohereEmbeddingsModel() throws IOException { ); var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool)); var overriddenTaskSettings = CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, CohereTruncation.END); - var action = actionCreator.create(model, overriddenTaskSettings); + var action = actionCreator.create(model, overriddenTaskSettings, InputType.UNSPECIFIED); PlainActionFuture listener = new PlainActionFuture<>(); action.execute(List.of("abc"), listener); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java index 8ef9ea4b0316b..2d3ff25222ab9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java @@ -66,4 +66,9 @@ public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws I MatcherAssert.assertThat(xContentResult, is(""" {"texts":["abc"]}""")); } + + public void testConvertToString_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { + var thrownException = expectThrows(AssertionError.class, () -> CohereEmbeddingsRequestEntity.covertToString(InputType.UNSPECIFIED)); + MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]")); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index bae2e7e9b68c9..f2891a04fc6d5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -105,6 +106,7 @@ protected void doInfer( Model model, List input, Map taskSettings, + InputType inputType, ActionListener listener ) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index b935c5a8c64b3..dafe12f21bf3d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -271,7 +271,14 @@ public void testExtractOptionalEnum_ReturnsNull_WhenFieldDoesNotExist() { public void testExtractOptionalEnum_ReturnsNullAndAddsException_WhenAnInvalidValueExists() { var validation = new ValidationException(); Map map = modifiableMap(Map.of("key", "invalid_value")); - var createdEnum = extractOptionalEnum(map, "key", "scope", InputType::fromString, InputType.values(), validation); + var createdEnum = extractOptionalEnum( + map, + "key", + "scope", + InputType::fromString, + new InputType[] { InputType.INGEST, InputType.SEARCH }, + validation + ); assertNull(createdEnum); assertFalse(validation.validationErrors().isEmpty()); @@ -282,6 +289,27 @@ public void testExtractOptionalEnum_ReturnsNullAndAddsException_WhenAnInvalidVal ); } + public void testExtractOptionalEnum_ReturnsNullAndAddsException_WhenValueIsNotPartOfTheAcceptableValues() { + var validation = new ValidationException(); + Map map = modifiableMap(Map.of("key", InputType.UNSPECIFIED.toString())); + var createdEnum = extractOptionalEnum(map, "key", "scope", InputType::fromString, new InputType[] { InputType.INGEST }, validation); + + assertNull(createdEnum); + assertFalse(validation.validationErrors().isEmpty()); + assertTrue(map.isEmpty()); + assertThat(validation.validationErrors().get(0), is("[scope] Invalid value [unspecified] received. [key] must be one of [ingest]")); + } + + public void testExtractOptionalEnum_ReturnsIngest_WhenValueIsAcceptable() { + var validation = new ValidationException(); + Map map = modifiableMap(Map.of("key", InputType.INGEST.toString())); + var createdEnum = extractOptionalEnum(map, "key", "scope", InputType::fromString, new InputType[] { InputType.INGEST }, validation); + + assertThat(createdEnum, is(InputType.INGEST)); + assertTrue(validation.validationErrors().isEmpty()); + assertTrue(map.isEmpty()); + } + public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingResults_IsEmpty() { var service = mock(InferenceService.class); @@ -290,11 +318,11 @@ public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingResults_IsEmpty() doAnswer(invocation -> { @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocation.getArguments()[3]; + ActionListener listener = (ActionListener) invocation.getArguments()[4]; listener.onResponse(new TextEmbeddingResults(List.of())); return Void.TYPE; - }).when(service).infer(any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); getEmbeddingSize(model, service, listener); @@ -313,11 +341,11 @@ public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingByteResults_IsEmp doAnswer(invocation -> { @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocation.getArguments()[3]; + ActionListener listener = (ActionListener) invocation.getArguments()[4]; listener.onResponse(new TextEmbeddingByteResults(List.of())); return Void.TYPE; - }).when(service).infer(any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); getEmbeddingSize(model, service, listener); @@ -338,11 +366,11 @@ public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingResults() { doAnswer(invocation -> { @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocation.getArguments()[3]; + ActionListener listener = (ActionListener) invocation.getArguments()[4]; listener.onResponse(textEmbedding); return Void.TYPE; - }).when(service).infer(any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); getEmbeddingSize(model, service, listener); @@ -362,11 +390,11 @@ public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingByteResults() { doAnswer(invocation -> { @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocation.getArguments()[3]; + ActionListener listener = (ActionListener) invocation.getArguments()[4]; listener.onResponse(textEmbedding); return Void.TYPE; - }).when(service).infer(any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); getEmbeddingSize(model, service, listener); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 0250e08a48452..25d5dd74d733a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -686,7 +686,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException try (var service = new CohereService(new SetOnce<>(factory), new SetOnce<>(createWithEmptySettings(threadPool)))) { PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(mockModel, List.of(""), new HashMap<>(), listener); + service.infer(mockModel, List.of(""), new HashMap<>(), InputType.INGEST, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); MatcherAssert.assertThat( @@ -745,7 +745,7 @@ public void testInfer_SendsRequest() throws IOException { null ); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, List.of("abc"), new HashMap<>(), listener); + service.infer(model, List.of("abc"), new HashMap<>(), InputType.INGEST, listener); var result = listener.actionGet(TIMEOUT); @@ -848,7 +848,7 @@ public void testInfer_UnauthorisedResponse() throws IOException { null ); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, List.of("abc"), new HashMap<>(), listener); + service.infer(model, List.of("abc"), new HashMap<>(), InputType.INGEST, listener); var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); MatcherAssert.assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); @@ -857,6 +857,186 @@ public void testInfer_UnauthorisedResponse() throws IOException { } } + public void testInfer_SetsInputTypeToIngest_FromInferParameter_WhenTaskSettingsAreEmpty() throws IOException { + var senderFactory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY); + + try (var service = new CohereService(new SetOnce<>(senderFactory), new SetOnce<>(createWithEmptySettings(threadPool)))) { + + String responseJson = """ + { + "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", + "texts": [ + "hello" + ], + "embeddings": { + "float": [ + [ + 0.123, + -0.123 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = CohereEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024, + "model", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer(model, List.of("abc"), new HashMap<>(), InputType.INGEST, listener); + + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat(result.asMap(), Matchers.is(buildExpectation(List.of(List.of(0.123F, -0.123F))))); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc"), "model", "model", "input_type", "search_document"))); + } + } + + public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenPresentInTaskSettingsAsSearch() throws IOException { + var senderFactory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY); + + try (var service = new CohereService(new SetOnce<>(senderFactory), new SetOnce<>(createWithEmptySettings(threadPool)))) { + + String responseJson = """ + { + "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", + "texts": [ + "hello" + ], + "embeddings": { + "float": [ + [ + 0.123, + -0.123 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = CohereEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new CohereEmbeddingsTaskSettings(InputType.SEARCH, null), + 1024, + 1024, + "model", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer(model, List.of("abc"), new HashMap<>(), InputType.INGEST, listener); + + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat(result.asMap(), Matchers.is(buildExpectation(List.of(List.of(0.123F, -0.123F))))); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc"), "model", "model", "input_type", "search_query"))); + } + } + + public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException { + var senderFactory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY); + + try (var service = new CohereService(new SetOnce<>(senderFactory), new SetOnce<>(createWithEmptySettings(threadPool)))) { + + String responseJson = """ + { + "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", + "texts": [ + "hello" + ], + "embeddings": { + "float": [ + [ + 0.123, + -0.123 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = CohereEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new CohereEmbeddingsTaskSettings(null, null), + 1024, + 1024, + "model", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer(model, List.of("abc"), new HashMap<>(), InputType.UNSPECIFIED, listener); + + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat(result.asMap(), Matchers.is(buildExpectation(List.of(List.of(0.123F, -0.123F))))); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc"), "model", "model"))); + } + } + private Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java index 1961d6b168d54..943a000b0c3aa 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java @@ -21,11 +21,10 @@ import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.sameInstance; public class CohereEmbeddingsModelTests extends ESTestCase { - public void testOverrideWith_OverridesInputType_WithSearch() { + public void testOverrideWith_OverridesInputType_WithRequestTaskSettingsSearch_NotRequestIngest() { var model = createModel( "url", "api_key", @@ -36,7 +35,9 @@ public void testOverrideWith_OverridesInputType_WithSearch() { CohereEmbeddingType.FLOAT ); - var overriddenModel = model.overrideWith(getTaskSettingsMap(InputType.SEARCH, null)); + // if the request task settings specify an input type value we should honor that instead of the input type that comes from the + // request itself + var overriddenModel = model.overrideWith(getTaskSettingsMap(InputType.SEARCH, null), InputType.INGEST); var expectedModel = createModel( "url", "api_key", @@ -49,18 +50,92 @@ public void testOverrideWith_OverridesInputType_WithSearch() { MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } - public void testOverrideWith_DoesNotOverride_WhenSettingsAreEmpty() { + public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEmpty() { var model = createModel("url", "api_key", null, null, null); - var overriddenModel = model.overrideWith(Map.of()); - MatcherAssert.assertThat(overriddenModel, sameInstance(model)); + var overriddenModel = model.overrideWith(Map.of(), InputType.UNSPECIFIED); + MatcherAssert.assertThat(overriddenModel, is(model)); } - public void testOverrideWith_DoesNotOverride_WhenSettingsAreNull() { + public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull() { var model = createModel("url", "api_key", null, null, null); - var overriddenModel = model.overrideWith(null); - MatcherAssert.assertThat(overriddenModel, sameInstance(model)); + var overriddenModel = model.overrideWith(null, InputType.UNSPECIFIED); + MatcherAssert.assertThat(overriddenModel, is(model)); + } + + public void testOverrideWith_SetsInputTypeToIngest_WhenTheFieldIsNullInModelTaskSettings_AndNullInRequestTaskSettings() { + var model = createModel( + "url", + "api_key", + new CohereEmbeddingsTaskSettings(null, null), + null, + null, + "model", + CohereEmbeddingType.FLOAT + ); + + // since the stored model and request task settings does not have input type defined we'll get it from the request + var overriddenModel = model.overrideWith(getTaskSettingsMap(null, null), InputType.INGEST); + var expectedModel = createModel( + "url", + "api_key", + new CohereEmbeddingsTaskSettings(InputType.INGEST, null), + null, + null, + "model", + CohereEmbeddingType.FLOAT + ); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_DoesNotSetInputType_FromRequest_IfStoredModelHasInputTypeSet() { + var model = createModel( + "url", + "api_key", + new CohereEmbeddingsTaskSettings(InputType.INGEST, null), + null, + null, + "model", + CohereEmbeddingType.FLOAT + ); + + // since the stored model has the input type field set, we will not set it with the request input type + var overriddenModel = model.overrideWith(getTaskSettingsMap(null, null), InputType.SEARCH); + var expectedModel = createModel( + "url", + "api_key", + new CohereEmbeddingsTaskSettings(InputType.INGEST, null), + null, + null, + "model", + CohereEmbeddingType.FLOAT + ); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_DoesNotSetInputType_FromRequest_IfInputTypeIsInvalid() { + var model = createModel( + "url", + "api_key", + new CohereEmbeddingsTaskSettings(null, null), + null, + null, + "model", + CohereEmbeddingType.FLOAT + ); + + var overriddenModel = model.overrideWith(getTaskSettingsMap(null, null), InputType.UNSPECIFIED); + var expectedModel = createModel( + "url", + "api_key", + new CohereEmbeddingsTaskSettings(null, null), + null, + null, + "model", + CohereEmbeddingType.FLOAT + ); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } public static CohereEmbeddingsModel createModel( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java index 164d3998f138f..120fd4556d1aa 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java @@ -15,18 +15,20 @@ import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; +import org.hamcrest.CoreMatchers; import org.hamcrest.MatcherAssert; import java.io.IOException; import java.util.HashMap; import java.util.Map; +import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithoutUnspecified; import static org.hamcrest.Matchers.is; public class CohereEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase { public static CohereEmbeddingsTaskSettings createRandom() { - var inputType = randomBoolean() ? randomFrom(InputType.values()) : null; + var inputType = randomBoolean() ? randomWithoutUnspecified() : null; var truncation = randomBoolean() ? randomFrom(CohereTruncation.values()) : null; return new CohereEmbeddingsTaskSettings(inputType, truncation); @@ -39,6 +41,10 @@ public void testFromMap_CreatesEmptySettings_WhenAllFieldsAreNull() { ); } + public void testFromMap_CreatesEmptySettings_WhenMapIsNull() { + MatcherAssert.assertThat(CohereEmbeddingsTaskSettings.fromMap(null), is(new CohereEmbeddingsTaskSettings(null, null))); + } + public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() { MatcherAssert.assertThat( CohereEmbeddingsTaskSettings.fromMap( @@ -67,6 +73,25 @@ public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() { ); } + public void testFromMap_ReturnsFailure_WhenInputTypeIsUnspecified() { + var exception = expectThrows( + ValidationException.class, + () -> CohereEmbeddingsTaskSettings.fromMap( + new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.INPUT_TYPE, InputType.UNSPECIFIED.toString())) + ) + ); + + MatcherAssert.assertThat( + exception.getMessage(), + is("Validation Failed: 1: [task_settings] Invalid value [unspecified] received. [input_type] must be one of [ingest, search];") + ); + } + + public void testXContent_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { + var thrownException = expectThrows(AssertionError.class, () -> new CohereEmbeddingsTaskSettings(InputType.UNSPECIFIED, null)); + MatcherAssert.assertThat(thrownException.getMessage(), CoreMatchers.is("received invalid input type value [unspecified]")); + } + public void testOverrideWith_KeepsOriginalValuesWhenOverridesAreNull() { var taskSettings = CohereEmbeddingsTaskSettings.fromMap( new HashMap<>(Map.of(CohereServiceSettings.MODEL, "model", CohereServiceFields.TRUNCATE, CohereTruncation.END.toString())) @@ -89,6 +114,27 @@ public void testOverrideWith_UsesOverriddenSettings() { MatcherAssert.assertThat(overriddenTaskSettings, is(new CohereEmbeddingsTaskSettings(null, CohereTruncation.START))); } + public void testSetIfAbsent_DoesNotSetInputType_IfAlreadySetInTaskSettings() { + MatcherAssert.assertThat( + new CohereEmbeddingsTaskSettings(InputType.INGEST, null).setIfAbsent(InputType.SEARCH), + is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null)) + ); + } + + public void testSetIfAbsent_DoesNotSetInputType_IfInputTypeIsInvalid() { + MatcherAssert.assertThat( + new CohereEmbeddingsTaskSettings(null, null).setIfAbsent(InputType.UNSPECIFIED), + is(new CohereEmbeddingsTaskSettings(null, null)) + ); + } + + public void testSetIfAbsent_SetsInputType_IfFieldIsNull() { + MatcherAssert.assertThat( + new CohereEmbeddingsTaskSettings(null, null).setIfAbsent(InputType.INGEST), + is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null)) + ); + } + @Override protected Writeable.Reader instanceReader() { return CohereEmbeddingsTaskSettings::new; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java index b82812d6c393a..615c831db0114 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -64,7 +65,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotHuggingFaceModel() throws IOExcep try (var service = new TestService(new SetOnce<>(factory), new SetOnce<>(createWithEmptySettings(threadPool)))) { PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(mockModel, List.of(""), new HashMap<>(), listener); + service.infer(mockModel, List.of(""), new HashMap<>(), InputType.INGEST, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index a76cce41b4fe4..36a4d144d8c5c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -492,7 +493,7 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret"); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, List.of("abc"), new HashMap<>(), listener); + service.infer(model, List.of("abc"), new HashMap<>(), InputType.INGEST, listener); var result = listener.actionGet(TIMEOUT); @@ -527,7 +528,7 @@ public void testInfer_SendsElserRequest() throws IOException { var model = HuggingFaceElserModelTests.createModel(getUrl(webServer), "secret"); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, List.of("abc"), new HashMap<>(), listener); + service.infer(model, List.of("abc"), new HashMap<>(), InputType.INGEST, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 394286ee5287b..2659715771686 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -667,7 +668,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException try (var service = new OpenAiService(new SetOnce<>(factory), new SetOnce<>(createWithEmptySettings(threadPool)))) { PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(mockModel, List.of(""), new HashMap<>(), listener); + service.infer(mockModel, List.of(""), new HashMap<>(), InputType.INGEST, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat( @@ -713,7 +714,7 @@ public void testInfer_SendsRequest() throws IOException { var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user"); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, List.of("abc"), new HashMap<>(), listener); + service.infer(model, List.of("abc"), new HashMap<>(), InputType.INGEST, listener); var result = listener.actionGet(TIMEOUT); @@ -787,7 +788,7 @@ public void testInfer_UnauthorisedResponse() throws IOException { var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user"); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, List.of("abc"), new HashMap<>(), listener); + service.infer(model, List.of("abc"), new HashMap<>(), InputType.INGEST, listener); var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); From e129e4c2ea9ffc76b7c4af7050417edc0760c565 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 26 Jan 2024 11:08:41 -0500 Subject: [PATCH 2/6] switching logic to allow request to always override --- .../embeddings/CohereEmbeddingsModel.java | 2 +- .../CohereEmbeddingsTaskSettings.java | 8 +- .../services/cohere/CohereServiceTests.java | 16 +++- .../CohereEmbeddingsModelTests.java | 94 ++++++++++++++----- .../CohereEmbeddingsTaskSettingsTests.java | 10 +- 5 files changed, 91 insertions(+), 39 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java index b0ff6381696e7..39a7fa6ee240a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java @@ -80,6 +80,6 @@ public ExecutableAction accept(CohereActionVisitor visitor, Map public CohereEmbeddingsModel overrideWith(Map taskSettings, InputType inputType) { var requestTaskSettings = CohereEmbeddingsTaskSettings.fromMap(taskSettings); - return new CohereEmbeddingsModel(this, getTaskSettings().overrideWith(requestTaskSettings).setIfAbsent(inputType)); + return new CohereEmbeddingsModel(this, getTaskSettings().overrideWith(requestTaskSettings).setInputType(inputType)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java index e48661f6ae1e6..a1a93ba232fb7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java @@ -159,12 +159,12 @@ public CohereEmbeddingsTaskSettings overrideWith(CohereEmbeddingsTaskSettings re } /** - * Sets the input type field for the task settings if it is currently null and the passed in input type value is valid. + * Sets the input type field for the task settings if input type value is valid. * @param inputType the new input type to use - * @return task settings with the values set if they were previously null + * @return newly updated task settings */ - public CohereEmbeddingsTaskSettings setIfAbsent(InputType inputType) { - if (this.inputType != null || VALID_INPUT_VALUES_LIST.contains(inputType) == false) { + public CohereEmbeddingsTaskSettings setInputType(InputType inputType) { + if (VALID_INPUT_VALUES_LIST.contains(inputType) == false) { return this; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 25d5dd74d733a..7daad207f9068 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettingsTests; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests; import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; @@ -917,7 +918,8 @@ public void testInfer_SetsInputTypeToIngest_FromInferParameter_WhenTaskSettingsA } } - public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenPresentInTaskSettingsAsSearch() throws IOException { + public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIsNull_AndRequestTaskSettingsIsSearch() + throws IOException { var senderFactory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY); try (var service = new CohereService(new SetOnce<>(senderFactory), new SetOnce<>(createWithEmptySettings(threadPool)))) { @@ -952,14 +954,20 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenPresentInTaskS var model = CohereEmbeddingsModelTests.createModel( getUrl(webServer), "secret", - new CohereEmbeddingsTaskSettings(InputType.SEARCH, null), + new CohereEmbeddingsTaskSettings(null, null), 1024, 1024, "model", null ); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, List.of("abc"), new HashMap<>(), InputType.INGEST, listener); + service.infer( + model, + List.of("abc"), + CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, null), + InputType.INGEST, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -973,7 +981,7 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenPresentInTaskS MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc"), "model", "model", "input_type", "search_query"))); + MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc"), "model", "model", "input_type", "search_document"))); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java index 943a000b0c3aa..bc5de527fe8e5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java @@ -24,24 +24,36 @@ public class CohereEmbeddingsModelTests extends ESTestCase { - public void testOverrideWith_OverridesInputType_WithRequestTaskSettingsSearch_NotRequestIngest() { + public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEmpty_AndInputTypeIsInvalid() { + var model = createModel("url", "api_key", null, null, null); + + var overriddenModel = model.overrideWith(Map.of(), InputType.UNSPECIFIED); + MatcherAssert.assertThat(overriddenModel, is(model)); + } + + public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull_AndInputTypeIsInvalid() { + var model = createModel("url", "api_key", null, null, null); + + var overriddenModel = model.overrideWith(null, InputType.UNSPECIFIED); + MatcherAssert.assertThat(overriddenModel, is(model)); + } + + public void testOverrideWith_SetsInputTypeToIngest_WhenTheFieldIsNullInModelTaskSettings_AndNullInRequestTaskSettings() { var model = createModel( "url", "api_key", - new CohereEmbeddingsTaskSettings(InputType.INGEST, null), + new CohereEmbeddingsTaskSettings(null, null), null, null, "model", CohereEmbeddingType.FLOAT ); - // if the request task settings specify an input type value we should honor that instead of the input type that comes from the - // request itself - var overriddenModel = model.overrideWith(getTaskSettingsMap(InputType.SEARCH, null), InputType.INGEST); + var overriddenModel = model.overrideWith(getTaskSettingsMap(null, null), InputType.INGEST); var expectedModel = createModel( "url", "api_key", - new CohereEmbeddingsTaskSettings(InputType.SEARCH, null), + new CohereEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model", @@ -50,21 +62,31 @@ public void testOverrideWith_OverridesInputType_WithRequestTaskSettingsSearch_No MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } - public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEmpty() { - var model = createModel("url", "api_key", null, null, null); - - var overriddenModel = model.overrideWith(Map.of(), InputType.UNSPECIFIED); - MatcherAssert.assertThat(overriddenModel, is(model)); - } - - public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull() { - var model = createModel("url", "api_key", null, null, null); + public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingStoredTaskSettings() { + var model = createModel( + "url", + "api_key", + new CohereEmbeddingsTaskSettings(InputType.INGEST, null), + null, + null, + "model", + CohereEmbeddingType.FLOAT + ); - var overriddenModel = model.overrideWith(null, InputType.UNSPECIFIED); - MatcherAssert.assertThat(overriddenModel, is(model)); + var overriddenModel = model.overrideWith(getTaskSettingsMap(null, null), InputType.SEARCH); + var expectedModel = createModel( + "url", + "api_key", + new CohereEmbeddingsTaskSettings(InputType.SEARCH, null), + null, + null, + "model", + CohereEmbeddingType.FLOAT + ); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } - public void testOverrideWith_SetsInputTypeToIngest_WhenTheFieldIsNullInModelTaskSettings_AndNullInRequestTaskSettings() { + public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingRequestTaskSettings() { var model = createModel( "url", "api_key", @@ -75,12 +97,11 @@ public void testOverrideWith_SetsInputTypeToIngest_WhenTheFieldIsNullInModelTask CohereEmbeddingType.FLOAT ); - // since the stored model and request task settings does not have input type defined we'll get it from the request - var overriddenModel = model.overrideWith(getTaskSettingsMap(null, null), InputType.INGEST); + var overriddenModel = model.overrideWith(getTaskSettingsMap(InputType.INGEST, null), InputType.SEARCH); var expectedModel = createModel( "url", "api_key", - new CohereEmbeddingsTaskSettings(InputType.INGEST, null), + new CohereEmbeddingsTaskSettings(InputType.SEARCH, null), null, null, "model", @@ -89,7 +110,7 @@ public void testOverrideWith_SetsInputTypeToIngest_WhenTheFieldIsNullInModelTask MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } - public void testOverrideWith_DoesNotSetInputType_FromRequest_IfStoredModelHasInputTypeSet() { + public void testOverrideWith_OverridesInputType_WithRequestTaskSettingsSearch_WhenRequestInputTypeIsInvalid() { var model = createModel( "url", "api_key", @@ -100,12 +121,11 @@ public void testOverrideWith_DoesNotSetInputType_FromRequest_IfStoredModelHasInp CohereEmbeddingType.FLOAT ); - // since the stored model has the input type field set, we will not set it with the request input type - var overriddenModel = model.overrideWith(getTaskSettingsMap(null, null), InputType.SEARCH); + var overriddenModel = model.overrideWith(getTaskSettingsMap(InputType.SEARCH, null), InputType.UNSPECIFIED); var expectedModel = createModel( "url", "api_key", - new CohereEmbeddingsTaskSettings(InputType.INGEST, null), + new CohereEmbeddingsTaskSettings(InputType.SEARCH, null), null, null, "model", @@ -138,6 +158,30 @@ public void testOverrideWith_DoesNotSetInputType_FromRequest_IfInputTypeIsInvali MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } + public void testOverrideWith_DoesNotSetInputType_WhenRequestTaskSettingsIsNull_AndRequestInputTypeIsInvalid() { + var model = createModel( + "url", + "api_key", + new CohereEmbeddingsTaskSettings(InputType.INGEST, null), + null, + null, + "model", + CohereEmbeddingType.FLOAT + ); + + var overriddenModel = model.overrideWith(getTaskSettingsMap(null, null), InputType.UNSPECIFIED); + var expectedModel = createModel( + "url", + "api_key", + new CohereEmbeddingsTaskSettings(InputType.INGEST, null), + null, + null, + "model", + CohereEmbeddingType.FLOAT + ); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + public static CohereEmbeddingsModel createModel( String url, String apiKey, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java index 120fd4556d1aa..d571309ffb800 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java @@ -114,23 +114,23 @@ public void testOverrideWith_UsesOverriddenSettings() { MatcherAssert.assertThat(overriddenTaskSettings, is(new CohereEmbeddingsTaskSettings(null, CohereTruncation.START))); } - public void testSetIfAbsent_DoesNotSetInputType_IfAlreadySetInTaskSettings() { + public void testSetInputType_SetsInputType() { MatcherAssert.assertThat( - new CohereEmbeddingsTaskSettings(InputType.INGEST, null).setIfAbsent(InputType.SEARCH), - is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null)) + new CohereEmbeddingsTaskSettings(InputType.INGEST, null).setInputType(InputType.SEARCH), + is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, null)) ); } public void testSetIfAbsent_DoesNotSetInputType_IfInputTypeIsInvalid() { MatcherAssert.assertThat( - new CohereEmbeddingsTaskSettings(null, null).setIfAbsent(InputType.UNSPECIFIED), + new CohereEmbeddingsTaskSettings(null, null).setInputType(InputType.UNSPECIFIED), is(new CohereEmbeddingsTaskSettings(null, null)) ); } public void testSetIfAbsent_SetsInputType_IfFieldIsNull() { MatcherAssert.assertThat( - new CohereEmbeddingsTaskSettings(null, null).setIfAbsent(InputType.INGEST), + new CohereEmbeddingsTaskSettings(null, null).setInputType(InputType.INGEST), is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null)) ); } From 4182cd5290593638aeeaea134e13d91aa0e53ad3 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 26 Jan 2024 11:19:31 -0500 Subject: [PATCH 3/6] Fixing failure --- .../xpack/inference/mock/TestInferenceServiceExtension.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java index eee6f68c20ff7..5ffb4b5df08cc 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -123,11 +124,11 @@ public void infer( Model model, List input, Map taskSettings, + InputType inputType, ActionListener listener ) { switch (model.getConfigurations().getTaskType()) { - case ANY -> listener.onResponse(makeResults(input)); - case SPARSE_EMBEDDING -> listener.onResponse(makeResults(input)); + case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeResults(input)); default -> listener.onFailure( new ElasticsearchStatusException( TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), From 9ddc9b55d09a005589b025dcbbaca333c57e99d8 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 29 Jan 2024 13:08:51 -0500 Subject: [PATCH 4/6] Removing getModelId calls --- .../xpack/inference/action/InferenceActionRequestTests.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java index 9d50bdc886cb0..396af55ce5616 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java @@ -145,7 +145,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque if (version.before(TransportVersions.INFERENCE_MULTIPLE_INPUTS)) { return new InferenceAction.Request( instance.getTaskType(), - instance.getModelId(), + instance.getInferenceEntityId(), instance.getInput().subList(0, 1), instance.getTaskSettings(), InputType.UNSPECIFIED @@ -153,7 +153,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque } else if (version.before(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)) { return new InferenceAction.Request( instance.getTaskType(), - instance.getModelId(), + instance.getInferenceEntityId(), instance.getInput(), instance.getTaskSettings(), InputType.UNSPECIFIED @@ -162,7 +162,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque && instance.getInputType() == InputType.UNSPECIFIED) { return new InferenceAction.Request( instance.getTaskType(), - instance.getModelId(), + instance.getInferenceEntityId(), instance.getInput(), instance.getTaskSettings(), InputType.INGEST From e081684a6810200edbc849258a5dd6067a96c083 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 29 Jan 2024 17:10:27 -0500 Subject: [PATCH 5/6] Addressing feedback --- .../inference/action/InferenceAction.java | 10 ++- .../action/cohere/CohereActionCreator.java | 2 +- .../action/openai/OpenAiActionCreator.java | 2 +- .../embeddings/CohereEmbeddingsModel.java | 10 +-- .../CohereEmbeddingsTaskSettings.java | 72 ++++++++++++------- .../embeddings/OpenAiEmbeddingsModel.java | 18 ++--- .../OpenAiEmbeddingsTaskSettings.java | 24 +++++-- .../CohereEmbeddingsModelTests.java | 16 ++--- .../CohereEmbeddingsTaskSettingsTests.java | 52 ++++++-------- .../OpenAiEmbeddingsModelTests.java | 6 +- .../OpenAiEmbeddingsTaskSettingsTests.java | 6 +- 11 files changed, 123 insertions(+), 95 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index 2b6d5ccb449a7..2ddba3446d79a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -61,6 +61,8 @@ public static Request parseRequest(String inferenceEntityId, String taskType, XC Request.Builder builder = PARSER.apply(parser, null); builder.setInferenceEntityId(inferenceEntityId); builder.setTaskType(taskType); + // For rest requests we won't know what the input type is + builder.setInputType(InputType.UNSPECIFIED); return builder.build(); } @@ -185,6 +187,7 @@ public static class Builder { private TaskType taskType; private String inferenceEntityId; private List input; + private InputType inputType = InputType.UNSPECIFIED; private Map taskSettings = Map.of(); private Builder() {} @@ -209,13 +212,18 @@ public Builder setInput(List input) { return this; } + public Builder setInputType(InputType inputType) { + this.inputType = inputType; + return this; + } + public Builder setTaskSettings(Map taskSettings) { this.taskSettings = taskSettings; return this; } public Request build() { - return new Request(taskType, inferenceEntityId, input, taskSettings, InputType.UNSPECIFIED); + return new Request(taskType, inferenceEntityId, input, taskSettings, inputType); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java index e6b5ceb03f147..0fb5ca9283fae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java @@ -30,7 +30,7 @@ public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) { @Override public ExecutableAction create(CohereEmbeddingsModel model, Map taskSettings, InputType inputType) { - var overriddenModel = model.overrideWith(taskSettings, inputType); + var overriddenModel = CohereEmbeddingsModel.of(model, taskSettings, inputType); return new CohereEmbeddingsAction(sender, overriddenModel, serviceComponents); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java index 6c423760d0b35..94583c634fb26 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java @@ -29,7 +29,7 @@ public OpenAiActionCreator(Sender sender, ServiceComponents serviceComponents) { @Override public ExecutableAction create(OpenAiEmbeddingsModel model, Map taskSettings) { - var overriddenModel = model.overrideWith(taskSettings); + var overriddenModel = OpenAiEmbeddingsModel.of(model, taskSettings); return new OpenAiEmbeddingsAction(sender, overriddenModel, serviceComponents); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java index 39a7fa6ee240a..a3afdc306b217 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java @@ -20,6 +20,11 @@ import java.util.Map; public class CohereEmbeddingsModel extends CohereModel { + public static CohereEmbeddingsModel of(CohereEmbeddingsModel model, Map taskSettings, InputType inputType) { + var requestTaskSettings = CohereEmbeddingsTaskSettings.fromMap(taskSettings); + return new CohereEmbeddingsModel(model, CohereEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings, inputType)); + } + public CohereEmbeddingsModel( String modelId, TaskType taskType, @@ -77,9 +82,4 @@ public DefaultSecretSettings getSecretSettings() { public ExecutableAction accept(CohereActionVisitor visitor, Map taskSettings, InputType inputType) { return visitor.create(this, taskSettings, inputType); } - - public CohereEmbeddingsModel overrideWith(Map taskSettings, InputType inputType) { - var requestTaskSettings = CohereEmbeddingsTaskSettings.fromMap(taskSettings); - return new CohereEmbeddingsModel(this, getTaskSettings().overrideWith(requestTaskSettings).setInputType(inputType)); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java index a1a93ba232fb7..7061c80decb19 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java @@ -75,6 +75,52 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) { return new CohereEmbeddingsTaskSettings(inputType, truncation); } + /** + * Creates a new {@link CohereEmbeddingsTaskSettings} by preferring non-null fields from the provided parameters. + * For the input type, preference is given to requestInputType if it is not null and not UNSPECIFIED. + * Then preference is given to the requestTaskSettings and finally to originalSettings even if the value is null. + * + * Similarly, for the truncation field preference is given to requestTaskSettings if it is not null and then to + * originalSettings. + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @param requestInputType the input type passed in the request parameters + * @return a constructed {@link CohereEmbeddingsTaskSettings} + */ + public static CohereEmbeddingsTaskSettings of( + CohereEmbeddingsTaskSettings originalSettings, + CohereEmbeddingsTaskSettings requestTaskSettings, + InputType requestInputType + ) { + var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings, requestInputType); + var truncationToUse = getValidTruncation(originalSettings, requestTaskSettings); + + return new CohereEmbeddingsTaskSettings(inputTypeToUse, truncationToUse); + } + + private static InputType getValidInputType( + CohereEmbeddingsTaskSettings originalSettings, + CohereEmbeddingsTaskSettings requestTaskSettings, + InputType requestInputType + ) { + InputType inputTypeToUse = originalSettings.inputType; + + if (VALID_INPUT_VALUES_LIST.contains(requestInputType)) { + inputTypeToUse = requestInputType; + } else if (requestTaskSettings.inputType != null) { + inputTypeToUse = requestTaskSettings.inputType; + } + + return inputTypeToUse; + } + + private static CohereTruncation getValidTruncation( + CohereEmbeddingsTaskSettings originalSettings, + CohereEmbeddingsTaskSettings requestTaskSettings + ) { + return requestTaskSettings.getTruncation() == null ? originalSettings.truncation : requestTaskSettings.getTruncation(); + } + private final InputType inputType; private final CohereTruncation truncation; @@ -139,7 +185,7 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; CohereEmbeddingsTaskSettings that = (CohereEmbeddingsTaskSettings) o; - return inputType == that.inputType && truncation == that.truncation; + return Objects.equals(inputType, that.inputType) && Objects.equals(truncation, that.truncation); } @Override @@ -147,30 +193,6 @@ public int hashCode() { return Objects.hash(inputType, truncation); } - public CohereEmbeddingsTaskSettings overrideWith(CohereEmbeddingsTaskSettings requestTaskSettings) { - if (requestTaskSettings.equals(EMPTY_SETTINGS)) { - return this; - } - - var inputTypeToUse = requestTaskSettings.getInputType() == null ? inputType : requestTaskSettings.getInputType(); - var truncationToUse = requestTaskSettings.getTruncation() == null ? truncation : requestTaskSettings.getTruncation(); - - return new CohereEmbeddingsTaskSettings(inputTypeToUse, truncationToUse); - } - - /** - * Sets the input type field for the task settings if input type value is valid. - * @param inputType the new input type to use - * @return newly updated task settings - */ - public CohereEmbeddingsTaskSettings setInputType(InputType inputType) { - if (VALID_INPUT_VALUES_LIST.contains(inputType) == false) { - return this; - } - - return new CohereEmbeddingsTaskSettings(inputType, truncation); - } - public static String invalidInputTypeMessage(InputType inputType) { return Strings.format("received invalid input type value [%s]", inputType.toString()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java index 98b0161665d8e..74d97099bbb76 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java @@ -21,6 +21,15 @@ public class OpenAiEmbeddingsModel extends OpenAiModel { + public static OpenAiEmbeddingsModel of(OpenAiEmbeddingsModel model, Map taskSettings) { + if (taskSettings == null || taskSettings.isEmpty()) { + return model; + } + + var requestTaskSettings = OpenAiEmbeddingsRequestTaskSettings.fromMap(taskSettings); + return new OpenAiEmbeddingsModel(model, OpenAiEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + } + public OpenAiEmbeddingsModel( String inferenceEntityId, TaskType taskType, @@ -78,13 +87,4 @@ public DefaultSecretSettings getSecretSettings() { public ExecutableAction accept(OpenAiActionVisitor creator, Map taskSettings) { return creator.create(this, taskSettings); } - - public OpenAiEmbeddingsModel overrideWith(Map taskSettings) { - if (taskSettings == null || taskSettings.isEmpty()) { - return this; - } - - var requestTaskSettings = OpenAiEmbeddingsRequestTaskSettings.fromMap(taskSettings); - return new OpenAiEmbeddingsModel(this, getTaskSettings().overrideWith(requestTaskSettings)); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettings.java index 45a9ce1cabbc3..c6f3179a4f088 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettings.java @@ -50,6 +50,23 @@ public static OpenAiEmbeddingsTaskSettings fromMap(Map map) { return new OpenAiEmbeddingsTaskSettings(model, user); } + /** + * Creates a new {@link OpenAiEmbeddingsTaskSettings} object by overriding the values in originalSettings with the ones + * passed in via requestSettings if the fields are not null. + * @param originalSettings the original task settings from the inference entity configuration from storage + * @param requestSettings the task settings from the request + * @return a new {@link OpenAiEmbeddingsTaskSettings} + */ + public static OpenAiEmbeddingsTaskSettings of( + OpenAiEmbeddingsTaskSettings originalSettings, + OpenAiEmbeddingsRequestTaskSettings requestSettings + ) { + var modelToUse = requestSettings.model() == null ? originalSettings.model : requestSettings.model(); + var userToUse = requestSettings.user() == null ? originalSettings.user : requestSettings.user(); + + return new OpenAiEmbeddingsTaskSettings(modelToUse, userToUse); + } + public OpenAiEmbeddingsTaskSettings { Objects.requireNonNull(model); } @@ -84,11 +101,4 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(model); out.writeOptionalString(user); } - - public OpenAiEmbeddingsTaskSettings overrideWith(OpenAiEmbeddingsRequestTaskSettings requestSettings) { - var modelToUse = requestSettings.model() == null ? model : requestSettings.model(); - var userToUse = requestSettings.user() == null ? user : requestSettings.user(); - - return new OpenAiEmbeddingsTaskSettings(modelToUse, userToUse); - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java index bc5de527fe8e5..5570731dbe8d9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java @@ -27,14 +27,14 @@ public class CohereEmbeddingsModelTests extends ESTestCase { public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEmpty_AndInputTypeIsInvalid() { var model = createModel("url", "api_key", null, null, null); - var overriddenModel = model.overrideWith(Map.of(), InputType.UNSPECIFIED); + var overriddenModel = CohereEmbeddingsModel.of(model, Map.of(), InputType.UNSPECIFIED); MatcherAssert.assertThat(overriddenModel, is(model)); } public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull_AndInputTypeIsInvalid() { var model = createModel("url", "api_key", null, null, null); - var overriddenModel = model.overrideWith(null, InputType.UNSPECIFIED); + var overriddenModel = CohereEmbeddingsModel.of(model, null, InputType.UNSPECIFIED); MatcherAssert.assertThat(overriddenModel, is(model)); } @@ -49,7 +49,7 @@ public void testOverrideWith_SetsInputTypeToIngest_WhenTheFieldIsNullInModelTask CohereEmbeddingType.FLOAT ); - var overriddenModel = model.overrideWith(getTaskSettingsMap(null, null), InputType.INGEST); + var overriddenModel = CohereEmbeddingsModel.of(model, getTaskSettingsMap(null, null), InputType.INGEST); var expectedModel = createModel( "url", "api_key", @@ -73,7 +73,7 @@ public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingStoredT CohereEmbeddingType.FLOAT ); - var overriddenModel = model.overrideWith(getTaskSettingsMap(null, null), InputType.SEARCH); + var overriddenModel = CohereEmbeddingsModel.of(model, getTaskSettingsMap(null, null), InputType.SEARCH); var expectedModel = createModel( "url", "api_key", @@ -97,7 +97,7 @@ public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingRequest CohereEmbeddingType.FLOAT ); - var overriddenModel = model.overrideWith(getTaskSettingsMap(InputType.INGEST, null), InputType.SEARCH); + var overriddenModel = CohereEmbeddingsModel.of(model, getTaskSettingsMap(InputType.INGEST, null), InputType.SEARCH); var expectedModel = createModel( "url", "api_key", @@ -121,7 +121,7 @@ public void testOverrideWith_OverridesInputType_WithRequestTaskSettingsSearch_Wh CohereEmbeddingType.FLOAT ); - var overriddenModel = model.overrideWith(getTaskSettingsMap(InputType.SEARCH, null), InputType.UNSPECIFIED); + var overriddenModel = CohereEmbeddingsModel.of(model, getTaskSettingsMap(InputType.SEARCH, null), InputType.UNSPECIFIED); var expectedModel = createModel( "url", "api_key", @@ -145,7 +145,7 @@ public void testOverrideWith_DoesNotSetInputType_FromRequest_IfInputTypeIsInvali CohereEmbeddingType.FLOAT ); - var overriddenModel = model.overrideWith(getTaskSettingsMap(null, null), InputType.UNSPECIFIED); + var overriddenModel = CohereEmbeddingsModel.of(model, getTaskSettingsMap(null, null), InputType.UNSPECIFIED); var expectedModel = createModel( "url", "api_key", @@ -169,7 +169,7 @@ public void testOverrideWith_DoesNotSetInputType_WhenRequestTaskSettingsIsNull_A CohereEmbeddingType.FLOAT ); - var overriddenModel = model.overrideWith(getTaskSettingsMap(null, null), InputType.UNSPECIFIED); + var overriddenModel = CohereEmbeddingsModel.of(model, getTaskSettingsMap(null, null), InputType.UNSPECIFIED); var expectedModel = createModel( "url", "api_key", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java index d571309ffb800..77e3280d18f93 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; -import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; import org.hamcrest.CoreMatchers; import org.hamcrest.MatcherAssert; @@ -92,47 +91,36 @@ public void testXContent_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { MatcherAssert.assertThat(thrownException.getMessage(), CoreMatchers.is("received invalid input type value [unspecified]")); } - public void testOverrideWith_KeepsOriginalValuesWhenOverridesAreNull() { - var taskSettings = CohereEmbeddingsTaskSettings.fromMap( - new HashMap<>(Map.of(CohereServiceSettings.MODEL, "model", CohereServiceFields.TRUNCATE, CohereTruncation.END.toString())) + public void testOf_KeepsOriginalValuesWhenRequestSettingsAreNull_AndRequestInputTypeIsInvalid() { + var taskSettings = new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.NONE); + var overriddenTaskSettings = CohereEmbeddingsTaskSettings.of( + taskSettings, + CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, + InputType.UNSPECIFIED ); - - var overriddenTaskSettings = taskSettings.overrideWith(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS); MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings)); } - public void testOverrideWith_UsesOverriddenSettings() { - var taskSettings = CohereEmbeddingsTaskSettings.fromMap( - new HashMap<>(Map.of(CohereServiceFields.TRUNCATE, CohereTruncation.END.toString())) - ); - - var requestTaskSettings = CohereEmbeddingsTaskSettings.fromMap( - new HashMap<>(Map.of(CohereServiceFields.TRUNCATE, CohereTruncation.START.toString())) + public void testOf_UsesRequestTaskSettings() { + var taskSettings = new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE); + var overriddenTaskSettings = CohereEmbeddingsTaskSettings.of( + taskSettings, + new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.END), + InputType.UNSPECIFIED ); - var overriddenTaskSettings = taskSettings.overrideWith(requestTaskSettings); - MatcherAssert.assertThat(overriddenTaskSettings, is(new CohereEmbeddingsTaskSettings(null, CohereTruncation.START))); - } - - public void testSetInputType_SetsInputType() { - MatcherAssert.assertThat( - new CohereEmbeddingsTaskSettings(InputType.INGEST, null).setInputType(InputType.SEARCH), - is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, null)) - ); + MatcherAssert.assertThat(overriddenTaskSettings, is(new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.END))); } - public void testSetIfAbsent_DoesNotSetInputType_IfInputTypeIsInvalid() { - MatcherAssert.assertThat( - new CohereEmbeddingsTaskSettings(null, null).setInputType(InputType.UNSPECIFIED), - is(new CohereEmbeddingsTaskSettings(null, null)) + public void testOf_UsesRequestTaskSettings_AndRequestInputType() { + var taskSettings = new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE); + var overriddenTaskSettings = CohereEmbeddingsTaskSettings.of( + taskSettings, + new CohereEmbeddingsTaskSettings(null, CohereTruncation.END), + InputType.INGEST ); - } - public void testSetIfAbsent_SetsInputType_IfFieldIsNull() { - MatcherAssert.assertThat( - new CohereEmbeddingsTaskSettings(null, null).setInputType(InputType.INGEST), - is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null)) - ); + MatcherAssert.assertThat(overriddenTaskSettings, is(new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.END))); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java index 10e856ec8a27e..e2144132af6c1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java @@ -27,7 +27,7 @@ public void testOverrideWith_OverridesUser() { var model = createModel("url", "org", "api_key", "model_name", null); var requestTaskSettingsMap = getRequestTaskSettingsMap(null, "user_override"); - var overriddenModel = model.overrideWith(requestTaskSettingsMap); + var overriddenModel = OpenAiEmbeddingsModel.of(model, requestTaskSettingsMap); assertThat(overriddenModel, is(createModel("url", "org", "api_key", "model_name", "user_override"))); } @@ -37,14 +37,14 @@ public void testOverrideWith_EmptyMap() { var requestTaskSettingsMap = Map.of(); - var overriddenModel = model.overrideWith(requestTaskSettingsMap); + var overriddenModel = OpenAiEmbeddingsModel.of(model, requestTaskSettingsMap); assertThat(overriddenModel, sameInstance(model)); } public void testOverrideWith_NullMap() { var model = createModel("url", "org", "api_key", "model_name", null); - var overriddenModel = model.overrideWith(null); + var overriddenModel = OpenAiEmbeddingsModel.of(model, null); assertThat(overriddenModel, sameInstance(model)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java index f297eb622c421..103fab071098e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java @@ -72,7 +72,7 @@ public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() { new HashMap<>(Map.of(OpenAiEmbeddingsTaskSettings.MODEL, "model", OpenAiEmbeddingsTaskSettings.USER, "user")) ); - var overriddenTaskSettings = taskSettings.overrideWith(OpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS); + var overriddenTaskSettings = OpenAiEmbeddingsTaskSettings.of(taskSettings, OpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS); MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings)); } @@ -85,7 +85,7 @@ public void testOverrideWith_UsesOverriddenSettings() { new HashMap<>(Map.of(OpenAiEmbeddingsTaskSettings.MODEL, "model2", OpenAiEmbeddingsTaskSettings.USER, "user2")) ); - var overriddenTaskSettings = taskSettings.overrideWith(requestTaskSettings); + var overriddenTaskSettings = OpenAiEmbeddingsTaskSettings.of(taskSettings, requestTaskSettings); MatcherAssert.assertThat(overriddenTaskSettings, is(new OpenAiEmbeddingsTaskSettings("model2", "user2"))); } @@ -98,7 +98,7 @@ public void testOverrideWith_UsesOnlyNonNullModelSetting() { new HashMap<>(Map.of(OpenAiEmbeddingsTaskSettings.MODEL, "model2")) ); - var overriddenTaskSettings = taskSettings.overrideWith(requestTaskSettings); + var overriddenTaskSettings = OpenAiEmbeddingsTaskSettings.of(taskSettings, requestTaskSettings); MatcherAssert.assertThat(overriddenTaskSettings, is(new OpenAiEmbeddingsTaskSettings("model2", "user"))); } From 6ab8503a178d1281ef153869da235c03a8bdbc8e Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 30 Jan 2024 12:08:36 -0500 Subject: [PATCH 6/6] Switching to enumset --- .../inference/services/ServiceUtils.java | 20 +++++++++---------- .../CohereEmbeddingsServiceSettings.java | 3 ++- .../CohereEmbeddingsTaskSettings.java | 14 ++++++------- .../inference/services/ServiceUtilsTests.java | 9 +++++---- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 63dc28ce3c12b..7637bd9740670 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -24,7 +24,7 @@ import java.net.URI; import java.net.URISyntaxException; -import java.util.Arrays; +import java.util.EnumSet; import java.util.List; import java.util.Locale; import java.util.Map; @@ -221,12 +221,12 @@ public static String extractOptionalString( return optionalField; } - public static T extractOptionalEnum( + public static > E extractOptionalEnum( Map map, String settingName, String scope, - EnumConstructor constructor, - T[] validValues, + EnumConstructor constructor, + EnumSet validValues, ValidationException validationException ) { var enumString = extractOptionalString(map, settingName, scope, validationException); @@ -234,7 +234,7 @@ public static T extractOptionalEnum( return null; } - var validValuesAsStrings = Arrays.stream(validValues).map(type -> type.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new); + var validValuesAsStrings = validValues.stream().map(value -> value.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new); try { var createdEnum = constructor.apply(enumString); validateEnumValue(createdEnum, validValues); @@ -247,19 +247,19 @@ public static T extractOptionalEnum( return null; } - private static void validateEnumValue(T enumValue, T[] validValues) { - if (Arrays.asList(validValues).contains(enumValue) == false) { + private static > void validateEnumValue(E enumValue, EnumSet validValues) { + if (validValues.contains(enumValue) == false) { throw new IllegalArgumentException(Strings.format("Enum value [%s] is not one of the acceptable values", enumValue.toString())); } } /** * Functional interface for creating an enum from a string. - * @param + * @param */ @FunctionalInterface - public interface EnumConstructor { - T apply(String name) throws IllegalArgumentException; + public interface EnumConstructor> { + E apply(String name) throws IllegalArgumentException; } public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java index 5327bcbcf22dd..916e7fadcc8fb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import java.io.IOException; +import java.util.EnumSet; import java.util.Map; import java.util.Objects; @@ -37,7 +38,7 @@ public static CohereEmbeddingsServiceSettings fromMap(Map map) { EMBEDDING_TYPE, ModelConfigurations.SERVICE_SETTINGS, CohereEmbeddingType::fromString, - CohereEmbeddingType.values(), + EnumSet.allOf(CohereEmbeddingType.class), validationException ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java index 7061c80decb19..b294350580a2e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java @@ -21,8 +21,7 @@ import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; import java.io.IOException; -import java.util.Arrays; -import java.util.List; +import java.util.EnumSet; import java.util.Map; import java.util.Objects; @@ -41,8 +40,7 @@ public class CohereEmbeddingsTaskSettings implements TaskSettings { public static final String NAME = "cohere_embeddings_task_settings"; public static final CohereEmbeddingsTaskSettings EMPTY_SETTINGS = new CohereEmbeddingsTaskSettings(null, null); static final String INPUT_TYPE = "input_type"; - private static final InputType[] VALID_REQUEST_VALUES = { InputType.INGEST, InputType.SEARCH }; - private static final List VALID_INPUT_VALUES_LIST = Arrays.asList(VALID_REQUEST_VALUES); + private static final EnumSet VALID_REQUEST_VALUES2 = EnumSet.of(InputType.INGEST, InputType.SEARCH); public static CohereEmbeddingsTaskSettings fromMap(Map map) { if (map == null || map.isEmpty()) { @@ -56,7 +54,7 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) { INPUT_TYPE, ModelConfigurations.TASK_SETTINGS, InputType::fromString, - VALID_REQUEST_VALUES, + VALID_REQUEST_VALUES2, validationException ); CohereTruncation truncation = extractOptionalEnum( @@ -64,7 +62,7 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) { TRUNCATE, ModelConfigurations.TASK_SETTINGS, CohereTruncation::fromString, - CohereTruncation.values(), + EnumSet.allOf(CohereTruncation.class), validationException ); @@ -105,7 +103,7 @@ private static InputType getValidInputType( ) { InputType inputTypeToUse = originalSettings.inputType; - if (VALID_INPUT_VALUES_LIST.contains(requestInputType)) { + if (VALID_REQUEST_VALUES2.contains(requestInputType)) { inputTypeToUse = requestInputType; } else if (requestTaskSettings.inputType != null) { inputTypeToUse = requestTaskSettings.inputType; @@ -139,7 +137,7 @@ private static void validateInputType(InputType inputType) { return; } - assert VALID_INPUT_VALUES_LIST.contains(inputType) : invalidInputTypeMessage(inputType); + assert VALID_REQUEST_VALUES2.contains(inputType) : invalidInputTypeMessage(inputType); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index dafe12f21bf3d..689c9f9b08a2b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.inference.results.TextEmbeddingByteResultsTests; import org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -261,7 +262,7 @@ public void testExtractOptionalString_AddsException_WhenFieldIsEmpty() { public void testExtractOptionalEnum_ReturnsNull_WhenFieldDoesNotExist() { var validation = new ValidationException(); Map map = modifiableMap(Map.of("key", "value")); - var createdEnum = extractOptionalEnum(map, "abc", "scope", InputType::fromString, InputType.values(), validation); + var createdEnum = extractOptionalEnum(map, "abc", "scope", InputType::fromString, EnumSet.allOf(InputType.class), validation); assertNull(createdEnum); assertTrue(validation.validationErrors().isEmpty()); @@ -276,7 +277,7 @@ public void testExtractOptionalEnum_ReturnsNullAndAddsException_WhenAnInvalidVal "key", "scope", InputType::fromString, - new InputType[] { InputType.INGEST, InputType.SEARCH }, + EnumSet.of(InputType.INGEST, InputType.SEARCH), validation ); @@ -292,7 +293,7 @@ public void testExtractOptionalEnum_ReturnsNullAndAddsException_WhenAnInvalidVal public void testExtractOptionalEnum_ReturnsNullAndAddsException_WhenValueIsNotPartOfTheAcceptableValues() { var validation = new ValidationException(); Map map = modifiableMap(Map.of("key", InputType.UNSPECIFIED.toString())); - var createdEnum = extractOptionalEnum(map, "key", "scope", InputType::fromString, new InputType[] { InputType.INGEST }, validation); + var createdEnum = extractOptionalEnum(map, "key", "scope", InputType::fromString, EnumSet.of(InputType.INGEST), validation); assertNull(createdEnum); assertFalse(validation.validationErrors().isEmpty()); @@ -303,7 +304,7 @@ public void testExtractOptionalEnum_ReturnsNullAndAddsException_WhenValueIsNotPa public void testExtractOptionalEnum_ReturnsIngest_WhenValueIsAcceptable() { var validation = new ValidationException(); Map map = modifiableMap(Map.of("key", InputType.INGEST.toString())); - var createdEnum = extractOptionalEnum(map, "key", "scope", InputType::fromString, new InputType[] { InputType.INGEST }, validation); + var createdEnum = extractOptionalEnum(map, "key", "scope", InputType::fromString, EnumSet.of(InputType.INGEST), validation); assertThat(createdEnum, is(InputType.INGEST)); assertTrue(validation.validationErrors().isEmpty());