From 244da7584fbaa2618a4befbb37b13f3173d76e83 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 10 Jan 2024 14:00:15 -0500 Subject: [PATCH 01/13] Starting cohere --- .../cohere/CohereEmbeddingsRequest.java | 41 +++++++++++ .../inference/services/ServiceUtils.java | 71 +++++++++++++++++++ .../cohere/CohereEmbeddingsTaskSettings.java | 60 ++++++++++++++++ .../services/cohere/CohereServiceFields.java | 14 ++++ .../services/cohere/CohereTruncation.java | 58 +++++++++++++++ 5 files changed, 244 insertions(+) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsTaskSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java new file mode 100644 index 0000000000000..093ce9fd2f4bb --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.cohere; + +import org.apache.http.client.methods.HttpRequestBase; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.net.URI; + +public class CohereEmbeddingsRequest implements Request { + + public CohereEmbeddingsRequest(Truncator truncator, Object account, Truncator.TruncationResult input, Object taskSettings) { + + } + + @Override + public HttpRequestBase createRequest() { + return null; + } + + @Override + public URI getURI() { + return null; + } + + @Override + public Request truncate() { + return null; + } + + @Override + public boolean[] getTruncationInfo() { + return new boolean[0]; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 1686cd32d4a6b..b503c0800d342 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.CheckedFunction; import org.elasticsearch.core.Strings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.Model; @@ -21,6 +22,7 @@ import java.net.URI; import java.net.URISyntaxException; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; @@ -105,6 +107,10 @@ public static String mustBeNonEmptyString(String settingName, String scope) { return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName); } + public static String invalidType(String settingName, String scope, String invalidType, String requiredType) { + return Strings.format("[%s] Invalid type [%s] received. [%s] must be type [%s]", scope, invalidType, settingName, requiredType); + } + // TODO improve URI validation logic public static URI convertToUri(String url, String settingName, String settingScope, ValidationException validationException) { try { @@ -154,6 +160,42 @@ public static SimilarityMeasure extractSimilarity(Map map, Strin return null; } + @SuppressWarnings("unchecked") + public static List extractOptionalListOfType( + Map map, + String settingName, + String scope, + Class type, + ValidationException validationException + ) { + List listField = ServiceUtils.removeAsType(map, settingName, List.class); + + if (listField == null) { + return null; + } + + if (listField.isEmpty()) { + validationException.addValidationError(ServiceUtils.mustBeNonEmptyString(settingName, scope)); + return null; + } + + List castedList = new ArrayList<>(listField.size()); + + for (Object listEntry : listField) { + if (type.isAssignableFrom(listEntry.getClass()) == false) { + // TODO should we just throw here like removeAsType + validationException.addValidationError( + invalidType(settingName, scope, listEntry.getClass().getSimpleName(), type.getSimpleName()) + ); + return null; + } + + castedList.add((T) listEntry); + } + + return castedList; + } + public static String extractRequiredString( Map map, String settingName, @@ -194,6 +236,35 @@ public static String extractOptionalString( return optionalField; } + public static T extractOptionalEnum( + Map map, + String settingName, + String scope, + CheckedFunction converter, + ValidationException validationException + ) { + var s = extractOptionalString(map, settingName, scope, validationException); + if (s == null) { + return null; + } + + try { + var e = converter.apply(s); + } catch (IllegalArgumentException e) { + validationException.addValidationError(invalidType(settingName, scope, s, )) + } + + if (s.isEmpty()) { + validationException.addValidationError(ServiceUtils.mustBeNonEmptyString(settingName, scope)); + } + + if (validationException.validationErrors().isEmpty() == false) { + return null; + } + + return optionalField; + } + public static String parsePersistedConfigErrorMsg(String modelId, String serviceName) { return format("Failed to parse stored model [%s] for [%s] service, please delete and add the service again", modelId, serviceName); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsTaskSettings.java new file mode 100644 index 0000000000000..229d81e468d3f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsTaskSettings.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; + +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfType; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.MODEL; + +/** + * Defines the task settings for the cohere text embeddings service. + * + * @param model the id of the model to use in the requests to cohere + * @param inputType Specifies the type of input you're giving to the model + * @param embeddingTypes Specifies the types of embeddings you want to get back + * @param truncate Specifies how the API will handle inputs longer than the maximum token length + */ +public record CohereEmbeddingsTaskSettings( + @Nullable String model, + @Nullable String inputType, + @Nullable List embeddingTypes, + @Nullable CohereTruncation truncate +) { + + public static final String NAME = "cohere_embeddings_task_settings"; + static final String INPUT_TYPE = "input_type"; + static final String EMBEDDING_TYPES = "embedding_types"; + + public static CohereEmbeddingsTaskSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + String model = extractOptionalString(map, MODEL, ModelConfigurations.TASK_SETTINGS, validationException); + String inputType = extractOptionalString(map, INPUT_TYPE, ModelConfigurations.TASK_SETTINGS, validationException); + List embeddingTypes = extractOptionalListOfType( + map, + EMBEDDING_TYPES, + ModelConfigurations.TASK_SETTINGS, + String.class, + validationException + ); + CohereTruncation truncation = extractOptionalString(map, INPUT_TYPE, ModelConfigurations.TASK_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new CohereEmbeddingsTaskSettings(model, user); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java new file mode 100644 index 0000000000000..ccfe1cb2593c6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java @@ -0,0 +1,14 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere; + +public class CohereServiceFields { + public static final String MODEL = "model"; + public static final String TRUNCATE = "truncate"; + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java new file mode 100644 index 0000000000000..ebf1d349e0b7a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Locale; + +/** + * Defines the type of truncation for a cohere request. The specified value determines how the Cohere API will handle inputs + * longer than the maximum token length. + * + *

+ * See api docs for details. + *

+ */ +public enum CohereTruncation implements Writeable { + /** + * When the input exceeds the maximum input token length an error will be returned. + */ + NONE, + /** + * Discard the start of the input + */ + START, + /** + * Discard the end of the input + */ + END; + + public static String NAME = "cohere_truncate"; + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + + public static CohereTruncation fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public static CohereTruncation fromStream(StreamInput in) throws IOException { + return in.readEnum(CohereTruncation.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(this); + } +} From e1008cf3c1d63aa78658e9c64e77f4df503742ef Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 12 Jan 2024 16:57:03 -0500 Subject: [PATCH 02/13] Making progress on cohere --- .../org/elasticsearch/TransportVersions.java | 1 + .../elasticsearch/inference/InputType.java | 8 +- .../org/elasticsearch/inference/Model.java | 14 ++ .../inference/ModelConfigurations.java | 26 +++ .../inference/src/main/java/module-info.java | 5 - .../action/cohere/CohereEmbeddingsAction.java | 89 +++++++++ .../action/cohere/OpenAiActionCreator.java | 36 ++++ .../action/cohere/OpenAiActionVisitor.java | 17 ++ .../external/cohere/CohereAccount.java | 21 +++ .../cohere/CohereResponseHandler.java | 99 ++++++++++ .../http/retry/ResponseHandlerUtils.java | 22 +++ .../openai/OpenAiResponseHandler.java | 10 +- .../external/request/RequestUtils.java | 19 ++ .../cohere/CohereEmbeddingsRequest.java | 60 +++++- .../cohere/CohereEmbeddingsRequestEntity.java | 54 ++++++ .../external/request/cohere/CohereUtils.java | 16 ++ .../openai/OpenAiEmbeddingsRequest.java | 13 +- .../CohereEmbeddingsResponseEntity.java | 117 ++++++++++++ .../cohere/CohereErrorResponseEntity.java | 58 ++++++ .../inference/services/ServiceUtils.java | 57 ++++-- .../cohere/CohereEmbeddingsTaskSettings.java | 60 ------ .../services/cohere/CohereModel.java | 31 ++++ .../cohere/CohereServiceSettings.java | 171 ++++++++++++++++++ .../services/cohere/CohereTruncation.java | 4 +- .../embeddings/CohereEmbeddingsModel.java | 79 ++++++++ .../CohereEmbeddingsTaskSettings.java | 138 ++++++++++++++ .../services/openai/OpenAiModel.java | 10 + .../openai/OpenAiServiceSettings.java | 10 +- .../embeddings/OpenAiEmbeddingsModel.java | 22 +-- .../cohere/CohereServiceSettingsTests.java | 151 ++++++++++++++++ .../CohereEmbeddingsModelTests.java | 41 +++++ .../CohereEmbeddingsTaskSettingsTests.java | 112 ++++++++++++ .../OpenAiEmbeddingsTaskSettingsTests.java | 15 +- 33 files changed, 1440 insertions(+), 146 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionCreator.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionVisitor.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereAccount.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandlerUtils.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereUtils.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntity.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsTaskSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index f289a7a3c89a1..a872b5ceb3496 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -182,6 +182,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_PLAN_POINT_LITERAL_WKB = def(8_570_00_0); public static final TransportVersion HOT_THREADS_AS_BYTES = def(8_571_00_0); public static final TransportVersion ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED = def(8_572_00_0); + public static final TransportVersion ML_INFERENCE_COHERE_EMBEDDINGS_ADDED = def(8_573_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/inference/InputType.java b/server/src/main/java/org/elasticsearch/inference/InputType.java index f8bbea4ae121f..b5ba3f0d1b506 100644 --- a/server/src/main/java/org/elasticsearch/inference/InputType.java +++ b/server/src/main/java/org/elasticsearch/inference/InputType.java @@ -29,12 +29,16 @@ public String toString() { return name().toLowerCase(Locale.ROOT); } + public static InputType fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + public static InputType fromStream(StreamInput in) throws IOException { - return in.readEnum(InputType.class); + return in.readOptionalEnum(InputType.class); } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeEnum(this); + out.writeOptionalEnum(this); } } diff --git a/server/src/main/java/org/elasticsearch/inference/Model.java b/server/src/main/java/org/elasticsearch/inference/Model.java index 02be39d8a653d..e5f1d7936e605 100644 --- a/server/src/main/java/org/elasticsearch/inference/Model.java +++ b/server/src/main/java/org/elasticsearch/inference/Model.java @@ -23,6 +23,20 @@ public Model(ModelConfigurations configurations, ModelSecrets secrets) { this.secrets = Objects.requireNonNull(secrets); } + public Model(Model model, TaskSettings taskSettings) { + Objects.requireNonNull(model); + + configurations = ModelConfigurations.of(model, taskSettings); + secrets = model.getSecrets(); + } + + public Model(Model model, ServiceSettings serviceSettings) { + Objects.requireNonNull(model); + + configurations = ModelConfigurations.of(model, serviceSettings); + secrets = model.getSecrets(); + } + public Model(ModelConfigurations configurations) { this(configurations, new ModelSecrets()); } diff --git a/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java b/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java index cdccca7eb0c0e..e91f373d55d37 100644 --- a/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java +++ b/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java @@ -27,6 +27,32 @@ public class ModelConfigurations implements ToXContentObject, VersionedNamedWrit public static final String TASK_SETTINGS = "task_settings"; private static final String NAME = "inference_model"; + public static ModelConfigurations of(Model model, TaskSettings taskSettings) { + Objects.requireNonNull(model); + Objects.requireNonNull(taskSettings); + + return new ModelConfigurations( + model.getConfigurations().getModelId(), + model.getConfigurations().getTaskType(), + model.getConfigurations().getService(), + model.getServiceSettings(), + taskSettings + ); + } + + public static ModelConfigurations of(Model model, ServiceSettings serviceSettings) { + Objects.requireNonNull(model); + Objects.requireNonNull(serviceSettings); + + return new ModelConfigurations( + model.getConfigurations().getModelId(), + model.getConfigurations().getTaskType(), + model.getConfigurations().getService(), + serviceSettings, + model.getTaskSettings() + ); + } + private final String modelId; private final TaskType taskType; private final String service; diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 3879a0a344e06..2d25a48117778 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -22,10 +22,5 @@ exports org.elasticsearch.xpack.inference.registry; exports org.elasticsearch.xpack.inference.rest; exports org.elasticsearch.xpack.inference.services; - exports org.elasticsearch.xpack.inference.external.http.sender; - exports org.elasticsearch.xpack.inference.external.http; - exports org.elasticsearch.xpack.inference.services.elser; - exports org.elasticsearch.xpack.inference.services.huggingface.elser; - exports org.elasticsearch.xpack.inference.services.openai; exports org.elasticsearch.xpack.inference; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java new file mode 100644 index 0000000000000..569f33b256803 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java @@ -0,0 +1,89 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.cohere; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.external.cohere.CohereResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; +import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.request.cohere.CohereEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; + +import java.net.URI; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.common.Truncator.truncate; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; + +public class CohereEmbeddingsAction implements ExecutableAction { + private static final Logger logger = LogManager.getLogger(CohereEmbeddingsAction.class); + + private final CohereAccount account; + private final CohereEmbeddingsModel model; + private final String errorMessage; + private final Truncator truncator; + private final RetryingHttpSender sender; + + public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model, ServiceComponents serviceComponents) { + this.model = Objects.requireNonNull(model); + this.account = new CohereAccount(this.model.getServiceSettings().uri(), this.model.getSecretSettings().apiKey()); + this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri()); + this.truncator = Objects.requireNonNull(serviceComponents.truncator()); + this.sender = new RetryingHttpSender( + Objects.requireNonNull(sender), + serviceComponents.throttlerManager(), + logger, + new RetrySettings(serviceComponents.settings()), + serviceComponents.threadPool() + ); + } + + private static String getErrorMessage(@Nullable URI uri) { + if (uri != null) { + return format("Failed to send Cohere embeddings request to [%s]", uri.toString()); + } + + return "Failed to send Cohere embeddings request"; + } + + @Override + public void execute(List input, ActionListener listener) { + try { + // TODO only truncate if the setting is NONE? + var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + + CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(truncator, account, truncatedInput, model.getTaskSettings()); + ActionListener wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener); + + sender.send(request, wrappedListener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } catch (Exception e) { + listener.onFailure(createInternalServerError(e, errorMessage)); + } + } + + private static ResponseHandler createEmbeddingsHandler() { + return new CohereResponseHandler("cohere text embedding", OpenAiEmbeddingsResponseEntity::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionCreator.java new file mode 100644 index 0000000000000..0353922959e0b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionCreator.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.cohere; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel; + +import java.util.Map; +import java.util.Objects; + +/** + * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the openai model type. + */ +public class OpenAiActionCreator implements OpenAiActionVisitor { + private final Sender sender; + private final ServiceComponents serviceComponents; + + public OpenAiActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(OpenAiEmbeddingsModel model, Map taskSettings) { + var overriddenModel = model.overrideWith(taskSettings); + + return new CohereEmbeddingsAction(sender, overriddenModel, serviceComponents); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionVisitor.java new file mode 100644 index 0000000000000..a3a4cfbbfc873 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionVisitor.java @@ -0,0 +1,17 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.cohere; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel; + +import java.util.Map; + +public interface OpenAiActionVisitor { + ExecutableAction create(OpenAiEmbeddingsModel model, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereAccount.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereAccount.java new file mode 100644 index 0000000000000..9847d496d14ee --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereAccount.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.cohere; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; + +import java.net.URI; +import java.util.Objects; + +public record CohereAccount(@Nullable URI url, SecureString apiKey) { + + public CohereAccount { + Objects.requireNonNull(apiKey); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java new file mode 100644 index 0000000000000..299cad20fb012 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.cohere; + +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.response.cohere.CohereErrorResponseEntity; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; + +import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody; +import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown; + +public class CohereResponseHandler extends BaseResponseHandler { + + static final String MONTHLY_REQUESTS_LIMIT = "x-endpoint-monthly-call-limit"; + // TODO determine the production versions of these + static final String TRIAL_REQUEST_LIMIT_PER_MINUTE = "x-trial-endpoint-call-limit"; + static final String TRIAL_REQUESTS_REMAINING = "x-trial-endpoint-call-remaining"; + static final String TEXTS_ARRAY_TOO_LARGE_MESSAGE_MATCHER = "invalid request: total number of texts must be at most"; + static final String TEXTS_ARRAY_ERROR_MESSAGE = "Received a texts array too large response"; + + public CohereResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, CohereErrorResponseEntity::fromResponse); + } + + @Override + public void validateResponse(ThrottlerManager throttlerManager, Logger logger, HttpRequestBase request, HttpResult result) + throws RetryException { + checkForFailureStatusCode(request, result); + checkForEmptyBody(throttlerManager, logger, request, result); + } + + /** + * Validates the status code throws an RetryException if not in the range [200, 300). + * + * The OpenAI API error codes are documented here. + * @param request The http request + * @param result The http response and body + * @throws RetryException Throws if status code is {@code >= 300 or < 200 } + */ + void checkForFailureStatusCode(HttpRequestBase request, HttpResult result) throws RetryException { + int statusCode = result.response().getStatusLine().getStatusCode(); + if (statusCode >= 200 && statusCode < 300) { + return; + } + + // handle error codes + if (statusCode >= 500) { + throw new RetryException(false, buildError(SERVER_ERROR, request, result)); + } else if (statusCode == 429) { + throw new RetryException(true, buildError(buildRateLimitErrorMessage(result), request, result)); + } else if (isTextsArrayTooLarge(result)) { + throw new RetryException(false, buildError(TEXTS_ARRAY_ERROR_MESSAGE, request, result)); + } else if (statusCode == 401) { + throw new RetryException(false, buildError(AUTHENTICATION, request, result)); + } else if (statusCode >= 300 && statusCode < 400) { + throw new RetryException(false, buildError(REDIRECTION, request, result)); + } else { + throw new RetryException(false, buildError(UNSUCCESSFUL, request, result)); + } + } + + private static boolean isTextsArrayTooLarge(HttpResult result) { + int statusCode = result.response().getStatusLine().getStatusCode(); + + if (statusCode == 400) { + var errorEntity = CohereErrorResponseEntity.fromResponse(result); + return errorEntity != null && errorEntity.getErrorMessage().contains(TEXTS_ARRAY_TOO_LARGE_MESSAGE_MATCHER); + } + + return false; + } + + static String buildRateLimitErrorMessage(HttpResult result) { + var response = result.response(); + var monthlyRequestLimit = getFirstHeaderOrUnknown(response, MONTHLY_REQUESTS_LIMIT); + var trialRequestsPerMinute = getFirstHeaderOrUnknown(response, TRIAL_REQUEST_LIMIT_PER_MINUTE); + var trialRequestsRemaining = getFirstHeaderOrUnknown(response, TRIAL_REQUESTS_REMAINING); + + var usageMessage = Strings.format( + "Monthly request limit [%s], permitted requests per minute [%s], remaining requests [%s]", + monthlyRequestLimit, + trialRequestsPerMinute, + trialRequestsRemaining + ); + + return RATE_LIMIT + ". " + usageMessage; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandlerUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandlerUtils.java new file mode 100644 index 0000000000000..6269c81d4ceb7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandlerUtils.java @@ -0,0 +1,22 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.retry; + +import org.apache.http.HttpResponse; + +public class ResponseHandlerUtils { + public static String getFirstHeaderOrUnknown(HttpResponse response, String name) { + var header = response.getFirstHeader(name); + if (header != null && header.getElements().length > 0) { + return header.getElements()[0].getName(); + } + return "unknown"; + } + + private ResponseHandlerUtils() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java index 207e3c2bbd035..6e54ba1bd90b5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.openai; -import org.apache.http.HttpResponse; import org.apache.http.client.methods.HttpRequestBase; import org.apache.logging.log4j.Logger; import org.elasticsearch.common.Strings; @@ -20,6 +19,7 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody; +import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown; public class OpenAiResponseHandler extends BaseResponseHandler { /** @@ -110,12 +110,4 @@ static String buildRateLimitErrorMessage(HttpResult result) { return RATE_LIMIT + ". " + usageMessage; } - - private static String getFirstHeaderOrUnknown(HttpResponse response, String name) { - var header = response.getFirstHeader(name); - if (header != null && header.getElements().length > 0) { - return header.getElements()[0].getName(); - } - return "unknown"; - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java index 355db7288dacc..4cb32b7bc95fd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java @@ -10,7 +10,14 @@ import org.apache.http.Header; import org.apache.http.HttpHeaders; import org.apache.http.message.BasicHeader; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.rest.RestStatus; + +import java.net.URI; +import java.net.URISyntaxException; public class RequestUtils { @@ -18,5 +25,17 @@ public static Header createAuthBearerHeader(SecureString apiKey) { return new BasicHeader(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.toString()); } + public static URI buildUri(URI accountUri, String service, CheckedSupplier uriBuilder) { + try { + return accountUri == null ? uriBuilder.get() : accountUri; + } catch (URISyntaxException e) { + throw new ElasticsearchStatusException( + Strings.format("Failed to construct %s URL", service), + RestStatus.INTERNAL_SERVER_ERROR, + e + ); + } + } + private RequestUtils() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java index 093ce9fd2f4bb..7f2d34e1f8b3b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java @@ -7,35 +7,85 @@ package org.elasticsearch.xpack.inference.external.request.cohere; +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.cohere.CohereAccount; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri; +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; public class CohereEmbeddingsRequest implements Request { - public CohereEmbeddingsRequest(Truncator truncator, Object account, Truncator.TruncationResult input, Object taskSettings) { + private final Truncator truncator; + private final CohereAccount account; + private final Truncator.TruncationResult truncationResult; + private final URI uri; + private final CohereEmbeddingsTaskSettings taskSettings; + public CohereEmbeddingsRequest( + Truncator truncator, + CohereAccount account, + Truncator.TruncationResult input, + CohereEmbeddingsTaskSettings taskSettings + ) { + this.truncator = Objects.requireNonNull(truncator); + this.account = Objects.requireNonNull(account); + this.truncationResult = Objects.requireNonNull(input); + this.uri = buildUri(this.account.url(), "Cohere", CohereEmbeddingsRequest::buildDefaultUri); + this.taskSettings = Objects.requireNonNull(taskSettings); } @Override public HttpRequestBase createRequest() { - return null; + HttpPost httpPost = new HttpPost(uri); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new CohereEmbeddingsRequestEntity(truncationResult.input(), taskSettings)).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + httpPost.setHeader(createAuthBearerHeader(account.apiKey())); + + return httpPost; } @Override public URI getURI() { - return null; + return uri; } @Override public Request truncate() { - return null; + // TODO only do this is the truncate setting is NONE? + var truncatedInput = truncator.truncate(truncationResult.input()); + + return new CohereEmbeddingsRequest(truncator, account, truncatedInput, taskSettings); } @Override public boolean[] getTruncationInfo() { - return new boolean[0]; + return truncationResult.truncated().clone(); + } + + // default for testing + static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(CohereUtils.HOST) + .setPathSegments(CohereUtils.VERSION_1, CohereUtils.EMBEDDINGS_PATH) + .build(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..8aea434d7339d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.cohere; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record CohereEmbeddingsRequestEntity(List input, CohereEmbeddingsTaskSettings taskSettings) implements ToXContentObject { + + private static final String TEXTS_FIELD = "texts"; + + static final String INPUT_TYPE_FIELD = "input_type"; + static final String EMBEDDING_TYPES_FIELD = "embedding_types"; + + public CohereEmbeddingsRequestEntity { + Objects.requireNonNull(input); + Objects.requireNonNull(taskSettings); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TEXTS_FIELD, input); + if (taskSettings.model() != null) { + builder.field(CohereServiceFields.MODEL, taskSettings.model()); + } + + if (taskSettings.inputType() != null) { + builder.field(INPUT_TYPE_FIELD, taskSettings.inputType()); + } + + if (taskSettings.embeddingTypes() != null) { + builder.field(EMBEDDING_TYPES_FIELD, taskSettings.embeddingTypes()); + } + + if (taskSettings.truncation() != null) { + builder.field(EMBEDDING_TYPES_FIELD, taskSettings.truncation()); + } + + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereUtils.java new file mode 100644 index 0000000000000..f8ccd91d4e3d2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereUtils.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.cohere; + +public class CohereUtils { + public static final String HOST = "api.cohere.ai"; + public static final String VERSION_1 = "v1"; + public static final String EMBEDDINGS_PATH = "embed"; + + private CohereUtils() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java index 3a9fab44aa04e..9ead692b9e110 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java @@ -12,9 +12,7 @@ import org.apache.http.client.methods.HttpRequestBase; import org.apache.http.client.utils.URIBuilder; import org.apache.http.entity.ByteArrayEntity; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.Strings; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.openai.OpenAiAccount; @@ -26,6 +24,7 @@ import java.nio.charset.StandardCharsets; import java.util.Objects; +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri; import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader; @@ -46,18 +45,10 @@ public OpenAiEmbeddingsRequest( this.truncator = Objects.requireNonNull(truncator); this.account = Objects.requireNonNull(account); this.truncationResult = Objects.requireNonNull(input); - this.uri = buildUri(this.account.url()); + this.uri = buildUri(this.account.url(), "OpenAI", OpenAiEmbeddingsRequest::buildDefaultUri); this.taskSettings = Objects.requireNonNull(taskSettings); } - private static URI buildUri(URI accountUri) { - try { - return accountUri == null ? buildDefaultUri() : accountUri; - } catch (URISyntaxException e) { - throw new ElasticsearchStatusException("Failed to construct OpenAI URL", RestStatus.INTERNAL_SERVER_ERROR, e); - } - } - public HttpRequestBase createRequest() { HttpPost httpPost = new HttpPost(uri); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java new file mode 100644 index 0000000000000..cbe70ad8d919a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java @@ -0,0 +1,117 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.cohere; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class CohereEmbeddingsResponseEntity { + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Cohere embeddings response"; + + /** + * Parses the OpenAI json response. + * For a request like: + * + *
+     * 
+     * {
+     *  "texts": ["hello this is my name", "I wish I was there!"]
+     * }
+     * 
+     * 
+ * + * The response would look like: + * + *
+     * 
+     * {
+     *  "id": "da4f9ea6-37e4-41ab-b5e1-9e2985609555",
+     *  "texts": [
+     *      "hello",
+     *      "awesome"
+     *  ],
+     *  "embeddings": [
+     *      [
+     *          123
+     *      ],
+     *      [
+     *          123
+     *      ]
+     *  ],
+     *  "meta": {
+     *      "api_version": {
+     *          "version": "1"
+     *      },
+     *      "warnings": [
+     *          "default model on embed will be deprecated in the future, please specify a model in the request."
+     *      ],
+     *      "billed_units": {
+     *          "input_tokens": 3
+     *      }
+     *  },
+     *  "response_type": "embeddings_floats"
+     * }
+     * 
+     * 
+ */ + public static TextEmbeddingResults fromResponse(Request request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); + + List embeddingList = XContentParserUtils.parseList( + jsonParser, + CohereEmbeddingsResponseEntity::parseEmbeddingObject + ); + + return new TextEmbeddingResults(embeddingList); + } + } + + private static TextEmbeddingResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + + positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); + + List embeddingValues = XContentParserUtils.parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingList); + + // the parser is currently sitting at an ARRAY_END so go to the next token + parser.nextToken(); + // if there are additional fields within this object, lets skip them, so we can begin parsing the next embedding array + parser.skipChildren(); + + return new TextEmbeddingResults.Embedding(embeddingValues); + } + + private static float parseEmbeddingList(XContentParser parser) throws IOException { + XContentParser.Token token = parser.currentToken(); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); + return parser.floatValue(); + } + + private CohereEmbeddingsResponseEntity() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntity.java new file mode 100644 index 0000000000000..6f947b45955be --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntity.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.cohere; + +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorMessage; + +public class CohereErrorResponseEntity implements ErrorMessage { + + private final String errorMessage; + + private CohereErrorResponseEntity(String errorMessage) { + this.errorMessage = errorMessage; + } + + public String getErrorMessage() { + return errorMessage; + } + + /** + * An example error response for invalid auth would look like + * + * { + * "message": "invalid request: total number of texts must be at most 96 - received 97" + * } + * + * + * + * @param response The error response + * @return An error entity if the response is JSON with the above structure + * or null if the response does not contain the message field + */ + public static CohereErrorResponseEntity fromResponse(HttpResult response) { + try ( + XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response.body()) + ) { + var responseMap = jsonParser.map(); + var message = (String) responseMap.get("message"); + if (message != null) { + return new CohereErrorResponseEntity(message); + } + } catch (Exception e) { + // swallow the error + } + + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index b503c0800d342..6fef9dac6095a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -23,7 +23,9 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -107,8 +109,29 @@ public static String mustBeNonEmptyString(String settingName, String scope) { return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName); } - public static String invalidType(String settingName, String scope, String invalidType, String requiredType) { - return Strings.format("[%s] Invalid type [%s] received. [%s] must be type [%s]", scope, invalidType, settingName, requiredType); + public static String invalidType(String settingName, String scope, String invalidType, String invalidValue, String requiredType) { + return Strings.format( + "[%s] Invalid type [%s] received for value [%s]. [%s] must be type [%s]", + scope, + invalidType, + invalidValue, + settingName, + requiredType + ); + } + + public static String invalidValue(String settingName, String scope, String invalidValue, String validValue) { + return invalidValue(settingName, scope, invalidValue, new String[] { validValue }); + } + + public static String invalidValue(String settingName, String scope, String invalidType, String... requiredTypes) { + return Strings.format( + "[%s] Invalid value [%s] received. [%s] must be one of [%s]", + scope, + invalidType, + settingName, + String.join(", ", requiredTypes) + ); } // TODO improve URI validation logic @@ -131,6 +154,14 @@ public static URI createUri(String url) throws IllegalArgumentException { } } + public static URI createOptionalUri(String url) { + if (url == null) { + return null; + } + + return createUri(url); + } + public static SecureString extractRequiredSecureString( Map map, String settingName, @@ -185,7 +216,7 @@ public static List extractOptionalListOfType( if (type.isAssignableFrom(listEntry.getClass()) == false) { // TODO should we just throw here like removeAsType validationException.addValidationError( - invalidType(settingName, scope, listEntry.getClass().getSimpleName(), type.getSimpleName()) + invalidType(settingName, scope, listEntry.getClass().getSimpleName(), listEntry.toString(), type.getSimpleName()) ); return null; } @@ -241,28 +272,22 @@ public static T extractOptionalEnum( String settingName, String scope, CheckedFunction converter, + T[] validTypes, ValidationException validationException ) { - var s = extractOptionalString(map, settingName, scope, validationException); - if (s == null) { + var enumString = extractOptionalString(map, settingName, scope, validationException); + if (enumString == null) { return null; } + var validTypesAsStrings = Arrays.stream(validTypes).map(type -> type.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new); try { - var e = converter.apply(s); + return converter.apply(enumString); } catch (IllegalArgumentException e) { - validationException.addValidationError(invalidType(settingName, scope, s, )) - } - - if (s.isEmpty()) { - validationException.addValidationError(ServiceUtils.mustBeNonEmptyString(settingName, scope)); - } - - if (validationException.validationErrors().isEmpty() == false) { - return null; + validationException.addValidationError(invalidValue(settingName, scope, enumString, validTypesAsStrings)); } - return optionalField; + return null; } public static String parsePersistedConfigErrorMsg(String modelId, String serviceName) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsTaskSettings.java deleted file mode 100644 index 229d81e468d3f..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsTaskSettings.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere; - -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.ModelConfigurations; - -import java.util.List; -import java.util.Map; - -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfType; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; -import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.MODEL; - -/** - * Defines the task settings for the cohere text embeddings service. - * - * @param model the id of the model to use in the requests to cohere - * @param inputType Specifies the type of input you're giving to the model - * @param embeddingTypes Specifies the types of embeddings you want to get back - * @param truncate Specifies how the API will handle inputs longer than the maximum token length - */ -public record CohereEmbeddingsTaskSettings( - @Nullable String model, - @Nullable String inputType, - @Nullable List embeddingTypes, - @Nullable CohereTruncation truncate -) { - - public static final String NAME = "cohere_embeddings_task_settings"; - static final String INPUT_TYPE = "input_type"; - static final String EMBEDDING_TYPES = "embedding_types"; - - public static CohereEmbeddingsTaskSettings fromMap(Map map) { - ValidationException validationException = new ValidationException(); - - String model = extractOptionalString(map, MODEL, ModelConfigurations.TASK_SETTINGS, validationException); - String inputType = extractOptionalString(map, INPUT_TYPE, ModelConfigurations.TASK_SETTINGS, validationException); - List embeddingTypes = extractOptionalListOfType( - map, - EMBEDDING_TYPES, - ModelConfigurations.TASK_SETTINGS, - String.class, - validationException - ); - CohereTruncation truncation = extractOptionalString(map, INPUT_TYPE, ModelConfigurations.TASK_SETTINGS, validationException); - - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - return new CohereEmbeddingsTaskSettings(model, user); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java new file mode 100644 index 0000000000000..d92ea419faef0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere; + +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; + +public abstract class CohereModel extends Model { + public CohereModel(ModelConfigurations configurations, ModelSecrets secrets) { + super(configurations, secrets); + } + + protected CohereModel(CohereModel model, TaskSettings taskSettings) { + super(model, taskSettings); + } + + protected CohereModel(CohereModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + public abstract ExecutableAction accept(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java new file mode 100644 index 0000000000000..bdd5e981b0490 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java @@ -0,0 +1,171 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.common.SimilarityMeasure; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; + +public class CohereServiceSettings implements ServiceSettings { + public static final String NAME = "cohere_service_settings"; + + public static CohereServiceSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + + // Throw if any of the settings were empty strings or invalid + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + // the url is optional and only for testing + if (url == null) { + return new CohereServiceSettings((URI) null, similarity, dims, maxInputTokens); + } + + URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new CohereServiceSettings(uri, similarity, dims, maxInputTokens); + } + + private final URI uri; + private final SimilarityMeasure similarity; + private final Integer dimensions; + private final Integer maxInputTokens; + + public CohereServiceSettings( + @Nullable URI uri, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens + ) { + this.uri = uri; + this.similarity = similarity; + this.dimensions = dimensions; + this.maxInputTokens = maxInputTokens; + } + + public CohereServiceSettings( + @Nullable String url, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens + ) { + this(createOptionalUri(url), similarity, dimensions, maxInputTokens); + } + + public CohereServiceSettings(StreamInput in) throws IOException { + uri = createOptionalUri(in.readOptionalString()); + similarity = in.readOptionalEnum(SimilarityMeasure.class); + dimensions = in.readOptionalVInt(); + maxInputTokens = in.readOptionalVInt(); + } + + public URI uri() { + return uri; + } + + public SimilarityMeasure similarity() { + return similarity; + } + + public Integer dimensions() { + return dimensions; + } + + public Integer maxInputTokens() { + return maxInputTokens; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if (uri != null) { + builder.field(URL, uri.toString()); + } + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + + builder.endObject(); + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + var uriToWrite = uri != null ? uri.toString() : null; + out.writeOptionalString(uriToWrite); + out.writeOptionalEnum(similarity); + out.writeOptionalVInt(dimensions); + out.writeOptionalVInt(maxInputTokens); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CohereServiceSettings that = (CohereServiceSettings) o; + return Objects.equals(uri, that.uri) + && Objects.equals(similarity, that.similarity) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens); + } + + @Override + public int hashCode() { + return Objects.hash(uri, similarity, dimensions, maxInputTokens); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java index ebf1d349e0b7a..3fec21a9e03b0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java @@ -48,11 +48,11 @@ public static CohereTruncation fromString(String name) { } public static CohereTruncation fromStream(StreamInput in) throws IOException { - return in.readEnum(CohereTruncation.class); + return in.readOptionalEnum(CohereTruncation.class); } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeEnum(this); + out.writeOptionalEnum(this); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java new file mode 100644 index 0000000000000..ed1295c8ee31c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.embeddings; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.cohere.CohereModel; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.Map; + +public class CohereEmbeddingsModel extends CohereModel { + public CohereEmbeddingsModel( + String modelId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets + ) { + this( + modelId, + taskType, + service, + CohereServiceSettings.fromMap(serviceSettings), + CohereEmbeddingsTaskSettings.fromMap(taskSettings), + DefaultSecretSettings.fromMap(secrets) + ); + } + + // should only be used for testing + CohereEmbeddingsModel( + String modelId, + TaskType taskType, + String service, + CohereServiceSettings serviceSettings, + CohereEmbeddingsTaskSettings taskSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings)); + } + + private CohereEmbeddingsModel(CohereEmbeddingsModel model, CohereEmbeddingsTaskSettings taskSettings) { + super(model, taskSettings); + } + + private CohereEmbeddingsModel(CohereEmbeddingsModel model, CohereServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + @Override + public CohereServiceSettings getServiceSettings() { + return (CohereServiceSettings) super.getServiceSettings(); + } + + @Override + public CohereEmbeddingsTaskSettings getTaskSettings() { + return (CohereEmbeddingsTaskSettings) super.getTaskSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + @Override + public ExecutableAction accept() { + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java new file mode 100644 index 0000000000000..c2ad8b2b8e958 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java @@ -0,0 +1,138 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfType; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.MODEL; +import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.TRUNCATE; + +/** + * Defines the task settings for the cohere text embeddings service. + * + *

+ * See api docs for details. + *

+ * + * @param model the id of the model to use in the requests to cohere + * @param inputType Specifies the type of input you're giving to the model + * @param embeddingTypes Specifies the types of embeddings you want to get back + * @param truncation Specifies how the API will handle inputs longer than the maximum token length + */ +public record CohereEmbeddingsTaskSettings( + @Nullable String model, + @Nullable InputType inputType, + @Nullable List embeddingTypes, + @Nullable CohereTruncation truncation +) implements TaskSettings { + + public static final String NAME = "cohere_embeddings_task_settings"; + static final CohereEmbeddingsTaskSettings EMPTY_SETTINGS = new CohereEmbeddingsTaskSettings(null, null, null, null); + static final String INPUT_TYPE = "input_type"; + static final String EMBEDDING_TYPES = "embedding_types"; + + public static CohereEmbeddingsTaskSettings fromMap(Map map) { + if (map.isEmpty()) { + return EMPTY_SETTINGS; + } + + ValidationException validationException = new ValidationException(); + + String model = extractOptionalString(map, MODEL, ModelConfigurations.TASK_SETTINGS, validationException); + InputType inputType = extractOptionalEnum( + map, + INPUT_TYPE, + ModelConfigurations.TASK_SETTINGS, + InputType::fromString, + InputType.values(), + validationException + ); + List embeddingTypes = extractOptionalListOfType( + map, + EMBEDDING_TYPES, + ModelConfigurations.TASK_SETTINGS, + String.class, + validationException + ); + CohereTruncation truncation = extractOptionalEnum( + map, + TRUNCATE, + ModelConfigurations.TASK_SETTINGS, + CohereTruncation::fromString, + CohereTruncation.values(), + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new CohereEmbeddingsTaskSettings(model, inputType, embeddingTypes, truncation); + } + + public CohereEmbeddingsTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalString(), InputType.fromStream(in), in.readOptionalStringCollectionAsList(), CohereTruncation.fromStream(in)); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (model != null) { + builder.field(MODEL, model); + } + + if (inputType != null) { + builder.field(INPUT_TYPE, inputType); + } + + if (embeddingTypes != null) { + builder.field(EMBEDDING_TYPES, embeddingTypes); + } + + if (truncation != null) { + builder.field(TRUNCATE, truncation); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(model); + inputType.writeTo(out); + out.writeOptionalStringCollection(embeddingTypes); + truncation.writeTo(out); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiModel.java index 97823e3bc9079..1e158725f531d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiModel.java @@ -10,6 +10,8 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionVisitor; @@ -21,5 +23,13 @@ public OpenAiModel(ModelConfigurations configurations, ModelSecrets secrets) { super(configurations, secrets); } + protected OpenAiModel(OpenAiModel model, TaskSettings taskSettings) { + super(model, taskSettings); + } + + protected OpenAiModel(OpenAiModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + } + public abstract ExecutableAction accept(OpenAiActionVisitor creator, Map taskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java index 5ade2aad0acb4..553a5eaf60dae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java @@ -28,7 +28,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; @@ -100,14 +100,6 @@ public OpenAiServiceSettings( this(createOptionalUri(uri), organizationId, similarity, dimensions, maxInputTokens); } - private static URI createOptionalUri(String url) { - if (url == null) { - return null; - } - - return createUri(url); - } - public OpenAiServiceSettings(StreamInput in) throws IOException { uri = createOptionalUri(in.readOptionalString()); organizationId = in.readOptionalString(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java index 250837d895590..0df3d126402c7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java @@ -52,29 +52,11 @@ public OpenAiEmbeddingsModel( } private OpenAiEmbeddingsModel(OpenAiEmbeddingsModel originalModel, OpenAiEmbeddingsTaskSettings taskSettings) { - super( - new ModelConfigurations( - originalModel.getConfigurations().getModelId(), - originalModel.getConfigurations().getTaskType(), - originalModel.getConfigurations().getService(), - originalModel.getServiceSettings(), - taskSettings - ), - new ModelSecrets(originalModel.getSecretSettings()) - ); + super(originalModel, taskSettings); } public OpenAiEmbeddingsModel(OpenAiEmbeddingsModel originalModel, OpenAiServiceSettings serviceSettings) { - super( - new ModelConfigurations( - originalModel.getConfigurations().getModelId(), - originalModel.getConfigurations().getTaskType(), - originalModel.getConfigurations().getService(), - serviceSettings, - originalModel.getTaskSettings() - ), - new ModelSecrets(originalModel.getSecretSettings()) - ); + super(originalModel, serviceSettings); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java new file mode 100644 index 0000000000000..21558eec87302 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java @@ -0,0 +1,151 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.inference.common.SimilarityMeasure; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class CohereServiceSettingsTests extends AbstractWireSerializingTestCase { + + public static CohereServiceSettings createRandomWithNonNullUrl() { + return createRandom(randomAlphaOfLength(15)); + } + + /** + * The created settings can have a url set to null. + */ + public static CohereServiceSettings createRandom() { + var url = randomBoolean() ? randomAlphaOfLength(15) : null; + return createRandom(url); + } + + private static CohereServiceSettings createRandom(String url) { + SimilarityMeasure similarityMeasure = null; + Integer dims = null; + var isTextEmbeddingModel = randomBoolean(); + if (isTextEmbeddingModel) { + similarityMeasure = SimilarityMeasure.DOT_PRODUCT; + dims = 1536; + } + Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256); + return new CohereServiceSettings(ServiceUtils.createUri(url), similarityMeasure, dims, maxInputTokens); + } + + public void testFromMap() { + var url = "https://www.abc.com"; + var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + var dims = 1536; + var maxInputTokens = 512; + var serviceSettings = CohereServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + url, + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens + ) + ) + ); + + MatcherAssert.assertThat( + serviceSettings, + is(new CohereServiceSettings(ServiceUtils.createUri(url), SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens)) + ); + } + + public void testFromMap_MissingUrl_DoesNotThrowException() { + var serviceSettings = CohereServiceSettings.fromMap(new HashMap<>(Map.of())); + assertNull(serviceSettings.uri()); + } + + public void testFromMap_EmptyUrl_ThrowsError() { + var thrownException = expectThrows( + ValidationException.class, + () -> CohereServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, ""))) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value empty string. [%s] must be a non-empty string;", + ServiceFields.URL + ) + ) + ); + } + + public void testFromMap_InvalidUrl_ThrowsError() { + var url = "https://www.abc^.com"; + var thrownException = expectThrows( + ValidationException.class, + () -> CohereServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url))) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is(Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s];", url, ServiceFields.URL)) + ); + } + + public void testFromMap_InvalidSimilarity_ThrowsError() { + var similarity = "by_size"; + var thrownException = expectThrows( + ValidationException.class, + () -> CohereServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.SIMILARITY, similarity))) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("Validation Failed: 1: [service_settings] Unknown similarity measure [by_size];") + ); + } + + @Override + protected Writeable.Reader instanceReader() { + return CohereServiceSettings::new; + } + + @Override + protected CohereServiceSettings createTestInstance() { + return createRandomWithNonNullUrl(); + } + + @Override + protected CohereServiceSettings mutateInstance(CohereServiceSettings instance) throws IOException { + return createRandomWithNonNullUrl(); + } + + public static Map getServiceSettingsMap(@Nullable String url) { + var map = new HashMap(); + + if (url != null) { + map.put(ServiceFields.URL, url); + } + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java new file mode 100644 index 0000000000000..be3f77251fd08 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.embeddings; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.common.SimilarityMeasure; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +public class CohereEmbeddingsModelTests extends ESTestCase { + + public static CohereEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit) { + return new CohereEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "service", + new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, 1536, tokenLimit), + CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static CohereEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit, @Nullable Integer dimensions) { + return new CohereEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "service", + new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit), + CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java new file mode 100644 index 0000000000000..b5e44b88b9b4c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.embeddings; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; +import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class CohereEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase { + + public static CohereEmbeddingsTaskSettings createRandom() { + var model = randomBoolean() ? randomAlphaOfLength(15) : null; + var inputType = randomBoolean() ? randomFrom(InputType.values()) : null; + var embeddingTypes = randomBoolean() ? List.of(randomAlphaOfLength(6)) : null; + var truncation = randomBoolean() ? randomFrom(CohereTruncation.values()) : null; + + return new CohereEmbeddingsTaskSettings(model, inputType, embeddingTypes, truncation); + } + + public void testFromMap_CreatesEmptySettings_WhenAllFieldsAreNull() { + MatcherAssert.assertThat( + CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of())), + is(new CohereEmbeddingsTaskSettings(null, null, null, null)) + ); + } + + public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() { + MatcherAssert.assertThat( + CohereEmbeddingsTaskSettings.fromMap( + new HashMap<>( + Map.of( + CohereServiceFields.MODEL, + "abc", + CohereEmbeddingsTaskSettings.INPUT_TYPE, + InputType.INGEST.toString().toLowerCase(Locale.ROOT), + CohereEmbeddingsTaskSettings.EMBEDDING_TYPES, + List.of("abc", "123"), + CohereServiceFields.TRUNCATE, + CohereTruncation.END.toString().toLowerCase(Locale.ROOT) + ) + ) + ), + is(new CohereEmbeddingsTaskSettings("abc", InputType.INGEST, List.of("abc", "123"), CohereTruncation.END)) + ); + } + + public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() { + var exception = expectThrows( + ValidationException.class, + () -> CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.INPUT_TYPE, "abc"))) + ); + + MatcherAssert.assertThat( + exception.getMessage(), + is("Validation Failed: 1: [task_settings] Invalid value [abc] received. [input_type] must be one of [ingest, search];") + ); + } + + public void testFromMap_ReturnsFailure_WhenEmbeddingTypesAreNotValid() { + var exception = expectThrows( + ValidationException.class, + () -> CohereEmbeddingsTaskSettings.fromMap( + new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.EMBEDDING_TYPES, List.of("abc", 123))) + ) + ); + + MatcherAssert.assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: [task_settings] Invalid type [Integer]" + + " received for value [123]. [embedding_types] must be type [String];" + ) + ); + } + + @Override + protected Writeable.Reader instanceReader() { + return CohereEmbeddingsTaskSettings::new; + } + + @Override + protected CohereEmbeddingsTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected CohereEmbeddingsTaskSettings mutateInstance(CohereEmbeddingsTaskSettings instance) throws IOException { + return null; + } + + public static Map getTaskSettingsMap() { + return new HashMap<>(Collections.emptyMap()); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java index d33ec12016cad..f297eb622c421 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Nullable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.hamcrest.MatcherAssert; import java.io.IOException; import java.util.HashMap; @@ -39,7 +40,7 @@ public void testFromMap_MissingModel_ThrowException() { () -> OpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiEmbeddingsTaskSettings.USER, "user"))) ); - assertThat( + MatcherAssert.assertThat( thrownException.getMessage(), is( Strings.format( @@ -55,14 +56,14 @@ public void testFromMap_CreatesWithModelAndUser() { new HashMap<>(Map.of(OpenAiEmbeddingsTaskSettings.MODEL, "model", OpenAiEmbeddingsTaskSettings.USER, "user")) ); - assertThat(taskSettings.model(), is("model")); - assertThat(taskSettings.user(), is("user")); + MatcherAssert.assertThat(taskSettings.model(), is("model")); + MatcherAssert.assertThat(taskSettings.user(), is("user")); } public void testFromMap_MissingUser_DoesNotThrowException() { var taskSettings = OpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiEmbeddingsTaskSettings.MODEL, "model"))); - assertThat(taskSettings.model(), is("model")); + MatcherAssert.assertThat(taskSettings.model(), is("model")); assertNull(taskSettings.user()); } @@ -72,7 +73,7 @@ public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() { ); var overriddenTaskSettings = taskSettings.overrideWith(OpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS); - assertThat(overriddenTaskSettings, is(taskSettings)); + MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings)); } public void testOverrideWith_UsesOverriddenSettings() { @@ -85,7 +86,7 @@ public void testOverrideWith_UsesOverriddenSettings() { ); var overriddenTaskSettings = taskSettings.overrideWith(requestTaskSettings); - assertThat(overriddenTaskSettings, is(new OpenAiEmbeddingsTaskSettings("model2", "user2"))); + MatcherAssert.assertThat(overriddenTaskSettings, is(new OpenAiEmbeddingsTaskSettings("model2", "user2"))); } public void testOverrideWith_UsesOnlyNonNullModelSetting() { @@ -98,7 +99,7 @@ public void testOverrideWith_UsesOnlyNonNullModelSetting() { ); var overriddenTaskSettings = taskSettings.overrideWith(requestTaskSettings); - assertThat(overriddenTaskSettings, is(new OpenAiEmbeddingsTaskSettings("model2", "user"))); + MatcherAssert.assertThat(overriddenTaskSettings, is(new OpenAiEmbeddingsTaskSettings("model2", "user"))); } @Override From f94c32c66a988ea15edee6e7aca64bc312aea107 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 17 Jan 2024 09:54:41 -0500 Subject: [PATCH 03/13] Filling out the embedding types --- .../core/inference/results/ByteValue.java | 64 +++++++++++ .../inference/results/EmbeddingValue.java | 15 +++ .../core/inference/results/FloatValue.java | 64 +++++++++++ .../results/TextEmbeddingResults.java | 59 ++++++++-- .../InferenceNamedWriteablesProvider.java | 5 + ...nCreator.java => CohereActionCreator.java} | 8 +- ...nVisitor.java => CohereActionVisitor.java} | 6 +- .../action/cohere/CohereEmbeddingsAction.java | 15 +-- .../cohere/CohereEmbeddingsRequest.java | 31 +++--- .../cohere/CohereEmbeddingsRequestEntity.java | 2 +- .../CohereEmbeddingsResponseEntity.java | 104 ++++++++++++++++-- .../HuggingFaceEmbeddingsResponseEntity.java | 2 +- .../OpenAiEmbeddingsResponseEntity.java | 2 +- .../inference/services/ServiceUtils.java | 54 ++++++++- .../embeddings/CohereEmbeddingType.java | 57 ++++++++++ .../embeddings/CohereEmbeddingsModel.java | 9 ++ .../CohereEmbeddingsTaskSettings.java | 27 ++++- .../CohereEmbeddingsModelTests.java | 50 +++++++-- .../CohereEmbeddingsTaskSettingsTests.java | 68 ++++++++++-- 19 files changed, 562 insertions(+), 80 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingValue.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/{OpenAiActionCreator.java => CohereActionCreator.java} (80%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/{OpenAiActionVisitor.java => CohereActionVisitor.java} (69%) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java new file mode 100644 index 0000000000000..9e3dc866d304e --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class ByteValue implements EmbeddingValue { + + public static final String NAME = "byte_value"; + + private final Byte value; + + public ByteValue(Byte value) { + this.value = value; + } + + public ByteValue(StreamInput in) throws IOException { + value = in.readByte(); + } + + @Override + public Byte getValue() { + return value; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.value(value); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ByteValue that = (ByteValue) o; + return Objects.equals(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hash(value); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeByte(value); + } + + @Override + public String getWriteableName() { + return NAME; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingValue.java new file mode 100644 index 0000000000000..f3fd641db395d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingValue.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.xcontent.ToXContentFragment; + +public interface EmbeddingValue extends NamedWriteable, ToXContentFragment { + Number getValue(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java new file mode 100644 index 0000000000000..9ead125d2d48a --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class FloatValue implements EmbeddingValue { + + public static final String NAME = "float_value"; + + private final Float value; + + public FloatValue(Float value) { + this.value = value; + } + + public FloatValue(StreamInput in) throws IOException { + value = in.readFloat(); + } + + @Override + public Number getValue() { + return value; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.value(value); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FloatValue that = (FloatValue) o; + return Objects.equals(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hash(value); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeFloat(value); + } + + @Override + public String getWriteableName() { + return NAME; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java index ace5974866038..9fc99ba3f03f8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.inference.results; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -21,6 +22,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.stream.Collectors; /** @@ -51,10 +53,7 @@ public TextEmbeddingResults(StreamInput in) throws IOException { @SuppressWarnings("deprecation") TextEmbeddingResults(LegacyTextEmbeddingResults legacyTextEmbeddingResults) { this( - legacyTextEmbeddingResults.embeddings() - .stream() - .map(embedding -> new Embedding(embedding.values())) - .collect(Collectors.toList()) + legacyTextEmbeddingResults.embeddings().stream().map(embedding -> Embedding.of(embedding.values())).collect(Collectors.toList()) ); } @@ -81,7 +80,7 @@ public String getWriteableName() { @Override public List transformToCoordinationFormat() { return embeddings.stream() - .map(embedding -> embedding.values.stream().mapToDouble(value -> value).toArray()) + .map(embedding -> embedding.values.stream().mapToDouble(value -> value.getValue().doubleValue()).toArray()) .map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false)) .toList(); } @@ -90,7 +89,7 @@ public List transformToCoordinationFormat() { @SuppressWarnings("deprecation") public List transformToLegacyFormat() { var legacyEmbedding = new LegacyTextEmbeddingResults( - embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.values)).toList() + embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.toFloats())).toList() ); return List.of(legacyEmbedding); @@ -103,16 +102,43 @@ public Map asMap() { return map; } - public record Embedding(List values) implements Writeable, ToXContentObject { + public static class Embedding implements Writeable, ToXContentObject { public static final String EMBEDDING = "embedding"; + public static Embedding of(List values) { + return new Embedding(convertFloatsToEmbeddingValues(values)); + } + + private final List values; + + public Embedding(List values) { + this.values = values; + } + public Embedding(StreamInput in) throws IOException { - this(in.readCollectionAsImmutableList(StreamInput::readFloat)); + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) { + values = in.readNamedWriteableCollectionAsList(EmbeddingValue.class); + } else { + values = convertFloatsToEmbeddingValues(in.readCollectionAsImmutableList(StreamInput::readFloat)); + } + } + + private static List convertFloatsToEmbeddingValues(List floats) { + return floats.stream().map(FloatValue::new).collect(Collectors.toList()); + } + + public List toFloats() { + return values.stream().map(value -> value.getValue().floatValue()).toList(); } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(values, StreamOutput::writeFloat); + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) { + out.writeNamedWriteableCollection(values); + } else { + // TODO do we need to check that the values are floats here? + out.writeCollection(toFloats(), StreamOutput::writeFloat); + } } @Override @@ -120,7 +146,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.startArray(EMBEDDING); - for (Float value : values) { + for (EmbeddingValue value : values) { builder.value(value); } builder.endArray(); @@ -134,6 +160,19 @@ public String toString() { return Strings.toString(this); } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Embedding embedding = (Embedding) o; + return Objects.equals(values, embedding.values); + } + + @Override + public int hashCode() { + return Objects.hash(values); + } + public Map asMap() { return Map.of(EMBEDDING, values); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index c632c568fea16..a6d6c3cbee4d5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -14,6 +14,9 @@ import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.core.inference.results.ByteValue; +import org.elasticsearch.xpack.core.inference.results.EmbeddingValue; +import org.elasticsearch.xpack.core.inference.results.FloatValue; import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; @@ -49,6 +52,8 @@ public static List getNamedWriteables() { namedWriteables.add( new NamedWriteableRegistry.Entry(InferenceServiceResults.class, TextEmbeddingResults.NAME, TextEmbeddingResults::new) ); + namedWriteables.add(new NamedWriteableRegistry.Entry(EmbeddingValue.class, FloatValue.NAME, FloatValue::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EmbeddingValue.class, ByteValue.NAME, ByteValue::new)); // Empty default task settings namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java similarity index 80% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionCreator.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java index 0353922959e0b..b3ac96979b68b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java @@ -10,7 +10,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; import java.util.Map; import java.util.Objects; @@ -18,17 +18,17 @@ /** * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the openai model type. */ -public class OpenAiActionCreator implements OpenAiActionVisitor { +public class CohereActionCreator implements CohereActionVisitor { private final Sender sender; private final ServiceComponents serviceComponents; - public OpenAiActionCreator(Sender sender, ServiceComponents serviceComponents) { + public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) { this.sender = Objects.requireNonNull(sender); this.serviceComponents = Objects.requireNonNull(serviceComponents); } @Override - public ExecutableAction create(OpenAiEmbeddingsModel model, Map taskSettings) { + public ExecutableAction create(CohereEmbeddingsModel model, Map taskSettings) { var overriddenModel = model.overrideWith(taskSettings); return new CohereEmbeddingsAction(sender, overriddenModel, serviceComponents); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionVisitor.java similarity index 69% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionVisitor.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionVisitor.java index a3a4cfbbfc873..1500d48e3c201 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionVisitor.java @@ -8,10 +8,10 @@ package org.elasticsearch.xpack.inference.external.action.cohere; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; -import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; import java.util.Map; -public interface OpenAiActionVisitor { - ExecutableAction create(OpenAiEmbeddingsModel model, Map taskSettings); +public interface CohereActionVisitor { + ExecutableAction create(CohereEmbeddingsModel model, Map taskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java index 569f33b256803..cc84b06a7ba8f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java @@ -13,7 +13,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.cohere.CohereAccount; import org.elasticsearch.xpack.inference.external.cohere.CohereResponseHandler; @@ -22,7 +21,7 @@ import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.request.cohere.CohereEmbeddingsRequest; -import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.external.response.cohere.CohereEmbeddingsResponseEntity; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; @@ -31,24 +30,22 @@ import java.util.Objects; import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.inference.common.Truncator.truncate; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; public class CohereEmbeddingsAction implements ExecutableAction { private static final Logger logger = LogManager.getLogger(CohereEmbeddingsAction.class); + private static final ResponseHandler HANDLER = createEmbeddingsHandler(); private final CohereAccount account; private final CohereEmbeddingsModel model; private final String errorMessage; - private final Truncator truncator; private final RetryingHttpSender sender; public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model, ServiceComponents serviceComponents) { this.model = Objects.requireNonNull(model); this.account = new CohereAccount(this.model.getServiceSettings().uri(), this.model.getSecretSettings().apiKey()); this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri()); - this.truncator = Objects.requireNonNull(serviceComponents.truncator()); this.sender = new RetryingHttpSender( Objects.requireNonNull(sender), serviceComponents.throttlerManager(), @@ -70,12 +67,12 @@ private static String getErrorMessage(@Nullable URI uri) { public void execute(List input, ActionListener listener) { try { // TODO only truncate if the setting is NONE? - var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + // var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); - CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(truncator, account, truncatedInput, model.getTaskSettings()); + CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(account, input, model.getTaskSettings()); ActionListener wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener); - sender.send(request, wrappedListener); + sender.send(request, HANDLER, wrappedListener); } catch (ElasticsearchException e) { listener.onFailure(e); } catch (Exception e) { @@ -84,6 +81,6 @@ public void execute(List input, ActionListener } private static ResponseHandler createEmbeddingsHandler() { - return new CohereResponseHandler("cohere text embedding", OpenAiEmbeddingsResponseEntity::fromResponse); + return new CohereResponseHandler("cohere text embedding", CohereEmbeddingsResponseEntity::fromResponse); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java index 7f2d34e1f8b3b..ba0e3d258dba9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java @@ -14,7 +14,6 @@ import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.cohere.CohereAccount; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; @@ -22,6 +21,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; +import java.util.List; import java.util.Objects; import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri; @@ -29,21 +29,14 @@ public class CohereEmbeddingsRequest implements Request { - private final Truncator truncator; private final CohereAccount account; - private final Truncator.TruncationResult truncationResult; + private final List input; private final URI uri; private final CohereEmbeddingsTaskSettings taskSettings; - public CohereEmbeddingsRequest( - Truncator truncator, - CohereAccount account, - Truncator.TruncationResult input, - CohereEmbeddingsTaskSettings taskSettings - ) { - this.truncator = Objects.requireNonNull(truncator); + public CohereEmbeddingsRequest(CohereAccount account, List input, CohereEmbeddingsTaskSettings taskSettings) { this.account = Objects.requireNonNull(account); - this.truncationResult = Objects.requireNonNull(input); + this.input = Objects.requireNonNull(input); this.uri = buildUri(this.account.url(), "Cohere", CohereEmbeddingsRequest::buildDefaultUri); this.taskSettings = Objects.requireNonNull(taskSettings); } @@ -53,7 +46,7 @@ public HttpRequestBase createRequest() { HttpPost httpPost = new HttpPost(uri); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new CohereEmbeddingsRequestEntity(truncationResult.input(), taskSettings)).getBytes(StandardCharsets.UTF_8) + Strings.toString(new CohereEmbeddingsRequestEntity(input, taskSettings)).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); @@ -70,15 +63,21 @@ public URI getURI() { @Override public Request truncate() { - // TODO only do this is the truncate setting is NONE? - var truncatedInput = truncator.truncate(truncationResult.input()); - return new CohereEmbeddingsRequest(truncator, account, truncatedInput, taskSettings); + return this; + // TODO only do this is the truncate setting is NONE? + // var truncatedInput = truncator.truncate(truncationResult.input()); + // + // return new CohereEmbeddingsRequest(truncator, account, truncatedInput, taskSettings); } @Override public boolean[] getTruncationInfo() { - return truncationResult.truncated().clone(); + return null; + } + + public CohereEmbeddingsTaskSettings getTaskSettings() { + return taskSettings; } // default for testing diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java index 8aea434d7339d..7331426b5c683 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java @@ -45,7 +45,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } if (taskSettings.truncation() != null) { - builder.field(EMBEDDING_TYPES_FIELD, taskSettings.truncation()); + builder.field(CohereServiceFields.TRUNCATE, taskSettings.truncation()); } builder.endObject(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java index cbe70ad8d919a..8f0eabeda3a04 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java @@ -9,23 +9,38 @@ import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.core.CheckedFunction; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.ByteValue; +import org.elasticsearch.xpack.core.inference.results.EmbeddingValue; +import org.elasticsearch.xpack.core.inference.results.FloatValue; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; import java.io.IOException; import java.util.List; +import java.util.Map; +import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken; import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; +import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType.toLowerCase; public class CohereEmbeddingsResponseEntity { private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Cohere embeddings response"; + private static final Map> EMBEDDING_PARSERS = Map.of( + toLowerCase(CohereEmbeddingType.FLOAT), + CohereEmbeddingsResponseEntity::parseEmbeddingFloatEntry, + toLowerCase(CohereEmbeddingType.INT8), + CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry + ); + /** * Parses the OpenAI json response. * For a request like: @@ -81,23 +96,77 @@ public static TextEmbeddingResults fromResponse(Request request, HttpResult resp XContentParser.Token token = jsonParser.currentToken(); XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); - positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); + // TODO can't really rely on the texts field being before embeddings, this will need to be a loop + // we don't return the is_truncated result for text embeddings so we don't really need this yet + positionParserAtTokenAfterField(jsonParser, "texts", FAILED_TO_FIND_FIELD_TEMPLATE); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, jsonParser.currentToken(), jsonParser); - List embeddingList = XContentParserUtils.parseList( + List inputAfterTruncation = XContentParserUtils.parseList( jsonParser, - CohereEmbeddingsResponseEntity::parseEmbeddingObject + CohereEmbeddingsResponseEntity::parseTruncatedTextsField ); - return new TextEmbeddingResults(embeddingList); + positionParserAtTokenAfterField(jsonParser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); + // TODO we don't need this yet, just require that the embedding type be a single string so that + // this is always the second style + token = jsonParser.currentToken(); + if (token == XContentParser.Token.START_OBJECT) { + return parseEmbeddingsObject(jsonParser); + } else if (token == XContentParser.Token.START_ARRAY) { + List embeddingList = XContentParserUtils.parseList( + jsonParser, + parser -> CohereEmbeddingsResponseEntity.parseEmbeddingsArray( + parser, + CohereEmbeddingsResponseEntity::parseEmbeddingFloatEntry + ) + ); + + return new TextEmbeddingResults(embeddingList); + } else { + throwUnknownToken(token, jsonParser); + } + + // This should never be reached. The above code should either return successfully or hit the throwUnknownToken + // or throw a parsing exception + throw new IllegalStateException("Reached an invalid state while parsing the Cohere response"); } } - private static TextEmbeddingResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { - XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + private static String parseTruncatedTextsField(XContentParser parser) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); + return parser.text(); + } + + private static TextEmbeddingResults parseEmbeddingsObject(XContentParser parser) throws IOException { + XContentParser.Token token; + + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + var embeddingValueParser = EMBEDDING_PARSERS.get(parser.currentName()); + if (embeddingValueParser == null) { + continue; + } + + parser.nextToken(); + var embeddingList = XContentParserUtils.parseList( + parser, + listParser -> CohereEmbeddingsResponseEntity.parseEmbeddingsArray(listParser, embeddingValueParser) + ); + + return new TextEmbeddingResults(embeddingList); + } + } + + throw new IllegalStateException("Failed to find a supported embedding type"); + } - positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); + private static TextEmbeddingResults.Embedding parseEmbeddingsArray( + XContentParser parser, + CheckedFunction parseEntry + ) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - List embeddingValues = XContentParserUtils.parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingList); + List embeddingValues = XContentParserUtils.parseList(parser, parseEntry); // the parser is currently sitting at an ARRAY_END so go to the next token parser.nextToken(); @@ -107,10 +176,25 @@ private static TextEmbeddingResults.Embedding parseEmbeddingObject(XContentParse return new TextEmbeddingResults.Embedding(embeddingValues); } - private static float parseEmbeddingList(XContentParser parser) throws IOException { + private static FloatValue parseEmbeddingFloatEntry(XContentParser parser) throws IOException { XContentParser.Token token = parser.currentToken(); XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); - return parser.floatValue(); + return new FloatValue(parser.floatValue()); + } + + private static ByteValue parseEmbeddingInt8Entry(XContentParser parser) throws IOException { + XContentParser.Token token = parser.currentToken(); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); + var parsedByte = parser.shortValue(); + checkByteBounds(parsedByte); + + return new ByteValue((byte) parsedByte); + } + + private static void checkByteBounds(short value) { + if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) { + throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte"); + } } private CohereEmbeddingsResponseEntity() {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java index b74b03891034f..9bd01bac7bee1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java @@ -149,7 +149,7 @@ private static TextEmbeddingResults.Embedding parseEmbeddingEntry(XContentParser XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); List embeddingValues = XContentParserUtils.parseList(parser, HuggingFaceEmbeddingsResponseEntity::parseEmbeddingList); - return new TextEmbeddingResults.Embedding(embeddingValues); + return TextEmbeddingResults.Embedding.of(embeddingValues); } private static float parseEmbeddingList(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java index 4926ba3f0ef6b..a5fe46ab1de5f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java @@ -101,7 +101,7 @@ private static TextEmbeddingResults.Embedding parseEmbeddingObject(XContentParse // if there are additional fields within this object, lets skip them, so we can begin parsing the next embedding array parser.skipChildren(); - return new TextEmbeddingResults.Embedding(embeddingValues); + return TextEmbeddingResults.Embedding.of(embeddingValues); } private static float parseEmbeddingList(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 6fef9dac6095a..37119b872bf12 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -109,6 +109,10 @@ public static String mustBeNonEmptyString(String settingName, String scope) { return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName); } + public static String mustBeNonEmptyList(String settingName, String scope) { + return Strings.format("[%s] Invalid value empty list. [%s] must be a non-empty list", scope, settingName); + } + public static String invalidType(String settingName, String scope, String invalidType, String invalidValue, String requiredType) { return Strings.format( "[%s] Invalid type [%s] received for value [%s]. [%s] must be type [%s]", @@ -191,6 +195,54 @@ public static SimilarityMeasure extractSimilarity(Map map, Strin return null; } + public static List extractOptionalListOfEnums( + Map map, + String settingName, + String scope, + CheckedFunction converter, + T[] validTypes, + ValidationException validationException + ) { + List listField = ServiceUtils.removeAsType(map, settingName, List.class); + if (listField == null) { + return null; + } + + if (listField.isEmpty()) { + validationException.addValidationError(ServiceUtils.mustBeNonEmptyList(settingName, scope)); + return null; + } + + var validTypesAsStrings = Arrays.stream(validTypes).map(type -> type.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new); + + List castedList = new ArrayList<>(listField.size()); + + for (Object listEntry : listField) { + if (listEntry instanceof String == false) { + validationException.addValidationError( + invalidType( + settingName, + scope, + listEntry.getClass().getSimpleName(), + listEntry.toString(), + String.class.getSimpleName() + ) + ); + return null; + } + + var stringEntry = (String) listEntry; + try { + castedList.add(converter.apply(stringEntry)); + } catch (IllegalArgumentException e) { + validationException.addValidationError(invalidValue(settingName, scope, stringEntry, validTypesAsStrings)); + return null; + } + } + + return castedList; + } + @SuppressWarnings("unchecked") public static List extractOptionalListOfType( Map map, @@ -206,7 +258,7 @@ public static List extractOptionalListOfType( } if (listField.isEmpty()) { - validationException.addValidationError(ServiceUtils.mustBeNonEmptyString(settingName, scope)); + validationException.addValidationError(ServiceUtils.mustBeNonEmptyList(settingName, scope)); return null; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java new file mode 100644 index 0000000000000..0e4f9ffeba466 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.embeddings; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Locale; + +/** + * Defines the type of embedding that the cohere api should return for a request. + * + *

+ * See api docs for details. + *

+ */ +public enum CohereEmbeddingType implements Writeable { + /** + * Use this when you want to get back the default float embeddings. Valid for all models. + */ + FLOAT, + /** + * Use this when you want to get back signed int8 embeddings. Valid for only v3 models. + */ + INT8; + + public static String NAME = "cohere_embedding_type"; + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + + public static String toLowerCase(CohereEmbeddingType type) { + return type.toString().toLowerCase(Locale.ROOT); + } + + public static CohereEmbeddingType fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public static CohereEmbeddingType fromStream(StreamInput in) throws IOException { + return in.readEnum(CohereEmbeddingType.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java index ed1295c8ee31c..816e7bc6933cd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java @@ -76,4 +76,13 @@ public DefaultSecretSettings getSecretSettings() { public ExecutableAction accept() { return null; } + + public CohereEmbeddingsModel overrideWith(Map taskSettings) { + if (taskSettings == null || taskSettings.isEmpty()) { + return this; + } + + var requestTaskSettings = CohereEmbeddingsTaskSettings.fromMap(taskSettings); + return new CohereEmbeddingsModel(this, getTaskSettings().overrideWith(requestTaskSettings)); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java index c2ad8b2b8e958..849290cfaa66d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java @@ -24,7 +24,7 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfType; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfEnums; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.MODEL; import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.TRUNCATE; @@ -44,7 +44,7 @@ public record CohereEmbeddingsTaskSettings( @Nullable String model, @Nullable InputType inputType, - @Nullable List embeddingTypes, + @Nullable List embeddingTypes, @Nullable CohereTruncation truncation ) implements TaskSettings { @@ -69,11 +69,12 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) { InputType.values(), validationException ); - List embeddingTypes = extractOptionalListOfType( + List embeddingTypes = extractOptionalListOfEnums( map, EMBEDDING_TYPES, ModelConfigurations.TASK_SETTINGS, - String.class, + CohereEmbeddingType::fromString, + CohereEmbeddingType.values(), validationException ); CohereTruncation truncation = extractOptionalEnum( @@ -93,7 +94,12 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) { } public CohereEmbeddingsTaskSettings(StreamInput in) throws IOException { - this(in.readOptionalString(), InputType.fromStream(in), in.readOptionalStringCollectionAsList(), CohereTruncation.fromStream(in)); + this( + in.readOptionalString(), + InputType.fromStream(in), + in.readOptionalCollectionAsList(CohereEmbeddingType::fromStream), + CohereTruncation.fromStream(in) + ); } @Override @@ -132,7 +138,16 @@ public TransportVersion getMinimalSupportedVersion() { public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(model); inputType.writeTo(out); - out.writeOptionalStringCollection(embeddingTypes); + out.writeOptionalCollection(embeddingTypes); truncation.writeTo(out); } + + public CohereEmbeddingsTaskSettings overrideWith(CohereEmbeddingsTaskSettings requestTaskSettings) { + var modelToUse = requestTaskSettings.model() == null ? model : requestTaskSettings.model(); + var inputTypeToUse = requestTaskSettings.inputType() == null ? inputType : requestTaskSettings.inputType(); + var embeddingTypesToUse = requestTaskSettings.embeddingTypes() == null ? embeddingTypes : requestTaskSettings.embeddingTypes(); + var truncationToUse = requestTaskSettings.truncation() == null ? truncation : requestTaskSettings.truncation(); + + return new CohereEmbeddingsTaskSettings(modelToUse, inputTypeToUse, embeddingTypesToUse, truncationToUse); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java index be3f77251fd08..d6b01486ae3ec 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java @@ -14,27 +14,59 @@ import org.elasticsearch.xpack.inference.common.SimilarityMeasure; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.hamcrest.MatcherAssert; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; public class CohereEmbeddingsModelTests extends ESTestCase { + public void testOverrideWith_OverridesModel() { + var model = createModel("url", "api_key", null); + + var overriddenModel = model.overrideWith(getTaskSettingsMap("model", null, null, null)); + var expectedModel = createModel("url", "api_key", new CohereEmbeddingsTaskSettings("model", null, null, null), null, null); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_DoesNotOverride_WhenSettingsAreEmpty() { + var model = createModel("url", "api_key", null); + + var overriddenModel = model.overrideWith(Map.of()); + MatcherAssert.assertThat(overriddenModel, sameInstance(model)); + } + + public void testOverrideWith_DoesNotOverride_WhenSettingsAreNull() { + var model = createModel("url", "api_key", null); + + var overriddenModel = model.overrideWith(null); + MatcherAssert.assertThat(overriddenModel, sameInstance(model)); + } + public static CohereEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit) { - return new CohereEmbeddingsModel( - "id", - TaskType.TEXT_EMBEDDING, - "service", - new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, 1536, tokenLimit), - CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, - new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) - ); + return createModel(url, apiKey, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, null); } public static CohereEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit, @Nullable Integer dimensions) { + return createModel(url, apiKey, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, dimensions); + } + + public static CohereEmbeddingsModel createModel( + String url, + String apiKey, + CohereEmbeddingsTaskSettings taskSettings, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions + ) { return new CohereEmbeddingsModel( "id", TaskType.TEXT_EMBEDDING, "service", new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit), - CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, + taskSettings, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java index b5e44b88b9b4c..f2ea1a8c4f3fe 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; @@ -16,7 +17,6 @@ import org.hamcrest.MatcherAssert; import java.io.IOException; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Locale; @@ -29,7 +29,7 @@ public class CohereEmbeddingsTaskSettingsTests extends AbstractWireSerializingTe public static CohereEmbeddingsTaskSettings createRandom() { var model = randomBoolean() ? randomAlphaOfLength(15) : null; var inputType = randomBoolean() ? randomFrom(InputType.values()) : null; - var embeddingTypes = randomBoolean() ? List.of(randomAlphaOfLength(6)) : null; + var embeddingTypes = randomBoolean() ? List.of(randomFrom(CohereEmbeddingType.values())) : null; var truncation = randomBoolean() ? randomFrom(CohereTruncation.values()) : null; return new CohereEmbeddingsTaskSettings(model, inputType, embeddingTypes, truncation); @@ -52,13 +52,20 @@ public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() { CohereEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString().toLowerCase(Locale.ROOT), CohereEmbeddingsTaskSettings.EMBEDDING_TYPES, - List.of("abc", "123"), + List.of(CohereEmbeddingType.FLOAT, CohereEmbeddingType.INT8), CohereServiceFields.TRUNCATE, CohereTruncation.END.toString().toLowerCase(Locale.ROOT) ) ) ), - is(new CohereEmbeddingsTaskSettings("abc", InputType.INGEST, List.of("abc", "123"), CohereTruncation.END)) + is( + new CohereEmbeddingsTaskSettings( + "abc", + InputType.INGEST, + List.of(CohereEmbeddingType.FLOAT, CohereEmbeddingType.INT8), + CohereTruncation.END + ) + ) ); } @@ -77,9 +84,7 @@ public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() { public void testFromMap_ReturnsFailure_WhenEmbeddingTypesAreNotValid() { var exception = expectThrows( ValidationException.class, - () -> CohereEmbeddingsTaskSettings.fromMap( - new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.EMBEDDING_TYPES, List.of("abc", 123))) - ) + () -> CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.EMBEDDING_TYPES, List.of("abc")))) ); MatcherAssert.assertThat( @@ -91,6 +96,28 @@ public void testFromMap_ReturnsFailure_WhenEmbeddingTypesAreNotValid() { ); } + public void testOverrideWith_KeepsOriginalValuesWhenOverridesAreNull() { + var taskSettings = CohereEmbeddingsTaskSettings.fromMap( + new HashMap<>(Map.of(CohereServiceFields.MODEL, "model", CohereServiceFields.TRUNCATE, CohereTruncation.END.toString())) + ); + + var overriddenTaskSettings = taskSettings.overrideWith(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS); + MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings)); + } + + public void testOverrideWith_UsesOverriddenSettings() { + var taskSettings = CohereEmbeddingsTaskSettings.fromMap( + new HashMap<>(Map.of(CohereServiceFields.MODEL, "model", CohereServiceFields.TRUNCATE, CohereTruncation.END.toString())) + ); + + var requestTaskSettings = CohereEmbeddingsTaskSettings.fromMap( + new HashMap<>(Map.of(CohereServiceFields.TRUNCATE, CohereTruncation.START.toString())) + ); + + var overriddenTaskSettings = taskSettings.overrideWith(requestTaskSettings); + MatcherAssert.assertThat(overriddenTaskSettings, is(new CohereEmbeddingsTaskSettings("model", null, null, CohereTruncation.START))); + } + @Override protected Writeable.Reader instanceReader() { return CohereEmbeddingsTaskSettings::new; @@ -106,7 +133,30 @@ protected CohereEmbeddingsTaskSettings mutateInstance(CohereEmbeddingsTaskSettin return null; } - public static Map getTaskSettingsMap() { - return new HashMap<>(Collections.emptyMap()); + public static Map getTaskSettingsMap( + @Nullable String model, + @Nullable InputType inputType, + @Nullable List embeddingTypes, + @Nullable CohereTruncation truncation + ) { + var map = new HashMap(); + + if (model != null) { + map.put(CohereServiceFields.MODEL, model); + } + + if (inputType != null) { + map.put(CohereEmbeddingsTaskSettings.INPUT_TYPE, inputType); + } + + if (embeddingTypes != null) { + map.put(CohereEmbeddingsTaskSettings.EMBEDDING_TYPES, embeddingTypes); + } + + if (truncation != null) { + map.put(CohereServiceFields.TRUNCATE, truncation); + } + + return map; } } From 58c707d4ad4a85d5a68c97e5095c2d03f21a1222 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 18 Jan 2024 13:26:32 -0500 Subject: [PATCH 04/13] Working cohere --- .../core/inference/results/ByteValue.java | 5 + .../core/inference/results/FloatValue.java | 5 + .../results/TextEmbeddingResults.java | 18 +- .../InferenceNamedWriteablesProvider.java | 10 + .../xpack/inference/InferencePlugin.java | 4 +- .../external/action/ActionUtils.java | 12 + .../action/cohere/CohereEmbeddingsAction.java | 17 +- .../action/openai/OpenAiEmbeddingsAction.java | 14 +- .../cohere/CohereEmbeddingsRequest.java | 9 - .../cohere/CohereEmbeddingsRequestEntity.java | 29 +- .../CohereEmbeddingsResponseEntity.java | 75 +- .../cohere/CohereErrorResponseEntity.java | 1 + .../HuggingFaceEmbeddingsResponseEntity.java | 2 +- .../OpenAiEmbeddingsResponseEntity.java | 2 +- .../inference/services/ServiceUtils.java | 85 -- .../services/cohere/CohereModel.java | 5 +- .../services/cohere/CohereService.java | 174 ++++ .../embeddings/CohereEmbeddingType.java | 4 +- .../embeddings/CohereEmbeddingsModel.java | 7 +- .../CohereEmbeddingsTaskSettings.java | 29 +- .../cohere/CohereActionCreatorTests.java | 152 ++++ .../cohere/CohereEmbeddingsActionTests.java | 333 +++++++ .../HuggingFaceActionCreatorTests.java | 6 +- .../openai/OpenAiActionCreatorTests.java | 10 +- .../openai/OpenAiEmbeddingsActionTests.java | 4 +- .../cohere/CohereResponseHandlerTests.java | 165 ++++ .../external/openai/OpenAiClientTests.java | 8 +- .../CohereEmbeddingsRequestEntityTests.java | 65 ++ .../cohere/CohereEmbeddingsRequestTests.java | 156 ++++ .../CohereEmbeddingsResponseEntityTests.java | 434 +++++++++ .../CohereErrorResponseEntityTests.java | 50 + ...gingFaceEmbeddingsResponseEntityTests.java | 20 +- .../OpenAiEmbeddingsResponseEntityTests.java | 10 +- .../inference/results/ByteValueTests.java | 35 + .../inference/results/FloatValueTests.java | 35 + .../results/TextEmbeddingResultsTests.java | 222 ++++- .../services/cohere/CohereServiceTests.java | 853 ++++++++++++++++++ .../cohere/CohereTruncationTests.java | 34 + .../embeddings/CohereEmbeddingTypeTests.java | 34 + .../CohereEmbeddingsTaskSettingsTests.java | 33 +- .../huggingface/HuggingFaceServiceTests.java | 4 +- .../services/openai/OpenAiServiceTests.java | 4 +- 42 files changed, 2910 insertions(+), 264 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ByteValueTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/FloatValueTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncationTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java index 9e3dc866d304e..5017aa9be0d52 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java @@ -28,6 +28,11 @@ public ByteValue(StreamInput in) throws IOException { value = in.readByte(); } + @Override + public String toString() { + return value.toString(); + } + @Override public Byte getValue() { return value; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java index 9ead125d2d48a..5c89720e54b2e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java @@ -28,6 +28,11 @@ public FloatValue(StreamInput in) throws IOException { value = in.readFloat(); } + @Override + public String toString() { + return value.toString(); + } + @Override public Number getValue() { return value; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java index 9fc99ba3f03f8..fadffb2782018 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java @@ -53,7 +53,10 @@ public TextEmbeddingResults(StreamInput in) throws IOException { @SuppressWarnings("deprecation") TextEmbeddingResults(LegacyTextEmbeddingResults legacyTextEmbeddingResults) { this( - legacyTextEmbeddingResults.embeddings().stream().map(embedding -> Embedding.of(embedding.values())).collect(Collectors.toList()) + legacyTextEmbeddingResults.embeddings() + .stream() + .map(embedding -> Embedding.ofFloats(embedding.values())) + .collect(Collectors.toList()) ); } @@ -105,10 +108,16 @@ public Map asMap() { public static class Embedding implements Writeable, ToXContentObject { public static final String EMBEDDING = "embedding"; - public static Embedding of(List values) { + public static Embedding ofFloats(List values) { return new Embedding(convertFloatsToEmbeddingValues(values)); } + public static Embedding ofBytes(List values) { + List convertedValues = values.stream().map(ByteValue::new).collect(Collectors.toList()); + + return new Embedding(convertedValues); + } + private final List values; public Embedding(List values) { @@ -131,12 +140,15 @@ public List toFloats() { return values.stream().map(value -> value.getValue().floatValue()).toList(); } + public List values() { + return values; + } + @Override public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) { out.writeNamedWriteableCollection(values); } else { - // TODO do we need to check that the values are floats here? out.writeCollection(toFloats(), StreamOutput::writeFloat); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index a6d6c3cbee4d5..02d19ba60b0e7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -20,6 +20,8 @@ import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettings; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings; @@ -92,6 +94,14 @@ public static List getNamedWriteables() { new NamedWriteableRegistry.Entry(TaskSettings.class, OpenAiEmbeddingsTaskSettings.NAME, OpenAiEmbeddingsTaskSettings::new) ); + // Cohere + namedWriteables.add( + new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(TaskSettings.class, CohereEmbeddingsTaskSettings.NAME, CohereEmbeddingsTaskSettings::new) + ); + return namedWriteables; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 33d71c65ed643..d5e41f20113cc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -54,6 +54,7 @@ import org.elasticsearch.xpack.inference.rest.RestInferenceAction; import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.cohere.CohereService; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeService; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService; @@ -154,7 +155,8 @@ public List getInferenceServiceFactories() { ElserMlNodeService::new, context -> new HuggingFaceElserService(httpFactory, serviceComponents), context -> new HuggingFaceService(httpFactory, serviceComponents), - context -> new OpenAiService(httpFactory, serviceComponents) + context -> new OpenAiService(httpFactory, serviceComponents), + context -> new CohereService(httpFactory, serviceComponents) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java index 856146fafcb45..4a8519934c63c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java @@ -11,9 +11,13 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; +import java.net.URI; + public class ActionUtils { public static ActionListener wrapFailuresInElasticsearchException( @@ -35,5 +39,13 @@ public static ElasticsearchStatusException createInternalServerError(Throwable e return new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR, e); } + public static String getErrorMessage(@Nullable URI uri, String message) { + if (uri != null) { + return Strings.format("Failed to send %s request to [%s]", message, uri); + } + + return Strings.format("Failed to send %s request", message); + } + private ActionUtils() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java index cc84b06a7ba8f..62afbc190d573 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java @@ -11,7 +11,6 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.cohere.CohereAccount; @@ -25,12 +24,11 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; -import java.net.URI; import java.util.List; import java.util.Objects; -import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.getErrorMessage; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; public class CohereEmbeddingsAction implements ExecutableAction { @@ -45,7 +43,7 @@ public class CohereEmbeddingsAction implements ExecutableAction { public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model, ServiceComponents serviceComponents) { this.model = Objects.requireNonNull(model); this.account = new CohereAccount(this.model.getServiceSettings().uri(), this.model.getSecretSettings().apiKey()); - this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri()); + this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri(), "Cohere embeddings"); this.sender = new RetryingHttpSender( Objects.requireNonNull(sender), serviceComponents.throttlerManager(), @@ -55,20 +53,9 @@ public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model, Servic ); } - private static String getErrorMessage(@Nullable URI uri) { - if (uri != null) { - return format("Failed to send Cohere embeddings request to [%s]", uri.toString()); - } - - return "Failed to send Cohere embeddings request"; - } - @Override public void execute(List input, ActionListener listener) { try { - // TODO only truncate if the setting is NONE? - // var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); - CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(account, input, model.getTaskSettings()); ActionListener wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java index 20128f1168bb9..417935f8c920c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java @@ -9,7 +9,6 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; @@ -20,13 +19,12 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel; -import java.net.URI; import java.util.List; import java.util.Objects; -import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.common.Truncator.truncate; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.getErrorMessage; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; public class OpenAiEmbeddingsAction implements ExecutableAction { @@ -45,18 +43,10 @@ public OpenAiEmbeddingsAction(Sender sender, OpenAiEmbeddingsModel model, Servic this.model.getSecretSettings().apiKey() ); this.client = new OpenAiClient(Objects.requireNonNull(sender), Objects.requireNonNull(serviceComponents)); - this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri()); + this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri(), "send OpenAI embeddings"); this.truncator = Objects.requireNonNull(serviceComponents.truncator()); } - private static String getErrorMessage(@Nullable URI uri) { - if (uri != null) { - return format("Failed to send OpenAI embeddings request to [%s]", uri.toString()); - } - - return "Failed to send OpenAI embeddings request"; - } - @Override public void execute(List input, ActionListener listener) { try { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java index ba0e3d258dba9..277996140001f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java @@ -63,12 +63,7 @@ public URI getURI() { @Override public Request truncate() { - return this; - // TODO only do this is the truncate setting is NONE? - // var truncatedInput = truncator.truncate(truncationResult.input()); - // - // return new CohereEmbeddingsRequest(truncator, account, truncatedInput, taskSettings); } @Override @@ -76,10 +71,6 @@ public boolean[] getTruncationInfo() { return null; } - public CohereEmbeddingsTaskSettings getTaskSettings() { - return taskSettings; - } - // default for testing static URI buildDefaultUri() throws URISyntaxException { return new URIBuilder().setScheme("https") diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java index 7331426b5c683..831fb07b30e24 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.request.cohere; +import org.elasticsearch.inference.InputType; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; @@ -14,10 +15,22 @@ import java.io.IOException; import java.util.List; +import java.util.Map; import java.util.Objects; public record CohereEmbeddingsRequestEntity(List input, CohereEmbeddingsTaskSettings taskSettings) implements ToXContentObject { + private static final String SEARCH_DOCUMENT = "search_document"; + private static final String SEARCH_QUERY = "search_query"; + /** + * Maps the {@link InputType} to the expected value for cohere for the input_type field in the request + */ + private static final Map INPUT_TYPE_MAPPING = Map.of( + InputType.INGEST, + SEARCH_DOCUMENT, + InputType.SEARCH, + SEARCH_QUERY + ); private static final String TEXTS_FIELD = "texts"; static final String INPUT_TYPE_FIELD = "input_type"; @@ -37,11 +50,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } if (taskSettings.inputType() != null) { - builder.field(INPUT_TYPE_FIELD, taskSettings.inputType()); + builder.field(INPUT_TYPE_FIELD, covertToString(taskSettings.inputType())); } - if (taskSettings.embeddingTypes() != null) { - builder.field(EMBEDDING_TYPES_FIELD, taskSettings.embeddingTypes()); + if (taskSettings.embeddingType() != null) { + builder.field(EMBEDDING_TYPES_FIELD, List.of(taskSettings.embeddingType())); } if (taskSettings.truncation() != null) { @@ -51,4 +64,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } + + private static String covertToString(InputType inputType) { + var stringValue = INPUT_TYPE_MAPPING.get(inputType); + + if (stringValue == null) { + return SEARCH_DOCUMENT; + } + + return stringValue; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java index 8f0eabeda3a04..8f6d6c75d8c71 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.response.cohere; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.core.CheckedFunction; @@ -40,6 +41,12 @@ public class CohereEmbeddingsResponseEntity { toLowerCase(CohereEmbeddingType.INT8), CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry ); + private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes(); + + private static String supportedEmbeddingTypes() { + var validTypes = EMBEDDING_PARSERS.keySet().toArray(String[]::new); + return String.join(", ", validTypes); + } /** * Parses the OpenAI json response. @@ -86,6 +93,42 @@ public class CohereEmbeddingsResponseEntity { * } * * + * + * Or this: + * + *
+     * 
+     * {
+     *  "id": "da4f9ea6-37e4-41ab-b5e1-9e2985609555",
+     *  "texts": [
+     *      "hello",
+     *      "awesome"
+     *  ],
+     *  "embeddings": {
+     *      "float": [
+     *          [
+     *              123
+     *          ],
+     *          [
+     *              123
+     *          ],
+     *      ]
+     *  },
+     *  "meta": {
+     *      "api_version": {
+     *          "version": "1"
+     *      },
+     *      "warnings": [
+     *          "default model on embed will be deprecated in the future, please specify a model in the request."
+     *      ],
+     *      "billed_units": {
+     *          "input_tokens": 3
+     *      }
+     *  },
+     *  "response_type": "embeddings_floats"
+     * }
+     * 
+     * 
*/ public static TextEmbeddingResults fromResponse(Request request, HttpResult response) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); @@ -96,23 +139,13 @@ public static TextEmbeddingResults fromResponse(Request request, HttpResult resp XContentParser.Token token = jsonParser.currentToken(); XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); - // TODO can't really rely on the texts field being before embeddings, this will need to be a loop - // we don't return the is_truncated result for text embeddings so we don't really need this yet - positionParserAtTokenAfterField(jsonParser, "texts", FAILED_TO_FIND_FIELD_TEMPLATE); - XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, jsonParser.currentToken(), jsonParser); - - List inputAfterTruncation = XContentParserUtils.parseList( - jsonParser, - CohereEmbeddingsResponseEntity::parseTruncatedTextsField - ); - positionParserAtTokenAfterField(jsonParser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); - // TODO we don't need this yet, just require that the embedding type be a single string so that - // this is always the second style + token = jsonParser.currentToken(); if (token == XContentParser.Token.START_OBJECT) { return parseEmbeddingsObject(jsonParser); } else if (token == XContentParser.Token.START_ARRAY) { + // if the request did not specify the embedding types then it will default to floats List embeddingList = XContentParserUtils.parseList( jsonParser, parser -> CohereEmbeddingsResponseEntity.parseEmbeddingsArray( @@ -132,11 +165,6 @@ public static TextEmbeddingResults fromResponse(Request request, HttpResult resp } } - private static String parseTruncatedTextsField(XContentParser parser) throws IOException { - XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - return parser.text(); - } - private static TextEmbeddingResults parseEmbeddingsObject(XContentParser parser) throws IOException { XContentParser.Token token; @@ -157,7 +185,12 @@ private static TextEmbeddingResults parseEmbeddingsObject(XContentParser parser) } } - throw new IllegalStateException("Failed to find a supported embedding type"); + throw new IllegalStateException( + Strings.format( + "Failed to find a supported embedding type in the Cohere embeddings response. Supported types are [%s]", + VALID_EMBEDDING_TYPES_STRING + ) + ); } private static TextEmbeddingResults.Embedding parseEmbeddingsArray( @@ -165,14 +198,8 @@ private static TextEmbeddingResults.Embedding parseEmbeddingsArray( CheckedFunction parseEntry ) throws IOException { XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - List embeddingValues = XContentParserUtils.parseList(parser, parseEntry); - // the parser is currently sitting at an ARRAY_END so go to the next token - parser.nextToken(); - // if there are additional fields within this object, lets skip them, so we can begin parsing the next embedding array - parser.skipChildren(); - return new TextEmbeddingResults.Embedding(embeddingValues); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntity.java index 6f947b45955be..7d1731105e2f5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntity.java @@ -22,6 +22,7 @@ private CohereErrorResponseEntity(String errorMessage) { this.errorMessage = errorMessage; } + @Override public String getErrorMessage() { return errorMessage; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java index 9bd01bac7bee1..b3706b318439b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java @@ -149,7 +149,7 @@ private static TextEmbeddingResults.Embedding parseEmbeddingEntry(XContentParser XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); List embeddingValues = XContentParserUtils.parseList(parser, HuggingFaceEmbeddingsResponseEntity::parseEmbeddingList); - return TextEmbeddingResults.Embedding.of(embeddingValues); + return TextEmbeddingResults.Embedding.ofFloats(embeddingValues); } private static float parseEmbeddingList(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java index a5fe46ab1de5f..0640821d0b9e1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java @@ -101,7 +101,7 @@ private static TextEmbeddingResults.Embedding parseEmbeddingObject(XContentParse // if there are additional fields within this object, lets skip them, so we can begin parsing the next embedding array parser.skipChildren(); - return TextEmbeddingResults.Embedding.of(embeddingValues); + return TextEmbeddingResults.Embedding.ofFloats(embeddingValues); } private static float parseEmbeddingList(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 37119b872bf12..3e9f6e1f75a8a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -22,7 +22,6 @@ import java.net.URI; import java.net.URISyntaxException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Locale; @@ -195,90 +194,6 @@ public static SimilarityMeasure extractSimilarity(Map map, Strin return null; } - public static List extractOptionalListOfEnums( - Map map, - String settingName, - String scope, - CheckedFunction converter, - T[] validTypes, - ValidationException validationException - ) { - List listField = ServiceUtils.removeAsType(map, settingName, List.class); - if (listField == null) { - return null; - } - - if (listField.isEmpty()) { - validationException.addValidationError(ServiceUtils.mustBeNonEmptyList(settingName, scope)); - return null; - } - - var validTypesAsStrings = Arrays.stream(validTypes).map(type -> type.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new); - - List castedList = new ArrayList<>(listField.size()); - - for (Object listEntry : listField) { - if (listEntry instanceof String == false) { - validationException.addValidationError( - invalidType( - settingName, - scope, - listEntry.getClass().getSimpleName(), - listEntry.toString(), - String.class.getSimpleName() - ) - ); - return null; - } - - var stringEntry = (String) listEntry; - try { - castedList.add(converter.apply(stringEntry)); - } catch (IllegalArgumentException e) { - validationException.addValidationError(invalidValue(settingName, scope, stringEntry, validTypesAsStrings)); - return null; - } - } - - return castedList; - } - - @SuppressWarnings("unchecked") - public static List extractOptionalListOfType( - Map map, - String settingName, - String scope, - Class type, - ValidationException validationException - ) { - List listField = ServiceUtils.removeAsType(map, settingName, List.class); - - if (listField == null) { - return null; - } - - if (listField.isEmpty()) { - validationException.addValidationError(ServiceUtils.mustBeNonEmptyList(settingName, scope)); - return null; - } - - List castedList = new ArrayList<>(listField.size()); - - for (Object listEntry : listField) { - if (type.isAssignableFrom(listEntry.getClass()) == false) { - // TODO should we just throw here like removeAsType - validationException.addValidationError( - invalidType(settingName, scope, listEntry.getClass().getSimpleName(), listEntry.toString(), type.getSimpleName()) - ); - return null; - } - - castedList.add((T) listEntry); - } - - return castedList; - } - public static String extractRequiredString( Map map, String settingName, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java index d92ea419faef0..1b4843e441248 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java @@ -13,6 +13,9 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionVisitor; + +import java.util.Map; public abstract class CohereModel extends Model { public CohereModel(ModelConfigurations configurations, ModelSecrets secrets) { @@ -27,5 +30,5 @@ protected CohereModel(CohereModel model, ServiceSettings serviceSettings) { super(model, serviceSettings); } - public abstract ExecutableAction accept(); + public abstract ExecutableAction accept(CohereActionVisitor creator, Map taskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java new file mode 100644 index 0000000000000..e654316a4fbd4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -0,0 +1,174 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere; + +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.common.SimilarityMeasure; +import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionCreator; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +public class CohereService extends SenderService { + public static final String NAME = "cohere"; + + public CohereService(SetOnce factory, SetOnce serviceComponents) { + super(factory, serviceComponents); + } + + @Override + public String name() { + return NAME; + } + + @Override + public CohereModel parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + Set platformArchitectures + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + + CohereModel model = createModel( + modelId, + taskType, + serviceSettingsMap, + taskSettingsMap, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME) + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + return model; + } + + private static CohereModel createModel( + String modelId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage + ) { + return switch (taskType) { + case TEXT_EMBEDDING -> new CohereEmbeddingsModel(modelId, taskType, NAME, serviceSettings, taskSettings, secretSettings); + default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + }; + } + + @Override + public CohereModel parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModel( + modelId, + taskType, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, + parsePersistedConfigErrorMsg(modelId, NAME) + ); + } + + @Override + public CohereModel parsePersistedConfig(String modelId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + + return createModel(modelId, taskType, serviceSettingsMap, taskSettingsMap, null, parsePersistedConfigErrorMsg(modelId, NAME)); + } + + @Override + public void doInfer( + Model model, + List input, + Map taskSettings, + ActionListener listener + ) { + if (model instanceof CohereModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + CohereModel cohereModel = (CohereModel) model; + var actionCreator = new CohereActionCreator(getSender(), getServiceComponents()); + + var action = cohereModel.accept(actionCreator, taskSettings); + action.execute(input, listener); + } + + /** + * For text embedding models get the embedding size and + * update the service settings. + * + * @param model The new model + * @param listener The listener + */ + @Override + public void checkModelConfig(Model model, ActionListener listener) { + if (model instanceof CohereEmbeddingsModel embeddingsModel) { + ServiceUtils.getEmbeddingSize( + model, + this, + listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size))) + ); + } else { + listener.onResponse(model); + } + } + + private CohereEmbeddingsModel updateModelWithEmbeddingDetails(CohereEmbeddingsModel model, int embeddingSize) { + CohereServiceSettings serviceSettings = new CohereServiceSettings( + model.getServiceSettings().uri(), + SimilarityMeasure.DOT_PRODUCT, + embeddingSize, + model.getServiceSettings().maxInputTokens() + ); + + return new CohereEmbeddingsModel(model, serviceSettings); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java index 0e4f9ffeba466..803382bb8c947 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java @@ -47,11 +47,11 @@ public static CohereEmbeddingType fromString(String name) { } public static CohereEmbeddingType fromStream(StreamInput in) throws IOException { - return in.readEnum(CohereEmbeddingType.class); + return in.readOptionalEnum(CohereEmbeddingType.class); } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeEnum(this); + out.writeOptionalEnum(this); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java index 816e7bc6933cd..accba9149a46e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java @@ -12,6 +12,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionVisitor; import org.elasticsearch.xpack.inference.services.cohere.CohereModel; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -53,7 +54,7 @@ private CohereEmbeddingsModel(CohereEmbeddingsModel model, CohereEmbeddingsTaskS super(model, taskSettings); } - private CohereEmbeddingsModel(CohereEmbeddingsModel model, CohereServiceSettings serviceSettings) { + public CohereEmbeddingsModel(CohereEmbeddingsModel model, CohereServiceSettings serviceSettings) { super(model, serviceSettings); } @@ -73,8 +74,8 @@ public DefaultSecretSettings getSecretSettings() { } @Override - public ExecutableAction accept() { - return null; + public ExecutableAction accept(CohereActionVisitor visitor, Map taskSettings) { + return visitor.create(this, taskSettings); } public CohereEmbeddingsModel overrideWith(Map taskSettings) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java index 849290cfaa66d..b0da7a63bf918 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java @@ -20,11 +20,9 @@ import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; import java.io.IOException; -import java.util.List; import java.util.Map; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfEnums; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.MODEL; import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.TRUNCATE; @@ -38,20 +36,20 @@ * * @param model the id of the model to use in the requests to cohere * @param inputType Specifies the type of input you're giving to the model - * @param embeddingTypes Specifies the types of embeddings you want to get back + * @param embeddingType Specifies the type of embeddings you want to get back (we only support retrieving a single type) * @param truncation Specifies how the API will handle inputs longer than the maximum token length */ public record CohereEmbeddingsTaskSettings( @Nullable String model, @Nullable InputType inputType, - @Nullable List embeddingTypes, + @Nullable CohereEmbeddingType embeddingType, @Nullable CohereTruncation truncation ) implements TaskSettings { public static final String NAME = "cohere_embeddings_task_settings"; - static final CohereEmbeddingsTaskSettings EMPTY_SETTINGS = new CohereEmbeddingsTaskSettings(null, null, null, null); + public static final CohereEmbeddingsTaskSettings EMPTY_SETTINGS = new CohereEmbeddingsTaskSettings(null, null, null, null); static final String INPUT_TYPE = "input_type"; - static final String EMBEDDING_TYPES = "embedding_types"; + static final String EMBEDDING_TYPE = "embedding_type"; public static CohereEmbeddingsTaskSettings fromMap(Map map) { if (map.isEmpty()) { @@ -69,9 +67,9 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) { InputType.values(), validationException ); - List embeddingTypes = extractOptionalListOfEnums( + CohereEmbeddingType embeddingTypes = extractOptionalEnum( map, - EMBEDDING_TYPES, + EMBEDDING_TYPE, ModelConfigurations.TASK_SETTINGS, CohereEmbeddingType::fromString, CohereEmbeddingType.values(), @@ -94,12 +92,7 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) { } public CohereEmbeddingsTaskSettings(StreamInput in) throws IOException { - this( - in.readOptionalString(), - InputType.fromStream(in), - in.readOptionalCollectionAsList(CohereEmbeddingType::fromStream), - CohereTruncation.fromStream(in) - ); + this(in.readOptionalString(), InputType.fromStream(in), CohereEmbeddingType.fromStream(in), CohereTruncation.fromStream(in)); } @Override @@ -113,8 +106,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(INPUT_TYPE, inputType); } - if (embeddingTypes != null) { - builder.field(EMBEDDING_TYPES, embeddingTypes); + if (embeddingType != null) { + builder.field(EMBEDDING_TYPE, embeddingType); } if (truncation != null) { @@ -138,14 +131,14 @@ public TransportVersion getMinimalSupportedVersion() { public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(model); inputType.writeTo(out); - out.writeOptionalCollection(embeddingTypes); + embeddingType.writeTo(out); truncation.writeTo(out); } public CohereEmbeddingsTaskSettings overrideWith(CohereEmbeddingsTaskSettings requestTaskSettings) { var modelToUse = requestTaskSettings.model() == null ? model : requestTaskSettings.model(); var inputTypeToUse = requestTaskSettings.inputType() == null ? inputType : requestTaskSettings.inputType(); - var embeddingTypesToUse = requestTaskSettings.embeddingTypes() == null ? embeddingTypes : requestTaskSettings.embeddingTypes(); + var embeddingTypesToUse = requestTaskSettings.embeddingType() == null ? embeddingType : requestTaskSettings.embeddingType(); var truncationToUse = requestTaskSettings.truncation() == null ? truncation : requestTaskSettings.truncation(); return new CohereEmbeddingsTaskSettings(modelToUse, inputTypeToUse, embeddingTypesToUse, truncationToUse); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java new file mode 100644 index 0000000000000..e1daa78454d25 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java @@ -0,0 +1,152 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.cohere; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests; +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class CohereActionCreatorTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testCreate_OpenAiEmbeddingsModel() throws IOException { + var senderFactory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + String responseJson = """ + { + "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", + "texts": [ + "hello" + ], + "embeddings": { + "float": [ + [ + 0.123, + -0.123 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = CohereEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START), + 1024, + 1024 + ); + var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap( + "model", + InputType.SEARCH, + CohereEmbeddingType.INT8, + CohereTruncation.END + ); + var action = actionCreator.create(model, overriddenTaskSettings); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(List.of("abc"), listener); + + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.123F, -0.123F))))); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "texts", + List.of("abc"), + "model", + "model", + "input_type", + "search_query", + "embedding_types", + List.of("int8"), + "truncate", + "end" + ) + ) + ); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java new file mode 100644 index 0000000000000..d2a8aad19b4c8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java @@ -0,0 +1,333 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.cohere; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationBytes; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +public class CohereEmbeddingsActionTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testExecute_ReturnsSuccessfulResponse() throws IOException { + var senderFactory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + String responseJson = """ + { + "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", + "texts": [ + "hello" + ], + "embeddings": { + "float": [ + [ + 0.123, + -0.123 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction( + getUrl(webServer), + "secret", + new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START), + sender + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(List.of("abc"), listener); + + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.123F, -0.123F))))); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "texts", + List.of("abc"), + "model", + "model", + "input_type", + "search_document", + "embedding_types", + List.of("float"), + "truncate", + "start" + ) + ) + ); + } + } + + public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws IOException { + var senderFactory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + String responseJson = """ + { + "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", + "texts": [ + "hello" + ], + "embeddings": { + "int8": [ + [ + 0, + -1 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction( + getUrl(webServer), + "secret", + new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.INT8, CohereTruncation.START), + sender + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(List.of("abc"), listener); + + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat(result.asMap(), is(buildExpectationBytes(List.of(List.of((byte) 0, (byte) -1))))); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "texts", + List.of("abc"), + "model", + "model", + "input_type", + "search_document", + "embedding_types", + List.of("int8"), + "truncate", + "start" + ) + ) + ); + } + } + + public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOException { + try (var sender = mock(Sender.class)) { + var thrownException = expectThrows( + IllegalArgumentException.class, + () -> createAction("^^", "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender) + ); + MatcherAssert.assertThat(thrownException.getMessage(), is("unable to parse url [^^]")); + } + } + + public void testExecute_ThrowsElasticsearchException() { + var sender = mock(Sender.class); + doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any()); + + var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(List.of("abc"), listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + MatcherAssert.assertThat(thrownException.getMessage(), is("failed")); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any()); + + var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(List.of("abc"), listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is(format("Failed to send Cohere embeddings request to [%s]", getUrl(webServer))) + ); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled_WhenUrlIsNull() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any()); + + var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(List.of("abc"), listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + MatcherAssert.assertThat(thrownException.getMessage(), is("Failed to send Cohere embeddings request")); + } + + public void testExecute_ThrowsException() { + var sender = mock(Sender.class); + doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any()); + + var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(List.of("abc"), listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is(format("Failed to send Cohere embeddings request to [%s]", getUrl(webServer))) + ); + } + + public void testExecute_ThrowsExceptionWithNullUrl() { + var sender = mock(Sender.class); + doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any()); + + var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(List.of("abc"), listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + MatcherAssert.assertThat(thrownException.getMessage(), is("Failed to send Cohere embeddings request")); + } + + private CohereEmbeddingsAction createAction(String url, String apiKey, CohereEmbeddingsTaskSettings taskSettings, Sender sender) { + var model = CohereEmbeddingsModelTests.createModel(url, apiKey, taskSettings, 1024, 1024); + + return new CohereEmbeddingsAction(sender, model, createWithEmptySettings(threadPool)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java index 95b69f1231e9d..c17b293c8bc35 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java @@ -214,7 +214,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectation(List.of(List.of(-0.0123F, 0.123F))))); + assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloats(List.of(List.of(-0.0123F, 0.123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); @@ -331,7 +331,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectation(List.of(List.of(-0.0123F, 0.123F))))); + assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloats(List.of(List.of(-0.0123F, 0.123F))))); assertThat(webServer.requests(), hasSize(2)); { @@ -389,7 +389,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectation(List.of(List.of(-0.0123F, 0.123F))))); + assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloats(List.of(List.of(-0.0123F, 0.123F))))); assertThat(webServer.requests(), hasSize(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java index 23b6f1ea2fbe3..2d7493a783c51 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java @@ -33,7 +33,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsRequestTaskSettingsTests.getRequestTaskSettingsMap; @@ -100,7 +100,7 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); @@ -169,7 +169,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(2)); { assertNull(webServer.requests().get(0).getUri().getQuery()); @@ -252,7 +252,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(2)); { assertNull(webServer.requests().get(0).getUri().getQuery()); @@ -320,7 +320,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java index 6bc8e2d61d579..2d5cea7bd981e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java @@ -38,7 +38,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel; import static org.hamcrest.Matchers.equalTo; @@ -104,7 +104,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java new file mode 100644 index 0000000000000..9b7edd12492ea --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java @@ -0,0 +1,165 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.cohere; + +import org.apache.http.Header; +import org.apache.http.HeaderElement; +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.message.BasicHeader; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.hamcrest.MatcherAssert; + +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.core.Is.is; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class CohereResponseHandlerTests extends ESTestCase { + public void testCheckForFailureStatusCode_DoesNotThrowFor200() { + callCheckForFailureStatusCode(200); + } + + public void testCheckForFailureStatusCode_ThrowsFor503() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(503)); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request [null] status [503]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor429() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(429)); + assertTrue(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a rate limit status code. Monthly request limit") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.TOO_MANY_REQUESTS)); + } + + public void testCheckForFailureStatusCode_ThrowsFor400() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(400)); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received an unsuccessful status code for request [null] status [400]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor400_TextsTooLarge() { + var exception = expectThrows( + RetryException.class, + () -> callCheckForFailureStatusCode(400, "invalid request: total number of texts must be at most 96 - received 100") + ); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a texts array too large response for request [null] status [400]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor401() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(401)); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received an authentication error status code for request [null] status [401]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.UNAUTHORIZED)); + } + + public void testCheckForFailureStatusCode_ThrowsFor300() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(300)); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Unhandled redirection for request [null] status [300]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.MULTIPLE_CHOICES)); + } + + public void testBuildRateLimitErrorMessage() { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(429); + var response = mock(HttpResponse.class); + when(response.getStatusLine()).thenReturn(statusLine); + var httpResult = new HttpResult(response, new byte[] {}); + + when(response.getFirstHeader(CohereResponseHandler.MONTHLY_REQUESTS_LIMIT)).thenReturn( + new BasicHeader(CohereResponseHandler.MONTHLY_REQUESTS_LIMIT, "3000") + ); + when(response.getFirstHeader(CohereResponseHandler.TRIAL_REQUEST_LIMIT_PER_MINUTE)).thenReturn( + new BasicHeader(CohereResponseHandler.TRIAL_REQUEST_LIMIT_PER_MINUTE, "2999") + ); + when(response.getFirstHeader(CohereResponseHandler.TRIAL_REQUESTS_REMAINING)).thenReturn( + new BasicHeader(CohereResponseHandler.TRIAL_REQUESTS_REMAINING, "12") + ); + + var error = CohereResponseHandler.buildRateLimitErrorMessage(httpResult); + MatcherAssert.assertThat( + error, + containsString("Monthly request limit [3000], permitted requests per minute [2999], remaining requests [12]") + ); + } + + public void testBuildRateLimitErrorMessage_FillsWithUnknown_WhenUnableToFindHeader() { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(429); + var response = mock(HttpResponse.class); + when(response.getStatusLine()).thenReturn(statusLine); + var httpResult = new HttpResult(response, new byte[] {}); + + var error = CohereResponseHandler.buildRateLimitErrorMessage(httpResult); + MatcherAssert.assertThat( + error, + containsString("Monthly request limit [unknown], permitted requests per minute [unknown], remaining requests [unknown]") + ); + } + + private static void callCheckForFailureStatusCode(int statusCode) { + callCheckForFailureStatusCode(statusCode, null); + } + + private static void callCheckForFailureStatusCode(int statusCode, @Nullable String errorMessage) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var httpResponse = mock(HttpResponse.class); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + var header = mock(Header.class); + when(header.getElements()).thenReturn(new HeaderElement[] {}); + when(httpResponse.getFirstHeader(anyString())).thenReturn(header); + + String responseJson = Strings.format(""" + { + "message": "%s" + } + """, errorMessage); + + var httpRequest = mock(HttpRequestBase.class); + var httpResult = new HttpResult(httpResponse, errorMessage == null ? new byte[] {} : responseJson.getBytes(StandardCharsets.UTF_8)); + var handler = new CohereResponseHandler("", (request, result) -> null); + + handler.checkForFailureStatusCode(httpRequest, httpResult); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiClientTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiClientTests.java index bb9612f01d8ff..0c4b354e7171f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiClientTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiClientTests.java @@ -40,7 +40,7 @@ import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiEmbeddingsRequestTests.createRequest; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -104,7 +104,7 @@ public void testSend_SuccessfulResponse() throws IOException, URISyntaxException var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); @@ -155,7 +155,7 @@ public void testSend_SuccessfulResponse_WithoutUser() throws IOException, URISyn var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); @@ -205,7 +205,7 @@ public void testSend_SuccessfulResponse_WithoutOrganization() throws IOException var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..01a2290446529 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.cohere; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class CohereEmbeddingsRequestEntityTests extends ESTestCase { + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new CohereEmbeddingsRequestEntity( + List.of("abc"), + new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"texts":["abc"],"model":"model","input_type":"search_document","embedding_types":["float"],"truncate":"start"}""")); + } + + public void testXContent_InputTypeSearch_EmbeddingTypesInt8_TruncateNone() throws IOException { + var entity = new CohereEmbeddingsRequestEntity( + List.of("abc"), + new CohereEmbeddingsTaskSettings("model", InputType.SEARCH, CohereEmbeddingType.INT8, CohereTruncation.NONE) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""")); + } + + public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { + var entity = new CohereEmbeddingsRequestEntity(List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"texts":["abc"]}""")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..903f6a9500831 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.cohere; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class CohereEmbeddingsRequestTests extends ESTestCase { + public void testCreateRequest_UrlDefined() throws URISyntaxException, IOException { + var request = createRequest("url", "secret", List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS); + + var httpRequest = request.createRequest(); + MatcherAssert.assertThat(httpRequest, instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest; + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc")))); + } + + public void testCreateRequest_AllOptionsDefined() throws URISyntaxException, IOException { + var request = createRequest( + "url", + "secret", + List.of("abc"), + new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START) + ); + + var httpRequest = request.createRequest(); + MatcherAssert.assertThat(httpRequest, instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest; + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "texts", + List.of("abc"), + "model", + "model", + "input_type", + "search_document", + "embedding_types", + List.of("float"), + "truncate", + "start" + ) + ) + ); + } + + public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() throws URISyntaxException, IOException { + var request = createRequest( + "url", + "secret", + List.of("abc"), + new CohereEmbeddingsTaskSettings("model", InputType.SEARCH, CohereEmbeddingType.INT8, CohereTruncation.END) + ); + + var httpRequest = request.createRequest(); + MatcherAssert.assertThat(httpRequest, instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest; + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "texts", + List.of("abc"), + "model", + "model", + "input_type", + "search_query", + "embedding_types", + List.of("int8"), + "truncate", + "end" + ) + ) + ); + } + + public void testCreateRequest_TruncateNone() throws URISyntaxException, IOException { + var request = createRequest( + "url", + "secret", + List.of("abc"), + new CohereEmbeddingsTaskSettings(null, null, null, CohereTruncation.NONE) + ); + + var httpRequest = request.createRequest(); + MatcherAssert.assertThat(httpRequest, instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest; + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc"), "truncate", "none"))); + } + + public static CohereEmbeddingsRequest createRequest( + @Nullable String url, + String apiKey, + List input, + CohereEmbeddingsTaskSettings taskSettings + ) throws URISyntaxException { + var uri = url == null ? null : new URI(url); + + var account = new CohereAccount(uri, new SecureString(apiKey.toCharArray())); + return new CohereEmbeddingsRequest(account, input, taskSettings); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java new file mode 100644 index 0000000000000..1d7af8577a722 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java @@ -0,0 +1,434 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.cohere; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class CohereEmbeddingsResponseEntityTests extends ESTestCase { + public void testFromResponse_CreatesResultsForASingleItem() throws IOException { + String responseJson = """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello" + ], + "embeddings": [ + [ + -0.0018434525, + 0.01777649 + ] + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_floats" + } + """; + + TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + MatcherAssert.assertThat( + parsedResults.embeddings(), + is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F)))) + ); + } + + public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws IOException { + String responseJson = """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello" + ], + "embeddings": { + "float": [ + [ + -0.0018434525, + 0.01777649 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_floats" + } + """; + + TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + MatcherAssert.assertThat( + parsedResults.embeddings(), + is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F)))) + ); + } + + public void testFromResponse_UsesTheFirstValidEmbeddingsEntry() throws IOException { + String responseJson = """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello" + ], + "embeddings": { + "float": [ + [ + -0.0018434525, + 0.01777649 + ] + ], + "int8": [ + [ + -1, + 0 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_floats" + } + """; + + TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + MatcherAssert.assertThat( + parsedResults.embeddings(), + is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F)))) + ); + } + + public void testFromResponse_UsesTheFirstValidEmbeddingsEntryInt8_WithInvalidFirst() throws IOException { + String responseJson = """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello" + ], + "embeddings": { + "invalid_type": [ + [ + -0.0018434525, + 0.01777649 + ] + ], + "int8": [ + [ + -1, + 0 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_floats" + } + """; + + TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + MatcherAssert.assertThat( + parsedResults.embeddings(), + is(List.of(TextEmbeddingResults.Embedding.ofBytes(List.of((byte) -1, (byte) 0)))) + ); + } + + public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { + String responseJson = """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello" + ], + "embeddings": [ + [ + -0.0018434525, + 0.01777649 + ], + [ + -0.123, + 0.123 + ] + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_floats" + } + """; + + TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + MatcherAssert.assertThat( + parsedResults.embeddings(), + is( + List.of( + TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F)), + TextEmbeddingResults.Embedding.ofFloats(List.of(-0.123F, 0.123F)) + ) + ) + ); + } + + public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throws IOException { + String responseJson = """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello" + ], + "embeddings": { + "float": [ + [ + -0.0018434525, + 0.01777649 + ], + [ + -0.123, + 0.123 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_floats" + } + """; + + TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + MatcherAssert.assertThat( + parsedResults.embeddings(), + is( + List.of( + TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F)), + TextEmbeddingResults.Embedding.ofFloats(List.of(-0.123F, 0.123F)) + ) + ) + ); + } + + public void testFromResponse_FailsWhenEmbeddingsFieldIsNotPresent() { + String responseJson = """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello" + ], + "embeddings_not_here": [ + [ + -0.0018434525, + 0.01777649 + ] + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_floats" + } + """; + + var thrownException = expectThrows( + IllegalStateException.class, + () -> CohereEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("Failed to find required field [embeddings] in Cohere embeddings response") + ); + } + + public void testFromResponse_FailsWhenEmbeddingsByteValue_IsOutsideByteRange_Negative() { + String responseJson = """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello" + ], + "embeddings": { + "int8": [ + [ + -129, + 127 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_floats" + } + """; + + var thrownException = expectThrows( + IllegalArgumentException.class, + () -> CohereEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + MatcherAssert.assertThat(thrownException.getMessage(), is("Value [-129] is out of range for a byte")); + } + + public void testFromResponse_FailsWhenEmbeddingsByteValue_IsOutsideByteRange_Positive() { + String responseJson = """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello" + ], + "embeddings": { + "int8": [ + [ + -128, + 128 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_floats" + } + """; + + var thrownException = expectThrows( + IllegalArgumentException.class, + () -> CohereEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + MatcherAssert.assertThat(thrownException.getMessage(), is("Value [128] is out of range for a byte")); + } + + public void testFromResponse_FailsToFindAValidEmbeddingType() { + String responseJson = """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello" + ], + "embeddings": { + "invalid_type": [ + [ + -0.0018434525, + 0.01777649 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_floats" + } + """; + + var thrownException = expectThrows( + IllegalStateException.class, + () -> CohereEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("Failed to find a supported embedding type in the Cohere embeddings response. Supported types are [float, int8]") + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntityTests.java new file mode 100644 index 0000000000000..a2b1c26b2b3d5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntityTests.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.cohere; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.hamcrest.MatcherAssert; + +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class CohereErrorResponseEntityTests extends ESTestCase { + public void testFromResponse() { + String responseJson = """ + { + "message": "invalid request: total number of texts must be at most 96 - received 97" + } + """; + + CohereErrorResponseEntity errorMessage = CohereErrorResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + assertNotNull(errorMessage); + MatcherAssert.assertThat( + errorMessage.getErrorMessage(), + is("invalid request: total number of texts must be at most 96 - received 97") + ); + } + + public void testFromResponse_noMessage() { + String responseJson = """ + { + "error": "abc" + } + """; + + CohereErrorResponseEntity errorMessage = CohereErrorResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + assertNull(errorMessage); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java index 2b6e11fdfafa7..d54bcd91c9eda 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java @@ -37,7 +37,7 @@ public void testFromResponse_CreatesResultsForASingleItem_ArrayFormat() throws I new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F))))); + assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F))))); } public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws IOException { @@ -57,7 +57,7 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F))))); + assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F))))); } public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws IOException { @@ -83,8 +83,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws parsedResults.embeddings(), is( List.of( - new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)), - new TextEmbeddingResults.Embedding(List.of(0.0123F, -0.0123F)) + TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)), + TextEmbeddingResults.Embedding.ofFloats(List.of(0.0123F, -0.0123F)) ) ) ); @@ -115,8 +115,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw parsedResults.embeddings(), is( List.of( - new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)), - new TextEmbeddingResults.Embedding(List.of(0.0123F, -0.0123F)) + TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)), + TextEmbeddingResults.Embedding.ofFloats(List.of(0.0123F, -0.0123F)) ) ) ); @@ -254,7 +254,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ArrayFormat() throw new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(1.0F))))); + assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(1.0F))))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() throws IOException { @@ -273,7 +273,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() thro new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(1.0F))))); + assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(1.0F))))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() throws IOException { @@ -290,7 +290,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() thro new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(4.0294965E10F))))); + assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(4.0294965E10F))))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() throws IOException { @@ -309,7 +309,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() thr new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(4.0294965E10F))))); + assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(4.0294965E10F))))); } public void testFromResponse_FailsWhenEmbeddingValueIsAnObject_ObjectFormat() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java index 010e990a3ce80..3e8b50591e3b6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java @@ -49,7 +49,7 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F))))); + assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F))))); } public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { @@ -91,8 +91,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException parsedResults.embeddings(), is( List.of( - new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)), - new TextEmbeddingResults.Embedding(List.of(0.0123F, -0.0123F)) + TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)), + TextEmbeddingResults.Embedding.ofFloats(List.of(0.0123F, -0.0123F)) ) ) ); @@ -261,7 +261,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(1.0F))))); + assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(1.0F))))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOException { @@ -290,7 +290,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(4.0294965E10F))))); + assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(4.0294965E10F))))); } public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ByteValueTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ByteValueTests.java new file mode 100644 index 0000000000000..b89c3c9ab0ec9 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ByteValueTests.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.results.ByteValue; + +import java.io.IOException; + +public class ByteValueTests extends AbstractWireSerializingTestCase { + public static ByteValue createRandom() { + return new ByteValue(randomByte()); + } + + @Override + protected Writeable.Reader instanceReader() { + return ByteValue::new; + } + + @Override + protected ByteValue createTestInstance() { + return createRandom(); + } + + @Override + protected ByteValue mutateInstance(ByteValue instance) throws IOException { + return null; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/FloatValueTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/FloatValueTests.java new file mode 100644 index 0000000000000..1d37c0d3ddee1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/FloatValueTests.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.results.FloatValue; + +import java.io.IOException; + +public class FloatValueTests extends AbstractWireSerializingTestCase { + public static FloatValue createRandom() { + return new FloatValue(randomFloat()); + } + + @Override + protected Writeable.Reader instanceReader() { + return FloatValue::new; + } + + @Override + protected FloatValue createTestInstance() { + return createRandom(); + } + + @Override + protected FloatValue mutateInstance(FloatValue instance) throws IOException { + return null; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index 09d9894d98853..1bb237b87234b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -7,31 +7,80 @@ package org.elasticsearch.xpack.inference.results; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.results.ByteValue; +import org.elasticsearch.xpack.core.inference.results.FloatValue; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; +import org.hamcrest.MatcherAssert; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.function.Supplier; +import static org.elasticsearch.TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED; import static org.hamcrest.Matchers.is; -public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase { +public class TextEmbeddingResultsTests extends AbstractBWCWireSerializationTestCase { + + private enum EmbeddingType { + FLOAT, + BYTE + } + + private static Map> EMBEDDING_TYPE_BUILDERS = Map.of( + EmbeddingType.FLOAT, + TextEmbeddingResultsTests::createRandomFloatEmbedding, + EmbeddingType.BYTE, + TextEmbeddingResultsTests::createRandomByteEmbedding + ); + public static TextEmbeddingResults createRandomResults() { + var embeddingType = randomFrom(EmbeddingType.values()); + var createFunction = EMBEDDING_TYPE_BUILDERS.get(embeddingType); + + if (createFunction == null) { + createFunction = TextEmbeddingResultsTests::createRandomFloatEmbedding; + } + + return createRandomResults(createFunction); + } + + private static TextEmbeddingResults createRandomResults(Supplier creator) { int embeddings = randomIntBetween(1, 10); List embeddingResults = new ArrayList<>(embeddings); for (int i = 0; i < embeddings; i++) { - embeddingResults.add(createRandomEmbedding()); + embeddingResults.add(creator.get()); } return new TextEmbeddingResults(embeddingResults); } - private static TextEmbeddingResults.Embedding createRandomEmbedding() { + private static TextEmbeddingResults.Embedding createRandomEmbedding(boolean makeFloatEmbedding) { + return makeFloatEmbedding ? createRandomFloatEmbedding() : createRandomByteEmbedding(); + } + + private static TextEmbeddingResults.Embedding createRandomByteEmbedding() { + int columns = randomIntBetween(1, 10); + List bytes = new ArrayList<>(columns); + + for (int i = 0; i < columns; i++) { + bytes.add(randomByte()); + } + + return TextEmbeddingResults.Embedding.ofBytes(bytes); + } + + private static TextEmbeddingResults.Embedding createRandomFloatEmbedding() { int columns = randomIntBetween(1, 10); List floats = new ArrayList<>(columns); @@ -39,19 +88,34 @@ private static TextEmbeddingResults.Embedding createRandomEmbedding() { floats.add(randomFloat()); } - return new TextEmbeddingResults.Embedding(floats); + return TextEmbeddingResults.Embedding.ofFloats(floats); } - public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException { - var entity = new TextEmbeddingResults(List.of(new TextEmbeddingResults.Embedding(List.of(0.1F)))); + public static Map buildExpectationFloats(List> embeddings) { + return Map.of( + TextEmbeddingResults.TEXT_EMBEDDING, + embeddings.stream() + .map(embedding -> Map.of(TextEmbeddingResults.Embedding.EMBEDDING, embedding.stream().map(FloatValue::new).toList())) + .toList() + ); + } - assertThat( - entity.asMap(), - is(Map.of(TextEmbeddingResults.TEXT_EMBEDDING, List.of(Map.of(TextEmbeddingResults.Embedding.EMBEDDING, List.of(0.1F))))) + public static Map buildExpectationBytes(List> embeddings) { + return Map.of( + TextEmbeddingResults.TEXT_EMBEDDING, + embeddings.stream() + .map(embedding -> Map.of(TextEmbeddingResults.Embedding.EMBEDDING, embedding.stream().map(ByteValue::new).toList())) + .toList() ); + } + + public void testToXContent_CreatesTheRightFormatForASingleEmbedding() { + var entity = new TextEmbeddingResults(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.1F)))); + + MatcherAssert.assertThat(entity.asMap(), is(buildExpectationFloats(List.of(List.of(0.1F))))); String xContentResult = Strings.toString(entity, true, true); - assertThat(xContentResult, is(""" + MatcherAssert.assertThat(xContentResult, is(""" { "text_embedding" : [ { @@ -63,27 +127,33 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE }""")); } - public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException { - var entity = new TextEmbeddingResults( - List.of(new TextEmbeddingResults.Embedding(List.of(0.1F)), new TextEmbeddingResults.Embedding(List.of(0.2F))) + public void testToXContent_CreatesTheRightFormatForASingleEmbedding_ForBytes() { + var entity = new TextEmbeddingResults(List.of(TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 12)))); - ); + MatcherAssert.assertThat(entity.asMap(), is(buildExpectationBytes(List.of(List.of((byte) 12))))); - assertThat( - entity.asMap(), - is( - Map.of( - TextEmbeddingResults.TEXT_EMBEDDING, - List.of( - Map.of(TextEmbeddingResults.Embedding.EMBEDDING, List.of(0.1F)), - Map.of(TextEmbeddingResults.Embedding.EMBEDDING, List.of(0.2F)) - ) - ) - ) + String xContentResult = Strings.toString(entity, true, true); + MatcherAssert.assertThat(xContentResult, is(""" + { + "text_embedding" : [ + { + "embedding" : [ + 12 + ] + } + ] + }""")); + } + + public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() { + var entity = new TextEmbeddingResults( + List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.1F)), TextEmbeddingResults.Embedding.ofFloats(List.of(0.2F))) ); + MatcherAssert.assertThat(entity.asMap(), is(buildExpectationFloats(List.of(List.of(0.1F), List.of(0.2F))))); + String xContentResult = Strings.toString(entity, true, true); - assertThat(xContentResult, is(""" + MatcherAssert.assertThat(xContentResult, is(""" { "text_embedding" : [ { @@ -100,12 +170,40 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I }""")); } + public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings_ForBytes() { + var entity = new TextEmbeddingResults( + List.of(TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 12)), TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 34))) + ); + + MatcherAssert.assertThat(entity.asMap(), is(buildExpectationBytes(List.of(List.of((byte) 12), List.of((byte) 34))))); + + String xContentResult = Strings.toString(entity, true, true); + MatcherAssert.assertThat(xContentResult, is(""" + { + "text_embedding" : [ + { + "embedding" : [ + 12 + ] + }, + { + "embedding" : [ + 34 + ] + } + ] + }""")); + } + public void testTransformToCoordinationFormat() { var results = new TextEmbeddingResults( - List.of(new TextEmbeddingResults.Embedding(List.of(0.1F, 0.2F)), new TextEmbeddingResults.Embedding(List.of(0.3F, 0.4F))) + List.of( + TextEmbeddingResults.Embedding.ofFloats(List.of(0.1F, 0.2F)), + TextEmbeddingResults.Embedding.ofFloats(List.of(0.3F, 0.4F)) + ) ).transformToCoordinationFormat(); - assertThat( + MatcherAssert.assertThat( results, is( List.of( @@ -124,6 +222,49 @@ public void testTransformToCoordinationFormat() { ); } + public void testTransformToCoordinationFormat_FromBytes() { + var results = new TextEmbeddingResults( + List.of( + TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 12, (byte) 34)), + TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 56, (byte) -78)) + ) + ).transformToCoordinationFormat(); + + MatcherAssert.assertThat( + results, + is( + List.of( + new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( + TextEmbeddingResults.TEXT_EMBEDDING, + new double[] { 12F, 34F }, + false + ), + new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( + TextEmbeddingResults.TEXT_EMBEDDING, + new double[] { 56F, -78F }, + false + ) + ) + ) + ); + } + + public void testSerializesToFloats_WhenVersionIsPriorToByteSupport() throws IOException { + var instance = createRandomResults(TextEmbeddingResultsTests::createRandomByteEmbedding); + var modifiedForOlderVersion = mutateInstanceForVersion(instance, ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED); + + var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED); + assertOnBWCObject(copy, modifiedForOlderVersion, ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + @Override protected Writeable.Reader instanceReader() { return TextEmbeddingResults::new; @@ -143,15 +284,26 @@ protected TextEmbeddingResults mutateInstance(TextEmbeddingResults instance) thr return new TextEmbeddingResults(instance.embeddings().subList(0, end)); } else { List embeddings = new ArrayList<>(instance.embeddings()); - embeddings.add(createRandomEmbedding()); + embeddings.add(createRandomEmbedding(randomBoolean())); return new TextEmbeddingResults(embeddings); } } - public static Map buildExpectation(List> embeddings) { - return Map.of( - TextEmbeddingResults.TEXT_EMBEDDING, - embeddings.stream().map(embedding -> Map.of(TextEmbeddingResults.Embedding.EMBEDDING, embedding)).toList() - ); + @Override + protected TextEmbeddingResults mutateInstanceForVersion(TextEmbeddingResults instance, TransportVersion version) { + if (version.before(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) { + return convertToFloatEmbeddings(instance); + } + + return instance; + } + + public TextEmbeddingResults convertToFloatEmbeddings(TextEmbeddingResults results) { + var floatEmbeddings = results.embeddings() + .stream() + .map(embedding -> TextEmbeddingResults.Embedding.ofFloats(embedding.toFloats())) + .toList(); + + return new TextEmbeddingResults(floatEmbeddings); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java new file mode 100644 index 0000000000000..2061ac646c56b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -0,0 +1,853 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere; + +import org.apache.http.HttpHeaders; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel; +import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettingsTests.getServiceSettingsMap; +import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap; +import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class CohereServiceTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_CreatesACohereEmbeddingsModel() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var model = service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getServiceSettingsMap("url"), + getTaskSettingsMap("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START), + getSecretSettingsMap("secret") + ), + Set.of() + ); + + MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + + var embeddingsModel = (CohereEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + MatcherAssert.assertThat( + embeddingsModel.getTaskSettings(), + is(new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START)) + ); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parseRequestConfig( + "id", + TaskType.SPARSE_EMBEDDING, + getRequestConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty(), getSecretSettingsMap("secret")), + Set.of() + ) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("The [cohere] service does not support task type [sparse_embedding]") + ); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var config = getRequestConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty(), getSecretSettingsMap("secret")); + config.put("extra_key", "value"); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of()) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [cohere] service") + ); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var serviceSettings = getServiceSettingsMap("url"); + serviceSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + serviceSettings, + getTaskSettingsMap("model", null, null, null), + getSecretSettingsMap("secret") + ); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of()) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [cohere] service") + ); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var taskSettingsMap = getTaskSettingsMap("model", null, null, null); + taskSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap(getServiceSettingsMap("url"), taskSettingsMap, getSecretSettingsMap("secret")); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of()) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [cohere] service") + ); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var secretSettingsMap = getSecretSettingsMap("secret"); + secretSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty(), secretSettingsMap); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of()) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [cohere] service") + ); + } + } + + public void testParseRequestConfig_CreatesACohereEmbeddingsModelWithoutUrl() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var model = service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap(getServiceSettingsMap(null), getTaskSettingsMapEmpty(), getSecretSettingsMap("secret")), + Set.of() + ); + + MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + + var embeddingsModel = (CohereEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().uri()); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModel() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("url"), + getTaskSettingsMap("model", null, null, null), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + + var embeddingsModel = (CohereEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings("model", null, null, null))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("url"), + getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.SPARSE_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [cohere] service, please delete and add the service again") + ); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWithoutUrl() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap(null), + getTaskSettingsMap(null, InputType.INGEST, null, null), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + + var embeddingsModel = (CohereEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().uri()); + MatcherAssert.assertThat( + embeddingsModel.getTaskSettings(), + is(new CohereEmbeddingsTaskSettings(null, InputType.INGEST, null, null)) + ); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("url"), + getTaskSettingsMap("model", InputType.SEARCH, CohereEmbeddingType.INT8, CohereTruncation.NONE), + getSecretSettingsMap("secret") + ); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + + var embeddingsModel = (CohereEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + MatcherAssert.assertThat( + embeddingsModel.getTaskSettings(), + is(new CohereEmbeddingsTaskSettings("model", InputType.SEARCH, CohereEmbeddingType.INT8, CohereTruncation.NONE)) + ); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var secretSettingsMap = getSecretSettingsMap("secret"); + secretSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty(), secretSettingsMap); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + + var embeddingsModel = (CohereEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("url"), + getTaskSettingsMap("model", null, null, null), + getSecretSettingsMap("secret") + ); + persistedConfig.secrets.put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + + var embeddingsModel = (CohereEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings("model", null, null, null))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var serviceSettingsMap = getServiceSettingsMap("url"); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMapEmpty(), getSecretSettingsMap("secret")); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + + var embeddingsModel = (CohereEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var taskSettingsMap = getTaskSettingsMap("model", InputType.SEARCH, null, null); + taskSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), taskSettingsMap, getSecretSettingsMap("secret")); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + + var embeddingsModel = (CohereEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + MatcherAssert.assertThat( + embeddingsModel.getTaskSettings(), + is(new CohereEmbeddingsTaskSettings("model", InputType.SEARCH, null, null)) + ); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfig_CreatesACohereEmbeddingsModel() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("url"), + getTaskSettingsMap("model", null, null, CohereTruncation.NONE) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + + var embeddingsModel = (CohereEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + MatcherAssert.assertThat( + embeddingsModel.getTaskSettings(), + is(new CohereEmbeddingsTaskSettings("model", null, null, CohereTruncation.NONE)) + ); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty()); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [cohere] service, please delete and add the service again") + ); + } + } + + public void testParsePersistedConfig_CreatesAnCohereEmbeddingsModelWithoutUrl() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap(null), + getTaskSettingsMap("model", null, CohereEmbeddingType.FLOAT, null) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + + var embeddingsModel = (CohereEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().uri()); + MatcherAssert.assertThat( + embeddingsModel.getTaskSettings(), + is(new CohereEmbeddingsTaskSettings("model", null, CohereEmbeddingType.FLOAT, null)) + ); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty()); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + + var embeddingsModel = (CohereEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var serviceSettingsMap = getServiceSettingsMap("url"); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMap(null, InputType.SEARCH, null, null)); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + + var embeddingsModel = (CohereEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + MatcherAssert.assertThat( + embeddingsModel.getTaskSettings(), + is(new CohereEmbeddingsTaskSettings(null, InputType.SEARCH, null, null)) + ); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try ( + var service = new CohereService( + new SetOnce<>(mock(HttpRequestSenderFactory.class)), + new SetOnce<>(createWithEmptySettings(threadPool)) + ) + ) { + var taskSettingsMap = getTaskSettingsMap("model", InputType.INGEST, null, null); + taskSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), taskSettingsMap); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + + var embeddingsModel = (CohereEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + MatcherAssert.assertThat( + embeddingsModel.getTaskSettings(), + is(new CohereEmbeddingsTaskSettings("model", InputType.INGEST, null, null)) + ); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSenderFactory.class); + when(factory.createSender(anyString())).thenReturn(sender); + + var mockModel = getInvalidModel("model_id", "service_name"); + + try (var service = new CohereService(new SetOnce<>(factory), new SetOnce<>(createWithEmptySettings(threadPool)))) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer(mockModel, List.of(""), new HashMap<>(), listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + MatcherAssert.assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + + verify(factory, times(1)).createSender(anyString()); + verify(sender, times(1)).start(); + } + + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } + + public void testInfer_SendsRequest() throws IOException { + var senderFactory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY); + + try (var service = new CohereService(new SetOnce<>(senderFactory), new SetOnce<>(createWithEmptySettings(threadPool)))) { + + String responseJson = """ + { + "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", + "texts": [ + "hello" + ], + "embeddings": { + "float": [ + [ + 0.123, + -0.123 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = CohereEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new CohereEmbeddingsTaskSettings("model", InputType.INGEST, null, null), + 1024, + 1024 + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer(model, List.of("abc"), new HashMap<>(), listener); + + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat(result.asMap(), Matchers.is(buildExpectationFloats(List.of(List.of(0.123F, -0.123F))))); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc"), "model", "model", "input_type", "search_document"))); + } + } + + public void testCheckModelConfig_UpdatesDimensions() throws IOException { + var senderFactory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY); + + try (var service = new CohereService(new SetOnce<>(senderFactory), new SetOnce<>(createWithEmptySettings(threadPool)))) { + + String responseJson = """ + { + "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", + "texts": [ + "hello" + ], + "embeddings": { + "float": [ + [ + 0.123, + -0.123 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = CohereEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 1 + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat( + result, + // the dimension is set to 2 because there are 2 embeddings returned from the mock server + is(CohereEmbeddingsModelTests.createModel(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, 10, 2)) + ); + } + } + + public void testInfer_UnauthorisedResponse() throws IOException { + var senderFactory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY); + + try (var service = new CohereService(new SetOnce<>(senderFactory), new SetOnce<>(createWithEmptySettings(threadPool)))) { + + String responseJson = """ + { + "message": "invalid api token" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var model = CohereEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024 + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer(model, List.of("abc"), new HashMap<>(), listener); + + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + MatcherAssert.assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); + MatcherAssert.assertThat(error.getMessage(), containsString("Error message: [invalid api token]")); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + } + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>( + Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) + ); + } + + private PeristedConfig getPersistedConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + + return new PeristedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + new HashMap<>(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings)) + ); + } + + private PeristedConfig getPersistedConfigMap(Map serviceSettings, Map taskSettings) { + + return new PeristedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + null + ); + } + + private record PeristedConfig(Map config, Map secrets) {} +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncationTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncationTests.java new file mode 100644 index 0000000000000..600ddb54eddd3 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncationTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; + +public class CohereTruncationTests extends AbstractWireSerializingTestCase { + public static CohereTruncation createRandom() { + return randomFrom(CohereTruncation.values()); + } + + @Override + protected Writeable.Reader instanceReader() { + return CohereTruncation::fromStream; + } + + @Override + protected CohereTruncation createTestInstance() { + return createRandom(); + } + + @Override + protected CohereTruncation mutateInstance(CohereTruncation instance) throws IOException { + return null; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java new file mode 100644 index 0000000000000..8b907746b0779 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.embeddings; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; + +public class CohereEmbeddingTypeTests extends AbstractWireSerializingTestCase { + public static CohereEmbeddingType createRandom() { + return randomFrom(CohereEmbeddingType.values()); + } + + @Override + protected Writeable.Reader instanceReader() { + return CohereEmbeddingType::fromStream; + } + + @Override + protected CohereEmbeddingType createTestInstance() { + return createRandom(); + } + + @Override + protected CohereEmbeddingType mutateInstance(CohereEmbeddingType instance) throws IOException { + return null; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java index f2ea1a8c4f3fe..84e65ed87e372 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java @@ -29,10 +29,10 @@ public class CohereEmbeddingsTaskSettingsTests extends AbstractWireSerializingTe public static CohereEmbeddingsTaskSettings createRandom() { var model = randomBoolean() ? randomAlphaOfLength(15) : null; var inputType = randomBoolean() ? randomFrom(InputType.values()) : null; - var embeddingTypes = randomBoolean() ? List.of(randomFrom(CohereEmbeddingType.values())) : null; + var embeddingType = randomBoolean() ? randomFrom(CohereEmbeddingType.values()) : null; var truncation = randomBoolean() ? randomFrom(CohereTruncation.values()) : null; - return new CohereEmbeddingsTaskSettings(model, inputType, embeddingTypes, truncation); + return new CohereEmbeddingsTaskSettings(model, inputType, embeddingType, truncation); } public void testFromMap_CreatesEmptySettings_WhenAllFieldsAreNull() { @@ -51,21 +51,14 @@ public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() { "abc", CohereEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString().toLowerCase(Locale.ROOT), - CohereEmbeddingsTaskSettings.EMBEDDING_TYPES, - List.of(CohereEmbeddingType.FLOAT, CohereEmbeddingType.INT8), + CohereEmbeddingsTaskSettings.EMBEDDING_TYPE, + CohereEmbeddingType.INT8, CohereServiceFields.TRUNCATE, CohereTruncation.END.toString().toLowerCase(Locale.ROOT) ) ) ), - is( - new CohereEmbeddingsTaskSettings( - "abc", - InputType.INGEST, - List.of(CohereEmbeddingType.FLOAT, CohereEmbeddingType.INT8), - CohereTruncation.END - ) - ) + is(new CohereEmbeddingsTaskSettings("abc", InputType.INGEST, CohereEmbeddingType.INT8, CohereTruncation.END)) ); } @@ -84,7 +77,7 @@ public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() { public void testFromMap_ReturnsFailure_WhenEmbeddingTypesAreNotValid() { var exception = expectThrows( ValidationException.class, - () -> CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.EMBEDDING_TYPES, List.of("abc")))) + () -> CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.EMBEDDING_TYPE, List.of("abc")))) ); MatcherAssert.assertThat( @@ -133,10 +126,14 @@ protected CohereEmbeddingsTaskSettings mutateInstance(CohereEmbeddingsTaskSettin return null; } + public static Map getTaskSettingsMapEmpty() { + return new HashMap<>(); + } + public static Map getTaskSettingsMap( @Nullable String model, @Nullable InputType inputType, - @Nullable List embeddingTypes, + @Nullable CohereEmbeddingType embeddingType, @Nullable CohereTruncation truncation ) { var map = new HashMap(); @@ -146,15 +143,15 @@ public static Map getTaskSettingsMap( } if (inputType != null) { - map.put(CohereEmbeddingsTaskSettings.INPUT_TYPE, inputType); + map.put(CohereEmbeddingsTaskSettings.INPUT_TYPE, inputType.toString()); } - if (embeddingTypes != null) { - map.put(CohereEmbeddingsTaskSettings.EMBEDDING_TYPES, embeddingTypes); + if (embeddingType != null) { + map.put(CohereEmbeddingsTaskSettings.EMBEDDING_TYPE, embeddingType.toString()); } if (truncation != null) { - map.put(CohereServiceFields.TRUNCATE, truncation); + map.put(CohereServiceFields.TRUNCATE, truncation.toString()); } return map; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index a76cce41b4fe4..aac50b8645993 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -47,7 +47,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettingsTests.getServiceSettingsMap; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; @@ -496,7 +496,7 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), Matchers.is(buildExpectation(List.of(List.of(-0.0123F, 0.0123F))))); + assertThat(result.asMap(), Matchers.is(buildExpectationFloats(List.of(List.of(-0.0123F, 0.0123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 394286ee5287b..ab4ee881a224e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -46,7 +46,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceSettingsTests.getServiceSettingsMap; @@ -717,7 +717,7 @@ public void testInfer_SendsRequest() throws IOException { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), Matchers.is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), Matchers.is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); From 153785af8e27bde3f147508b23040f5ac3a119f9 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 18 Jan 2024 14:23:08 -0500 Subject: [PATCH 05/13] Fixing tests --- .../elasticsearch/inference/InputType.java | 16 +-------- .../inference/action/InferenceAction.java | 4 +-- .../action/openai/OpenAiEmbeddingsAction.java | 2 +- .../CohereEmbeddingsResponseEntity.java | 2 ++ .../services/cohere/CohereTruncation.java | 18 +--------- .../embeddings/CohereEmbeddingType.java | 18 +--------- .../CohereEmbeddingsTaskSettings.java | 13 ++++--- .../action/InferenceActionResponseTests.java | 6 ++-- .../results/TextEmbeddingResultsTests.java | 10 +++--- .../cohere/CohereTruncationTests.java | 34 ------------------- .../embeddings/CohereEmbeddingTypeTests.java | 34 ------------------- .../CohereEmbeddingsTaskSettingsTests.java | 15 ++++---- 12 files changed, 31 insertions(+), 141 deletions(-) delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncationTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java diff --git a/server/src/main/java/org/elasticsearch/inference/InputType.java b/server/src/main/java/org/elasticsearch/inference/InputType.java index b5ba3f0d1b506..ffc67995c1dda 100644 --- a/server/src/main/java/org/elasticsearch/inference/InputType.java +++ b/server/src/main/java/org/elasticsearch/inference/InputType.java @@ -8,17 +8,12 @@ package org.elasticsearch.inference; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; - -import java.io.IOException; import java.util.Locale; /** * Defines the type of request, whether the request is to ingest a document or search for a document. */ -public enum InputType implements Writeable { +public enum InputType { INGEST, SEARCH; @@ -32,13 +27,4 @@ public String toString() { public static InputType fromString(String name) { return valueOf(name.trim().toUpperCase(Locale.ROOT)); } - - public static InputType fromStream(StreamInput in) throws IOException { - return in.readOptionalEnum(InputType.class); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalEnum(this); - } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index 732bc3d66bedc..30375e36a0e1d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -88,7 +88,7 @@ public Request(StreamInput in) throws IOException { } this.taskSettings = in.readGenericMap(); if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)) { - this.inputType = InputType.fromStream(in); + this.inputType = in.readEnum(InputType.class); } else { this.inputType = InputType.INGEST; } @@ -141,7 +141,7 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeGenericMap(taskSettings); if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)) { - inputType.writeTo(out); + out.writeEnum(inputType); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java index 417935f8c920c..3877157188832 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java @@ -43,7 +43,7 @@ public OpenAiEmbeddingsAction(Sender sender, OpenAiEmbeddingsModel model, Servic this.model.getSecretSettings().apiKey() ); this.client = new OpenAiClient(Objects.requireNonNull(sender), Objects.requireNonNull(serviceComponents)); - this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri(), "send OpenAI embeddings"); + this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri(), "OpenAI embeddings"); this.truncator = Objects.requireNonNull(serviceComponents.truncator()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java index 8f6d6c75d8c71..b9830b72de010 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java @@ -24,6 +24,7 @@ import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; import java.io.IOException; +import java.util.Arrays; import java.util.List; import java.util.Map; @@ -45,6 +46,7 @@ public class CohereEmbeddingsResponseEntity { private static String supportedEmbeddingTypes() { var validTypes = EMBEDDING_PARSERS.keySet().toArray(String[]::new); + Arrays.sort(validTypes); return String.join(", ", validTypes); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java index 3fec21a9e03b0..e7c9a0247bb1a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java @@ -7,11 +7,6 @@ package org.elasticsearch.xpack.inference.services.cohere; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; - -import java.io.IOException; import java.util.Locale; /** @@ -22,7 +17,7 @@ * See api docs for details. *

*/ -public enum CohereTruncation implements Writeable { +public enum CohereTruncation { /** * When the input exceeds the maximum input token length an error will be returned. */ @@ -36,8 +31,6 @@ public enum CohereTruncation implements Writeable { */ END; - public static String NAME = "cohere_truncate"; - @Override public String toString() { return name().toLowerCase(Locale.ROOT); @@ -46,13 +39,4 @@ public String toString() { public static CohereTruncation fromString(String name) { return valueOf(name.trim().toUpperCase(Locale.ROOT)); } - - public static CohereTruncation fromStream(StreamInput in) throws IOException { - return in.readOptionalEnum(CohereTruncation.class); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalEnum(this); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java index 803382bb8c947..82d57cfb92381 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java @@ -7,11 +7,6 @@ package org.elasticsearch.xpack.inference.services.cohere.embeddings; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; - -import java.io.IOException; import java.util.Locale; /** @@ -21,7 +16,7 @@ * See api docs for details. *

*/ -public enum CohereEmbeddingType implements Writeable { +public enum CohereEmbeddingType { /** * Use this when you want to get back the default float embeddings. Valid for all models. */ @@ -31,8 +26,6 @@ public enum CohereEmbeddingType implements Writeable { */ INT8; - public static String NAME = "cohere_embedding_type"; - @Override public String toString() { return name().toLowerCase(Locale.ROOT); @@ -45,13 +38,4 @@ public static String toLowerCase(CohereEmbeddingType type) { public static CohereEmbeddingType fromString(String name) { return valueOf(name.trim().toUpperCase(Locale.ROOT)); } - - public static CohereEmbeddingType fromStream(StreamInput in) throws IOException { - return in.readOptionalEnum(CohereEmbeddingType.class); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalEnum(this); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java index b0da7a63bf918..33bd4095c47cd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java @@ -92,7 +92,12 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) { } public CohereEmbeddingsTaskSettings(StreamInput in) throws IOException { - this(in.readOptionalString(), InputType.fromStream(in), CohereEmbeddingType.fromStream(in), CohereTruncation.fromStream(in)); + this( + in.readOptionalString(), + in.readOptionalEnum(InputType.class), + in.readOptionalEnum(CohereEmbeddingType.class), + in.readOptionalEnum(CohereTruncation.class) + ); } @Override @@ -130,9 +135,9 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(model); - inputType.writeTo(out); - embeddingType.writeTo(out); - truncation.writeTo(out); + out.writeOptionalEnum(inputType); + out.writeOptionalEnum(embeddingType); + out.writeOptionalEnum(truncation); } public CohereEmbeddingsTaskSettings overrideWith(CohereEmbeddingsTaskSettings requestTaskSettings) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java index 759411cec1212..622cf51798609 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java @@ -46,7 +46,7 @@ protected Writeable.Reader instanceReader() { @Override protected InferenceAction.Response createTestInstance() { var result = switch (randomIntBetween(0, 2)) { - case 0 -> TextEmbeddingResultsTests.createRandomResults(); + case 0 -> TextEmbeddingResultsTests.createRandomFloatResults(); case 1 -> LegacyTextEmbeddingResultsTests.createRandomResults().transformToTextEmbeddingResults(); default -> SparseEmbeddingResultsTests.createRandomResults(); }; @@ -90,7 +90,7 @@ public void testSerializesOpenAiAddedVersion_UsingSparseEmbeddingResult() throws } public void testSerializesMultipleInputsVersion_UsingLegacyTextEmbeddingResult() throws IOException { - var embeddingResults = TextEmbeddingResultsTests.createRandomResults(); + var embeddingResults = TextEmbeddingResultsTests.createRandomFloatResults(); var instance = new InferenceAction.Response(embeddingResults); var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), INFERENCE_MULTIPLE_INPUTS); assertOnBWCObject(copy, instance, INFERENCE_MULTIPLE_INPUTS); @@ -106,7 +106,7 @@ public void testSerializesMultipleInputsVersion_UsingSparseEmbeddingResult() thr // Technically we should never see a text embedding result in the transport version of this test because support // for it wasn't added until openai public void testSerializesSingleInputVersion_UsingLegacyTextEmbeddingResult() throws IOException { - var embeddingResults = TextEmbeddingResultsTests.createRandomResults(); + var embeddingResults = TextEmbeddingResultsTests.createRandomFloatResults(); var instance = new InferenceAction.Response(embeddingResults); var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), ML_INFERENCE_TASK_SETTINGS_OPTIONAL_ADDED); assertOnBWCObject(copy, instance, ML_INFERENCE_TASK_SETTINGS_OPTIONAL_ADDED); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index 1bb237b87234b..c46c100c9a476 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -54,6 +54,10 @@ public static TextEmbeddingResults createRandomResults() { return createRandomResults(createFunction); } + public static TextEmbeddingResults createRandomFloatResults() { + return createRandomResults(TextEmbeddingResultsTests::createRandomFloatEmbedding); + } + private static TextEmbeddingResults createRandomResults(Supplier creator) { int embeddings = randomIntBetween(1, 10); List embeddingResults = new ArrayList<>(embeddings); @@ -65,10 +69,6 @@ private static TextEmbeddingResults createRandomResults(Supplier bytes = new ArrayList<>(columns); @@ -284,7 +284,7 @@ protected TextEmbeddingResults mutateInstance(TextEmbeddingResults instance) thr return new TextEmbeddingResults(instance.embeddings().subList(0, end)); } else { List embeddings = new ArrayList<>(instance.embeddings()); - embeddings.add(createRandomEmbedding(randomBoolean())); + embeddings.add(createRandomFloatEmbedding()); return new TextEmbeddingResults(embeddings); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncationTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncationTests.java deleted file mode 100644 index 600ddb54eddd3..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncationTests.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere; - -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; - -import java.io.IOException; - -public class CohereTruncationTests extends AbstractWireSerializingTestCase { - public static CohereTruncation createRandom() { - return randomFrom(CohereTruncation.values()); - } - - @Override - protected Writeable.Reader instanceReader() { - return CohereTruncation::fromStream; - } - - @Override - protected CohereTruncation createTestInstance() { - return createRandom(); - } - - @Override - protected CohereTruncation mutateInstance(CohereTruncation instance) throws IOException { - return null; - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java deleted file mode 100644 index 8b907746b0779..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.embeddings; - -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; - -import java.io.IOException; - -public class CohereEmbeddingTypeTests extends AbstractWireSerializingTestCase { - public static CohereEmbeddingType createRandom() { - return randomFrom(CohereEmbeddingType.values()); - } - - @Override - protected Writeable.Reader instanceReader() { - return CohereEmbeddingType::fromStream; - } - - @Override - protected CohereEmbeddingType createTestInstance() { - return createRandom(); - } - - @Override - protected CohereEmbeddingType mutateInstance(CohereEmbeddingType instance) throws IOException { - return null; - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java index 84e65ed87e372..a568a0eaa1e5a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.cohere.embeddings; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Nullable; @@ -19,7 +20,6 @@ import java.io.IOException; import java.util.HashMap; import java.util.List; -import java.util.Locale; import java.util.Map; import static org.hamcrest.Matchers.is; @@ -50,11 +50,11 @@ public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() { CohereServiceFields.MODEL, "abc", CohereEmbeddingsTaskSettings.INPUT_TYPE, - InputType.INGEST.toString().toLowerCase(Locale.ROOT), + InputType.INGEST.toString(), CohereEmbeddingsTaskSettings.EMBEDDING_TYPE, - CohereEmbeddingType.INT8, + CohereEmbeddingType.INT8.toString(), CohereServiceFields.TRUNCATE, - CohereTruncation.END.toString().toLowerCase(Locale.ROOT) + CohereTruncation.END.toString() ) ) ), @@ -76,16 +76,13 @@ public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() { public void testFromMap_ReturnsFailure_WhenEmbeddingTypesAreNotValid() { var exception = expectThrows( - ValidationException.class, + ElasticsearchStatusException.class, () -> CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.EMBEDDING_TYPE, List.of("abc")))) ); MatcherAssert.assertThat( exception.getMessage(), - is( - "Validation Failed: 1: [task_settings] Invalid type [Integer]" - + " received for value [123]. [embedding_types] must be type [String];" - ) + is("field [embedding_type] is not of the expected type. The value [[abc]] cannot be converted to a [String]") ); } From 4a127ff1db8b4b054dd5772654a4d7f1c01cbfb1 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 18 Jan 2024 16:11:12 -0500 Subject: [PATCH 06/13] Removing rate limit error message --- .../cohere/CohereResponseHandler.java | 37 ++++++----------- .../cohere/CohereResponseHandlerTests.java | 41 +------------------ 2 files changed, 14 insertions(+), 64 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java index 299cad20fb012..211550f8dbc08 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java @@ -9,7 +9,6 @@ import org.apache.http.client.methods.HttpRequestBase; import org.apache.logging.log4j.Logger; -import org.elasticsearch.common.Strings; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; @@ -18,14 +17,20 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody; -import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown; +/** + * Defines how to handle various errors returned from the Cohere integration. + * + * NOTE: + * These headers are returned for trial API keys only (they also do not exist within 429 responses) + * + * + * x-endpoint-monthly-call-limit + * x-trial-endpoint-call-limit + * x-trial-endpoint-call-remaining + * + */ public class CohereResponseHandler extends BaseResponseHandler { - - static final String MONTHLY_REQUESTS_LIMIT = "x-endpoint-monthly-call-limit"; - // TODO determine the production versions of these - static final String TRIAL_REQUEST_LIMIT_PER_MINUTE = "x-trial-endpoint-call-limit"; - static final String TRIAL_REQUESTS_REMAINING = "x-trial-endpoint-call-remaining"; static final String TEXTS_ARRAY_TOO_LARGE_MESSAGE_MATCHER = "invalid request: total number of texts must be at most"; static final String TEXTS_ARRAY_ERROR_MESSAGE = "Received a texts array too large response"; @@ -58,7 +63,7 @@ void checkForFailureStatusCode(HttpRequestBase request, HttpResult result) throw if (statusCode >= 500) { throw new RetryException(false, buildError(SERVER_ERROR, request, result)); } else if (statusCode == 429) { - throw new RetryException(true, buildError(buildRateLimitErrorMessage(result), request, result)); + throw new RetryException(true, buildError(RATE_LIMIT, request, result)); } else if (isTextsArrayTooLarge(result)) { throw new RetryException(false, buildError(TEXTS_ARRAY_ERROR_MESSAGE, request, result)); } else if (statusCode == 401) { @@ -80,20 +85,4 @@ private static boolean isTextsArrayTooLarge(HttpResult result) { return false; } - - static String buildRateLimitErrorMessage(HttpResult result) { - var response = result.response(); - var monthlyRequestLimit = getFirstHeaderOrUnknown(response, MONTHLY_REQUESTS_LIMIT); - var trialRequestsPerMinute = getFirstHeaderOrUnknown(response, TRIAL_REQUEST_LIMIT_PER_MINUTE); - var trialRequestsRemaining = getFirstHeaderOrUnknown(response, TRIAL_REQUESTS_REMAINING); - - var usageMessage = Strings.format( - "Monthly request limit [%s], permitted requests per minute [%s], remaining requests [%s]", - monthlyRequestLimit, - trialRequestsPerMinute, - trialRequestsRemaining - ); - - return RATE_LIMIT + ". " + usageMessage; - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java index 9b7edd12492ea..8873f603dd082 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java @@ -12,7 +12,6 @@ import org.apache.http.HttpResponse; import org.apache.http.StatusLine; import org.apache.http.client.methods.HttpRequestBase; -import org.apache.http.message.BasicHeader; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.Strings; import org.elasticsearch.core.Nullable; @@ -50,7 +49,7 @@ public void testCheckForFailureStatusCode_ThrowsFor429() { assertTrue(exception.shouldRetry()); MatcherAssert.assertThat( exception.getCause().getMessage(), - containsString("Received a rate limit status code. Monthly request limit") + containsString("Received a rate limit status code for request [null] status [429]") ); MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.TOO_MANY_REQUESTS)); } @@ -98,44 +97,6 @@ public void testCheckForFailureStatusCode_ThrowsFor300() { MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.MULTIPLE_CHOICES)); } - public void testBuildRateLimitErrorMessage() { - var statusLine = mock(StatusLine.class); - when(statusLine.getStatusCode()).thenReturn(429); - var response = mock(HttpResponse.class); - when(response.getStatusLine()).thenReturn(statusLine); - var httpResult = new HttpResult(response, new byte[] {}); - - when(response.getFirstHeader(CohereResponseHandler.MONTHLY_REQUESTS_LIMIT)).thenReturn( - new BasicHeader(CohereResponseHandler.MONTHLY_REQUESTS_LIMIT, "3000") - ); - when(response.getFirstHeader(CohereResponseHandler.TRIAL_REQUEST_LIMIT_PER_MINUTE)).thenReturn( - new BasicHeader(CohereResponseHandler.TRIAL_REQUEST_LIMIT_PER_MINUTE, "2999") - ); - when(response.getFirstHeader(CohereResponseHandler.TRIAL_REQUESTS_REMAINING)).thenReturn( - new BasicHeader(CohereResponseHandler.TRIAL_REQUESTS_REMAINING, "12") - ); - - var error = CohereResponseHandler.buildRateLimitErrorMessage(httpResult); - MatcherAssert.assertThat( - error, - containsString("Monthly request limit [3000], permitted requests per minute [2999], remaining requests [12]") - ); - } - - public void testBuildRateLimitErrorMessage_FillsWithUnknown_WhenUnableToFindHeader() { - var statusLine = mock(StatusLine.class); - when(statusLine.getStatusCode()).thenReturn(429); - var response = mock(HttpResponse.class); - when(response.getStatusLine()).thenReturn(statusLine); - var httpResult = new HttpResult(response, new byte[] {}); - - var error = CohereResponseHandler.buildRateLimitErrorMessage(httpResult); - MatcherAssert.assertThat( - error, - containsString("Monthly request limit [unknown], permitted requests per minute [unknown], remaining requests [unknown]") - ); - } - private static void callCheckForFailureStatusCode(int statusCode) { callCheckForFailureStatusCode(statusCode, null); } From 3ebed6d73729c9e55a12a86170b49c8a04f94940 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 18 Jan 2024 16:38:01 -0500 Subject: [PATCH 07/13] Fixing a few comments --- .../inference/external/action/cohere/CohereActionCreator.java | 2 +- .../xpack/inference/external/cohere/CohereResponseHandler.java | 1 - .../xpack/inference/services/cohere/CohereServiceTests.java | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java index b3ac96979b68b..8c9d70f0a7323 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java @@ -16,7 +16,7 @@ import java.util.Objects; /** - * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the openai model type. + * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the cohere model type. */ public class CohereActionCreator implements CohereActionVisitor { private final Sender sender; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java index 211550f8dbc08..59cf6f4ca52f6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java @@ -48,7 +48,6 @@ public void validateResponse(ThrottlerManager throttlerManager, Logger logger, H /** * Validates the status code throws an RetryException if not in the range [200, 300). * - * The OpenAI API error codes are documented here. * @param request The http request * @param result The http response and body * @throws RetryException Throws if status code is {@code >= 300 or < 200 } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 2061ac646c56b..26bc51683155c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -263,7 +263,7 @@ public void testParseRequestConfig_CreatesACohereEmbeddingsModelWithoutUrl() thr } } - public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModel() throws IOException { + public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModel() throws IOException { try ( var service = new CohereService( new SetOnce<>(mock(HttpRequestSenderFactory.class)), From a4e576f4a3cca762693b6b95b477bcece7888394 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Mon, 22 Jan 2024 08:37:50 -0500 Subject: [PATCH 08/13] Update docs/changelog/104559.yaml --- docs/changelog/104559.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/104559.yaml diff --git a/docs/changelog/104559.yaml b/docs/changelog/104559.yaml new file mode 100644 index 0000000000000..d6d030783c4cc --- /dev/null +++ b/docs/changelog/104559.yaml @@ -0,0 +1,5 @@ +pr: 104559 +summary: Adding support for Cohere inference service +area: Machine Learning +type: enhancement +issues: [] From 82001b2fad754bcdd2a137dfd6a39715c136d005 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 23 Jan 2024 15:17:54 -0500 Subject: [PATCH 09/13] Addressing most feedback --- .../InferenceNamedWriteablesProvider.java | 11 + .../external/action/ActionUtils.java | 2 +- .../action/cohere/CohereEmbeddingsAction.java | 29 ++- .../action/openai/OpenAiEmbeddingsAction.java | 4 +- .../external/request/RequestUtils.java | 7 +- .../cohere/CohereEmbeddingsRequest.java | 16 +- .../cohere/CohereEmbeddingsRequestEntity.java | 40 ++-- .../inference/services/ServiceUtils.java | 7 +- .../services/cohere/CohereService.java | 15 +- .../cohere/CohereServiceSettings.java | 55 +++-- .../embeddings/CohereEmbeddingsModel.java | 11 +- .../CohereEmbeddingsServiceSettings.java | 111 +++++++++ .../CohereEmbeddingsTaskSettings.java | 46 +--- .../HuggingFaceServiceSettings.java | 4 +- .../openai/OpenAiServiceSettings.java | 11 - .../cohere/CohereActionCreatorTests.java | 17 +- .../cohere/CohereEmbeddingsActionTests.java | 32 ++- .../CohereEmbeddingsRequestEntityTests.java | 10 +- .../cohere/CohereEmbeddingsRequestTests.java | 20 +- .../results/TextEmbeddingResultsTests.java | 5 +- .../inference/services/ServiceUtilsTests.java | 5 +- .../cohere/CohereServiceSettingsTests.java | 19 +- .../services/cohere/CohereServiceTests.java | 214 +++++++++++------- .../CohereEmbeddingsModelTests.java | 59 ++++- .../CohereEmbeddingsServiceSettingsTests.java | 167 ++++++++++++++ .../CohereEmbeddingsTaskSettingsTests.java | 45 +--- 26 files changed, 659 insertions(+), 303 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 02d19ba60b0e7..f2f604e22bc24 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettings; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; @@ -98,6 +99,16 @@ public static List getNamedWriteables() { namedWriteables.add( new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new) ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(CohereServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + CohereEmbeddingsServiceSettings.NAME, + CohereEmbeddingsServiceSettings::new + ) + ); namedWriteables.add( new NamedWriteableRegistry.Entry(TaskSettings.class, CohereEmbeddingsTaskSettings.NAME, CohereEmbeddingsTaskSettings::new) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java index 4a8519934c63c..e4d6e39fdf1f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java @@ -39,7 +39,7 @@ public static ElasticsearchStatusException createInternalServerError(Throwable e return new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR, e); } - public static String getErrorMessage(@Nullable URI uri, String message) { + public static String constructFailedToSendRequestMessage(@Nullable URI uri, String message) { if (uri != null) { return Strings.format("Failed to send %s request to [%s]", message, uri); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java index 62afbc190d573..55611ab06a641 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java @@ -27,8 +27,8 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; -import static org.elasticsearch.xpack.inference.external.action.ActionUtils.getErrorMessage; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; public class CohereEmbeddingsAction implements ExecutableAction { @@ -37,13 +37,19 @@ public class CohereEmbeddingsAction implements ExecutableAction { private final CohereAccount account; private final CohereEmbeddingsModel model; - private final String errorMessage; + private final String failedToSendRequestErrorMessage; private final RetryingHttpSender sender; public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model, ServiceComponents serviceComponents) { this.model = Objects.requireNonNull(model); - this.account = new CohereAccount(this.model.getServiceSettings().uri(), this.model.getSecretSettings().apiKey()); - this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri(), "Cohere embeddings"); + this.account = new CohereAccount( + this.model.getServiceSettings().getCommonSettings().getUri(), + this.model.getSecretSettings().apiKey() + ); + this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage( + this.model.getServiceSettings().getCommonSettings().getUri(), + "Cohere embeddings" + ); this.sender = new RetryingHttpSender( Objects.requireNonNull(sender), serviceComponents.throttlerManager(), @@ -56,14 +62,23 @@ public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model, Servic @Override public void execute(List input, ActionListener listener) { try { - CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(account, input, model.getTaskSettings()); - ActionListener wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener); + CohereEmbeddingsRequest request = new CohereEmbeddingsRequest( + account, + input, + model.getTaskSettings(), + model.getServiceSettings().getCommonSettings().getModel(), + model.getServiceSettings().getEmbeddingType() + ); + ActionListener wrappedListener = wrapFailuresInElasticsearchException( + failedToSendRequestErrorMessage, + listener + ); sender.send(request, HANDLER, wrappedListener); } catch (ElasticsearchException e) { listener.onFailure(e); } catch (Exception e) { - listener.onFailure(createInternalServerError(e, errorMessage)); + listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java index 3877157188832..f7bcd53724168 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java @@ -23,8 +23,8 @@ import java.util.Objects; import static org.elasticsearch.xpack.inference.common.Truncator.truncate; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; -import static org.elasticsearch.xpack.inference.external.action.ActionUtils.getErrorMessage; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; public class OpenAiEmbeddingsAction implements ExecutableAction { @@ -43,7 +43,7 @@ public OpenAiEmbeddingsAction(Sender sender, OpenAiEmbeddingsModel model, Servic this.model.getSecretSettings().apiKey() ); this.client = new OpenAiClient(Objects.requireNonNull(sender), Objects.requireNonNull(serviceComponents)); - this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri(), "OpenAI embeddings"); + this.errorMessage = constructFailedToSendRequestMessage(this.model.getServiceSettings().uri(), "OpenAI embeddings"); this.truncator = Objects.requireNonNull(serviceComponents.truncator()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java index 4cb32b7bc95fd..6116b1cc234c6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java @@ -29,11 +29,8 @@ public static URI buildUri(URI accountUri, String service, CheckedSupplier input; private final URI uri; private final CohereEmbeddingsTaskSettings taskSettings; + private final String model; + private final CohereEmbeddingType embeddingType; - public CohereEmbeddingsRequest(CohereAccount account, List input, CohereEmbeddingsTaskSettings taskSettings) { + public CohereEmbeddingsRequest( + CohereAccount account, + List input, + CohereEmbeddingsTaskSettings taskSettings, + @Nullable String model, + @Nullable CohereEmbeddingType embeddingType + ) { this.account = Objects.requireNonNull(account); this.input = Objects.requireNonNull(input); this.uri = buildUri(this.account.url(), "Cohere", CohereEmbeddingsRequest::buildDefaultUri); this.taskSettings = Objects.requireNonNull(taskSettings); + this.model = model; + this.embeddingType = embeddingType; } @Override @@ -46,7 +58,7 @@ public HttpRequestBase createRequest() { HttpPost httpPost = new HttpPost(uri); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new CohereEmbeddingsRequestEntity(input, taskSettings)).getBytes(StandardCharsets.UTF_8) + Strings.toString(new CohereEmbeddingsRequestEntity(input, taskSettings, model, embeddingType)).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java index 831fb07b30e24..a7d8743359f39 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java @@ -7,30 +7,36 @@ package org.elasticsearch.xpack.inference.external.request.cohere; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InputType; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; import java.io.IOException; import java.util.List; -import java.util.Map; import java.util.Objects; -public record CohereEmbeddingsRequestEntity(List input, CohereEmbeddingsTaskSettings taskSettings) implements ToXContentObject { +public record CohereEmbeddingsRequestEntity( + List input, + CohereEmbeddingsTaskSettings taskSettings, + @Nullable String model, + @Nullable CohereEmbeddingType embeddingType +) implements ToXContentObject { private static final String SEARCH_DOCUMENT = "search_document"; private static final String SEARCH_QUERY = "search_query"; /** - * Maps the {@link InputType} to the expected value for cohere for the input_type field in the request + * Maps the {@link InputType} to the expected value for cohere for the input_type field in the request using the enum's ordinal. + * The order of these entries is important and needs to match the order in the enum */ - private static final Map INPUT_TYPE_MAPPING = Map.of( - InputType.INGEST, - SEARCH_DOCUMENT, - InputType.SEARCH, - SEARCH_QUERY - ); + private static final String[] INPUT_TYPE_MAPPING = { SEARCH_DOCUMENT, SEARCH_QUERY }; + static { + assert INPUT_TYPE_MAPPING.length == InputType.values().length : "input type mapping was incorrectly defined"; + } + private static final String TEXTS_FIELD = "texts"; static final String INPUT_TYPE_FIELD = "input_type"; @@ -45,16 +51,16 @@ public record CohereEmbeddingsRequestEntity(List input, CohereEmbeddings public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(TEXTS_FIELD, input); - if (taskSettings.model() != null) { - builder.field(CohereServiceFields.MODEL, taskSettings.model()); + if (model != null) { + builder.field(CohereServiceFields.MODEL, model); } if (taskSettings.inputType() != null) { builder.field(INPUT_TYPE_FIELD, covertToString(taskSettings.inputType())); } - if (taskSettings.embeddingType() != null) { - builder.field(EMBEDDING_TYPES_FIELD, List.of(taskSettings.embeddingType())); + if (embeddingType != null) { + builder.field(EMBEDDING_TYPES_FIELD, List.of(embeddingType)); } if (taskSettings.truncation() != null) { @@ -66,12 +72,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } private static String covertToString(InputType inputType) { - var stringValue = INPUT_TYPE_MAPPING.get(inputType); - - if (stringValue == null) { - return SEARCH_DOCUMENT; - } - - return stringValue; + return INPUT_TYPE_MAPPING[inputType.ordinal()]; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 3e9f6e1f75a8a..20912a3811af1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.CheckedFunction; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.Model; @@ -138,8 +139,12 @@ public static String invalidValue(String settingName, String scope, String inval } // TODO improve URI validation logic - public static URI convertToUri(String url, String settingName, String settingScope, ValidationException validationException) { + public static URI convertToUri(@Nullable String url, String settingName, String settingScope, ValidationException validationException) { try { + if (url == null) { + return null; + } + return createUri(url); } catch (IllegalArgumentException ignored) { validationException.addValidationError(ServiceUtils.invalidUrlErrorMsg(url, settingName, settingScope)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index e654316a4fbd4..8783f12852ec8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettings; import java.util.List; import java.util.Map; @@ -157,11 +158,15 @@ public void checkModelConfig(Model model, ActionListener listener) { } private CohereEmbeddingsModel updateModelWithEmbeddingDetails(CohereEmbeddingsModel model, int embeddingSize) { - CohereServiceSettings serviceSettings = new CohereServiceSettings( - model.getServiceSettings().uri(), - SimilarityMeasure.DOT_PRODUCT, - embeddingSize, - model.getServiceSettings().maxInputTokens() + CohereEmbeddingsServiceSettings serviceSettings = new CohereEmbeddingsServiceSettings( + new CohereServiceSettings( + model.getServiceSettings().getCommonSettings().getUri(), + SimilarityMeasure.DOT_PRODUCT, + embeddingSize, + model.getServiceSettings().getCommonSettings().getMaxInputTokens(), + model.getServiceSettings().getCommonSettings().getModel() + ), + model.getServiceSettings().getEmbeddingType() ); return new CohereEmbeddingsModel(model, serviceSettings); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java index bdd5e981b0490..f03371593a340 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java @@ -35,6 +35,7 @@ public class CohereServiceSettings implements ServiceSettings { public static final String NAME = "cohere_service_settings"; + public static final String MODEL = "model"; public static CohereServiceSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); @@ -44,50 +45,44 @@ public static CohereServiceSettings fromMap(Map map) { SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); Integer dims = removeAsType(map, DIMENSIONS, Integer.class); Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); - - // Throw if any of the settings were empty strings or invalid - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - // the url is optional and only for testing - if (url == null) { - return new CohereServiceSettings((URI) null, similarity, dims, maxInputTokens); - } - URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + String model = extractOptionalString(map, MODEL, ModelConfigurations.TASK_SETTINGS, validationException); if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new CohereServiceSettings(uri, similarity, dims, maxInputTokens); + return new CohereServiceSettings(uri, similarity, dims, maxInputTokens, model); } private final URI uri; private final SimilarityMeasure similarity; private final Integer dimensions; private final Integer maxInputTokens; + private final String model; public CohereServiceSettings( @Nullable URI uri, @Nullable SimilarityMeasure similarity, @Nullable Integer dimensions, - @Nullable Integer maxInputTokens + @Nullable Integer maxInputTokens, + @Nullable String model ) { this.uri = uri; this.similarity = similarity; this.dimensions = dimensions; this.maxInputTokens = maxInputTokens; + this.model = model; } public CohereServiceSettings( @Nullable String url, @Nullable SimilarityMeasure similarity, @Nullable Integer dimensions, - @Nullable Integer maxInputTokens + @Nullable Integer maxInputTokens, + @Nullable String model ) { - this(createOptionalUri(url), similarity, dimensions, maxInputTokens); + this(createOptionalUri(url), similarity, dimensions, maxInputTokens, model); } public CohereServiceSettings(StreamInput in) throws IOException { @@ -95,24 +90,29 @@ public CohereServiceSettings(StreamInput in) throws IOException { similarity = in.readOptionalEnum(SimilarityMeasure.class); dimensions = in.readOptionalVInt(); maxInputTokens = in.readOptionalVInt(); + model = in.readOptionalString(); } - public URI uri() { + public URI getUri() { return uri; } - public SimilarityMeasure similarity() { + public SimilarityMeasure getSimilarity() { return similarity; } - public Integer dimensions() { + public Integer getDimensions() { return dimensions; } - public Integer maxInputTokens() { + public Integer getMaxInputTokens() { return maxInputTokens; } + public String getModel() { + return model; + } + @Override public String getWriteableName() { return NAME; @@ -122,6 +122,13 @@ public String getWriteableName() { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); + toXContentFragment(builder); + + builder.endObject(); + return builder; + } + + public XContentBuilder toXContentFragment(XContentBuilder builder) throws IOException { if (uri != null) { builder.field(URL, uri.toString()); } @@ -134,8 +141,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (maxInputTokens != null) { builder.field(MAX_INPUT_TOKENS, maxInputTokens); } + if (model != null) { + builder.field(MODEL, model); + } - builder.endObject(); return builder; } @@ -151,6 +160,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalEnum(similarity); out.writeOptionalVInt(dimensions); out.writeOptionalVInt(maxInputTokens); + out.writeOptionalString(model); } @Override @@ -161,11 +171,12 @@ public boolean equals(Object o) { return Objects.equals(uri, that.uri) && Objects.equals(similarity, that.similarity) && Objects.equals(dimensions, that.dimensions) - && Objects.equals(maxInputTokens, that.maxInputTokens); + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(model, that.model); } @Override public int hashCode() { - return Objects.hash(uri, similarity, dimensions, maxInputTokens); + return Objects.hash(uri, similarity, dimensions, maxInputTokens, model); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java index accba9149a46e..c92700e87cd96 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java @@ -14,7 +14,6 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionVisitor; import org.elasticsearch.xpack.inference.services.cohere.CohereModel; -import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import java.util.Map; @@ -32,7 +31,7 @@ public CohereEmbeddingsModel( modelId, taskType, service, - CohereServiceSettings.fromMap(serviceSettings), + CohereEmbeddingsServiceSettings.fromMap(serviceSettings), CohereEmbeddingsTaskSettings.fromMap(taskSettings), DefaultSecretSettings.fromMap(secrets) ); @@ -43,7 +42,7 @@ public CohereEmbeddingsModel( String modelId, TaskType taskType, String service, - CohereServiceSettings serviceSettings, + CohereEmbeddingsServiceSettings serviceSettings, CohereEmbeddingsTaskSettings taskSettings, @Nullable DefaultSecretSettings secretSettings ) { @@ -54,13 +53,13 @@ private CohereEmbeddingsModel(CohereEmbeddingsModel model, CohereEmbeddingsTaskS super(model, taskSettings); } - public CohereEmbeddingsModel(CohereEmbeddingsModel model, CohereServiceSettings serviceSettings) { + public CohereEmbeddingsModel(CohereEmbeddingsModel model, CohereEmbeddingsServiceSettings serviceSettings) { super(model, serviceSettings); } @Override - public CohereServiceSettings getServiceSettings() { - return (CohereServiceSettings) super.getServiceSettings(); + public CohereEmbeddingsServiceSettings getServiceSettings() { + return (CohereEmbeddingsServiceSettings) super.getServiceSettings(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..f8400bc168d69 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java @@ -0,0 +1,111 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; + +public class CohereEmbeddingsServiceSettings implements ServiceSettings { + public static final String NAME = "cohere_embeddings_service_settings"; + + static final String EMBEDDING_TYPE = "embedding_type"; + + public static CohereEmbeddingsServiceSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + var commonServiceSettings = CohereServiceSettings.fromMap(map); + CohereEmbeddingType embeddingTypes = extractOptionalEnum( + map, + EMBEDDING_TYPE, + ModelConfigurations.SERVICE_SETTINGS, + CohereEmbeddingType::fromString, + CohereEmbeddingType.values(), + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new CohereEmbeddingsServiceSettings(commonServiceSettings, embeddingTypes); + } + + private final CohereServiceSettings commonSettings; + private final CohereEmbeddingType embeddingType; + + public CohereEmbeddingsServiceSettings(CohereServiceSettings commonSettings, @Nullable CohereEmbeddingType embeddingType) { + this.commonSettings = commonSettings; + this.embeddingType = embeddingType; + } + + public CohereEmbeddingsServiceSettings(StreamInput in) throws IOException { + commonSettings = in.readNamedWriteable(CohereServiceSettings.class); + embeddingType = in.readOptionalEnum(CohereEmbeddingType.class); + } + + public CohereServiceSettings getCommonSettings() { + return commonSettings; + } + + public CohereEmbeddingType getEmbeddingType() { + return embeddingType; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + commonSettings.toXContentFragment(builder); + builder.field(EMBEDDING_TYPE, embeddingType); + + builder.endObject(); + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeNamedWriteable(commonSettings); + out.writeOptionalEnum(embeddingType); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CohereEmbeddingsServiceSettings that = (CohereEmbeddingsServiceSettings) o; + return Objects.equals(commonSettings, that.commonSettings) && Objects.equals(embeddingType, that.embeddingType); + } + + @Override + public int hashCode() { + return Objects.hash(commonSettings, embeddingType); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java index 33bd4095c47cd..858efdb0d1ace 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java @@ -23,8 +23,6 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; -import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.MODEL; import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.TRUNCATE; /** @@ -34,22 +32,14 @@ * See api docs for details. *

* - * @param model the id of the model to use in the requests to cohere * @param inputType Specifies the type of input you're giving to the model - * @param embeddingType Specifies the type of embeddings you want to get back (we only support retrieving a single type) * @param truncation Specifies how the API will handle inputs longer than the maximum token length */ -public record CohereEmbeddingsTaskSettings( - @Nullable String model, - @Nullable InputType inputType, - @Nullable CohereEmbeddingType embeddingType, - @Nullable CohereTruncation truncation -) implements TaskSettings { +public record CohereEmbeddingsTaskSettings(@Nullable InputType inputType, @Nullable CohereTruncation truncation) implements TaskSettings { public static final String NAME = "cohere_embeddings_task_settings"; - public static final CohereEmbeddingsTaskSettings EMPTY_SETTINGS = new CohereEmbeddingsTaskSettings(null, null, null, null); + public static final CohereEmbeddingsTaskSettings EMPTY_SETTINGS = new CohereEmbeddingsTaskSettings(null, null); static final String INPUT_TYPE = "input_type"; - static final String EMBEDDING_TYPE = "embedding_type"; public static CohereEmbeddingsTaskSettings fromMap(Map map) { if (map.isEmpty()) { @@ -58,7 +48,6 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); - String model = extractOptionalString(map, MODEL, ModelConfigurations.TASK_SETTINGS, validationException); InputType inputType = extractOptionalEnum( map, INPUT_TYPE, @@ -67,14 +56,6 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) { InputType.values(), validationException ); - CohereEmbeddingType embeddingTypes = extractOptionalEnum( - map, - EMBEDDING_TYPE, - ModelConfigurations.TASK_SETTINGS, - CohereEmbeddingType::fromString, - CohereEmbeddingType.values(), - validationException - ); CohereTruncation truncation = extractOptionalEnum( map, TRUNCATE, @@ -88,33 +69,20 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) { throw validationException; } - return new CohereEmbeddingsTaskSettings(model, inputType, embeddingTypes, truncation); + return new CohereEmbeddingsTaskSettings(inputType, truncation); } public CohereEmbeddingsTaskSettings(StreamInput in) throws IOException { - this( - in.readOptionalString(), - in.readOptionalEnum(InputType.class), - in.readOptionalEnum(CohereEmbeddingType.class), - in.readOptionalEnum(CohereTruncation.class) - ); + this(in.readOptionalEnum(InputType.class), in.readOptionalEnum(CohereTruncation.class)); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - if (model != null) { - builder.field(MODEL, model); - } - if (inputType != null) { builder.field(INPUT_TYPE, inputType); } - if (embeddingType != null) { - builder.field(EMBEDDING_TYPE, embeddingType); - } - if (truncation != null) { builder.field(TRUNCATE, truncation); } @@ -134,18 +102,14 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalString(model); out.writeOptionalEnum(inputType); - out.writeOptionalEnum(embeddingType); out.writeOptionalEnum(truncation); } public CohereEmbeddingsTaskSettings overrideWith(CohereEmbeddingsTaskSettings requestTaskSettings) { - var modelToUse = requestTaskSettings.model() == null ? model : requestTaskSettings.model(); var inputTypeToUse = requestTaskSettings.inputType() == null ? inputType : requestTaskSettings.inputType(); - var embeddingTypesToUse = requestTaskSettings.embeddingType() == null ? embeddingType : requestTaskSettings.embeddingType(); var truncationToUse = requestTaskSettings.truncation() == null ? truncation : requestTaskSettings.truncation(); - return new CohereEmbeddingsTaskSettings(modelToUse, inputTypeToUse, embeddingTypesToUse, truncationToUse); + return new CohereEmbeddingsTaskSettings(inputTypeToUse, truncationToUse); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java index 6464ca0e0fda8..b3b130b22a1fa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java @@ -52,9 +52,7 @@ public static HuggingFaceServiceSettings fromMap(Map map) { public static URI extractUri(Map map, String fieldName, ValidationException validationException) { String parsedUrl = extractRequiredString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); - if (parsedUrl == null) { - return null; - } + return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java index 553a5eaf60dae..4e96ac73157ad 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java @@ -50,17 +50,6 @@ public static OpenAiServiceSettings fromMap(Map map) { SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); Integer dims = removeAsType(map, DIMENSIONS, Integer.class); Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); - - // Throw if any of the settings were empty strings or invalid - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - // the url is optional and only for testing - if (url == null) { - return new OpenAiServiceSettings((URI) null, organizationId, similarity, dims, maxInputTokens); - } - URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); if (validationException.validationErrors().isEmpty() == false) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java index e1daa78454d25..acd0145cbad24 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java @@ -66,7 +66,7 @@ public void shutdown() throws IOException { webServer.close(); } - public void testCreate_OpenAiEmbeddingsModel() throws IOException { + public void testCreate_CohereEmbeddingsModel() throws IOException { var senderFactory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY); try (var sender = senderFactory.createSender("test_service")) { @@ -102,17 +102,14 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException { var model = CohereEmbeddingsModelTests.createModel( getUrl(webServer), "secret", - new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START), + new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START), + 1024, 1024, - 1024 - ); - var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool)); - var overriddenTaskSettings = CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap( "model", - InputType.SEARCH, - CohereEmbeddingType.INT8, - CohereTruncation.END + CohereEmbeddingType.FLOAT ); + var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, CohereTruncation.END); var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); @@ -141,7 +138,7 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException { "input_type", "search_query", "embedding_types", - List.of("int8"), + List.of("float"), "truncate", "end" ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java index d2a8aad19b4c8..86ba29c95b4bc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -110,7 +111,9 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction( getUrl(webServer), "secret", - new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START), + new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START), + "model", + CohereEmbeddingType.FLOAT, sender ); @@ -185,7 +188,9 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I var action = createAction( getUrl(webServer), "secret", - new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.INT8, CohereTruncation.START), + new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START), + "model", + CohereEmbeddingType.INT8, sender ); @@ -228,7 +233,7 @@ public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOExcept try (var sender = mock(Sender.class)) { var thrownException = expectThrows( IllegalArgumentException.class, - () -> createAction("^^", "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender) + () -> createAction("^^", "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender) ); MatcherAssert.assertThat(thrownException.getMessage(), is("unable to parse url [^^]")); } @@ -238,7 +243,7 @@ public void testExecute_ThrowsElasticsearchException() { var sender = mock(Sender.class); doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any()); - var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender); + var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); action.execute(List.of("abc"), listener); @@ -259,7 +264,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled return Void.TYPE; }).when(sender).send(any(), any()); - var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender); + var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); action.execute(List.of("abc"), listener); @@ -283,7 +288,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled return Void.TYPE; }).when(sender).send(any(), any()); - var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender); + var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); action.execute(List.of("abc"), listener); @@ -297,7 +302,7 @@ public void testExecute_ThrowsException() { var sender = mock(Sender.class); doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any()); - var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender); + var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); action.execute(List.of("abc"), listener); @@ -314,7 +319,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { var sender = mock(Sender.class); doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any()); - var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender); + var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); action.execute(List.of("abc"), listener); @@ -324,8 +329,15 @@ public void testExecute_ThrowsExceptionWithNullUrl() { MatcherAssert.assertThat(thrownException.getMessage(), is("Failed to send Cohere embeddings request")); } - private CohereEmbeddingsAction createAction(String url, String apiKey, CohereEmbeddingsTaskSettings taskSettings, Sender sender) { - var model = CohereEmbeddingsModelTests.createModel(url, apiKey, taskSettings, 1024, 1024); + private CohereEmbeddingsAction createAction( + String url, + String apiKey, + CohereEmbeddingsTaskSettings taskSettings, + @Nullable String modelName, + @Nullable CohereEmbeddingType embeddingType, + Sender sender + ) { + var model = CohereEmbeddingsModelTests.createModel(url, apiKey, taskSettings, 1024, 1024, modelName, embeddingType); return new CohereEmbeddingsAction(sender, model, createWithEmptySettings(threadPool)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java index 01a2290446529..8ef9ea4b0316b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java @@ -27,7 +27,9 @@ public class CohereEmbeddingsRequestEntityTests extends ESTestCase { public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { var entity = new CohereEmbeddingsRequestEntity( List.of("abc"), - new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START) + new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START), + "model", + CohereEmbeddingType.FLOAT ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -41,7 +43,9 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException public void testXContent_InputTypeSearch_EmbeddingTypesInt8_TruncateNone() throws IOException { var entity = new CohereEmbeddingsRequestEntity( List.of("abc"), - new CohereEmbeddingsTaskSettings("model", InputType.SEARCH, CohereEmbeddingType.INT8, CohereTruncation.NONE) + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), + "model", + CohereEmbeddingType.INT8 ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -53,7 +57,7 @@ public void testXContent_InputTypeSearch_EmbeddingTypesInt8_TruncateNone() throw } public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { - var entity = new CohereEmbeddingsRequestEntity(List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS); + var entity = new CohereEmbeddingsRequestEntity(List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java index 903f6a9500831..a8a661149b120 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java @@ -32,7 +32,7 @@ public class CohereEmbeddingsRequestTests extends ESTestCase { public void testCreateRequest_UrlDefined() throws URISyntaxException, IOException { - var request = createRequest("url", "secret", List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS); + var request = createRequest("url", "secret", List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null); var httpRequest = request.createRequest(); MatcherAssert.assertThat(httpRequest, instanceOf(HttpPost.class)); @@ -52,7 +52,9 @@ public void testCreateRequest_AllOptionsDefined() throws URISyntaxException, IOE "url", "secret", List.of("abc"), - new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START) + new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START), + "model", + CohereEmbeddingType.FLOAT ); var httpRequest = request.createRequest(); @@ -89,7 +91,9 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() th "url", "secret", List.of("abc"), - new CohereEmbeddingsTaskSettings("model", InputType.SEARCH, CohereEmbeddingType.INT8, CohereTruncation.END) + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.END), + "model", + CohereEmbeddingType.INT8 ); var httpRequest = request.createRequest(); @@ -126,7 +130,9 @@ public void testCreateRequest_TruncateNone() throws URISyntaxException, IOExcept "url", "secret", List.of("abc"), - new CohereEmbeddingsTaskSettings(null, null, null, CohereTruncation.NONE) + new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE), + null, + null ); var httpRequest = request.createRequest(); @@ -146,11 +152,13 @@ public static CohereEmbeddingsRequest createRequest( @Nullable String url, String apiKey, List input, - CohereEmbeddingsTaskSettings taskSettings + CohereEmbeddingsTaskSettings taskSettings, + @Nullable String model, + @Nullable CohereEmbeddingType embeddingType ) throws URISyntaxException { var uri = url == null ? null : new URI(url); var account = new CohereAccount(uri, new SecureString(apiKey.toCharArray())); - return new CohereEmbeddingsRequest(account, input, taskSettings); + return new CohereEmbeddingsRequest(account, input, taskSettings, model, embeddingType); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index c46c100c9a476..9c154a20531f7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -46,10 +46,7 @@ private enum EmbeddingType { public static TextEmbeddingResults createRandomResults() { var embeddingType = randomFrom(EmbeddingType.values()); var createFunction = EMBEDDING_TYPE_BUILDERS.get(embeddingType); - - if (createFunction == null) { - createFunction = TextEmbeddingResultsTests::createRandomFloatEmbedding; - } + assert createFunction != null : "the embeddings type map is missing a value from the EmbeddingType enum"; return createRandomResults(createFunction); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index eb54745806a68..c72b161941ad2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -105,10 +105,11 @@ public void testConvertToUri_CreatesUri() { assertThat(uri.toString(), is("www.elastic.co")); } - public void testConvertToUri_ThrowsNullPointerException_WhenPassedNull() { + public void testConvertToUri_DoesNotThrowNullPointerException_WhenPassedNull() { var validation = new ValidationException(); - expectThrows(NullPointerException.class, () -> convertToUri(null, "name", "scope", validation)); + var uri = convertToUri(null, "name", "scope", validation); + assertNull(uri); assertTrue(validation.validationErrors().isEmpty()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java index 21558eec87302..8f829c11c0a11 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java @@ -47,7 +47,9 @@ private static CohereServiceSettings createRandom(String url) { dims = 1536; } Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256); - return new CohereServiceSettings(ServiceUtils.createUri(url), similarityMeasure, dims, maxInputTokens); + var model = randomBoolean() ? randomAlphaOfLength(15) : null; + + return new CohereServiceSettings(ServiceUtils.createOptionalUri(url), similarityMeasure, dims, maxInputTokens, model); } public void testFromMap() { @@ -55,6 +57,7 @@ public void testFromMap() { var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); var dims = 1536; var maxInputTokens = 512; + var model = "model"; var serviceSettings = CohereServiceSettings.fromMap( new HashMap<>( Map.of( @@ -65,20 +68,22 @@ public void testFromMap() { ServiceFields.DIMENSIONS, dims, ServiceFields.MAX_INPUT_TOKENS, - maxInputTokens + maxInputTokens, + CohereServiceSettings.MODEL, + model ) ) ); MatcherAssert.assertThat( serviceSettings, - is(new CohereServiceSettings(ServiceUtils.createUri(url), SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens)) + is(new CohereServiceSettings(ServiceUtils.createUri(url), SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens, model)) ); } public void testFromMap_MissingUrl_DoesNotThrowException() { var serviceSettings = CohereServiceSettings.fromMap(new HashMap<>(Map.of())); - assertNull(serviceSettings.uri()); + assertNull(serviceSettings.getUri()); } public void testFromMap_EmptyUrl_ThrowsError() { @@ -139,13 +144,17 @@ protected CohereServiceSettings mutateInstance(CohereServiceSettings instance) t return createRandomWithNonNullUrl(); } - public static Map getServiceSettingsMap(@Nullable String url) { + public static Map getServiceSettingsMap(@Nullable String url, @Nullable String model) { var map = new HashMap(); if (url != null) { map.put(ServiceFields.URL, url); } + if (model != null) { + map.put(CohereServiceSettings.MODEL, model); + } + return map; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 26bc51683155c..dc3c8eafb46ac 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -32,6 +32,7 @@ import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettingsTests; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; @@ -52,7 +53,6 @@ import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel; -import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettingsTests.getServiceSettingsMap; import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap; import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; @@ -99,8 +99,8 @@ public void testParseRequestConfig_CreatesACohereEmbeddingsModel() throws IOExce "id", TaskType.TEXT_EMBEDDING, getRequestConfigMap( - getServiceSettingsMap("url"), - getTaskSettingsMap("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START), + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", CohereEmbeddingType.FLOAT), + getTaskSettingsMap(InputType.INGEST, CohereTruncation.START), getSecretSettingsMap("secret") ), Set.of() @@ -109,10 +109,12 @@ public void testParseRequestConfig_CreatesACohereEmbeddingsModel() throws IOExce MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModel(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT)); MatcherAssert.assertThat( embeddingsModel.getTaskSettings(), - is(new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START)) + is(new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START)) ); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } @@ -130,7 +132,11 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti () -> service.parseRequestConfig( "id", TaskType.SPARSE_EMBEDDING, - getRequestConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty(), getSecretSettingsMap("secret")), + getRequestConfigMap( + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null), + getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ), Set.of() ) ); @@ -149,7 +155,11 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I new SetOnce<>(createWithEmptySettings(threadPool)) ) ) { - var config = getRequestConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty(), getSecretSettingsMap("secret")); + var config = getRequestConfigMap( + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null), + getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ); config.put("extra_key", "value"); var thrownException = expectThrows( @@ -171,14 +181,10 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa new SetOnce<>(createWithEmptySettings(threadPool)) ) ) { - var serviceSettings = getServiceSettingsMap("url"); + var serviceSettings = CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", null); serviceSettings.put("extra_key", "value"); - var config = getRequestConfigMap( - serviceSettings, - getTaskSettingsMap("model", null, null, null), - getSecretSettingsMap("secret") - ); + var config = getRequestConfigMap(serviceSettings, getTaskSettingsMap(null, null), getSecretSettingsMap("secret")); var thrownException = expectThrows( ElasticsearchStatusException.class, @@ -199,10 +205,14 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() new SetOnce<>(createWithEmptySettings(threadPool)) ) ) { - var taskSettingsMap = getTaskSettingsMap("model", null, null, null); + var taskSettingsMap = getTaskSettingsMap(InputType.INGEST, null); taskSettingsMap.put("extra_key", "value"); - var config = getRequestConfigMap(getServiceSettingsMap("url"), taskSettingsMap, getSecretSettingsMap("secret")); + var config = getRequestConfigMap( + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", null), + taskSettingsMap, + getSecretSettingsMap("secret") + ); var thrownException = expectThrows( ElasticsearchStatusException.class, @@ -226,7 +236,11 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var secretSettingsMap = getSecretSettingsMap("secret"); secretSettingsMap.put("extra_key", "value"); - var config = getRequestConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty(), secretSettingsMap); + var config = getRequestConfigMap( + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null), + getTaskSettingsMapEmpty(), + secretSettingsMap + ); var thrownException = expectThrows( ElasticsearchStatusException.class, @@ -250,14 +264,18 @@ public void testParseRequestConfig_CreatesACohereEmbeddingsModelWithoutUrl() thr var model = service.parseRequestConfig( "id", TaskType.TEXT_EMBEDDING, - getRequestConfigMap(getServiceSettingsMap(null), getTaskSettingsMapEmpty(), getSecretSettingsMap("secret")), + getRequestConfigMap( + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null), + getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ), Set.of() ); MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); + assertNull(embeddingsModel.getServiceSettings().getCommonSettings().getUri()); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } @@ -271,8 +289,8 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModel() ) ) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("url"), - getTaskSettingsMap("model", null, null, null), + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", null), + getTaskSettingsMap(null, null), getSecretSettingsMap("secret") ); @@ -286,8 +304,9 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModel() MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings("model", null, null, null))); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModel(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null))); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -300,7 +319,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM ) ) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("url"), + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null), getTaskSettingsMapEmpty(), getSecretSettingsMap("secret") ); @@ -330,8 +349,8 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWit ) ) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap(null), - getTaskSettingsMap(null, InputType.INGEST, null, null), + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null), + getTaskSettingsMap(InputType.INGEST, null), getSecretSettingsMap("secret") ); @@ -345,11 +364,8 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWit MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - MatcherAssert.assertThat( - embeddingsModel.getTaskSettings(), - is(new CohereEmbeddingsTaskSettings(null, InputType.INGEST, null, null)) - ); + assertNull(embeddingsModel.getServiceSettings().getCommonSettings().getUri()); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null))); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -362,8 +378,8 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists ) ) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("url"), - getTaskSettingsMap("model", InputType.SEARCH, CohereEmbeddingType.INT8, CohereTruncation.NONE), + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", CohereEmbeddingType.INT8), + getTaskSettingsMap(InputType.SEARCH, CohereTruncation.NONE), getSecretSettingsMap("secret") ); persistedConfig.config().put("extra_key", "value"); @@ -378,10 +394,12 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModel(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.INT8)); MatcherAssert.assertThat( embeddingsModel.getTaskSettings(), - is(new CohereEmbeddingsTaskSettings("model", InputType.SEARCH, CohereEmbeddingType.INT8, CohereTruncation.NONE)) + is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE)) ); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } @@ -397,7 +415,11 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists var secretSettingsMap = getSecretSettingsMap("secret"); secretSettingsMap.put("extra_key", "value"); - var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty(), secretSettingsMap); + var persistedConfig = getPersistedConfigMap( + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null), + getTaskSettingsMapEmpty(), + secretSettingsMap + ); var model = service.parsePersistedConfigWithSecrets( "id", @@ -409,7 +431,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } @@ -423,8 +445,8 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe ) ) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("url"), - getTaskSettingsMap("model", null, null, null), + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", null), + getTaskSettingsMap(null, null), getSecretSettingsMap("secret") ); persistedConfig.secrets.put("extra_key", "value"); @@ -439,8 +461,9 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings("model", null, null, null))); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModel(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null))); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -452,7 +475,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe new SetOnce<>(createWithEmptySettings(threadPool)) ) ) { - var serviceSettingsMap = getServiceSettingsMap("url"); + var serviceSettingsMap = CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null); serviceSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMapEmpty(), getSecretSettingsMap("secret")); @@ -467,7 +490,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } @@ -480,10 +503,14 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa new SetOnce<>(createWithEmptySettings(threadPool)) ) ) { - var taskSettingsMap = getTaskSettingsMap("model", InputType.SEARCH, null, null); + var taskSettingsMap = getTaskSettingsMap(InputType.SEARCH, null); taskSettingsMap.put("extra_key", "value"); - var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), taskSettingsMap, getSecretSettingsMap("secret")); + var persistedConfig = getPersistedConfigMap( + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", null), + taskSettingsMap, + getSecretSettingsMap("secret") + ); var model = service.parsePersistedConfigWithSecrets( "id", @@ -495,11 +522,9 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - MatcherAssert.assertThat( - embeddingsModel.getTaskSettings(), - is(new CohereEmbeddingsTaskSettings("model", InputType.SEARCH, null, null)) - ); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModel(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, null))); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -512,8 +537,8 @@ public void testParsePersistedConfig_CreatesACohereEmbeddingsModel() throws IOEx ) ) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("url"), - getTaskSettingsMap("model", null, null, CohereTruncation.NONE) + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", null), + getTaskSettingsMap(null, CohereTruncation.NONE) ); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); @@ -521,11 +546,9 @@ public void testParsePersistedConfig_CreatesACohereEmbeddingsModel() throws IOEx MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - MatcherAssert.assertThat( - embeddingsModel.getTaskSettings(), - is(new CohereEmbeddingsTaskSettings("model", null, null, CohereTruncation.NONE)) - ); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModel(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE))); assertNull(embeddingsModel.getSecretSettings()); } } @@ -537,7 +560,10 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro new SetOnce<>(createWithEmptySettings(threadPool)) ) ) { - var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty()); + var persistedConfig = getPersistedConfigMap( + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null), + getTaskSettingsMapEmpty() + ); var thrownException = expectThrows( ElasticsearchStatusException.class, @@ -551,7 +577,7 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro } } - public void testParsePersistedConfig_CreatesAnCohereEmbeddingsModelWithoutUrl() throws IOException { + public void testParsePersistedConfig_CreatesACohereEmbeddingsModelWithoutUrl() throws IOException { try ( var service = new CohereService( new SetOnce<>(mock(HttpRequestSenderFactory.class)), @@ -559,8 +585,8 @@ public void testParsePersistedConfig_CreatesAnCohereEmbeddingsModelWithoutUrl() ) ) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap(null), - getTaskSettingsMap("model", null, CohereEmbeddingType.FLOAT, null) + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model", CohereEmbeddingType.FLOAT), + getTaskSettingsMap(null, null) ); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); @@ -568,11 +594,10 @@ public void testParsePersistedConfig_CreatesAnCohereEmbeddingsModelWithoutUrl() MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - MatcherAssert.assertThat( - embeddingsModel.getTaskSettings(), - is(new CohereEmbeddingsTaskSettings("model", null, CohereEmbeddingType.FLOAT, null)) - ); + assertNull(embeddingsModel.getServiceSettings().getCommonSettings().getUri()); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModel(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT)); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null))); assertNull(embeddingsModel.getSecretSettings()); } } @@ -584,7 +609,10 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() new SetOnce<>(createWithEmptySettings(threadPool)) ) ) { - var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty()); + var persistedConfig = getPersistedConfigMap( + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null), + getTaskSettingsMapEmpty() + ); persistedConfig.config().put("extra_key", "value"); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); @@ -592,7 +620,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); assertNull(embeddingsModel.getSecretSettings()); } @@ -605,21 +633,18 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettin new SetOnce<>(createWithEmptySettings(threadPool)) ) ) { - var serviceSettingsMap = getServiceSettingsMap("url"); + var serviceSettingsMap = CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null); serviceSettingsMap.put("extra_key", "value"); - var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMap(null, InputType.SEARCH, null, null)); + var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMap(InputType.SEARCH, null)); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - MatcherAssert.assertThat( - embeddingsModel.getTaskSettings(), - is(new CohereEmbeddingsTaskSettings(null, InputType.SEARCH, null, null)) - ); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, null))); assertNull(embeddingsModel.getSecretSettings()); } } @@ -631,21 +656,22 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( new SetOnce<>(createWithEmptySettings(threadPool)) ) ) { - var taskSettingsMap = getTaskSettingsMap("model", InputType.INGEST, null, null); + var taskSettingsMap = getTaskSettingsMap(InputType.INGEST, null); taskSettingsMap.put("extra_key", "value"); - var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), taskSettingsMap); + var persistedConfig = getPersistedConfigMap( + CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", null), + taskSettingsMap + ); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - MatcherAssert.assertThat( - embeddingsModel.getTaskSettings(), - is(new CohereEmbeddingsTaskSettings("model", InputType.INGEST, null, null)) - ); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModel(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null))); assertNull(embeddingsModel.getSecretSettings()); } } @@ -712,9 +738,11 @@ public void testInfer_SendsRequest() throws IOException { var model = CohereEmbeddingsModelTests.createModel( getUrl(webServer), "secret", - new CohereEmbeddingsTaskSettings("model", InputType.INGEST, null, null), + new CohereEmbeddingsTaskSettings(InputType.INGEST, null), 1024, - 1024 + 1024, + "model", + null ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer(model, List.of("abc"), new HashMap<>(), listener); @@ -772,7 +800,9 @@ public void testCheckModelConfig_UpdatesDimensions() throws IOException { "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, 10, - 1 + 1, + null, + null ); PlainActionFuture listener = new PlainActionFuture<>(); service.checkModelConfig(model, listener); @@ -781,7 +811,17 @@ public void testCheckModelConfig_UpdatesDimensions() throws IOException { MatcherAssert.assertThat( result, // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is(CohereEmbeddingsModelTests.createModel(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, 10, 2)) + is( + CohereEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 2, + null, + null + ) + ) ); } } @@ -803,7 +843,9 @@ public void testInfer_UnauthorisedResponse() throws IOException { "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, 1024, - 1024 + 1024, + null, + null ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer(model, List.of("abc"), new HashMap<>(), listener); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java index d6b01486ae3ec..1961d6b168d54 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.common.SimilarityMeasure; @@ -24,34 +25,63 @@ public class CohereEmbeddingsModelTests extends ESTestCase { - public void testOverrideWith_OverridesModel() { - var model = createModel("url", "api_key", null); + public void testOverrideWith_OverridesInputType_WithSearch() { + var model = createModel( + "url", + "api_key", + new CohereEmbeddingsTaskSettings(InputType.INGEST, null), + null, + null, + "model", + CohereEmbeddingType.FLOAT + ); - var overriddenModel = model.overrideWith(getTaskSettingsMap("model", null, null, null)); - var expectedModel = createModel("url", "api_key", new CohereEmbeddingsTaskSettings("model", null, null, null), null, null); + var overriddenModel = model.overrideWith(getTaskSettingsMap(InputType.SEARCH, null)); + var expectedModel = createModel( + "url", + "api_key", + new CohereEmbeddingsTaskSettings(InputType.SEARCH, null), + null, + null, + "model", + CohereEmbeddingType.FLOAT + ); MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } public void testOverrideWith_DoesNotOverride_WhenSettingsAreEmpty() { - var model = createModel("url", "api_key", null); + var model = createModel("url", "api_key", null, null, null); var overriddenModel = model.overrideWith(Map.of()); MatcherAssert.assertThat(overriddenModel, sameInstance(model)); } public void testOverrideWith_DoesNotOverride_WhenSettingsAreNull() { - var model = createModel("url", "api_key", null); + var model = createModel("url", "api_key", null, null, null); var overriddenModel = model.overrideWith(null); MatcherAssert.assertThat(overriddenModel, sameInstance(model)); } - public static CohereEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit) { - return createModel(url, apiKey, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, null); + public static CohereEmbeddingsModel createModel( + String url, + String apiKey, + @Nullable Integer tokenLimit, + @Nullable String model, + @Nullable CohereEmbeddingType embeddingType + ) { + return createModel(url, apiKey, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, null, model, embeddingType); } - public static CohereEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit, @Nullable Integer dimensions) { - return createModel(url, apiKey, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, dimensions); + public static CohereEmbeddingsModel createModel( + String url, + String apiKey, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + @Nullable String model, + @Nullable CohereEmbeddingType embeddingType + ) { + return createModel(url, apiKey, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, dimensions, model, embeddingType); } public static CohereEmbeddingsModel createModel( @@ -59,13 +89,18 @@ public static CohereEmbeddingsModel createModel( String apiKey, CohereEmbeddingsTaskSettings taskSettings, @Nullable Integer tokenLimit, - @Nullable Integer dimensions + @Nullable Integer dimensions, + @Nullable String model, + @Nullable CohereEmbeddingType embeddingType ) { return new CohereEmbeddingsModel( "id", TaskType.TEXT_EMBEDDING, "service", - new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit), + new CohereEmbeddingsServiceSettings( + new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit, model), + embeddingType + ), taskSettings, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..8daa5a27f9618 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java @@ -0,0 +1,167 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.embeddings; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; +import org.elasticsearch.xpack.inference.common.SimilarityMeasure; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettingsTests; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase { + public static CohereEmbeddingsServiceSettings createRandom() { + var commonSettings = CohereServiceSettingsTests.createRandom(); + var embeddingType = randomBoolean() ? randomFrom(CohereEmbeddingType.values()) : null; + + return new CohereEmbeddingsServiceSettings(commonSettings, embeddingType); + } + + public void testFromMap() { + var url = "https://www.abc.com"; + var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + var dims = 1536; + var maxInputTokens = 512; + var model = "model"; + var serviceSettings = CohereEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + url, + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + CohereServiceSettings.MODEL, + model, + CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, + CohereEmbeddingType.INT8.toString() + ) + ) + ); + + MatcherAssert.assertThat( + serviceSettings, + is( + new CohereEmbeddingsServiceSettings( + new CohereServiceSettings(ServiceUtils.createUri(url), SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens, model), + CohereEmbeddingType.INT8 + ) + ) + ); + } + + public void testFromMap_MissingEmbeddingType_DoesNotThrowException() { + var serviceSettings = CohereEmbeddingsServiceSettings.fromMap(new HashMap<>(Map.of())); + assertNull(serviceSettings.getEmbeddingType()); + } + + public void testFromMap_EmptyEmbeddingType_ThrowsError() { + var thrownException = expectThrows( + ValidationException.class, + () -> CohereEmbeddingsServiceSettings.fromMap(new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, ""))) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value empty string. [%s] must be a non-empty string;", + CohereEmbeddingsServiceSettings.EMBEDDING_TYPE + ) + ) + ); + } + + public void testFromMap_InvalidEmbeddingType_ThrowsError() { + var thrownException = expectThrows( + ValidationException.class, + () -> CohereEmbeddingsServiceSettings.fromMap(new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, "abc"))) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [abc] received. [embedding_type] must be one of [float, int8];" + ) + ) + ); + } + + public void testFromMap_ReturnsFailure_WhenEmbeddingTypesAreNotValid() { + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> CohereEmbeddingsServiceSettings.fromMap( + new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, List.of("abc"))) + ) + ); + + MatcherAssert.assertThat( + exception.getMessage(), + is("field [embedding_type] is not of the expected type. The value [[abc]] cannot be converted to a [String]") + ); + } + + @Override + protected Writeable.Reader instanceReader() { + return CohereEmbeddingsServiceSettings::new; + } + + @Override + protected CohereEmbeddingsServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected CohereEmbeddingsServiceSettings mutateInstance(CohereEmbeddingsServiceSettings instance) throws IOException { + return createRandom(); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + + public static Map getServiceSettingsMap( + @Nullable String url, + @Nullable String model, + @Nullable CohereEmbeddingType embeddingType + ) { + var map = new HashMap<>(CohereServiceSettingsTests.getServiceSettingsMap(url, model)); + + if (embeddingType != null) { + map.put(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, embeddingType.toString()); + } + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java index a568a0eaa1e5a..cf16473bdb70f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.cohere.embeddings; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Nullable; @@ -19,7 +18,6 @@ import java.io.IOException; import java.util.HashMap; -import java.util.List; import java.util.Map; import static org.hamcrest.Matchers.is; @@ -27,18 +25,16 @@ public class CohereEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase { public static CohereEmbeddingsTaskSettings createRandom() { - var model = randomBoolean() ? randomAlphaOfLength(15) : null; var inputType = randomBoolean() ? randomFrom(InputType.values()) : null; - var embeddingType = randomBoolean() ? randomFrom(CohereEmbeddingType.values()) : null; var truncation = randomBoolean() ? randomFrom(CohereTruncation.values()) : null; - return new CohereEmbeddingsTaskSettings(model, inputType, embeddingType, truncation); + return new CohereEmbeddingsTaskSettings(inputType, truncation); } public void testFromMap_CreatesEmptySettings_WhenAllFieldsAreNull() { MatcherAssert.assertThat( CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of())), - is(new CohereEmbeddingsTaskSettings(null, null, null, null)) + is(new CohereEmbeddingsTaskSettings(null, null)) ); } @@ -47,18 +43,14 @@ public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() { CohereEmbeddingsTaskSettings.fromMap( new HashMap<>( Map.of( - CohereServiceFields.MODEL, - "abc", CohereEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString(), - CohereEmbeddingsTaskSettings.EMBEDDING_TYPE, - CohereEmbeddingType.INT8.toString(), CohereServiceFields.TRUNCATE, CohereTruncation.END.toString() ) ) ), - is(new CohereEmbeddingsTaskSettings("abc", InputType.INGEST, CohereEmbeddingType.INT8, CohereTruncation.END)) + is(new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.END)) ); } @@ -74,18 +66,6 @@ public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() { ); } - public void testFromMap_ReturnsFailure_WhenEmbeddingTypesAreNotValid() { - var exception = expectThrows( - ElasticsearchStatusException.class, - () -> CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.EMBEDDING_TYPE, List.of("abc")))) - ); - - MatcherAssert.assertThat( - exception.getMessage(), - is("field [embedding_type] is not of the expected type. The value [[abc]] cannot be converted to a [String]") - ); - } - public void testOverrideWith_KeepsOriginalValuesWhenOverridesAreNull() { var taskSettings = CohereEmbeddingsTaskSettings.fromMap( new HashMap<>(Map.of(CohereServiceFields.MODEL, "model", CohereServiceFields.TRUNCATE, CohereTruncation.END.toString())) @@ -97,7 +77,7 @@ public void testOverrideWith_KeepsOriginalValuesWhenOverridesAreNull() { public void testOverrideWith_UsesOverriddenSettings() { var taskSettings = CohereEmbeddingsTaskSettings.fromMap( - new HashMap<>(Map.of(CohereServiceFields.MODEL, "model", CohereServiceFields.TRUNCATE, CohereTruncation.END.toString())) + new HashMap<>(Map.of(CohereServiceFields.TRUNCATE, CohereTruncation.END.toString())) ); var requestTaskSettings = CohereEmbeddingsTaskSettings.fromMap( @@ -105,7 +85,7 @@ public void testOverrideWith_UsesOverriddenSettings() { ); var overriddenTaskSettings = taskSettings.overrideWith(requestTaskSettings); - MatcherAssert.assertThat(overriddenTaskSettings, is(new CohereEmbeddingsTaskSettings("model", null, null, CohereTruncation.START))); + MatcherAssert.assertThat(overriddenTaskSettings, is(new CohereEmbeddingsTaskSettings(null, CohereTruncation.START))); } @Override @@ -127,26 +107,13 @@ public static Map getTaskSettingsMapEmpty() { return new HashMap<>(); } - public static Map getTaskSettingsMap( - @Nullable String model, - @Nullable InputType inputType, - @Nullable CohereEmbeddingType embeddingType, - @Nullable CohereTruncation truncation - ) { + public static Map getTaskSettingsMap(@Nullable InputType inputType, @Nullable CohereTruncation truncation) { var map = new HashMap(); - if (model != null) { - map.put(CohereServiceFields.MODEL, model); - } - if (inputType != null) { map.put(CohereEmbeddingsTaskSettings.INPUT_TYPE, inputType.toString()); } - if (embeddingType != null) { - map.put(CohereEmbeddingsTaskSettings.EMBEDDING_TYPE, embeddingType.toString()); - } - if (truncation != null) { map.put(CohereServiceFields.TRUNCATE, truncation.toString()); } From 83caa87d5b8196fec76142d83d9ab3fd510f203d Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 24 Jan 2024 10:57:13 -0500 Subject: [PATCH 10/13] Using separate named writeables for byte results and floats --- .../core/inference/results/ByteValue.java | 69 ------ ...{EmbeddingValue.java => EmbeddingInt.java} | 7 +- .../core/inference/results/FloatValue.java | 69 ------ .../core/inference/results/TextEmbedding.java | 18 ++ .../results/TextEmbeddingByteResults.java | 146 ++++++++++++ .../results/TextEmbeddingResults.java | 74 ++---- .../inference/results/TextEmbeddingUtils.java | 30 +++ .../InferenceNamedWriteablesProvider.java | 9 +- .../CohereEmbeddingsResponseEntity.java | 74 +++--- .../HuggingFaceEmbeddingsResponseEntity.java | 2 +- .../OpenAiEmbeddingsResponseEntity.java | 2 +- .../inference/services/ServiceUtils.java | 35 +-- .../action/InferenceActionResponseTests.java | 6 +- .../cohere/CohereActionCreatorTests.java | 4 +- .../cohere/CohereEmbeddingsActionTests.java | 11 +- .../HuggingFaceActionCreatorTests.java | 6 +- .../openai/OpenAiActionCreatorTests.java | 10 +- .../openai/OpenAiEmbeddingsActionTests.java | 4 +- .../external/openai/OpenAiClientTests.java | 8 +- .../CohereEmbeddingsResponseEntityTests.java | 34 +-- ...gingFaceEmbeddingsResponseEntityTests.java | 20 +- .../OpenAiEmbeddingsResponseEntityTests.java | 10 +- .../inference/results/ByteValueTests.java | 35 --- .../inference/results/FloatValueTests.java | 35 --- .../TextEmbeddingByteResultsTests.java | 165 +++++++++++++ .../results/TextEmbeddingResultsTests.java | 219 +++--------------- .../inference/services/ServiceUtilsTests.java | 139 +++++++++++ .../services/cohere/CohereServiceTests.java | 4 +- .../huggingface/HuggingFaceServiceTests.java | 4 +- .../services/openai/OpenAiServiceTests.java | 4 +- 30 files changed, 668 insertions(+), 585 deletions(-) delete mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/{EmbeddingValue.java => EmbeddingInt.java} (59%) delete mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbedding.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ByteValueTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/FloatValueTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java deleted file mode 100644 index 5017aa9be0d52..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.Objects; - -public class ByteValue implements EmbeddingValue { - - public static final String NAME = "byte_value"; - - private final Byte value; - - public ByteValue(Byte value) { - this.value = value; - } - - public ByteValue(StreamInput in) throws IOException { - value = in.readByte(); - } - - @Override - public String toString() { - return value.toString(); - } - - @Override - public Byte getValue() { - return value; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.value(value); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - ByteValue that = (ByteValue) o; - return Objects.equals(value, that.value); - } - - @Override - public int hashCode() { - return Objects.hash(value); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeByte(value); - } - - @Override - public String getWriteableName() { - return NAME; - } -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingInt.java similarity index 59% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingValue.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingInt.java index f3fd641db395d..05fc8a3cef1b6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingValue.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingInt.java @@ -7,9 +7,6 @@ package org.elasticsearch.xpack.core.inference.results; -import org.elasticsearch.common.io.stream.NamedWriteable; -import org.elasticsearch.xcontent.ToXContentFragment; - -public interface EmbeddingValue extends NamedWriteable, ToXContentFragment { - Number getValue(); +public interface EmbeddingInt { + int getSize(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java deleted file mode 100644 index 5c89720e54b2e..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.Objects; - -public class FloatValue implements EmbeddingValue { - - public static final String NAME = "float_value"; - - private final Float value; - - public FloatValue(Float value) { - this.value = value; - } - - public FloatValue(StreamInput in) throws IOException { - value = in.readFloat(); - } - - @Override - public String toString() { - return value.toString(); - } - - @Override - public Number getValue() { - return value; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.value(value); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - FloatValue that = (FloatValue) o; - return Objects.equals(value, that.value); - } - - @Override - public int hashCode() { - return Objects.hash(value); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeFloat(value); - } - - @Override - public String getWriteableName() { - return NAME; - } -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbedding.java new file mode 100644 index 0000000000000..a185c2938223e --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbedding.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +public interface TextEmbedding { + + /** + * Returns the first text embedding entry in the result list's array size. + * @return the size of the text embedding + * @throws IllegalStateException if the list of embeddings is empty + */ + int getFirstEmbeddingSize() throws IllegalStateException; +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java new file mode 100644 index 0000000000000..4ffef36359589 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java @@ -0,0 +1,146 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Writes a text embedding result in the follow json format + * { + * "text_embedding": [ + * { + * "embedding": [ + * 23 + * ] + * }, + * { + * "embedding": [ + * -23 + * ] + * } + * ] + * } + */ +public record TextEmbeddingByteResults(List embeddings) implements InferenceServiceResults, TextEmbedding { + public static final String NAME = "text_embedding_service_byte_results"; + public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString(); + + public TextEmbeddingByteResults(StreamInput in) throws IOException { + this(in.readCollectionAsList(Embedding::new)); + } + + @Override + public int getFirstEmbeddingSize() { + return TextEmbeddingUtils.getFirstEmbeddingSize(new ArrayList<>(embeddings)); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startArray(TEXT_EMBEDDING); + for (Embedding embedding : embeddings) { + embedding.toXContent(builder, params); + } + builder.endArray(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(embeddings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public List transformToCoordinationFormat() { + return embeddings.stream() + .map(embedding -> embedding.values.stream().mapToDouble(value -> value).toArray()) + .map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false)) + .toList(); + } + + @Override + @SuppressWarnings("deprecation") + public List transformToLegacyFormat() { + var legacyEmbedding = new LegacyTextEmbeddingResults( + embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.toFloats())).toList() + ); + + return List.of(legacyEmbedding); + } + + public Map asMap() { + Map map = new LinkedHashMap<>(); + map.put(TEXT_EMBEDDING, embeddings.stream().map(Embedding::asMap).collect(Collectors.toList())); + + return map; + } + + public record Embedding(List values) implements Writeable, ToXContentObject, EmbeddingInt { + public static final String EMBEDDING = "embedding"; + + public Embedding(StreamInput in) throws IOException { + this(in.readCollectionAsImmutableList(StreamInput::readByte)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(values, StreamOutput::writeByte); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.startArray(EMBEDDING); + for (Byte value : values) { + builder.value(value); + } + builder.endArray(); + + builder.endObject(); + return builder; + } + + @Override + public String toString() { + return Strings.toString(this); + } + + public Map asMap() { + return Map.of(EMBEDDING, values); + } + + public List toFloats() { + return values.stream().map(Byte::floatValue).toList(); + } + + @Override + public int getSize() { + return values().size(); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java index fadffb2782018..75eb4ebc19902 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.core.inference.results; -import org.elasticsearch.TransportVersions; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -19,10 +18,10 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.stream.Collectors; /** @@ -42,7 +41,7 @@ * ] * } */ -public record TextEmbeddingResults(List embeddings) implements InferenceServiceResults { +public record TextEmbeddingResults(List embeddings) implements InferenceServiceResults, TextEmbedding { public static final String NAME = "text_embedding_service_results"; public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString(); @@ -55,11 +54,16 @@ public TextEmbeddingResults(StreamInput in) throws IOException { this( legacyTextEmbeddingResults.embeddings() .stream() - .map(embedding -> Embedding.ofFloats(embedding.values())) + .map(embedding -> new Embedding(embedding.values())) .collect(Collectors.toList()) ); } + @Override + public int getFirstEmbeddingSize() { + return TextEmbeddingUtils.getFirstEmbeddingSize(new ArrayList<>(embeddings)); + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startArray(TEXT_EMBEDDING); @@ -83,7 +87,7 @@ public String getWriteableName() { @Override public List transformToCoordinationFormat() { return embeddings.stream() - .map(embedding -> embedding.values.stream().mapToDouble(value -> value.getValue().doubleValue()).toArray()) + .map(embedding -> embedding.values.stream().mapToDouble(value -> value).toArray()) .map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false)) .toList(); } @@ -92,7 +96,7 @@ public List transformToCoordinationFormat() { @SuppressWarnings("deprecation") public List transformToLegacyFormat() { var legacyEmbedding = new LegacyTextEmbeddingResults( - embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.toFloats())).toList() + embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.values)).toList() ); return List.of(legacyEmbedding); @@ -105,52 +109,21 @@ public Map asMap() { return map; } - public static class Embedding implements Writeable, ToXContentObject { + public record Embedding(List values) implements Writeable, ToXContentObject, EmbeddingInt { public static final String EMBEDDING = "embedding"; - public static Embedding ofFloats(List values) { - return new Embedding(convertFloatsToEmbeddingValues(values)); - } - - public static Embedding ofBytes(List values) { - List convertedValues = values.stream().map(ByteValue::new).collect(Collectors.toList()); - - return new Embedding(convertedValues); - } - - private final List values; - - public Embedding(List values) { - this.values = values; - } - public Embedding(StreamInput in) throws IOException { - if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) { - values = in.readNamedWriteableCollectionAsList(EmbeddingValue.class); - } else { - values = convertFloatsToEmbeddingValues(in.readCollectionAsImmutableList(StreamInput::readFloat)); - } + this(in.readCollectionAsImmutableList(StreamInput::readFloat)); } - private static List convertFloatsToEmbeddingValues(List floats) { - return floats.stream().map(FloatValue::new).collect(Collectors.toList()); - } - - public List toFloats() { - return values.stream().map(value -> value.getValue().floatValue()).toList(); - } - - public List values() { - return values; + @Override + public int getSize() { + return values.size(); } @Override public void writeTo(StreamOutput out) throws IOException { - if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) { - out.writeNamedWriteableCollection(values); - } else { - out.writeCollection(toFloats(), StreamOutput::writeFloat); - } + out.writeCollection(values, StreamOutput::writeFloat); } @Override @@ -158,7 +131,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.startArray(EMBEDDING); - for (EmbeddingValue value : values) { + for (Float value : values) { builder.value(value); } builder.endArray(); @@ -172,19 +145,6 @@ public String toString() { return Strings.toString(this); } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Embedding embedding = (Embedding) o; - return Objects.equals(values, embedding.values); - } - - @Override - public int hashCode() { - return Objects.hash(values); - } - public Map asMap() { return Map.of(EMBEDDING, values); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java new file mode 100644 index 0000000000000..02cb3b878c7fe --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import java.util.List; + +public class TextEmbeddingUtils { + + /** + * Returns the first text embedding entry's array size. + * @param embeddings the list of embeddings + * @return the size of the text embedding + * @throws IllegalStateException if the list of embeddings is empty + */ + public static int getFirstEmbeddingSize(List embeddings) throws IllegalStateException { + if (embeddings.isEmpty()) { + throw new IllegalStateException("Embeddings list is empty"); + } + + return embeddings.get(0).getSize(); + } + + private TextEmbeddingUtils() {} + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index f2f604e22bc24..982c33a08d1fc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -14,11 +14,9 @@ import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; -import org.elasticsearch.xpack.core.inference.results.ByteValue; -import org.elasticsearch.xpack.core.inference.results.EmbeddingValue; -import org.elasticsearch.xpack.core.inference.results.FloatValue; import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettings; @@ -55,8 +53,9 @@ public static List getNamedWriteables() { namedWriteables.add( new NamedWriteableRegistry.Entry(InferenceServiceResults.class, TextEmbeddingResults.NAME, TextEmbeddingResults::new) ); - namedWriteables.add(new NamedWriteableRegistry.Entry(EmbeddingValue.class, FloatValue.NAME, FloatValue::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(EmbeddingValue.class, ByteValue.NAME, ByteValue::new)); + namedWriteables.add( + new NamedWriteableRegistry.Entry(InferenceServiceResults.class, TextEmbeddingByteResults.NAME, TextEmbeddingByteResults::new) + ); // Empty default task settings namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java index b9830b72de010..bd808c225d7e3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java @@ -11,13 +11,12 @@ import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.core.CheckedFunction; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.ByteValue; -import org.elasticsearch.xpack.core.inference.results.EmbeddingValue; -import org.elasticsearch.xpack.core.inference.results.FloatValue; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -36,11 +35,11 @@ public class CohereEmbeddingsResponseEntity { private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Cohere embeddings response"; - private static final Map> EMBEDDING_PARSERS = Map.of( + private static final Map> EMBEDDING_PARSERS = Map.of( toLowerCase(CohereEmbeddingType.FLOAT), - CohereEmbeddingsResponseEntity::parseEmbeddingFloatEntry, + CohereEmbeddingsResponseEntity::parseFloatEmbeddingsArray, toLowerCase(CohereEmbeddingType.INT8), - CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry + CohereEmbeddingsResponseEntity::parseByteEmbeddingsArray ); private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes(); @@ -132,7 +131,7 @@ private static String supportedEmbeddingTypes() { * * */ - public static TextEmbeddingResults fromResponse(Request request, HttpResult response) throws IOException { + public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { @@ -148,15 +147,7 @@ public static TextEmbeddingResults fromResponse(Request request, HttpResult resp return parseEmbeddingsObject(jsonParser); } else if (token == XContentParser.Token.START_ARRAY) { // if the request did not specify the embedding types then it will default to floats - List embeddingList = XContentParserUtils.parseList( - jsonParser, - parser -> CohereEmbeddingsResponseEntity.parseEmbeddingsArray( - parser, - CohereEmbeddingsResponseEntity::parseEmbeddingFloatEntry - ) - ); - - return new TextEmbeddingResults(embeddingList); + return parseFloatEmbeddingsArray(jsonParser); } else { throwUnknownToken(token, jsonParser); } @@ -167,7 +158,7 @@ public static TextEmbeddingResults fromResponse(Request request, HttpResult resp } } - private static TextEmbeddingResults parseEmbeddingsObject(XContentParser parser) throws IOException { + private static InferenceServiceResults parseEmbeddingsObject(XContentParser parser) throws IOException { XContentParser.Token token; while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { @@ -178,12 +169,7 @@ private static TextEmbeddingResults parseEmbeddingsObject(XContentParser parser) } parser.nextToken(); - var embeddingList = XContentParserUtils.parseList( - parser, - listParser -> CohereEmbeddingsResponseEntity.parseEmbeddingsArray(listParser, embeddingValueParser) - ); - - return new TextEmbeddingResults(embeddingList); + return embeddingValueParser.apply(parser); } } @@ -195,29 +181,26 @@ private static TextEmbeddingResults parseEmbeddingsObject(XContentParser parser) ); } - private static TextEmbeddingResults.Embedding parseEmbeddingsArray( - XContentParser parser, - CheckedFunction parseEntry - ) throws IOException { - XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - List embeddingValues = XContentParserUtils.parseList(parser, parseEntry); + private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser) throws IOException { + var embeddingList = XContentParserUtils.parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry); - return new TextEmbeddingResults.Embedding(embeddingValues); + return new TextEmbeddingByteResults(embeddingList); } - private static FloatValue parseEmbeddingFloatEntry(XContentParser parser) throws IOException { - XContentParser.Token token = parser.currentToken(); - XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); - return new FloatValue(parser.floatValue()); + private static TextEmbeddingByteResults.Embedding parseByteArrayEntry(XContentParser parser) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + List embeddingValues = XContentParserUtils.parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry); + + return new TextEmbeddingByteResults.Embedding(embeddingValues); } - private static ByteValue parseEmbeddingInt8Entry(XContentParser parser) throws IOException { + private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException { XContentParser.Token token = parser.currentToken(); XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); var parsedByte = parser.shortValue(); checkByteBounds(parsedByte); - return new ByteValue((byte) parsedByte); + return (byte) parsedByte; } private static void checkByteBounds(short value) { @@ -226,5 +209,24 @@ private static void checkByteBounds(short value) { } } + private static InferenceServiceResults parseFloatEmbeddingsArray(XContentParser parser) throws IOException { + var embeddingList = XContentParserUtils.parseList(parser, CohereEmbeddingsResponseEntity::parseFloatArrayEntry); + + return new TextEmbeddingResults(embeddingList); + } + + private static TextEmbeddingResults.Embedding parseFloatArrayEntry(XContentParser parser) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + List embeddingValues = XContentParserUtils.parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingFloatEntry); + + return new TextEmbeddingResults.Embedding(embeddingValues); + } + + private static Float parseEmbeddingFloatEntry(XContentParser parser) throws IOException { + XContentParser.Token token = parser.currentToken(); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); + return parser.floatValue(); + } + private CohereEmbeddingsResponseEntity() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java index b3706b318439b..b74b03891034f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java @@ -149,7 +149,7 @@ private static TextEmbeddingResults.Embedding parseEmbeddingEntry(XContentParser XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); List embeddingValues = XContentParserUtils.parseList(parser, HuggingFaceEmbeddingsResponseEntity::parseEmbeddingList); - return TextEmbeddingResults.Embedding.ofFloats(embeddingValues); + return new TextEmbeddingResults.Embedding(embeddingValues); } private static float parseEmbeddingList(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java index 0640821d0b9e1..4926ba3f0ef6b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java @@ -101,7 +101,7 @@ private static TextEmbeddingResults.Embedding parseEmbeddingObject(XContentParse // if there are additional fields within this object, lets skip them, so we can begin parsing the next embedding array parser.skipChildren(); - return TextEmbeddingResults.Embedding.ofFloats(embeddingValues); + return new TextEmbeddingResults.Embedding(embeddingValues); } private static float parseEmbeddingList(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 20912a3811af1..7029f9ca3bf56 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.TextEmbedding; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.common.SimilarityMeasure; @@ -109,25 +110,6 @@ public static String mustBeNonEmptyString(String settingName, String scope) { return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName); } - public static String mustBeNonEmptyList(String settingName, String scope) { - return Strings.format("[%s] Invalid value empty list. [%s] must be a non-empty list", scope, settingName); - } - - public static String invalidType(String settingName, String scope, String invalidType, String invalidValue, String requiredType) { - return Strings.format( - "[%s] Invalid type [%s] received for value [%s]. [%s] must be type [%s]", - scope, - invalidType, - invalidValue, - settingName, - requiredType - ); - } - - public static String invalidValue(String settingName, String scope, String invalidValue, String validValue) { - return invalidValue(settingName, scope, invalidValue, new String[] { validValue }); - } - public static String invalidValue(String settingName, String scope, String invalidType, String... requiredTypes) { return Strings.format( "[%s] Invalid value [%s] received. [%s] must be one of [%s]", @@ -287,16 +269,11 @@ public static void getEmbeddingSize(Model model, InferenceService service, Actio assert model.getTaskType() == TaskType.TEXT_EMBEDDING; service.infer(model, List.of(TEST_EMBEDDING_INPUT), Map.of(), listener.delegateFailureAndWrap((delegate, r) -> { - if (r instanceof TextEmbeddingResults embeddingResults) { - if (embeddingResults.embeddings().isEmpty()) { - delegate.onFailure( - new ElasticsearchStatusException( - "Could not determine embedding size, no embeddings were returned in test call", - RestStatus.BAD_REQUEST - ) - ); - } else { - delegate.onResponse(embeddingResults.embeddings().get(0).values().size()); + if (r instanceof TextEmbedding embeddingResults) { + try { + delegate.onResponse(embeddingResults.getFirstEmbeddingSize()); + } catch (Exception e) { + delegate.onFailure(new ElasticsearchStatusException("Could not determine embedding size", RestStatus.BAD_REQUEST, e)); } } else { delegate.onFailure( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java index 622cf51798609..759411cec1212 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java @@ -46,7 +46,7 @@ protected Writeable.Reader instanceReader() { @Override protected InferenceAction.Response createTestInstance() { var result = switch (randomIntBetween(0, 2)) { - case 0 -> TextEmbeddingResultsTests.createRandomFloatResults(); + case 0 -> TextEmbeddingResultsTests.createRandomResults(); case 1 -> LegacyTextEmbeddingResultsTests.createRandomResults().transformToTextEmbeddingResults(); default -> SparseEmbeddingResultsTests.createRandomResults(); }; @@ -90,7 +90,7 @@ public void testSerializesOpenAiAddedVersion_UsingSparseEmbeddingResult() throws } public void testSerializesMultipleInputsVersion_UsingLegacyTextEmbeddingResult() throws IOException { - var embeddingResults = TextEmbeddingResultsTests.createRandomFloatResults(); + var embeddingResults = TextEmbeddingResultsTests.createRandomResults(); var instance = new InferenceAction.Response(embeddingResults); var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), INFERENCE_MULTIPLE_INPUTS); assertOnBWCObject(copy, instance, INFERENCE_MULTIPLE_INPUTS); @@ -106,7 +106,7 @@ public void testSerializesMultipleInputsVersion_UsingSparseEmbeddingResult() thr // Technically we should never see a text embedding result in the transport version of this test because support // for it wasn't added until openai public void testSerializesSingleInputVersion_UsingLegacyTextEmbeddingResult() throws IOException { - var embeddingResults = TextEmbeddingResultsTests.createRandomFloatResults(); + var embeddingResults = TextEmbeddingResultsTests.createRandomResults(); var instance = new InferenceAction.Response(embeddingResults); var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), ML_INFERENCE_TASK_SETTINGS_OPTIONAL_ADDED); assertOnBWCObject(copy, instance, ML_INFERENCE_TASK_SETTINGS_OPTIONAL_ADDED); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java index acd0145cbad24..67a95265f093d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java @@ -39,7 +39,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -117,7 +117,7 @@ public void testCreate_CohereEmbeddingsModel() throws IOException { var result = listener.actionGet(TIMEOUT); - MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.123F, -0.123F))))); + MatcherAssert.assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.123F, -0.123F))))); MatcherAssert.assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); MatcherAssert.assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java index 86ba29c95b4bc..501d5a5e42bfe 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.results.TextEmbeddingByteResultsTests; import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; @@ -44,8 +45,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationBytes; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -122,7 +122,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var result = listener.actionGet(TIMEOUT); - MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.123F, -0.123F))))); + MatcherAssert.assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.123F, -0.123F))))); MatcherAssert.assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); MatcherAssert.assertThat( @@ -199,7 +199,10 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I var result = listener.actionGet(TIMEOUT); - MatcherAssert.assertThat(result.asMap(), is(buildExpectationBytes(List.of(List.of((byte) 0, (byte) -1))))); + MatcherAssert.assertThat( + result.asMap(), + is(TextEmbeddingByteResultsTests.buildExpectation(List.of(List.of((byte) 0, (byte) -1)))) + ); MatcherAssert.assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); MatcherAssert.assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java index c17b293c8bc35..95b69f1231e9d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java @@ -214,7 +214,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloats(List.of(List.of(-0.0123F, 0.123F))))); + assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectation(List.of(List.of(-0.0123F, 0.123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); @@ -331,7 +331,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloats(List.of(List.of(-0.0123F, 0.123F))))); + assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectation(List.of(List.of(-0.0123F, 0.123F))))); assertThat(webServer.requests(), hasSize(2)); { @@ -389,7 +389,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloats(List.of(List.of(-0.0123F, 0.123F))))); + assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectation(List.of(List.of(-0.0123F, 0.123F))))); assertThat(webServer.requests(), hasSize(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java index 2d7493a783c51..23b6f1ea2fbe3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java @@ -33,7 +33,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsRequestTaskSettingsTests.getRequestTaskSettingsMap; @@ -100,7 +100,7 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); @@ -169,7 +169,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(2)); { assertNull(webServer.requests().get(0).getUri().getQuery()); @@ -252,7 +252,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(2)); { assertNull(webServer.requests().get(0).getUri().getQuery()); @@ -320,7 +320,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java index 2d5cea7bd981e..6bc8e2d61d579 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java @@ -38,7 +38,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel; import static org.hamcrest.Matchers.equalTo; @@ -104,7 +104,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiClientTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiClientTests.java index 0c4b354e7171f..bb9612f01d8ff 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiClientTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiClientTests.java @@ -40,7 +40,7 @@ import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiEmbeddingsRequestTests.createRequest; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -104,7 +104,7 @@ public void testSend_SuccessfulResponse() throws IOException, URISyntaxException var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); @@ -155,7 +155,7 @@ public void testSend_SuccessfulResponse_WithoutUser() throws IOException, URISyn var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); @@ -205,7 +205,7 @@ public void testSend_SuccessfulResponse_WithoutOrganization() throws IOException var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java index 1d7af8577a722..76aecd997414c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java @@ -8,7 +8,9 @@ package org.elasticsearch.xpack.inference.external.response.cohere; import org.apache.http.HttpResponse; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -18,6 +20,7 @@ import java.nio.charset.StandardCharsets; import java.util.List; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; @@ -47,14 +50,15 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { } """; - TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse( + InferenceServiceResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); + MatcherAssert.assertThat(parsedResults, instanceOf(TextEmbeddingResults.class)); MatcherAssert.assertThat( - parsedResults.embeddings(), - is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F)))) + ((TextEmbeddingResults) parsedResults).embeddings(), + is(List.of(new TextEmbeddingResults.Embedding(List.of(-0.0018434525F, 0.01777649F)))) ); } @@ -85,14 +89,14 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws } """; - TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse( + TextEmbeddingResults parsedResults = (TextEmbeddingResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); MatcherAssert.assertThat( parsedResults.embeddings(), - is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F)))) + is(List.of(new TextEmbeddingResults.Embedding(List.of(-0.0018434525F, 0.01777649F)))) ); } @@ -129,14 +133,14 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntry() throws IOExcepti } """; - TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse( + TextEmbeddingResults parsedResults = (TextEmbeddingResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); MatcherAssert.assertThat( parsedResults.embeddings(), - is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F)))) + is(List.of(new TextEmbeddingResults.Embedding(List.of(-0.0018434525F, 0.01777649F)))) ); } @@ -173,14 +177,14 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntryInt8_WithInvalidFir } """; - TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse( + TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); MatcherAssert.assertThat( parsedResults.embeddings(), - is(List.of(TextEmbeddingResults.Embedding.ofBytes(List.of((byte) -1, (byte) 0)))) + is(List.of(new TextEmbeddingByteResults.Embedding(List.of((byte) -1, (byte) 0)))) ); } @@ -213,7 +217,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException } """; - TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse( + TextEmbeddingResults parsedResults = (TextEmbeddingResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -222,8 +226,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException parsedResults.embeddings(), is( List.of( - TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F)), - TextEmbeddingResults.Embedding.ofFloats(List.of(-0.123F, 0.123F)) + new TextEmbeddingResults.Embedding(List.of(-0.0018434525F, 0.01777649F)), + new TextEmbeddingResults.Embedding(List.of(-0.123F, 0.123F)) ) ) ); @@ -260,7 +264,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw } """; - TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse( + TextEmbeddingResults parsedResults = (TextEmbeddingResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -269,8 +273,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw parsedResults.embeddings(), is( List.of( - TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F)), - TextEmbeddingResults.Embedding.ofFloats(List.of(-0.123F, 0.123F)) + new TextEmbeddingResults.Embedding(List.of(-0.0018434525F, 0.01777649F)), + new TextEmbeddingResults.Embedding(List.of(-0.123F, 0.123F)) ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java index d54bcd91c9eda..2b6e11fdfafa7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java @@ -37,7 +37,7 @@ public void testFromResponse_CreatesResultsForASingleItem_ArrayFormat() throws I new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F))))); + assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F))))); } public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws IOException { @@ -57,7 +57,7 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F))))); + assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F))))); } public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws IOException { @@ -83,8 +83,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws parsedResults.embeddings(), is( List.of( - TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)), - TextEmbeddingResults.Embedding.ofFloats(List.of(0.0123F, -0.0123F)) + new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)), + new TextEmbeddingResults.Embedding(List.of(0.0123F, -0.0123F)) ) ) ); @@ -115,8 +115,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw parsedResults.embeddings(), is( List.of( - TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)), - TextEmbeddingResults.Embedding.ofFloats(List.of(0.0123F, -0.0123F)) + new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)), + new TextEmbeddingResults.Embedding(List.of(0.0123F, -0.0123F)) ) ) ); @@ -254,7 +254,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ArrayFormat() throw new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(1.0F))))); + assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(1.0F))))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() throws IOException { @@ -273,7 +273,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() thro new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(1.0F))))); + assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(1.0F))))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() throws IOException { @@ -290,7 +290,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() thro new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(4.0294965E10F))))); + assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(4.0294965E10F))))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() throws IOException { @@ -309,7 +309,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() thr new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(4.0294965E10F))))); + assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(4.0294965E10F))))); } public void testFromResponse_FailsWhenEmbeddingValueIsAnObject_ObjectFormat() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java index 3e8b50591e3b6..010e990a3ce80 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java @@ -49,7 +49,7 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F))))); + assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F))))); } public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { @@ -91,8 +91,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException parsedResults.embeddings(), is( List.of( - TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)), - TextEmbeddingResults.Embedding.ofFloats(List.of(0.0123F, -0.0123F)) + new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)), + new TextEmbeddingResults.Embedding(List.of(0.0123F, -0.0123F)) ) ) ); @@ -261,7 +261,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(1.0F))))); + assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(1.0F))))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOException { @@ -290,7 +290,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(4.0294965E10F))))); + assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(4.0294965E10F))))); } public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ByteValueTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ByteValueTests.java deleted file mode 100644 index b89c3c9ab0ec9..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ByteValueTests.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.ByteValue; - -import java.io.IOException; - -public class ByteValueTests extends AbstractWireSerializingTestCase { - public static ByteValue createRandom() { - return new ByteValue(randomByte()); - } - - @Override - protected Writeable.Reader instanceReader() { - return ByteValue::new; - } - - @Override - protected ByteValue createTestInstance() { - return createRandom(); - } - - @Override - protected ByteValue mutateInstance(ByteValue instance) throws IOException { - return null; - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/FloatValueTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/FloatValueTests.java deleted file mode 100644 index 1d37c0d3ddee1..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/FloatValueTests.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.FloatValue; - -import java.io.IOException; - -public class FloatValueTests extends AbstractWireSerializingTestCase { - public static FloatValue createRandom() { - return new FloatValue(randomFloat()); - } - - @Override - protected Writeable.Reader instanceReader() { - return FloatValue::new; - } - - @Override - protected FloatValue createTestInstance() { - return createRandom(); - } - - @Override - protected FloatValue mutateInstance(FloatValue instance) throws IOException { - return null; - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java new file mode 100644 index 0000000000000..b9318db6ece34 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java @@ -0,0 +1,165 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.results; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase { + public static TextEmbeddingByteResults createRandomResults() { + int embeddings = randomIntBetween(1, 10); + List embeddingResults = new ArrayList<>(embeddings); + + for (int i = 0; i < embeddings; i++) { + embeddingResults.add(createRandomEmbedding()); + } + + return new TextEmbeddingByteResults(embeddingResults); + } + + private static TextEmbeddingByteResults.Embedding createRandomEmbedding() { + int columns = randomIntBetween(1, 10); + List floats = new ArrayList<>(columns); + + for (int i = 0; i < columns; i++) { + floats.add(randomByte()); + } + + return new TextEmbeddingByteResults.Embedding(floats); + } + + public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException { + var entity = new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(List.of((byte) 23)))); + + assertThat( + entity.asMap(), + is( + Map.of( + TextEmbeddingByteResults.TEXT_EMBEDDING, + List.of(Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, List.of((byte) 23))) + ) + ) + ); + + String xContentResult = Strings.toString(entity, true, true); + assertThat(xContentResult, is(""" + { + "text_embedding" : [ + { + "embedding" : [ + 23 + ] + } + ] + }""")); + } + + public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException { + var entity = new TextEmbeddingByteResults( + List.of(new TextEmbeddingByteResults.Embedding(List.of((byte) 23)), new TextEmbeddingByteResults.Embedding(List.of((byte) 24))) + + ); + + assertThat( + entity.asMap(), + is( + Map.of( + TextEmbeddingByteResults.TEXT_EMBEDDING, + List.of( + Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, List.of((byte) 23)), + Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, List.of((byte) 24)) + ) + ) + ) + ); + + String xContentResult = Strings.toString(entity, true, true); + assertThat(xContentResult, is(""" + { + "text_embedding" : [ + { + "embedding" : [ + 23 + ] + }, + { + "embedding" : [ + 24 + ] + } + ] + }""")); + } + + public void testTransformToCoordinationFormat() { + var results = new TextEmbeddingByteResults( + List.of( + new TextEmbeddingByteResults.Embedding(List.of((byte) 23, (byte) 24)), + new TextEmbeddingByteResults.Embedding(List.of((byte) 25, (byte) 26)) + ) + ).transformToCoordinationFormat(); + + assertThat( + results, + is( + List.of( + new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( + TextEmbeddingByteResults.TEXT_EMBEDDING, + new double[] { 23F, 24F }, + false + ), + new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( + TextEmbeddingByteResults.TEXT_EMBEDDING, + new double[] { 25F, 26F }, + false + ) + ) + ) + ); + } + + @Override + protected Writeable.Reader instanceReader() { + return TextEmbeddingByteResults::new; + } + + @Override + protected TextEmbeddingByteResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected TextEmbeddingByteResults mutateInstance(TextEmbeddingByteResults instance) throws IOException { + // if true we reduce the embeddings list by a random amount, if false we add an embedding to the list + if (randomBoolean()) { + // -1 to remove at least one item from the list + int end = randomInt(instance.embeddings().size() - 1); + return new TextEmbeddingByteResults(instance.embeddings().subList(0, end)); + } else { + List embeddings = new ArrayList<>(instance.embeddings()); + embeddings.add(createRandomEmbedding()); + return new TextEmbeddingByteResults(embeddings); + } + } + + public static Map buildExpectation(List> embeddings) { + return Map.of( + TextEmbeddingByteResults.TEXT_EMBEDDING, + embeddings.stream().map(embedding -> Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList() + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index 9c154a20531f7..09d9894d98853 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -7,77 +7,31 @@ package org.elasticsearch.xpack.inference.results; -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.xpack.core.inference.results.ByteValue; -import org.elasticsearch.xpack.core.inference.results.FloatValue; +import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; -import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; -import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; -import org.hamcrest.MatcherAssert; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.function.Supplier; -import static org.elasticsearch.TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED; import static org.hamcrest.Matchers.is; -public class TextEmbeddingResultsTests extends AbstractBWCWireSerializationTestCase { - - private enum EmbeddingType { - FLOAT, - BYTE - } - - private static Map> EMBEDDING_TYPE_BUILDERS = Map.of( - EmbeddingType.FLOAT, - TextEmbeddingResultsTests::createRandomFloatEmbedding, - EmbeddingType.BYTE, - TextEmbeddingResultsTests::createRandomByteEmbedding - ); - +public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase { public static TextEmbeddingResults createRandomResults() { - var embeddingType = randomFrom(EmbeddingType.values()); - var createFunction = EMBEDDING_TYPE_BUILDERS.get(embeddingType); - assert createFunction != null : "the embeddings type map is missing a value from the EmbeddingType enum"; - - return createRandomResults(createFunction); - } - - public static TextEmbeddingResults createRandomFloatResults() { - return createRandomResults(TextEmbeddingResultsTests::createRandomFloatEmbedding); - } - - private static TextEmbeddingResults createRandomResults(Supplier creator) { int embeddings = randomIntBetween(1, 10); List embeddingResults = new ArrayList<>(embeddings); for (int i = 0; i < embeddings; i++) { - embeddingResults.add(creator.get()); + embeddingResults.add(createRandomEmbedding()); } return new TextEmbeddingResults(embeddingResults); } - private static TextEmbeddingResults.Embedding createRandomByteEmbedding() { - int columns = randomIntBetween(1, 10); - List bytes = new ArrayList<>(columns); - - for (int i = 0; i < columns; i++) { - bytes.add(randomByte()); - } - - return TextEmbeddingResults.Embedding.ofBytes(bytes); - } - - private static TextEmbeddingResults.Embedding createRandomFloatEmbedding() { + private static TextEmbeddingResults.Embedding createRandomEmbedding() { int columns = randomIntBetween(1, 10); List floats = new ArrayList<>(columns); @@ -85,34 +39,19 @@ private static TextEmbeddingResults.Embedding createRandomFloatEmbedding() { floats.add(randomFloat()); } - return TextEmbeddingResults.Embedding.ofFloats(floats); + return new TextEmbeddingResults.Embedding(floats); } - public static Map buildExpectationFloats(List> embeddings) { - return Map.of( - TextEmbeddingResults.TEXT_EMBEDDING, - embeddings.stream() - .map(embedding -> Map.of(TextEmbeddingResults.Embedding.EMBEDDING, embedding.stream().map(FloatValue::new).toList())) - .toList() - ); - } + public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException { + var entity = new TextEmbeddingResults(List.of(new TextEmbeddingResults.Embedding(List.of(0.1F)))); - public static Map buildExpectationBytes(List> embeddings) { - return Map.of( - TextEmbeddingResults.TEXT_EMBEDDING, - embeddings.stream() - .map(embedding -> Map.of(TextEmbeddingResults.Embedding.EMBEDDING, embedding.stream().map(ByteValue::new).toList())) - .toList() + assertThat( + entity.asMap(), + is(Map.of(TextEmbeddingResults.TEXT_EMBEDDING, List.of(Map.of(TextEmbeddingResults.Embedding.EMBEDDING, List.of(0.1F))))) ); - } - - public void testToXContent_CreatesTheRightFormatForASingleEmbedding() { - var entity = new TextEmbeddingResults(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.1F)))); - - MatcherAssert.assertThat(entity.asMap(), is(buildExpectationFloats(List.of(List.of(0.1F))))); String xContentResult = Strings.toString(entity, true, true); - MatcherAssert.assertThat(xContentResult, is(""" + assertThat(xContentResult, is(""" { "text_embedding" : [ { @@ -124,33 +63,27 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() { }""")); } - public void testToXContent_CreatesTheRightFormatForASingleEmbedding_ForBytes() { - var entity = new TextEmbeddingResults(List.of(TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 12)))); - - MatcherAssert.assertThat(entity.asMap(), is(buildExpectationBytes(List.of(List.of((byte) 12))))); - - String xContentResult = Strings.toString(entity, true, true); - MatcherAssert.assertThat(xContentResult, is(""" - { - "text_embedding" : [ - { - "embedding" : [ - 12 - ] - } - ] - }""")); - } - - public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() { + public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException { var entity = new TextEmbeddingResults( - List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.1F)), TextEmbeddingResults.Embedding.ofFloats(List.of(0.2F))) + List.of(new TextEmbeddingResults.Embedding(List.of(0.1F)), new TextEmbeddingResults.Embedding(List.of(0.2F))) + ); - MatcherAssert.assertThat(entity.asMap(), is(buildExpectationFloats(List.of(List.of(0.1F), List.of(0.2F))))); + assertThat( + entity.asMap(), + is( + Map.of( + TextEmbeddingResults.TEXT_EMBEDDING, + List.of( + Map.of(TextEmbeddingResults.Embedding.EMBEDDING, List.of(0.1F)), + Map.of(TextEmbeddingResults.Embedding.EMBEDDING, List.of(0.2F)) + ) + ) + ) + ); String xContentResult = Strings.toString(entity, true, true); - MatcherAssert.assertThat(xContentResult, is(""" + assertThat(xContentResult, is(""" { "text_embedding" : [ { @@ -167,40 +100,12 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() { }""")); } - public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings_ForBytes() { - var entity = new TextEmbeddingResults( - List.of(TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 12)), TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 34))) - ); - - MatcherAssert.assertThat(entity.asMap(), is(buildExpectationBytes(List.of(List.of((byte) 12), List.of((byte) 34))))); - - String xContentResult = Strings.toString(entity, true, true); - MatcherAssert.assertThat(xContentResult, is(""" - { - "text_embedding" : [ - { - "embedding" : [ - 12 - ] - }, - { - "embedding" : [ - 34 - ] - } - ] - }""")); - } - public void testTransformToCoordinationFormat() { var results = new TextEmbeddingResults( - List.of( - TextEmbeddingResults.Embedding.ofFloats(List.of(0.1F, 0.2F)), - TextEmbeddingResults.Embedding.ofFloats(List.of(0.3F, 0.4F)) - ) + List.of(new TextEmbeddingResults.Embedding(List.of(0.1F, 0.2F)), new TextEmbeddingResults.Embedding(List.of(0.3F, 0.4F))) ).transformToCoordinationFormat(); - MatcherAssert.assertThat( + assertThat( results, is( List.of( @@ -219,49 +124,6 @@ public void testTransformToCoordinationFormat() { ); } - public void testTransformToCoordinationFormat_FromBytes() { - var results = new TextEmbeddingResults( - List.of( - TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 12, (byte) 34)), - TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 56, (byte) -78)) - ) - ).transformToCoordinationFormat(); - - MatcherAssert.assertThat( - results, - is( - List.of( - new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( - TextEmbeddingResults.TEXT_EMBEDDING, - new double[] { 12F, 34F }, - false - ), - new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( - TextEmbeddingResults.TEXT_EMBEDDING, - new double[] { 56F, -78F }, - false - ) - ) - ) - ); - } - - public void testSerializesToFloats_WhenVersionIsPriorToByteSupport() throws IOException { - var instance = createRandomResults(TextEmbeddingResultsTests::createRandomByteEmbedding); - var modifiedForOlderVersion = mutateInstanceForVersion(instance, ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED); - - var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED); - assertOnBWCObject(copy, modifiedForOlderVersion, ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED); - } - - @Override - protected NamedWriteableRegistry getNamedWriteableRegistry() { - List entries = new ArrayList<>(); - entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); - entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables()); - return new NamedWriteableRegistry(entries); - } - @Override protected Writeable.Reader instanceReader() { return TextEmbeddingResults::new; @@ -281,26 +143,15 @@ protected TextEmbeddingResults mutateInstance(TextEmbeddingResults instance) thr return new TextEmbeddingResults(instance.embeddings().subList(0, end)); } else { List embeddings = new ArrayList<>(instance.embeddings()); - embeddings.add(createRandomFloatEmbedding()); + embeddings.add(createRandomEmbedding()); return new TextEmbeddingResults(embeddings); } } - @Override - protected TextEmbeddingResults mutateInstanceForVersion(TextEmbeddingResults instance, TransportVersion version) { - if (version.before(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) { - return convertToFloatEmbeddings(instance); - } - - return instance; - } - - public TextEmbeddingResults convertToFloatEmbeddings(TextEmbeddingResults results) { - var floatEmbeddings = results.embeddings() - .stream() - .map(embedding -> TextEmbeddingResults.Embedding.ofFloats(embedding.toFloats())) - .toList(); - - return new TextEmbeddingResults(floatEmbeddings); + public static Map buildExpectation(List> embeddings) { + return Map.of( + TextEmbeddingResults.TEXT_EMBEDDING, + embeddings.stream().map(embedding -> Map.of(TextEmbeddingResults.Embedding.EMBEDDING, embedding)).toList() + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index c72b161941ad2..b935c5a8c64b3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -8,24 +8,45 @@ package org.elasticsearch.xpack.inference.services; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.ValidationException; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.inference.results.TextEmbeddingByteResultsTests; +import org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.getEmbeddingSize; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class ServiceUtilsTests extends ESTestCase { + private static final TimeValue TIMEOUT = TimeValue.timeValueSeconds(30); + public void testRemoveAsTypeWithTheCorrectType() { Map map = new HashMap<>(Map.of("a", 5, "b", "a string", "c", Boolean.TRUE, "d", 1.0)); @@ -237,6 +258,124 @@ public void testExtractOptionalString_AddsException_WhenFieldIsEmpty() { assertThat(validation.validationErrors().get(0), is("[scope] Invalid value empty string. [key] must be a non-empty string")); } + public void testExtractOptionalEnum_ReturnsNull_WhenFieldDoesNotExist() { + var validation = new ValidationException(); + Map map = modifiableMap(Map.of("key", "value")); + var createdEnum = extractOptionalEnum(map, "abc", "scope", InputType::fromString, InputType.values(), validation); + + assertNull(createdEnum); + assertTrue(validation.validationErrors().isEmpty()); + assertThat(map.size(), is(1)); + } + + public void testExtractOptionalEnum_ReturnsNullAndAddsException_WhenAnInvalidValueExists() { + var validation = new ValidationException(); + Map map = modifiableMap(Map.of("key", "invalid_value")); + var createdEnum = extractOptionalEnum(map, "key", "scope", InputType::fromString, InputType.values(), validation); + + assertNull(createdEnum); + assertFalse(validation.validationErrors().isEmpty()); + assertTrue(map.isEmpty()); + assertThat( + validation.validationErrors().get(0), + is("[scope] Invalid value [invalid_value] received. [key] must be one of [ingest, search]") + ); + } + + public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingResults_IsEmpty() { + var service = mock(InferenceService.class); + + var model = mock(Model.class); + when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[3]; + listener.onResponse(new TextEmbeddingResults(List.of())); + + return Void.TYPE; + }).when(service).infer(any(), any(), any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + getEmbeddingSize(model, service, listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("Could not determine embedding size")); + assertThat(thrownException.getCause().getMessage(), is("Embeddings list is empty")); + } + + public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingByteResults_IsEmpty() { + var service = mock(InferenceService.class); + + var model = mock(Model.class); + when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[3]; + listener.onResponse(new TextEmbeddingByteResults(List.of())); + + return Void.TYPE; + }).when(service).infer(any(), any(), any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + getEmbeddingSize(model, service, listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("Could not determine embedding size")); + assertThat(thrownException.getCause().getMessage(), is("Embeddings list is empty")); + } + + public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingResults() { + var service = mock(InferenceService.class); + + var model = mock(Model.class); + when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); + + var textEmbedding = TextEmbeddingResultsTests.createRandomResults(); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[3]; + listener.onResponse(textEmbedding); + + return Void.TYPE; + }).when(service).infer(any(), any(), any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + getEmbeddingSize(model, service, listener); + + var size = listener.actionGet(TIMEOUT); + + assertThat(size, is(textEmbedding.embeddings().get(0).values().size())); + } + + public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingByteResults() { + var service = mock(InferenceService.class); + + var model = mock(Model.class); + when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); + + var textEmbedding = TextEmbeddingByteResultsTests.createRandomResults(); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[3]; + listener.onResponse(textEmbedding); + + return Void.TYPE; + }).when(service).infer(any(), any(), any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + getEmbeddingSize(model, service, listener); + + var size = listener.actionGet(TIMEOUT); + + assertThat(size, is(textEmbedding.embeddings().get(0).values().size())); + } + private static Map modifiableMap(Map aMap) { return new HashMap<>(aMap); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index dc3c8eafb46ac..0250e08a48452 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -50,7 +50,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap; @@ -749,7 +749,7 @@ public void testInfer_SendsRequest() throws IOException { var result = listener.actionGet(TIMEOUT); - MatcherAssert.assertThat(result.asMap(), Matchers.is(buildExpectationFloats(List.of(List.of(0.123F, -0.123F))))); + MatcherAssert.assertThat(result.asMap(), Matchers.is(buildExpectation(List.of(List.of(0.123F, -0.123F))))); MatcherAssert.assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); MatcherAssert.assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index aac50b8645993..a76cce41b4fe4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -47,7 +47,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettingsTests.getServiceSettingsMap; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; @@ -496,7 +496,7 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), Matchers.is(buildExpectationFloats(List.of(List.of(-0.0123F, 0.0123F))))); + assertThat(result.asMap(), Matchers.is(buildExpectation(List.of(List.of(-0.0123F, 0.0123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index ab4ee881a224e..394286ee5287b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -46,7 +46,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceSettingsTests.getServiceSettingsMap; @@ -717,7 +717,7 @@ public void testInfer_SendsRequest() throws IOException { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), Matchers.is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(result.asMap(), Matchers.is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); From 8102408efb3501ad75a8973e8e04247fd89a38fc Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 24 Jan 2024 11:17:42 -0500 Subject: [PATCH 11/13] Fixing a few comments and adding tests --- .../cohere/CohereEmbeddingsRequestEntity.java | 3 +- .../services/cohere/CohereServiceFields.java | 2 - .../cohere/CohereServiceSettings.java | 2 +- .../CohereEmbeddingsResponseEntityTests.java | 38 +++++++++++++++++++ .../CohereEmbeddingsTaskSettingsTests.java | 3 +- 5 files changed, 43 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java index a7d8743359f39..a0b5444ee45e4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java @@ -12,6 +12,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; @@ -52,7 +53,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field(TEXTS_FIELD, input); if (model != null) { - builder.field(CohereServiceFields.MODEL, model); + builder.field(CohereServiceSettings.MODEL, model); } if (taskSettings.inputType() != null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java index ccfe1cb2593c6..807520637f971 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java @@ -8,7 +8,5 @@ package org.elasticsearch.xpack.inference.services.cohere; public class CohereServiceFields { - public static final String MODEL = "model"; public static final String TRUNCATE = "truncate"; - } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java index f03371593a340..7964741d90343 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java @@ -46,7 +46,7 @@ public static CohereServiceSettings fromMap(Map map) { Integer dims = removeAsType(map, DIMENSIONS, Integer.class); Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); - String model = extractOptionalString(map, MODEL, ModelConfigurations.TASK_SETTINGS, validationException); + String model = extractOptionalString(map, MODEL, ModelConfigurations.SERVICE_SETTINGS, validationException); if (validationException.validationErrors().isEmpty() == false) { throw validationException; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java index 76aecd997414c..f04715be0838f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java @@ -188,6 +188,44 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntryInt8_WithInvalidFir ); } + public void testFromResponse_ParsesBytes() throws IOException { + String responseJson = """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello" + ], + "embeddings": { + "int8": [ + [ + -1, + 0 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_floats" + } + """; + + TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + MatcherAssert.assertThat( + parsedResults.embeddings(), + is(List.of(new TextEmbeddingByteResults.Embedding(List.of((byte) -1, (byte) 0)))) + ); + } + public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java index cf16473bdb70f..164d3998f138f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; import org.hamcrest.MatcherAssert; @@ -68,7 +69,7 @@ public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() { public void testOverrideWith_KeepsOriginalValuesWhenOverridesAreNull() { var taskSettings = CohereEmbeddingsTaskSettings.fromMap( - new HashMap<>(Map.of(CohereServiceFields.MODEL, "model", CohereServiceFields.TRUNCATE, CohereTruncation.END.toString())) + new HashMap<>(Map.of(CohereServiceSettings.MODEL, "model", CohereServiceFields.TRUNCATE, CohereTruncation.END.toString())) ); var overriddenTaskSettings = taskSettings.overrideWith(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS); From 6c8343ffa12edee0abccc3d52a0c42512e012dbd Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 24 Jan 2024 11:46:12 -0500 Subject: [PATCH 12/13] Fixing mutation issue --- .../inference/services/cohere/CohereServiceSettingsTests.java | 2 +- .../cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java index 8f829c11c0a11..6f47d5c74d81c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java @@ -141,7 +141,7 @@ protected CohereServiceSettings createTestInstance() { @Override protected CohereServiceSettings mutateInstance(CohereServiceSettings instance) throws IOException { - return createRandomWithNonNullUrl(); + return null; } public static Map getServiceSettingsMap(@Nullable String url, @Nullable String model) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java index 8daa5a27f9618..e0b29ce9c34da 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java @@ -140,7 +140,7 @@ protected CohereEmbeddingsServiceSettings createTestInstance() { @Override protected CohereEmbeddingsServiceSettings mutateInstance(CohereEmbeddingsServiceSettings instance) throws IOException { - return createRandom(); + return null; } @Override From de682b5568463085df6c115b60e8eeb576baa239 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 24 Jan 2024 15:16:15 -0500 Subject: [PATCH 13/13] Removing cohere service settings from named writeable registry --- .../xpack/inference/InferenceNamedWriteablesProvider.java | 3 --- .../cohere/embeddings/CohereEmbeddingsServiceSettings.java | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 982c33a08d1fc..c23e245b5696c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -98,9 +98,6 @@ public static List getNamedWriteables() { namedWriteables.add( new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new) ); - namedWriteables.add( - new NamedWriteableRegistry.Entry(CohereServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new) - ); namedWriteables.add( new NamedWriteableRegistry.Entry( ServiceSettings.class, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java index f8400bc168d69..5327bcbcf22dd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java @@ -57,7 +57,7 @@ public CohereEmbeddingsServiceSettings(CohereServiceSettings commonSettings, @Nu } public CohereEmbeddingsServiceSettings(StreamInput in) throws IOException { - commonSettings = in.readNamedWriteable(CohereServiceSettings.class); + commonSettings = new CohereServiceSettings(in); embeddingType = in.readOptionalEnum(CohereEmbeddingType.class); } @@ -92,7 +92,7 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeNamedWriteable(commonSettings); + commonSettings.writeTo(out); out.writeOptionalEnum(embeddingType); }