From 244da7584fbaa2618a4befbb37b13f3173d76e83 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Wed, 10 Jan 2024 14:00:15 -0500
Subject: [PATCH 01/13] Starting cohere
---
.../cohere/CohereEmbeddingsRequest.java | 41 +++++++++++
.../inference/services/ServiceUtils.java | 71 +++++++++++++++++++
.../cohere/CohereEmbeddingsTaskSettings.java | 60 ++++++++++++++++
.../services/cohere/CohereServiceFields.java | 14 ++++
.../services/cohere/CohereTruncation.java | 58 +++++++++++++++
5 files changed, 244 insertions(+)
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsTaskSettings.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java
new file mode 100644
index 0000000000000..093ce9fd2f4bb
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.cohere;
+
+import org.apache.http.client.methods.HttpRequestBase;
+import org.elasticsearch.xpack.inference.common.Truncator;
+import org.elasticsearch.xpack.inference.external.request.Request;
+
+import java.net.URI;
+
+/**
+ * Skeleton of the {@link Request} implementation for Cohere embeddings.
+ *
+ * Nothing is implemented yet: the constructor discards its arguments and every overridden method returns a
+ * placeholder value.
+ */
+public class CohereEmbeddingsRequest implements Request {
+
+ // The parameters are accepted but not stored; account and taskSettings are typed as Object for now.
+ public CohereEmbeddingsRequest(Truncator truncator, Object account, Truncator.TruncationResult input, Object taskSettings) {
+
+ }
+
+ // Placeholder: no HTTP request is built yet, so callers receive null.
+ @Override
+ public HttpRequestBase createRequest() {
+ return null;
+ }
+
+ // Placeholder: no target URI is known yet, so callers receive null.
+ @Override
+ public URI getURI() {
+ return null;
+ }
+
+ // Placeholder: no truncated copy is produced yet, so callers receive null.
+ @Override
+ public Request truncate() {
+ return null;
+ }
+
+ // Placeholder: always reports an empty array.
+ @Override
+ public boolean[] getTruncationInfo() {
+ return new boolean[0];
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
index 1686cd32d4a6b..b503c0800d342 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
@@ -11,6 +11,7 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.core.Strings;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.Model;
@@ -21,6 +22,7 @@
import java.net.URI;
import java.net.URISyntaxException;
+import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -105,6 +107,10 @@ public static String mustBeNonEmptyString(String settingName, String scope) {
return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName);
}
+    /**
+     * Builds the validation message used when a setting holds a value of an unexpected type.
+     *
+     * @param settingName the name of the offending setting
+     * @param scope the label placed at the front of the message
+     * @param invalidType the type that was received
+     * @param requiredType the type the setting must have
+     * @return the formatted message
+     */
+    public static String invalidType(String settingName, String scope, String invalidType, String requiredType) {
+        final String template = "[%s] Invalid type [%s] received. [%s] must be type [%s]";
+        return Strings.format(template, scope, invalidType, settingName, requiredType);
+    }
+
// TODO improve URI validation logic
public static URI convertToUri(String url, String settingName, String settingScope, ValidationException validationException) {
try {
@@ -154,6 +160,42 @@ public static SimilarityMeasure extractSimilarity(Map map, Strin
return null;
}
+ @SuppressWarnings("unchecked")
+ public static List extractOptionalListOfType(
+ Map map,
+ String settingName,
+ String scope,
+ Class type,
+ ValidationException validationException
+ ) {
+ List> listField = ServiceUtils.removeAsType(map, settingName, List.class);
+
+ if (listField == null) {
+ return null;
+ }
+
+ if (listField.isEmpty()) {
+ validationException.addValidationError(ServiceUtils.mustBeNonEmptyString(settingName, scope));
+ return null;
+ }
+
+ List castedList = new ArrayList<>(listField.size());
+
+ for (Object listEntry : listField) {
+ if (type.isAssignableFrom(listEntry.getClass()) == false) {
+ // TODO should we just throw here like removeAsType
+ validationException.addValidationError(
+ invalidType(settingName, scope, listEntry.getClass().getSimpleName(), type.getSimpleName())
+ );
+ return null;
+ }
+
+ castedList.add((T) listEntry);
+ }
+
+ return castedList;
+ }
+
public static String extractRequiredString(
Map map,
String settingName,
@@ -194,6 +236,35 @@ public static String extractOptionalString(
return optionalField;
}
+    /**
+     * Extracts an optional string setting and converts it to an enum constant.
+     *
+     * @param map the settings map to read from
+     * @param settingName the key of the setting
+     * @param scope the label used at the front of validation messages
+     * @param converter turns the raw string into the enum constant; expected to throw {@link IllegalArgumentException}
+     *                  for values it does not recognise
+     * @param validationException collects a validation error when the value cannot be converted
+     * @return the converted constant, or {@code null} when the setting is absent or invalid
+     */
+    public static <T extends Enum<T>> T extractOptionalEnum(
+        Map<String, Object> map,
+        String settingName,
+        String scope,
+        CheckedFunction<String, T, IllegalArgumentException> converter,
+        ValidationException validationException
+    ) {
+        var enumString = extractOptionalString(map, settingName, scope, validationException);
+        if (enumString == null) {
+            return null;
+        }
+
+        try {
+            // Return the converted value directly; previously it was computed and then dropped
+            return converter.apply(enumString);
+        } catch (IllegalArgumentException e) {
+            // An empty or unknown string ends up here, so no separate emptiness check is needed
+            validationException.addValidationError(
+                Strings.format("[%s] Invalid value [%s] received. [%s] must be a valid enum value", scope, enumString, settingName)
+            );
+            return null;
+        }
+    }
+
public static String parsePersistedConfigErrorMsg(String modelId, String serviceName) {
return format("Failed to parse stored model [%s] for [%s] service, please delete and add the service again", modelId, serviceName);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsTaskSettings.java
new file mode 100644
index 0000000000000..229d81e468d3f
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsTaskSettings.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfType;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
+import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.MODEL;
+
+/**
+ * Defines the task settings for the cohere text embeddings service.
+ *
+ * @param model the id of the model to use in the requests to cohere
+ * @param inputType Specifies the type of input you're giving to the model
+ * @param embeddingTypes Specifies the types of embeddings you want to get back
+ * @param truncate Specifies how the API will handle inputs longer than the maximum token length
+ */
+public record CohereEmbeddingsTaskSettings(
+    @Nullable String model,
+    @Nullable String inputType,
+    @Nullable List<String> embeddingTypes,
+    @Nullable CohereTruncation truncate
+) {
+
+    public static final String NAME = "cohere_embeddings_task_settings";
+    static final String INPUT_TYPE = "input_type";
+    static final String EMBEDDING_TYPES = "embedding_types";
+
+    /**
+     * Parses the task settings from a map. Every field is optional.
+     *
+     * @param map the task settings map
+     * @return the parsed settings
+     * @throws ValidationException if any validation error was recorded while parsing
+     */
+    public static CohereEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
+        ValidationException validationException = new ValidationException();
+
+        String model = extractOptionalString(map, MODEL, ModelConfigurations.TASK_SETTINGS, validationException);
+        String inputType = extractOptionalString(map, INPUT_TYPE, ModelConfigurations.TASK_SETTINGS, validationException);
+        List<String> embeddingTypes = extractOptionalListOfType(
+            map,
+            EMBEDDING_TYPES,
+            ModelConfigurations.TASK_SETTINGS,
+            String.class,
+            validationException
+        );
+        CohereTruncation truncation = extractTruncation(map, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new CohereEmbeddingsTaskSettings(model, inputType, embeddingTypes, truncation);
+    }
+
+    /**
+     * Reads the optional truncate setting and converts it to a {@link CohereTruncation}, recording a validation error
+     * when the value does not name one of the constants.
+     */
+    private static CohereTruncation extractTruncation(Map<String, Object> map, ValidationException validationException) {
+        // Read from the truncate key; the input_type key was used here before, which parsed the wrong field
+        String truncate = extractOptionalString(map, CohereServiceFields.TRUNCATE, ModelConfigurations.TASK_SETTINGS, validationException);
+        if (truncate == null) {
+            return null;
+        }
+
+        try {
+            return CohereTruncation.fromString(truncate);
+        } catch (IllegalArgumentException e) {
+            validationException.addValidationError(
+                "["
+                    + ModelConfigurations.TASK_SETTINGS
+                    + "] Invalid value ["
+                    + truncate
+                    + "] received. ["
+                    + CohereServiceFields.TRUNCATE
+                    + "] must be one of "
+                    + java.util.Arrays.toString(CohereTruncation.values())
+            );
+            return null;
+        }
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java
new file mode 100644
index 0000000000000..ccfe1cb2593c6
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java
@@ -0,0 +1,14 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere;
+
+/**
+ * Field names shared by the cohere service classes.
+ */
+public final class CohereServiceFields {
+    public static final String MODEL = "model";
+    public static final String TRUNCATE = "truncate";
+
+    // Constants holder, not meant to be instantiated
+    private CohereServiceFields() {}
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java
new file mode 100644
index 0000000000000..ebf1d349e0b7a
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java
@@ -0,0 +1,58 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+
+import java.io.IOException;
+import java.util.Locale;
+
+/**
+ * Defines the type of truncation for a cohere request. The specified value determines how the Cohere API will handle inputs
+ * longer than the maximum token length.
+ *
+ * <p>
+ * <a href="https://docs.cohere.com/reference/embed">See api docs for details.</a>
+ * </p>
+ */
+public enum CohereTruncation implements Writeable {
+    /**
+     * When the input exceeds the maximum input token length an error will be returned.
+     */
+    NONE,
+    /**
+     * Discard the start of the input
+     */
+    START,
+    /**
+     * Discard the end of the input
+     */
+    END;
+
+    // final so the shared name cannot be reassigned by other code
+    public static final String NAME = "cohere_truncate";
+
+    /**
+     * @return the constant name in lower case, using {@link Locale#ROOT}
+     */
+    @Override
+    public String toString() {
+        return name().toLowerCase(Locale.ROOT);
+    }
+
+    /**
+     * Parses a constant from its name, ignoring case and surrounding whitespace.
+     *
+     * @param name the textual form of the constant
+     * @return the matching constant
+     * @throws IllegalArgumentException if no constant matches
+     */
+    public static CohereTruncation fromString(String name) {
+        return valueOf(name.trim().toUpperCase(Locale.ROOT));
+    }
+
+    /**
+     * Reads a constant from the stream using {@code readEnum}.
+     */
+    public static CohereTruncation fromStream(StreamInput in) throws IOException {
+        return in.readEnum(CohereTruncation.class);
+    }
+
+    /**
+     * Writes this constant to the stream using {@code writeEnum}.
+     */
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeEnum(this);
+    }
+}
From e1008cf3c1d63aa78658e9c64e77f4df503742ef Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Fri, 12 Jan 2024 16:57:03 -0500
Subject: [PATCH 02/13] Making progress on cohere
---
.../org/elasticsearch/TransportVersions.java | 1 +
.../elasticsearch/inference/InputType.java | 8 +-
.../org/elasticsearch/inference/Model.java | 14 ++
.../inference/ModelConfigurations.java | 26 +++
.../inference/src/main/java/module-info.java | 5 -
.../action/cohere/CohereEmbeddingsAction.java | 89 +++++++++
.../action/cohere/OpenAiActionCreator.java | 36 ++++
.../action/cohere/OpenAiActionVisitor.java | 17 ++
.../external/cohere/CohereAccount.java | 21 +++
.../cohere/CohereResponseHandler.java | 99 ++++++++++
.../http/retry/ResponseHandlerUtils.java | 22 +++
.../openai/OpenAiResponseHandler.java | 10 +-
.../external/request/RequestUtils.java | 19 ++
.../cohere/CohereEmbeddingsRequest.java | 60 +++++-
.../cohere/CohereEmbeddingsRequestEntity.java | 54 ++++++
.../external/request/cohere/CohereUtils.java | 16 ++
.../openai/OpenAiEmbeddingsRequest.java | 13 +-
.../CohereEmbeddingsResponseEntity.java | 117 ++++++++++++
.../cohere/CohereErrorResponseEntity.java | 58 ++++++
.../inference/services/ServiceUtils.java | 57 ++++--
.../cohere/CohereEmbeddingsTaskSettings.java | 60 ------
.../services/cohere/CohereModel.java | 31 ++++
.../cohere/CohereServiceSettings.java | 171 ++++++++++++++++++
.../services/cohere/CohereTruncation.java | 4 +-
.../embeddings/CohereEmbeddingsModel.java | 79 ++++++++
.../CohereEmbeddingsTaskSettings.java | 138 ++++++++++++++
.../services/openai/OpenAiModel.java | 10 +
.../openai/OpenAiServiceSettings.java | 10 +-
.../embeddings/OpenAiEmbeddingsModel.java | 22 +--
.../cohere/CohereServiceSettingsTests.java | 151 ++++++++++++++++
.../CohereEmbeddingsModelTests.java | 41 +++++
.../CohereEmbeddingsTaskSettingsTests.java | 112 ++++++++++++
.../OpenAiEmbeddingsTaskSettingsTests.java | 15 +-
33 files changed, 1440 insertions(+), 146 deletions(-)
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionCreator.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionVisitor.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereAccount.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandlerUtils.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereUtils.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntity.java
delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsTaskSettings.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java
index f289a7a3c89a1..a872b5ceb3496 100644
--- a/server/src/main/java/org/elasticsearch/TransportVersions.java
+++ b/server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -182,6 +182,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_PLAN_POINT_LITERAL_WKB = def(8_570_00_0);
public static final TransportVersion HOT_THREADS_AS_BYTES = def(8_571_00_0);
public static final TransportVersion ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED = def(8_572_00_0);
+ public static final TransportVersion ML_INFERENCE_COHERE_EMBEDDINGS_ADDED = def(8_573_00_0);
/*
* STOP! READ THIS FIRST! No, really,
diff --git a/server/src/main/java/org/elasticsearch/inference/InputType.java b/server/src/main/java/org/elasticsearch/inference/InputType.java
index f8bbea4ae121f..b5ba3f0d1b506 100644
--- a/server/src/main/java/org/elasticsearch/inference/InputType.java
+++ b/server/src/main/java/org/elasticsearch/inference/InputType.java
@@ -29,12 +29,16 @@ public String toString() {
return name().toLowerCase(Locale.ROOT);
}
+    /**
+     * Parses an {@link InputType} from its name, ignoring case and surrounding whitespace.
+     *
+     * @param name the textual form of the constant
+     * @return the matching constant
+     * @throws IllegalArgumentException if no constant matches
+     */
+    public static InputType fromString(String name) {
+        final String normalized = name.trim().toUpperCase(Locale.ROOT);
+        return valueOf(normalized);
+    }
+
public static InputType fromStream(StreamInput in) throws IOException {
- return in.readEnum(InputType.class);
+ return in.readOptionalEnum(InputType.class);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
- out.writeEnum(this);
+ out.writeOptionalEnum(this);
}
}
diff --git a/server/src/main/java/org/elasticsearch/inference/Model.java b/server/src/main/java/org/elasticsearch/inference/Model.java
index 02be39d8a653d..e5f1d7936e605 100644
--- a/server/src/main/java/org/elasticsearch/inference/Model.java
+++ b/server/src/main/java/org/elasticsearch/inference/Model.java
@@ -23,6 +23,20 @@ public Model(ModelConfigurations configurations, ModelSecrets secrets) {
this.secrets = Objects.requireNonNull(secrets);
}
+ public Model(Model model, TaskSettings taskSettings) {
+ Objects.requireNonNull(model);
+
+ configurations = ModelConfigurations.of(model, taskSettings);
+ secrets = model.getSecrets();
+ }
+
+ public Model(Model model, ServiceSettings serviceSettings) {
+ Objects.requireNonNull(model);
+
+ configurations = ModelConfigurations.of(model, serviceSettings);
+ secrets = model.getSecrets();
+ }
+
public Model(ModelConfigurations configurations) {
this(configurations, new ModelSecrets());
}
diff --git a/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java b/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java
index cdccca7eb0c0e..e91f373d55d37 100644
--- a/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java
+++ b/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java
@@ -27,6 +27,32 @@ public class ModelConfigurations implements ToXContentObject, VersionedNamedWrit
public static final String TASK_SETTINGS = "task_settings";
private static final String NAME = "inference_model";
+ public static ModelConfigurations of(Model model, TaskSettings taskSettings) {
+ Objects.requireNonNull(model);
+ Objects.requireNonNull(taskSettings);
+
+ return new ModelConfigurations(
+ model.getConfigurations().getModelId(),
+ model.getConfigurations().getTaskType(),
+ model.getConfigurations().getService(),
+ model.getServiceSettings(),
+ taskSettings
+ );
+ }
+
+ public static ModelConfigurations of(Model model, ServiceSettings serviceSettings) {
+ Objects.requireNonNull(model);
+ Objects.requireNonNull(serviceSettings);
+
+ return new ModelConfigurations(
+ model.getConfigurations().getModelId(),
+ model.getConfigurations().getTaskType(),
+ model.getConfigurations().getService(),
+ serviceSettings,
+ model.getTaskSettings()
+ );
+ }
+
private final String modelId;
private final TaskType taskType;
private final String service;
diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java
index 3879a0a344e06..2d25a48117778 100644
--- a/x-pack/plugin/inference/src/main/java/module-info.java
+++ b/x-pack/plugin/inference/src/main/java/module-info.java
@@ -22,10 +22,5 @@
exports org.elasticsearch.xpack.inference.registry;
exports org.elasticsearch.xpack.inference.rest;
exports org.elasticsearch.xpack.inference.services;
- exports org.elasticsearch.xpack.inference.external.http.sender;
- exports org.elasticsearch.xpack.inference.external.http;
- exports org.elasticsearch.xpack.inference.services.elser;
- exports org.elasticsearch.xpack.inference.services.huggingface.elser;
- exports org.elasticsearch.xpack.inference.services.openai;
exports org.elasticsearch.xpack.inference;
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java
new file mode 100644
index 0000000000000..569f33b256803
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.cohere;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xpack.inference.common.Truncator;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.cohere.CohereAccount;
+import org.elasticsearch.xpack.inference.external.cohere.CohereResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
+import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.external.request.cohere.CohereEmbeddingsRequest;
+import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
+
+import java.net.URI;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.common.Truncator.truncate;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
+
+/**
+ * {@link ExecutableAction} that truncates the input, wraps it in a {@link CohereEmbeddingsRequest} and hands it to a
+ * {@link RetryingHttpSender}.
+ */
+public class CohereEmbeddingsAction implements ExecutableAction {
+ private static final Logger logger = LogManager.getLogger(CohereEmbeddingsAction.class);
+
+ private final CohereAccount account;
+ private final CohereEmbeddingsModel model;
+ // Prefix used for failures reported to the listener; built once from the configured URI
+ private final String errorMessage;
+ private final Truncator truncator;
+ private final RetryingHttpSender sender;
+
+ /**
+ * @param sender the underlying sender, wrapped in a {@link RetryingHttpSender}; must not be null
+ * @param model supplies the service settings, secret settings and task settings; must not be null
+ * @param serviceComponents supplies the truncator, throttler manager, settings and thread pool
+ */
+ public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model, ServiceComponents serviceComponents) {
+ this.model = Objects.requireNonNull(model);
+ this.account = new CohereAccount(this.model.getServiceSettings().uri(), this.model.getSecretSettings().apiKey());
+ this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri());
+ this.truncator = Objects.requireNonNull(serviceComponents.truncator());
+ this.sender = new RetryingHttpSender(
+ Objects.requireNonNull(sender),
+ serviceComponents.throttlerManager(),
+ logger,
+ new RetrySettings(serviceComponents.settings()),
+ serviceComponents.threadPool()
+ );
+ }
+
+ // Includes the target URI in the message only when one is configured
+ private static String getErrorMessage(@Nullable URI uri) {
+ if (uri != null) {
+ return format("Failed to send Cohere embeddings request to [%s]", uri.toString());
+ }
+
+ return "Failed to send Cohere embeddings request";
+ }
+
+ /**
+ * Truncates {@code input} to the model's max input tokens, builds the request and sends it. Any exception thrown
+ * while doing so is reported through {@code listener} rather than propagated.
+ */
+ @Override
+ public void execute(List input, ActionListener listener) {
+ try {
+ // TODO only truncate if the setting is NONE?
+ var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
+
+ CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(truncator, account, truncatedInput, model.getTaskSettings());
+ ActionListener wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener);
+
+ sender.send(request, wrappedListener);
+ } catch (ElasticsearchException e) {
+ // Already an ElasticsearchException, so it is passed on unchanged
+ listener.onFailure(e);
+ } catch (Exception e) {
+ listener.onFailure(createInternalServerError(e, errorMessage));
+ }
+ }
+
+ // NOTE(review): this handler is not referenced anywhere in the class, and it parses responses with
+ // OpenAiEmbeddingsResponseEntity rather than a Cohere-specific parser - looks like a leftover from the OpenAI
+ // action; confirm before wiring it into execute.
+ private static ResponseHandler createEmbeddingsHandler() {
+ return new CohereResponseHandler("cohere text embedding", OpenAiEmbeddingsResponseEntity::fromResponse);
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionCreator.java
new file mode 100644
index 0000000000000..0353922959e0b
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionCreator.java
@@ -0,0 +1,36 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.cohere;
+
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;
+
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the openai model type.
+ *
+ * NOTE(review): this class is declared in the cohere action package but is named for, and typed against, OpenAI.
+ * It appears to be an unfinished copy of the OpenAI creator - confirm whether it should be a Cohere creator.
+ */
+public class OpenAiActionCreator implements OpenAiActionVisitor {
+ private final Sender sender;
+ private final ServiceComponents serviceComponents;
+
+ /**
+ * @param sender passed on to every action created; must not be null
+ * @param serviceComponents passed on to every action created; must not be null
+ */
+ public OpenAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
+ this.sender = Objects.requireNonNull(sender);
+ this.serviceComponents = Objects.requireNonNull(serviceComponents);
+ }
+
+ /**
+ * Applies the task settings to the model via {@code overrideWith} and wraps the result in an action.
+ *
+ * NOTE(review): the overridden model comes from an OpenAiEmbeddingsModel, while CohereEmbeddingsAction's
+ * constructor takes a CohereEmbeddingsModel - verify the intended model type.
+ */
+ @Override
+ public ExecutableAction create(OpenAiEmbeddingsModel model, Map taskSettings) {
+ var overriddenModel = model.overrideWith(taskSettings);
+
+ return new CohereEmbeddingsAction(sender, overriddenModel, serviceComponents);
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionVisitor.java
new file mode 100644
index 0000000000000..a3a4cfbbfc873
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionVisitor.java
@@ -0,0 +1,17 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.cohere;
+
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;
+
+import java.util.Map;
+
+/**
+ * Visitor that creates an {@link ExecutableAction} for a model and a map of task settings.
+ *
+ * NOTE(review): declared in the cohere action package but named for, and typed against, OpenAI - confirm whether
+ * this should be a Cohere-specific visitor.
+ */
+public interface OpenAiActionVisitor {
+ // Creates the action for the given embeddings model and task settings map
+ ExecutableAction create(OpenAiEmbeddingsModel model, Map taskSettings);
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereAccount.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereAccount.java
new file mode 100644
index 0000000000000..9847d496d14ee
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereAccount.java
@@ -0,0 +1,21 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.cohere;
+
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.core.Nullable;
+
+import java.net.URI;
+import java.util.Objects;
+
+/**
+ * Connection details for a Cohere account.
+ *
+ * @param url the endpoint to send requests to, may be null
+ * @param apiKey the API key, must not be null
+ */
+public record CohereAccount(@Nullable URI url, SecureString apiKey) {
+
+    // Explicit canonical constructor: rejects a null api key and stores the url as given
+    public CohereAccount(@Nullable URI url, SecureString apiKey) {
+        this.url = url;
+        this.apiKey = Objects.requireNonNull(apiKey);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java
new file mode 100644
index 0000000000000..299cad20fb012
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java
@@ -0,0 +1,99 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.cohere;
+
+import org.apache.http.client.methods.HttpRequestBase;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
+import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
+import org.elasticsearch.xpack.inference.external.response.cohere.CohereErrorResponseEntity;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+
+import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
+import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown;
+
+public class CohereResponseHandler extends BaseResponseHandler {
+
+ static final String MONTHLY_REQUESTS_LIMIT = "x-endpoint-monthly-call-limit";
+ // TODO determine the production versions of these
+ static final String TRIAL_REQUEST_LIMIT_PER_MINUTE = "x-trial-endpoint-call-limit";
+ static final String TRIAL_REQUESTS_REMAINING = "x-trial-endpoint-call-remaining";
+ static final String TEXTS_ARRAY_TOO_LARGE_MESSAGE_MATCHER = "invalid request: total number of texts must be at most";
+ static final String TEXTS_ARRAY_ERROR_MESSAGE = "Received a texts array too large response";
+
+ public CohereResponseHandler(String requestType, ResponseParser parseFunction) {
+ super(requestType, parseFunction, CohereErrorResponseEntity::fromResponse);
+ }
+
+ @Override
+ public void validateResponse(ThrottlerManager throttlerManager, Logger logger, HttpRequestBase request, HttpResult result)
+ throws RetryException {
+ checkForFailureStatusCode(request, result);
+ checkForEmptyBody(throttlerManager, logger, request, result);
+ }
+
+    /**
+     * Validates the status code and throws a RetryException if it is not in the range [200, 300).
+     *
+     * The error codes handled here are the ones returned by the Cohere API, see the Cohere API reference.
+     * @param request The http request
+     * @param result The http response and body
+     * @throws RetryException Throws if status code is {@code >= 300 or < 200 }
+     */
+ void checkForFailureStatusCode(HttpRequestBase request, HttpResult result) throws RetryException {
+ int statusCode = result.response().getStatusLine().getStatusCode();
+ if (statusCode >= 200 && statusCode < 300) {
+ return;
+ }
+
+ // handle error codes
+ if (statusCode >= 500) {
+ throw new RetryException(false, buildError(SERVER_ERROR, request, result));
+ } else if (statusCode == 429) {
+ throw new RetryException(true, buildError(buildRateLimitErrorMessage(result), request, result));
+ } else if (isTextsArrayTooLarge(result)) {
+ throw new RetryException(false, buildError(TEXTS_ARRAY_ERROR_MESSAGE, request, result));
+ } else if (statusCode == 401) {
+ throw new RetryException(false, buildError(AUTHENTICATION, request, result));
+ } else if (statusCode >= 300 && statusCode < 400) {
+ throw new RetryException(false, buildError(REDIRECTION, request, result));
+ } else {
+ throw new RetryException(false, buildError(UNSUCCESSFUL, request, result));
+ }
+ }
+
+    // Reports whether a 400 response carries Cohere's "too many texts" message; any other status is never a match.
+    private static boolean isTextsArrayTooLarge(HttpResult result) {
+        if (result.response().getStatusLine().getStatusCode() != 400) {
+            return false;
+        }
+
+        // fromResponse yields null when the body has no usable message field
+        var parsedError = CohereErrorResponseEntity.fromResponse(result);
+        return parsedError != null && parsedError.getErrorMessage().contains(TEXTS_ARRAY_TOO_LARGE_MESSAGE_MATCHER);
+    }
+
+    // Builds the rate limit error text: the generic rate limit message followed by the
+    // request limits read from the response headers.
+    static String buildRateLimitErrorMessage(HttpResult result) {
+        var httpResponse = result.response();
+
+        var limits = Strings.format(
+            "Monthly request limit [%s], permitted requests per minute [%s], remaining requests [%s]",
+            getFirstHeaderOrUnknown(httpResponse, MONTHLY_REQUESTS_LIMIT),
+            getFirstHeaderOrUnknown(httpResponse, TRIAL_REQUEST_LIMIT_PER_MINUTE),
+            getFirstHeaderOrUnknown(httpResponse, TRIAL_REQUESTS_REMAINING)
+        );
+
+        // the three header values are substituted in the order monthly limit, per minute limit, remaining
+        return RATE_LIMIT + ". " + limits;
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandlerUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandlerUtils.java
new file mode 100644
index 0000000000000..6269c81d4ceb7
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandlerUtils.java
@@ -0,0 +1,22 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.http.retry;
+
+import org.apache.http.HttpResponse;
+
+public class ResponseHandlerUtils {
+    // Returns the name of the first element of the first header called name, or "unknown" when
+    // the header is missing or has no elements.
+    public static String getFirstHeaderOrUnknown(HttpResponse response, String name) {
+        var firstHeader = response.getFirstHeader(name);
+        boolean hasElements = firstHeader != null && firstHeader.getElements().length > 0;
+        return hasElements ? firstHeader.getElements()[0].getName() : "unknown";
+    }
+
+    private ResponseHandlerUtils() {}
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java
index 207e3c2bbd035..6e54ba1bd90b5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.openai;
-import org.apache.http.HttpResponse;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.Strings;
@@ -20,6 +19,7 @@
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
+import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown;
public class OpenAiResponseHandler extends BaseResponseHandler {
/**
@@ -110,12 +110,4 @@ static String buildRateLimitErrorMessage(HttpResult result) {
return RATE_LIMIT + ". " + usageMessage;
}
-
- private static String getFirstHeaderOrUnknown(HttpResponse response, String name) {
- var header = response.getFirstHeader(name);
- if (header != null && header.getElements().length > 0) {
- return header.getElements()[0].getName();
- }
- return "unknown";
- }
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java
index 355db7288dacc..4cb32b7bc95fd 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java
@@ -10,7 +10,14 @@
import org.apache.http.Header;
import org.apache.http.HttpHeaders;
import org.apache.http.message.BasicHeader;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.CheckedSupplier;
+import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.rest.RestStatus;
+
+import java.net.URI;
+import java.net.URISyntaxException;
public class RequestUtils {
@@ -18,5 +25,17 @@ public static Header createAuthBearerHeader(SecureString apiKey) {
return new BasicHeader(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.toString());
}
+    /**
+     * Returns accountUri when it is set, otherwise the URI produced by uriBuilder.
+     * A URISyntaxException thrown by the builder is rethrown as a 500 ElasticsearchStatusException naming the service.
+     */
+    public static URI buildUri(URI accountUri, String service, CheckedSupplier<URI, URISyntaxException> uriBuilder) {
+        try {
+            return accountUri == null ? uriBuilder.get() : accountUri;
+        } catch (URISyntaxException e) {
+            String message = Strings.format("Failed to construct %s URL", service);
+            throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR, e);
+        }
+
private RequestUtils() {}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java
index 093ce9fd2f4bb..7f2d34e1f8b3b 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java
@@ -7,35 +7,85 @@
package org.elasticsearch.xpack.inference.external.request.cohere;
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
+import org.apache.http.client.utils.URIBuilder;
+import org.apache.http.entity.ByteArrayEntity;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.common.Truncator;
+import org.elasticsearch.xpack.inference.external.cohere.CohereAccount;
import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.charset.StandardCharsets;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;
+import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
public class CohereEmbeddingsRequest implements Request {
- public CohereEmbeddingsRequest(Truncator truncator, Object account, Truncator.TruncationResult input, Object taskSettings) {
+ private final Truncator truncator;
+ private final CohereAccount account;
+ private final Truncator.TruncationResult truncationResult;
+ private final URI uri;
+ private final CohereEmbeddingsTaskSettings taskSettings;
+ public CohereEmbeddingsRequest(
+ Truncator truncator,
+ CohereAccount account,
+ Truncator.TruncationResult input,
+ CohereEmbeddingsTaskSettings taskSettings
+ ) {
+ this.truncator = Objects.requireNonNull(truncator);
+ this.account = Objects.requireNonNull(account);
+ this.truncationResult = Objects.requireNonNull(input);
+ this.uri = buildUri(this.account.url(), "Cohere", CohereEmbeddingsRequest::buildDefaultUri);
+ this.taskSettings = Objects.requireNonNull(taskSettings);
}
@Override
public HttpRequestBase createRequest() {
- return null;
+ HttpPost httpPost = new HttpPost(uri);
+
+ ByteArrayEntity byteEntity = new ByteArrayEntity(
+ Strings.toString(new CohereEmbeddingsRequestEntity(truncationResult.input(), taskSettings)).getBytes(StandardCharsets.UTF_8)
+ );
+ httpPost.setEntity(byteEntity);
+
+ httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
+ httpPost.setHeader(createAuthBearerHeader(account.apiKey()));
+
+ return httpPost;
}
@Override
public URI getURI() {
- return null;
+ return uri;
}
@Override
public Request truncate() {
- return null;
+        // TODO only do this if the truncate setting is NONE?
+ var truncatedInput = truncator.truncate(truncationResult.input());
+
+ return new CohereEmbeddingsRequest(truncator, account, truncatedInput, taskSettings);
}
@Override
public boolean[] getTruncationInfo() {
- return new boolean[0];
+ return truncationResult.truncated().clone();
+ }
+
+ // default for testing
+ static URI buildDefaultUri() throws URISyntaxException {
+ return new URIBuilder().setScheme("https")
+ .setHost(CohereUtils.HOST)
+ .setPathSegments(CohereUtils.VERSION_1, CohereUtils.EMBEDDINGS_PATH)
+ .build();
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java
new file mode 100644
index 0000000000000..8aea434d7339d
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.cohere;
+
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+/** Request body sent to the Cohere embeddings endpoint; task settings that are null are left out of the JSON. */
+public record CohereEmbeddingsRequestEntity(List<String> input, CohereEmbeddingsTaskSettings taskSettings) implements ToXContentObject {
+    private static final String TEXTS_FIELD = "texts";
+    static final String INPUT_TYPE_FIELD = "input_type";
+    static final String EMBEDDING_TYPES_FIELD = "embedding_types";
+    static final String TRUNCATE_FIELD = "truncate";
+
+    public CohereEmbeddingsRequestEntity {
+        Objects.requireNonNull(input);
+        Objects.requireNonNull(taskSettings);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(TEXTS_FIELD, input);
+        if (taskSettings.model() != null) {
+            builder.field(CohereServiceFields.MODEL, taskSettings.model());
+        }
+
+        if (taskSettings.inputType() != null) {
+            builder.field(INPUT_TYPE_FIELD, taskSettings.inputType());
+        }
+
+        if (taskSettings.embeddingTypes() != null) {
+            builder.field(EMBEDDING_TYPES_FIELD, taskSettings.embeddingTypes());
+        }
+
+        // truncation has its own key; writing it under embedding_types would emit that key twice
+        if (taskSettings.truncation() != null) {
+            builder.field(TRUNCATE_FIELD, taskSettings.truncation());
+        }
+        builder.endObject();
+        return builder;
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereUtils.java
new file mode 100644
index 0000000000000..f8ccd91d4e3d2
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereUtils.java
@@ -0,0 +1,16 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.cohere;
+
+public class CohereUtils {
+ public static final String HOST = "api.cohere.ai";
+ public static final String VERSION_1 = "v1";
+ public static final String EMBEDDINGS_PATH = "embed";
+
+ private CohereUtils() {}
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java
index 3a9fab44aa04e..9ead692b9e110 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java
@@ -12,9 +12,7 @@
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.ByteArrayEntity;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.Strings;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.openai.OpenAiAccount;
@@ -26,6 +24,7 @@
import java.nio.charset.StandardCharsets;
import java.util.Objects;
+import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader;
@@ -46,18 +45,10 @@ public OpenAiEmbeddingsRequest(
this.truncator = Objects.requireNonNull(truncator);
this.account = Objects.requireNonNull(account);
this.truncationResult = Objects.requireNonNull(input);
- this.uri = buildUri(this.account.url());
+ this.uri = buildUri(this.account.url(), "OpenAI", OpenAiEmbeddingsRequest::buildDefaultUri);
this.taskSettings = Objects.requireNonNull(taskSettings);
}
- private static URI buildUri(URI accountUri) {
- try {
- return accountUri == null ? buildDefaultUri() : accountUri;
- } catch (URISyntaxException e) {
- throw new ElasticsearchStatusException("Failed to construct OpenAI URL", RestStatus.INTERNAL_SERVER_ERROR, e);
- }
- }
-
public HttpRequestBase createRequest() {
HttpPost httpPost = new HttpPost(uri);
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
new file mode 100644
index 0000000000000..cbe70ad8d919a
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
@@ -0,0 +1,117 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.cohere;
+
+import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.common.xcontent.XContentParserUtils;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
+import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
+
+public class CohereEmbeddingsResponseEntity {
+ private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Cohere embeddings response";
+
+    /**
+     * Parses the Cohere json response.
+     * For a request like:
+     *
+     * <pre>
+     * <code>
+     * {
+     *     "texts": ["hello this is my name", "I wish I was there!"]
+     * }
+     * </code>
+     * </pre>
+     *
+     * The response would look like:
+     *
+     * <pre>
+     * <code>
+     * {
+     *     "id": "da4f9ea6-37e4-41ab-b5e1-9e2985609555",
+     *     "texts": [
+     *         "hello",
+     *         "awesome"
+     *     ],
+     *     "embeddings": [
+     *         [
+     *             123
+     *         ],
+     *         [
+     *             123
+     *         ]
+     *     ],
+     *     "meta": {
+     *         "api_version": {
+     *             "version": "1"
+     *         },
+     *         "warnings": [
+     *             "default model on embed will be deprecated in the future, please specify a model in the request."
+     *         ],
+     *         "billed_units": {
+     *             "input_tokens": 3
+     *         }
+     *     },
+     *     "response_type": "embeddings_floats"
+     * }
+     * </code>
+     * </pre>
+     *
+     * @return one embedding per entry of the {@code embeddings} array
+     */
+    public static TextEmbeddingResults fromResponse(Request request, HttpResult response) throws IOException {
+        var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
+
+        try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
+            moveToFirstToken(jsonParser);
+
+            XContentParser.Token token = jsonParser.currentToken();
+            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
+
+            // Cohere returns the vectors in a top level "embeddings" array, there is no "data" wrapper
+            positionParserAtTokenAfterField(jsonParser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
+
+            List<TextEmbeddingResults.Embedding> embeddingList = XContentParserUtils.parseList(
+                jsonParser,
+                CohereEmbeddingsResponseEntity::parseEmbeddingArray
+            );
+
+            return new TextEmbeddingResults(embeddingList);
+        }
+    }
+
+    /**
+     * Parses one entry of the embeddings array, which is itself an array of numbers.
+     */
+    private static TextEmbeddingResults.Embedding parseEmbeddingArray(XContentParser parser) throws IOException {
+        XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
+
+        // parseList leaves the parser on this entry's END_ARRAY, so there are no trailing fields to skip here
+        List<Float> embeddingValues = XContentParserUtils.parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingList);
+
+        return new TextEmbeddingResults.Embedding(embeddingValues);
+    }
+
+ private static float parseEmbeddingList(XContentParser parser) throws IOException {
+ XContentParser.Token token = parser.currentToken();
+ XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser);
+ return parser.floatValue();
+ }
+
+ private CohereEmbeddingsResponseEntity() {}
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntity.java
new file mode 100644
index 0000000000000..6f947b45955be
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntity.java
@@ -0,0 +1,58 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.cohere;
+
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.ErrorMessage;
+
+public class CohereErrorResponseEntity implements ErrorMessage {
+
+ private final String errorMessage;
+
+ private CohereErrorResponseEntity(String errorMessage) {
+ this.errorMessage = errorMessage;
+ }
+
+ public String getErrorMessage() {
+ return errorMessage;
+ }
+
+    /**
+     * An example error response for a request that contains too many texts would look like
+     * <pre><code>
+     * {
+     *     "message": "invalid request: total number of texts must be at most 96 - received 97"
+     * }
+     * </code></pre>
+     *
+     *
+     * @param response The error response
+     * @return An error entity if the response is JSON with the above structure
+     * or null if the response cannot be parsed or does not contain the message field
+     */
+ public static CohereErrorResponseEntity fromResponse(HttpResult response) {
+ try (
+ XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
+ .createParser(XContentParserConfiguration.EMPTY, response.body())
+ ) {
+ var responseMap = jsonParser.map();
+ var message = (String) responseMap.get("message");
+ if (message != null) {
+ return new CohereErrorResponseEntity(message);
+ }
+ } catch (Exception e) {
+ // swallow the error
+ }
+
+ return null;
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
index b503c0800d342..6fef9dac6095a 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
@@ -23,7 +23,9 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
+import java.util.Locale;
import java.util.Map;
import java.util.Objects;
@@ -107,8 +109,29 @@ public static String mustBeNonEmptyString(String settingName, String scope) {
return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName);
}
- public static String invalidType(String settingName, String scope, String invalidType, String requiredType) {
- return Strings.format("[%s] Invalid type [%s] received. [%s] must be type [%s]", scope, invalidType, settingName, requiredType);
+ public static String invalidType(String settingName, String scope, String invalidType, String invalidValue, String requiredType) {
+ return Strings.format(
+ "[%s] Invalid type [%s] received for value [%s]. [%s] must be type [%s]",
+ scope,
+ invalidType,
+ invalidValue,
+ settingName,
+ requiredType
+ );
+ }
+
+ public static String invalidValue(String settingName, String scope, String invalidValue, String validValue) {
+ return invalidValue(settingName, scope, invalidValue, new String[] { validValue });
+ }
+
+    /**
+     * Formats the validation error reported when a setting's value is not one of the accepted values.
+     * The accepted values are listed comma separated.
+     */
+    public static String invalidValue(String settingName, String scope, String invalidType, String... requiredTypes) {
+        var acceptedValues = String.join(", ", requiredTypes);
+
+        return Strings.format("[%s] Invalid value [%s] received. [%s] must be one of [%s]", scope, invalidType, settingName, acceptedValues);
+    }
// TODO improve URI validation logic
@@ -131,6 +154,14 @@ public static URI createUri(String url) throws IllegalArgumentException {
}
}
+    /**
+     * Variant of {@link #createUri(String)} for optional settings.
+     * A null url yields null, any other value is handed to createUri unchanged.
+     */
+    public static URI createOptionalUri(String url) {
+        return url == null ? null : createUri(url);
+    }
+
public static SecureString extractRequiredSecureString(
Map map,
String settingName,
@@ -185,7 +216,7 @@ public static List extractOptionalListOfType(
if (type.isAssignableFrom(listEntry.getClass()) == false) {
// TODO should we just throw here like removeAsType
validationException.addValidationError(
- invalidType(settingName, scope, listEntry.getClass().getSimpleName(), type.getSimpleName())
+ invalidType(settingName, scope, listEntry.getClass().getSimpleName(), listEntry.toString(), type.getSimpleName())
);
return null;
}
@@ -241,28 +272,22 @@ public static T extractOptionalEnum(
String settingName,
String scope,
CheckedFunction converter,
+ T[] validTypes,
ValidationException validationException
) {
- var s = extractOptionalString(map, settingName, scope, validationException);
- if (s == null) {
+ var enumString = extractOptionalString(map, settingName, scope, validationException);
+ if (enumString == null) {
return null;
}
+ var validTypesAsStrings = Arrays.stream(validTypes).map(type -> type.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new);
try {
- var e = converter.apply(s);
+ return converter.apply(enumString);
} catch (IllegalArgumentException e) {
- validationException.addValidationError(invalidType(settingName, scope, s, ))
- }
-
- if (s.isEmpty()) {
- validationException.addValidationError(ServiceUtils.mustBeNonEmptyString(settingName, scope));
- }
-
- if (validationException.validationErrors().isEmpty() == false) {
- return null;
+ validationException.addValidationError(invalidValue(settingName, scope, enumString, validTypesAsStrings));
}
- return optionalField;
+ return null;
}
public static String parsePersistedConfigErrorMsg(String modelId, String serviceName) {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsTaskSettings.java
deleted file mode 100644
index 229d81e468d3f..0000000000000
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsTaskSettings.java
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.services.cohere;
-
-import org.elasticsearch.common.ValidationException;
-import org.elasticsearch.core.Nullable;
-import org.elasticsearch.inference.ModelConfigurations;
-
-import java.util.List;
-import java.util.Map;
-
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfType;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
-import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.MODEL;
-
-/**
- * Defines the task settings for the cohere text embeddings service.
- *
- * @param model the id of the model to use in the requests to cohere
- * @param inputType Specifies the type of input you're giving to the model
- * @param embeddingTypes Specifies the types of embeddings you want to get back
- * @param truncate Specifies how the API will handle inputs longer than the maximum token length
- */
-public record CohereEmbeddingsTaskSettings(
- @Nullable String model,
- @Nullable String inputType,
- @Nullable List embeddingTypes,
- @Nullable CohereTruncation truncate
-) {
-
- public static final String NAME = "cohere_embeddings_task_settings";
- static final String INPUT_TYPE = "input_type";
- static final String EMBEDDING_TYPES = "embedding_types";
-
- public static CohereEmbeddingsTaskSettings fromMap(Map map) {
- ValidationException validationException = new ValidationException();
-
- String model = extractOptionalString(map, MODEL, ModelConfigurations.TASK_SETTINGS, validationException);
- String inputType = extractOptionalString(map, INPUT_TYPE, ModelConfigurations.TASK_SETTINGS, validationException);
- List embeddingTypes = extractOptionalListOfType(
- map,
- EMBEDDING_TYPES,
- ModelConfigurations.TASK_SETTINGS,
- String.class,
- validationException
- );
- CohereTruncation truncation = extractOptionalString(map, INPUT_TYPE, ModelConfigurations.TASK_SETTINGS, validationException);
-
- if (validationException.validationErrors().isEmpty() == false) {
- throw validationException;
- }
-
- return new CohereEmbeddingsTaskSettings(model, user);
- }
-}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java
new file mode 100644
index 0000000000000..d92ea419faef0
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java
@@ -0,0 +1,47 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere;
+
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+
+/**
+ * Abstract base class shared by the Cohere models; every constructor simply forwards to {@link Model}.
+ */
+public abstract class CohereModel extends Model {
+
+    /**
+     * Forwards the configuration and secrets to the {@link Model} constructor.
+     */
+    public CohereModel(ModelConfigurations configurations, ModelSecrets secrets) {
+        super(configurations, secrets);
+    }
+
+    /**
+     * Forwards an existing model together with task settings to the matching {@link Model} constructor.
+     */
+    protected CohereModel(CohereModel model, TaskSettings taskSettings) {
+        super(model, taskSettings);
+    }
+
+    /**
+     * Forwards an existing model together with service settings to the matching {@link Model} constructor.
+     */
+    protected CohereModel(CohereModel model, ServiceSettings serviceSettings) {
+        super(model, serviceSettings);
+    }
+
+    /**
+     * Produces the {@link ExecutableAction} for this model; each concrete subclass supplies the implementation.
+     */
+    public abstract ExecutableAction accept();
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java
new file mode 100644
index 0000000000000..bdd5e981b0490
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java
@@ -0,0 +1,192 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.common.SimilarityMeasure;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
+
+/**
+ * Service settings for the Cohere integration: an optional url plus optional similarity, dimensions and
+ * maximum input token values. Every field may be {@code null}.
+ */
+public class CohereServiceSettings implements ServiceSettings {
+    public static final String NAME = "cohere_service_settings";
+
+    /**
+     * Parses and validates the Cohere service settings contained in {@code map}.
+     *
+     * @param map the raw service settings
+     * @return the parsed settings; the uri is left {@code null} when no url is present in {@code map}
+     * @throws ValidationException if a setting failed validation or the url cannot be converted to a {@link URI}
+     */
+    public static CohereServiceSettings fromMap(Map<String, Object> map) {
+        ValidationException validationException = new ValidationException();
+
+        String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
+
+        SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
+        Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
+
+        // Throw if any of the settings were empty strings or invalid
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        // the url is optional and only for testing
+        if (url == null) {
+            return new CohereServiceSettings((URI) null, similarity, dims, maxInputTokens);
+        }
+
+        URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new CohereServiceSettings(uri, similarity, dims, maxInputTokens);
+    }
+
+    private final URI uri;
+    private final SimilarityMeasure similarity;
+    private final Integer dimensions;
+    private final Integer maxInputTokens;
+
+    /**
+     * Creates the settings from already-parsed values; the values are stored as given, without validation.
+     */
+    public CohereServiceSettings(
+        @Nullable URI uri,
+        @Nullable SimilarityMeasure similarity,
+        @Nullable Integer dimensions,
+        @Nullable Integer maxInputTokens
+    ) {
+        this.uri = uri;
+        this.similarity = similarity;
+        this.dimensions = dimensions;
+        this.maxInputTokens = maxInputTokens;
+    }
+
+    /**
+     * Convenience constructor that converts {@code url} with {@code createOptionalUri} and then delegates.
+     */
+    public CohereServiceSettings(
+        @Nullable String url,
+        @Nullable SimilarityMeasure similarity,
+        @Nullable Integer dimensions,
+        @Nullable Integer maxInputTokens
+    ) {
+        this(createOptionalUri(url), similarity, dimensions, maxInputTokens);
+    }
+
+    /**
+     * Reads the settings from the stream in the same order {@link #writeTo(StreamOutput)} writes them.
+     */
+    public CohereServiceSettings(StreamInput in) throws IOException {
+        uri = createOptionalUri(in.readOptionalString());
+        similarity = in.readOptionalEnum(SimilarityMeasure.class);
+        dimensions = in.readOptionalVInt();
+        maxInputTokens = in.readOptionalVInt();
+    }
+
+    public URI uri() {
+        return uri;
+    }
+
+    public SimilarityMeasure similarity() {
+        return similarity;
+    }
+
+    public Integer dimensions() {
+        return dimensions;
+    }
+
+    public Integer maxInputTokens() {
+        return maxInputTokens;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        if (uri != null) {
+            builder.field(URL, uri.toString());
+        }
+        if (similarity != null) {
+            builder.field(SIMILARITY, similarity);
+        }
+        if (dimensions != null) {
+            builder.field(DIMENSIONS, dimensions);
+        }
+        if (maxInputTokens != null) {
+            builder.field(MAX_INPUT_TOKENS, maxInputTokens);
+        }
+
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        // the uri is sent as an optional string and rebuilt with createOptionalUri by the StreamInput constructor
+        var uriToWrite = uri != null ? uri.toString() : null;
+        out.writeOptionalString(uriToWrite);
+        out.writeOptionalEnum(similarity);
+        out.writeOptionalVInt(dimensions);
+        out.writeOptionalVInt(maxInputTokens);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        CohereServiceSettings that = (CohereServiceSettings) o;
+        return Objects.equals(uri, that.uri)
+            && Objects.equals(similarity, that.similarity)
+            && Objects.equals(dimensions, that.dimensions)
+            && Objects.equals(maxInputTokens, that.maxInputTokens);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(uri, similarity, dimensions, maxInputTokens);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java
index ebf1d349e0b7a..3fec21a9e03b0 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java
@@ -48,11 +48,11 @@ public static CohereTruncation fromString(String name) {
}
public static CohereTruncation fromStream(StreamInput in) throws IOException {
- return in.readEnum(CohereTruncation.class);
+ return in.readOptionalEnum(CohereTruncation.class);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
- out.writeEnum(this);
+ out.writeOptionalEnum(this);
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java
new file mode 100644
index 0000000000000..ed1295c8ee31c
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java
@@ -0,0 +1,86 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere.embeddings;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.cohere.CohereModel;
+import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+import java.util.Map;
+
+/**
+ * Cohere model that pairs {@link CohereServiceSettings} with {@link CohereEmbeddingsTaskSettings}.
+ */
+public class CohereEmbeddingsModel extends CohereModel {
+    /**
+     * Creates the model from the raw settings maps, parsing each one with the matching {@code fromMap} method.
+     */
+    public CohereEmbeddingsModel(
+        String modelId,
+        TaskType taskType,
+        String service,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        @Nullable Map<String, Object> secrets
+    ) {
+        this(
+            modelId,
+            taskType,
+            service,
+            CohereServiceSettings.fromMap(serviceSettings),
+            CohereEmbeddingsTaskSettings.fromMap(taskSettings),
+            DefaultSecretSettings.fromMap(secrets)
+        );
+    }
+
+    // should only be used for testing
+    CohereEmbeddingsModel(
+        String modelId,
+        TaskType taskType,
+        String service,
+        CohereServiceSettings serviceSettings,
+        CohereEmbeddingsTaskSettings taskSettings,
+        @Nullable DefaultSecretSettings secretSettings
+    ) {
+        super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings));
+    }
+
+    private CohereEmbeddingsModel(CohereEmbeddingsModel model, CohereEmbeddingsTaskSettings taskSettings) {
+        super(model, taskSettings);
+    }
+
+    private CohereEmbeddingsModel(CohereEmbeddingsModel model, CohereServiceSettings serviceSettings) {
+        super(model, serviceSettings);
+    }
+
+    @Override
+    public CohereServiceSettings getServiceSettings() {
+        return (CohereServiceSettings) super.getServiceSettings();
+    }
+
+    @Override
+    public CohereEmbeddingsTaskSettings getTaskSettings() {
+        return (CohereEmbeddingsTaskSettings) super.getTaskSettings();
+    }
+
+    @Override
+    public DefaultSecretSettings getSecretSettings() {
+        return (DefaultSecretSettings) super.getSecretSettings();
+    }
+
+    @Override
+    public ExecutableAction accept() {
+        // NOTE(review): no action is created here yet, so callers receive null - presumably a placeholder, confirm before use
+        return null;
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java
new file mode 100644
index 0000000000000..c2ad8b2b8e958
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java
@@ -0,0 +1,154 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere.embeddings;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfType;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
+import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.MODEL;
+import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.TRUNCATE;
+
+/**
+ * Defines the task settings for the cohere text embeddings service.
+ *
+ * <p>
+ * <a href="https://docs.cohere.com/reference/embed">See api docs for details.</a>
+ * </p>
+ *
+ * @param model the id of the model to use in the requests to cohere
+ * @param inputType Specifies the type of input you're giving to the model
+ * @param embeddingTypes Specifies the types of embeddings you want to get back
+ * @param truncation Specifies how the API will handle inputs longer than the maximum token length
+ */
+public record CohereEmbeddingsTaskSettings(
+    @Nullable String model,
+    @Nullable InputType inputType,
+    @Nullable List<String> embeddingTypes,
+    @Nullable CohereTruncation truncation
+) implements TaskSettings {
+
+    public static final String NAME = "cohere_embeddings_task_settings";
+    static final CohereEmbeddingsTaskSettings EMPTY_SETTINGS = new CohereEmbeddingsTaskSettings(null, null, null, null);
+    static final String INPUT_TYPE = "input_type";
+    static final String EMBEDDING_TYPES = "embedding_types";
+
+    /**
+     * Parses and validates the Cohere embeddings task settings contained in {@code map}.
+     *
+     * @param map the raw task settings
+     * @return {@link #EMPTY_SETTINGS} when {@code map} is empty, otherwise the parsed settings
+     * @throws ValidationException if any of the settings failed validation
+     */
+    public static CohereEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
+        if (map.isEmpty()) {
+            return EMPTY_SETTINGS;
+        }
+
+        ValidationException validationException = new ValidationException();
+
+        String model = extractOptionalString(map, MODEL, ModelConfigurations.TASK_SETTINGS, validationException);
+        InputType inputType = extractOptionalEnum(
+            map,
+            INPUT_TYPE,
+            ModelConfigurations.TASK_SETTINGS,
+            InputType::fromString,
+            InputType.values(),
+            validationException
+        );
+        List<String> embeddingTypes = extractOptionalListOfType(
+            map,
+            EMBEDDING_TYPES,
+            ModelConfigurations.TASK_SETTINGS,
+            String.class,
+            validationException
+        );
+        CohereTruncation truncation = extractOptionalEnum(
+            map,
+            TRUNCATE,
+            ModelConfigurations.TASK_SETTINGS,
+            CohereTruncation::fromString,
+            CohereTruncation.values(),
+            validationException
+        );
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new CohereEmbeddingsTaskSettings(model, inputType, embeddingTypes, truncation);
+    }
+
+    /**
+     * Reads the settings from the stream; every field is read as optional, mirroring {@link #writeTo(StreamOutput)}.
+     */
+    public CohereEmbeddingsTaskSettings(StreamInput in) throws IOException {
+        this(
+            in.readOptionalString(),
+            in.readOptionalEnum(InputType.class),
+            in.readOptionalStringCollectionAsList(),
+            in.readOptionalEnum(CohereTruncation.class)
+        );
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (model != null) {
+            builder.field(MODEL, model);
+        }
+
+        if (inputType != null) {
+            builder.field(INPUT_TYPE, inputType);
+        }
+
+        if (embeddingTypes != null) {
+            builder.field(EMBEDDING_TYPES, embeddingTypes);
+        }
+
+        if (truncation != null) {
+            builder.field(TRUNCATE, truncation);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        // inputType and truncation may be null (see EMPTY_SETTINGS), so write them via the stream instead of calling writeTo on them
+        out.writeOptionalString(model);
+        out.writeOptionalEnum(inputType);
+        out.writeOptionalStringCollection(embeddingTypes);
+        out.writeOptionalEnum(truncation);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiModel.java
index 97823e3bc9079..1e158725f531d 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiModel.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiModel.java
@@ -10,6 +10,8 @@
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionVisitor;
@@ -21,5 +23,13 @@ public OpenAiModel(ModelConfigurations configurations, ModelSecrets secrets) {
super(configurations, secrets);
}
+ protected OpenAiModel(OpenAiModel model, TaskSettings taskSettings) {
+ super(model, taskSettings);
+ }
+
+ protected OpenAiModel(OpenAiModel model, ServiceSettings serviceSettings) {
+ super(model, serviceSettings);
+ }
+
public abstract ExecutableAction accept(OpenAiActionVisitor creator, Map taskSettings);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java
index 5ade2aad0acb4..553a5eaf60dae 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java
@@ -28,7 +28,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
@@ -100,14 +100,6 @@ public OpenAiServiceSettings(
this(createOptionalUri(uri), organizationId, similarity, dimensions, maxInputTokens);
}
- private static URI createOptionalUri(String url) {
- if (url == null) {
- return null;
- }
-
- return createUri(url);
- }
-
public OpenAiServiceSettings(StreamInput in) throws IOException {
uri = createOptionalUri(in.readOptionalString());
organizationId = in.readOptionalString();
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java
index 250837d895590..0df3d126402c7 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java
@@ -52,29 +52,11 @@ public OpenAiEmbeddingsModel(
}
private OpenAiEmbeddingsModel(OpenAiEmbeddingsModel originalModel, OpenAiEmbeddingsTaskSettings taskSettings) {
- super(
- new ModelConfigurations(
- originalModel.getConfigurations().getModelId(),
- originalModel.getConfigurations().getTaskType(),
- originalModel.getConfigurations().getService(),
- originalModel.getServiceSettings(),
- taskSettings
- ),
- new ModelSecrets(originalModel.getSecretSettings())
- );
+ super(originalModel, taskSettings);
}
public OpenAiEmbeddingsModel(OpenAiEmbeddingsModel originalModel, OpenAiServiceSettings serviceSettings) {
- super(
- new ModelConfigurations(
- originalModel.getConfigurations().getModelId(),
- originalModel.getConfigurations().getTaskType(),
- originalModel.getConfigurations().getService(),
- serviceSettings,
- originalModel.getTaskSettings()
- ),
- new ModelSecrets(originalModel.getSecretSettings())
- );
+ super(originalModel, serviceSettings);
}
@Override
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java
new file mode 100644
index 0000000000000..21558eec87302
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java
@@ -0,0 +1,158 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.inference.common.SimilarityMeasure;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Wire serialization and {@code fromMap} parsing tests for {@link CohereServiceSettings}.
+ */
+public class CohereServiceSettingsTests extends AbstractWireSerializingTestCase<CohereServiceSettings> {
+
+    public static CohereServiceSettings createRandomWithNonNullUrl() {
+        return createRandom(randomAlphaOfLength(15));
+    }
+
+    /**
+     * The created settings can have a url set to null.
+     */
+    public static CohereServiceSettings createRandom() {
+        var url = randomBoolean() ? randomAlphaOfLength(15) : null;
+        return createRandom(url);
+    }
+
+    /**
+     * Builds settings with randomized optional fields; {@code url} may be {@code null}.
+     */
+    private static CohereServiceSettings createRandom(@Nullable String url) {
+        SimilarityMeasure similarityMeasure = null;
+        Integer dims = null;
+        var isTextEmbeddingModel = randomBoolean();
+        if (isTextEmbeddingModel) {
+            similarityMeasure = SimilarityMeasure.DOT_PRODUCT;
+            dims = 1536;
+        }
+        Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256);
+        // use the String constructor: it goes through createOptionalUri, which accepts the null url createRandom() may pass
+        return new CohereServiceSettings(url, similarityMeasure, dims, maxInputTokens);
+    }
+
+    public void testFromMap() {
+        var url = "https://www.abc.com";
+        var similarity = SimilarityMeasure.DOT_PRODUCT.toString();
+        var dims = 1536;
+        var maxInputTokens = 512;
+        var serviceSettings = CohereServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    ServiceFields.URL,
+                    url,
+                    ServiceFields.SIMILARITY,
+                    similarity,
+                    ServiceFields.DIMENSIONS,
+                    dims,
+                    ServiceFields.MAX_INPUT_TOKENS,
+                    maxInputTokens
+                )
+            )
+        );
+
+        MatcherAssert.assertThat(
+            serviceSettings,
+            is(new CohereServiceSettings(ServiceUtils.createUri(url), SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens))
+        );
+    }
+
+    public void testFromMap_MissingUrl_DoesNotThrowException() {
+        var serviceSettings = CohereServiceSettings.fromMap(new HashMap<>(Map.of()));
+        assertNull(serviceSettings.uri());
+    }
+
+    public void testFromMap_EmptyUrl_ThrowsError() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> CohereServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "")))
+        );
+
+        MatcherAssert.assertThat(
+            thrownException.getMessage(),
+            containsString(
+                Strings.format(
+                    "Validation Failed: 1: [service_settings] Invalid value empty string. [%s] must be a non-empty string;",
+                    ServiceFields.URL
+                )
+            )
+        );
+    }
+
+    public void testFromMap_InvalidUrl_ThrowsError() {
+        var url = "https://www.abc^.com";
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> CohereServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)))
+        );
+
+        MatcherAssert.assertThat(
+            thrownException.getMessage(),
+            is(Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s];", url, ServiceFields.URL))
+        );
+    }
+
+    public void testFromMap_InvalidSimilarity_ThrowsError() {
+        var similarity = "by_size";
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> CohereServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.SIMILARITY, similarity)))
+        );
+
+        MatcherAssert.assertThat(
+            thrownException.getMessage(),
+            is("Validation Failed: 1: [service_settings] Unknown similarity measure [by_size];")
+        );
+    }
+
+    @Override
+    protected Writeable.Reader<CohereServiceSettings> instanceReader() {
+        return CohereServiceSettings::new;
+    }
+
+    @Override
+    protected CohereServiceSettings createTestInstance() {
+        return createRandomWithNonNullUrl();
+    }
+
+    @Override
+    protected CohereServiceSettings mutateInstance(CohereServiceSettings instance) throws IOException {
+        return randomValueOtherThan(instance, CohereServiceSettingsTests::createRandomWithNonNullUrl);
+    }
+
+    public static Map<String, Object> getServiceSettingsMap(@Nullable String url) {
+        var map = new HashMap<String, Object>();
+
+        if (url != null) {
+            map.put(ServiceFields.URL, url);
+        }
+
+        return map;
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java
new file mode 100644
index 0000000000000..be3f77251fd08
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere.embeddings;
+
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.common.SimilarityMeasure;
+import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+/**
+ * Factory helpers that build {@link CohereEmbeddingsModel} instances for tests.
+ */
+public class CohereEmbeddingsModelTests extends ESTestCase {
+
+    /**
+     * Creates a text embedding model whose service settings use 1536 dimensions.
+     */
+    public static CohereEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit) {
+        return createModel(url, apiKey, tokenLimit, 1536);
+    }
+
+    /**
+     * Creates a text embedding model with dot product similarity, empty task settings and the given dimensions.
+     */
+    public static CohereEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit, @Nullable Integer dimensions) {
+        return new CohereEmbeddingsModel(
+            "id",
+            TaskType.TEXT_EMBEDDING,
+            "service",
+            new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit),
+            CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
+            new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
+        );
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
new file mode 100644
index 0000000000000..b5e44b88b9b4c
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere.embeddings;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields;
+import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Wire serialization and {@code fromMap} parsing tests for {@link CohereEmbeddingsTaskSettings}.
+ */
+public class CohereEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase<CohereEmbeddingsTaskSettings> {
+
+    public static CohereEmbeddingsTaskSettings createRandom() {
+        var model = randomBoolean() ? randomAlphaOfLength(15) : null;
+        var inputType = randomBoolean() ? randomFrom(InputType.values()) : null;
+        var embeddingTypes = randomBoolean() ? List.of(randomAlphaOfLength(6)) : null;
+        var truncation = randomBoolean() ? randomFrom(CohereTruncation.values()) : null;
+
+        return new CohereEmbeddingsTaskSettings(model, inputType, embeddingTypes, truncation);
+    }
+
+    public void testFromMap_CreatesEmptySettings_WhenAllFieldsAreNull() {
+        MatcherAssert.assertThat(
+            CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of())),
+            is(new CohereEmbeddingsTaskSettings(null, null, null, null))
+        );
+    }
+
+    public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() {
+        MatcherAssert.assertThat(
+            CohereEmbeddingsTaskSettings.fromMap(
+                new HashMap<>(
+                    Map.of(
+                        CohereServiceFields.MODEL,
+                        "abc",
+                        CohereEmbeddingsTaskSettings.INPUT_TYPE,
+                        InputType.INGEST.toString().toLowerCase(Locale.ROOT),
+                        CohereEmbeddingsTaskSettings.EMBEDDING_TYPES,
+                        List.of("abc", "123"),
+                        CohereServiceFields.TRUNCATE,
+                        CohereTruncation.END.toString().toLowerCase(Locale.ROOT)
+                    )
+                )
+            ),
+            is(new CohereEmbeddingsTaskSettings("abc", InputType.INGEST, List.of("abc", "123"), CohereTruncation.END))
+        );
+    }
+
+    public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() {
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.INPUT_TYPE, "abc")))
+        );
+
+        MatcherAssert.assertThat(
+            exception.getMessage(),
+            is("Validation Failed: 1: [task_settings] Invalid value [abc] received. [input_type] must be one of [ingest, search];")
+        );
+    }
+
+    public void testFromMap_ReturnsFailure_WhenEmbeddingTypesAreNotValid() {
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> CohereEmbeddingsTaskSettings.fromMap(
+                new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.EMBEDDING_TYPES, List.of("abc", 123)))
+            )
+        );
+
+        MatcherAssert.assertThat(
+            exception.getMessage(),
+            is(
+                "Validation Failed: 1: [task_settings] Invalid type [Integer]"
+                    + " received for value [123]. [embedding_types] must be type [String];"
+            )
+        );
+    }
+
+    @Override
+    protected Writeable.Reader<CohereEmbeddingsTaskSettings> instanceReader() {
+        return CohereEmbeddingsTaskSettings::new;
+    }
+
+    @Override
+    protected CohereEmbeddingsTaskSettings createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected CohereEmbeddingsTaskSettings mutateInstance(CohereEmbeddingsTaskSettings instance) throws IOException {
+        return randomValueOtherThan(instance, CohereEmbeddingsTaskSettingsTests::createRandom);
+    }
+
+    public static Map<String, Object> getTaskSettingsMap() {
+        return new HashMap<>(Collections.emptyMap());
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java
index d33ec12016cad..f297eb622c421 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java
@@ -12,6 +12,7 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.hamcrest.MatcherAssert;
import java.io.IOException;
import java.util.HashMap;
@@ -39,7 +40,7 @@ public void testFromMap_MissingModel_ThrowException() {
() -> OpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiEmbeddingsTaskSettings.USER, "user")))
);
- assertThat(
+ MatcherAssert.assertThat(
thrownException.getMessage(),
is(
Strings.format(
@@ -55,14 +56,14 @@ public void testFromMap_CreatesWithModelAndUser() {
new HashMap<>(Map.of(OpenAiEmbeddingsTaskSettings.MODEL, "model", OpenAiEmbeddingsTaskSettings.USER, "user"))
);
- assertThat(taskSettings.model(), is("model"));
- assertThat(taskSettings.user(), is("user"));
+ MatcherAssert.assertThat(taskSettings.model(), is("model"));
+ MatcherAssert.assertThat(taskSettings.user(), is("user"));
}
public void testFromMap_MissingUser_DoesNotThrowException() {
var taskSettings = OpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiEmbeddingsTaskSettings.MODEL, "model")));
- assertThat(taskSettings.model(), is("model"));
+ MatcherAssert.assertThat(taskSettings.model(), is("model"));
assertNull(taskSettings.user());
}
@@ -72,7 +73,7 @@ public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() {
);
var overriddenTaskSettings = taskSettings.overrideWith(OpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS);
- assertThat(overriddenTaskSettings, is(taskSettings));
+ MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings));
}
public void testOverrideWith_UsesOverriddenSettings() {
@@ -85,7 +86,7 @@ public void testOverrideWith_UsesOverriddenSettings() {
);
var overriddenTaskSettings = taskSettings.overrideWith(requestTaskSettings);
- assertThat(overriddenTaskSettings, is(new OpenAiEmbeddingsTaskSettings("model2", "user2")));
+ MatcherAssert.assertThat(overriddenTaskSettings, is(new OpenAiEmbeddingsTaskSettings("model2", "user2")));
}
public void testOverrideWith_UsesOnlyNonNullModelSetting() {
@@ -98,7 +99,7 @@ public void testOverrideWith_UsesOnlyNonNullModelSetting() {
);
var overriddenTaskSettings = taskSettings.overrideWith(requestTaskSettings);
- assertThat(overriddenTaskSettings, is(new OpenAiEmbeddingsTaskSettings("model2", "user")));
+ MatcherAssert.assertThat(overriddenTaskSettings, is(new OpenAiEmbeddingsTaskSettings("model2", "user")));
}
@Override
From f94c32c66a988ea15edee6e7aca64bc312aea107 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Wed, 17 Jan 2024 09:54:41 -0500
Subject: [PATCH 03/13] Filling out the embedding types
---
.../core/inference/results/ByteValue.java | 64 +++++++++++
.../inference/results/EmbeddingValue.java | 15 +++
.../core/inference/results/FloatValue.java | 64 +++++++++++
.../results/TextEmbeddingResults.java | 59 ++++++++--
.../InferenceNamedWriteablesProvider.java | 5 +
...nCreator.java => CohereActionCreator.java} | 8 +-
...nVisitor.java => CohereActionVisitor.java} | 6 +-
.../action/cohere/CohereEmbeddingsAction.java | 15 +--
.../cohere/CohereEmbeddingsRequest.java | 31 +++---
.../cohere/CohereEmbeddingsRequestEntity.java | 2 +-
.../CohereEmbeddingsResponseEntity.java | 104 ++++++++++++++++--
.../HuggingFaceEmbeddingsResponseEntity.java | 2 +-
.../OpenAiEmbeddingsResponseEntity.java | 2 +-
.../inference/services/ServiceUtils.java | 54 ++++++++-
.../embeddings/CohereEmbeddingType.java | 57 ++++++++++
.../embeddings/CohereEmbeddingsModel.java | 9 ++
.../CohereEmbeddingsTaskSettings.java | 27 ++++-
.../CohereEmbeddingsModelTests.java | 50 +++++++--
.../CohereEmbeddingsTaskSettingsTests.java | 68 ++++++++++--
19 files changed, 562 insertions(+), 80 deletions(-)
create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java
create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingValue.java
create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java
rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/{OpenAiActionCreator.java => CohereActionCreator.java} (80%)
rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/{OpenAiActionVisitor.java => CohereActionVisitor.java} (69%)
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java
new file mode 100644
index 0000000000000..9e3dc866d304e
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * An {@link EmbeddingValue} that wraps one element of an embedding as a {@link Byte}.
+ * <p>
+ * The element is serialized as a single byte: it is read with {@code StreamInput#readByte()} and written with
+ * {@code StreamOutput#writeByte}.
+ */
+public class ByteValue implements EmbeddingValue {
+
+    /** Name reported by {@link #getWriteableName()}. */
+    public static final String NAME = "byte_value";
+
+    private final Byte value;
+
+    /**
+     * Wraps the given element.
+     *
+     * @param value the element to wrap, must not be {@code null}
+     * @throws NullPointerException if {@code value} is {@code null}
+     */
+    public ByteValue(Byte value) {
+        // Fail fast: writeTo() unboxes the field, so a null would otherwise only surface as an NPE at serialization time.
+        this.value = Objects.requireNonNull(value, "value");
+    }
+
+    /**
+     * Reads the element from {@code in} as a single byte.
+     *
+     * @throws IOException if reading from the stream fails
+     */
+    public ByteValue(StreamInput in) throws IOException {
+        this.value = in.readByte();
+    }
+
+    /**
+     * @return the wrapped element
+     */
+    @Override
+    public Byte getValue() {
+        return value;
+    }
+
+    /**
+     * Writes the element as a bare value; no field name or enclosing object is emitted here.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.value(value);
+        return builder;
+    }
+
+    /**
+     * Two instances are equal when they are of exactly the same class and wrap equal elements.
+     */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        ByteValue that = (ByteValue) o;
+        return Objects.equals(value, that.value);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(value);
+    }
+
+    /**
+     * Writes the element to {@code out} as a single byte.
+     *
+     * @throws IOException if writing to the stream fails
+     */
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeByte(value);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingValue.java
new file mode 100644
index 0000000000000..f3fd641db395d
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingValue.java
@@ -0,0 +1,15 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.xcontent.ToXContentFragment;
+
+public interface EmbeddingValue extends NamedWriteable, ToXContentFragment { // one element of an embedding; ByteValue and FloatValue implement it
+ Number getValue(); // the wrapped element, exposed as a Number so callers can convert via doubleValue()/floatValue()
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java
new file mode 100644
index 0000000000000..9ead125d2d48a
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * An {@link EmbeddingValue} that wraps one element of an embedding as a {@link Float}.
+ * <p>
+ * The element is serialized as a float: it is read with {@code StreamInput#readFloat()} and written with
+ * {@code StreamOutput#writeFloat}.
+ */
+public class FloatValue implements EmbeddingValue {
+
+    /** Name reported by {@link #getWriteableName()}. */
+    public static final String NAME = "float_value";
+
+    private final Float value;
+
+    /**
+     * Wraps the given element.
+     *
+     * @param value the element to wrap, must not be {@code null}
+     * @throws NullPointerException if {@code value} is {@code null}
+     */
+    public FloatValue(Float value) {
+        // Fail fast: writeTo() unboxes the field, so a null would otherwise only surface as an NPE at serialization time.
+        this.value = Objects.requireNonNull(value, "value");
+    }
+
+    /**
+     * Reads the element from {@code in} as a float.
+     *
+     * @throws IOException if reading from the stream fails
+     */
+    public FloatValue(StreamInput in) throws IOException {
+        this.value = in.readFloat();
+    }
+
+    /**
+     * Declared with the covariant {@code Float} return type, mirroring {@code ByteValue#getValue()} which returns {@code Byte};
+     * callers that only need a {@code Number} are unaffected.
+     *
+     * @return the wrapped element
+     */
+    @Override
+    public Float getValue() {
+        return value;
+    }
+
+    /**
+     * Writes the element as a bare value; no field name or enclosing object is emitted here.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.value(value);
+        return builder;
+    }
+
+    /**
+     * Two instances are equal when they are of exactly the same class and wrap equal elements.
+     */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        FloatValue that = (FloatValue) o;
+        return Objects.equals(value, that.value);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(value);
+    }
+
+    /**
+     * Writes the element to {@code out} as a float.
+     *
+     * @throws IOException if writing to the stream fails
+     */
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeFloat(value);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java
index ace5974866038..9fc99ba3f03f8 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java
@@ -7,6 +7,7 @@
package org.elasticsearch.xpack.core.inference.results;
+import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -21,6 +22,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.stream.Collectors;
/**
@@ -51,10 +53,7 @@ public TextEmbeddingResults(StreamInput in) throws IOException {
@SuppressWarnings("deprecation")
TextEmbeddingResults(LegacyTextEmbeddingResults legacyTextEmbeddingResults) {
this(
- legacyTextEmbeddingResults.embeddings()
- .stream()
- .map(embedding -> new Embedding(embedding.values()))
- .collect(Collectors.toList())
+ legacyTextEmbeddingResults.embeddings().stream().map(embedding -> Embedding.of(embedding.values())).collect(Collectors.toList())
);
}
@@ -81,7 +80,7 @@ public String getWriteableName() {
@Override
public List extends InferenceResults> transformToCoordinationFormat() {
return embeddings.stream()
- .map(embedding -> embedding.values.stream().mapToDouble(value -> value).toArray())
+ .map(embedding -> embedding.values.stream().mapToDouble(value -> value.getValue().doubleValue()).toArray())
.map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false))
.toList();
}
@@ -90,7 +89,7 @@ public List extends InferenceResults> transformToCoordinationFormat() {
@SuppressWarnings("deprecation")
public List extends InferenceResults> transformToLegacyFormat() {
var legacyEmbedding = new LegacyTextEmbeddingResults(
- embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.values)).toList()
+ embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.toFloats())).toList()
);
return List.of(legacyEmbedding);
@@ -103,16 +102,43 @@ public Map asMap() {
return map;
}
- public record Embedding(List values) implements Writeable, ToXContentObject {
+ public static class Embedding implements Writeable, ToXContentObject {
public static final String EMBEDDING = "embedding";
+ /**
+  * Creates an embedding from plain floats, wrapping each one via {@code convertFloatsToEmbeddingValues}.
+  *
+  * @param values the float elements of the embedding
+  */
+ public static Embedding of(List<Float> values) {
+     return new Embedding(convertFloatsToEmbeddingValues(values));
+ }
+
+ private final List<EmbeddingValue> values;
+
+ /**
+  * @param values the elements of this embedding; the list is stored as given, without copying
+  */
+ public Embedding(List<EmbeddingValue> values) {
+     this.values = values;
+ }
+
public Embedding(StreamInput in) throws IOException {
- this(in.readCollectionAsImmutableList(StreamInput::readFloat));
+ if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) {
+ values = in.readNamedWriteableCollectionAsList(EmbeddingValue.class);
+ } else {
+ values = convertFloatsToEmbeddingValues(in.readCollectionAsImmutableList(StreamInput::readFloat));
+ }
+ }
+
+ /**
+  * Wraps each float in a {@link FloatValue}, preserving order.
+  */
+ private static List<EmbeddingValue> convertFloatsToEmbeddingValues(List<Float> floats) {
+     return floats.stream().map(FloatValue::new).collect(Collectors.toList());
+ }
+
+ /**
+  * Converts every element to a float via {@code Number#floatValue()}.
+  *
+  * @return the elements of this embedding as floats, in order
+  */
+ public List<Float> toFloats() {
+     return values.stream().map(value -> value.getValue().floatValue()).toList();
+ }
@Override
public void writeTo(StreamOutput out) throws IOException {
- out.writeCollection(values, StreamOutput::writeFloat);
+ if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) {
+ out.writeNamedWriteableCollection(values);
+ } else {
+ // TODO do we need to check that the values are floats here?
+ out.writeCollection(toFloats(), StreamOutput::writeFloat);
+ }
}
@Override
@@ -120,7 +146,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();
builder.startArray(EMBEDDING);
- for (Float value : values) {
+ for (EmbeddingValue value : values) {
builder.value(value);
}
builder.endArray();
@@ -134,6 +160,19 @@ public String toString() {
return Strings.toString(this);
}
+ /**
+  * Embeddings are equal when they are of exactly the same class and hold equal value lists.
+  */
+ @Override
+ public boolean equals(Object o) {
+     if (o == this) {
+         return true;
+     }
+     if (o == null || o.getClass() != getClass()) {
+         return false;
+     }
+     var other = (Embedding) o;
+     return Objects.equals(values, other.values);
+ }
+
+ @Override
+ public int hashCode() { // derived from values only, the same field equals() compares
+ return Objects.hash(values);
+ }
+
public Map asMap() {
return Map.of(EMBEDDING, values);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
index c632c568fea16..a6d6c3cbee4d5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
@@ -14,6 +14,9 @@
import org.elasticsearch.inference.SecretSettings;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.xpack.core.inference.results.ByteValue;
+import org.elasticsearch.xpack.core.inference.results.EmbeddingValue;
+import org.elasticsearch.xpack.core.inference.results.FloatValue;
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
@@ -49,6 +52,8 @@ public static List getNamedWriteables() {
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, TextEmbeddingResults.NAME, TextEmbeddingResults::new)
);
+ namedWriteables.add(new NamedWriteableRegistry.Entry(EmbeddingValue.class, FloatValue.NAME, FloatValue::new));
+ namedWriteables.add(new NamedWriteableRegistry.Entry(EmbeddingValue.class, ByteValue.NAME, ByteValue::new));
// Empty default task settings
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new));
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java
similarity index 80%
rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionCreator.java
rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java
index 0353922959e0b..b3ac96979b68b 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionCreator.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java
@@ -10,7 +10,7 @@
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
-import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
import java.util.Map;
import java.util.Objects;
@@ -18,17 +18,17 @@
/**
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the openai model type.
*/
-public class OpenAiActionCreator implements OpenAiActionVisitor {
+public class CohereActionCreator implements CohereActionVisitor {
private final Sender sender;
private final ServiceComponents serviceComponents;
- public OpenAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
+ public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) {
this.sender = Objects.requireNonNull(sender);
this.serviceComponents = Objects.requireNonNull(serviceComponents);
}
@Override
- public ExecutableAction create(OpenAiEmbeddingsModel model, Map taskSettings) {
+ public ExecutableAction create(CohereEmbeddingsModel model, Map taskSettings) {
var overriddenModel = model.overrideWith(taskSettings);
return new CohereEmbeddingsAction(sender, overriddenModel, serviceComponents);
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionVisitor.java
similarity index 69%
rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionVisitor.java
rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionVisitor.java
index a3a4cfbbfc873..1500d48e3c201 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/OpenAiActionVisitor.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionVisitor.java
@@ -8,10 +8,10 @@
package org.elasticsearch.xpack.inference.external.action.cohere;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
-import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
import java.util.Map;
-public interface OpenAiActionVisitor {
- ExecutableAction create(OpenAiEmbeddingsModel model, Map taskSettings);
+public interface CohereActionVisitor {
+ ExecutableAction create(CohereEmbeddingsModel model, Map taskSettings);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java
index 569f33b256803..cc84b06a7ba8f 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java
@@ -13,7 +13,6 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
-import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.cohere.CohereAccount;
import org.elasticsearch.xpack.inference.external.cohere.CohereResponseHandler;
@@ -22,7 +21,7 @@
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.request.cohere.CohereEmbeddingsRequest;
-import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity;
+import org.elasticsearch.xpack.inference.external.response.cohere.CohereEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
@@ -31,24 +30,22 @@
import java.util.Objects;
import static org.elasticsearch.core.Strings.format;
-import static org.elasticsearch.xpack.inference.common.Truncator.truncate;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
public class CohereEmbeddingsAction implements ExecutableAction {
private static final Logger logger = LogManager.getLogger(CohereEmbeddingsAction.class);
+ private static final ResponseHandler HANDLER = createEmbeddingsHandler();
private final CohereAccount account;
private final CohereEmbeddingsModel model;
private final String errorMessage;
- private final Truncator truncator;
private final RetryingHttpSender sender;
public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model, ServiceComponents serviceComponents) {
this.model = Objects.requireNonNull(model);
this.account = new CohereAccount(this.model.getServiceSettings().uri(), this.model.getSecretSettings().apiKey());
this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri());
- this.truncator = Objects.requireNonNull(serviceComponents.truncator());
this.sender = new RetryingHttpSender(
Objects.requireNonNull(sender),
serviceComponents.throttlerManager(),
@@ -70,12 +67,12 @@ private static String getErrorMessage(@Nullable URI uri) {
public void execute(List input, ActionListener listener) {
try {
// TODO only truncate if the setting is NONE?
- var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
+ // var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
- CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(truncator, account, truncatedInput, model.getTaskSettings());
+ CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(account, input, model.getTaskSettings());
ActionListener wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener);
- sender.send(request, wrappedListener);
+ sender.send(request, HANDLER, wrappedListener);
} catch (ElasticsearchException e) {
listener.onFailure(e);
} catch (Exception e) {
@@ -84,6 +81,6 @@ public void execute(List input, ActionListener
}
private static ResponseHandler createEmbeddingsHandler() {
- return new CohereResponseHandler("cohere text embedding", OpenAiEmbeddingsResponseEntity::fromResponse);
+ return new CohereResponseHandler("cohere text embedding", CohereEmbeddingsResponseEntity::fromResponse);
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java
index 7f2d34e1f8b3b..ba0e3d258dba9 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java
@@ -14,7 +14,6 @@
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.cohere.CohereAccount;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
@@ -22,6 +21,7 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
+import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;
@@ -29,21 +29,14 @@
public class CohereEmbeddingsRequest implements Request {
- private final Truncator truncator;
private final CohereAccount account;
- private final Truncator.TruncationResult truncationResult;
+ private final List input;
private final URI uri;
private final CohereEmbeddingsTaskSettings taskSettings;
- public CohereEmbeddingsRequest(
- Truncator truncator,
- CohereAccount account,
- Truncator.TruncationResult input,
- CohereEmbeddingsTaskSettings taskSettings
- ) {
- this.truncator = Objects.requireNonNull(truncator);
+ public CohereEmbeddingsRequest(CohereAccount account, List input, CohereEmbeddingsTaskSettings taskSettings) {
this.account = Objects.requireNonNull(account);
- this.truncationResult = Objects.requireNonNull(input);
+ this.input = Objects.requireNonNull(input);
this.uri = buildUri(this.account.url(), "Cohere", CohereEmbeddingsRequest::buildDefaultUri);
this.taskSettings = Objects.requireNonNull(taskSettings);
}
@@ -53,7 +46,7 @@ public HttpRequestBase createRequest() {
HttpPost httpPost = new HttpPost(uri);
ByteArrayEntity byteEntity = new ByteArrayEntity(
- Strings.toString(new CohereEmbeddingsRequestEntity(truncationResult.input(), taskSettings)).getBytes(StandardCharsets.UTF_8)
+ Strings.toString(new CohereEmbeddingsRequestEntity(input, taskSettings)).getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);
@@ -70,15 +63,21 @@ public URI getURI() {
@Override
public Request truncate() {
- // TODO only do this is the truncate setting is NONE?
- var truncatedInput = truncator.truncate(truncationResult.input());
- return new CohereEmbeddingsRequest(truncator, account, truncatedInput, taskSettings);
+ return this;
+ // TODO only do this if the truncate setting is NONE?
+ // var truncatedInput = truncator.truncate(truncationResult.input());
+ //
+ // return new CohereEmbeddingsRequest(truncator, account, truncatedInput, taskSettings);
}
@Override
public boolean[] getTruncationInfo() {
- return truncationResult.truncated().clone();
+ return null;
+ }
+
+ public CohereEmbeddingsTaskSettings getTaskSettings() {
+ return taskSettings;
}
// default for testing
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java
index 8aea434d7339d..7331426b5c683 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java
@@ -45,7 +45,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
if (taskSettings.truncation() != null) {
- builder.field(EMBEDDING_TYPES_FIELD, taskSettings.truncation());
+ builder.field(CohereServiceFields.TRUNCATE, taskSettings.truncation());
}
builder.endObject();
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
index cbe70ad8d919a..8f0eabeda3a04 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
@@ -9,23 +9,38 @@
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.XContentParserUtils;
+import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.ByteValue;
+import org.elasticsearch.xpack.core.inference.results.EmbeddingValue;
+import org.elasticsearch.xpack.core.inference.results.FloatValue;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
import java.io.IOException;
import java.util.List;
+import java.util.Map;
+import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
+import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType.toLowerCase;
public class CohereEmbeddingsResponseEntity {
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Cohere embeddings response";
+ // Maps the lower-cased embedding type name to the function that parses an entry of that type.
+ // NOTE(review): the type arguments are reconstructed from this file's imports (CheckedFunction, XContentParser,
+ // EmbeddingValue, IOException) and the parse* method references — confirm against the parse methods' signatures.
+ private static final Map<String, CheckedFunction<XContentParser, EmbeddingValue, IOException>> EMBEDDING_PARSERS = Map.of(
+     toLowerCase(CohereEmbeddingType.FLOAT),
+     CohereEmbeddingsResponseEntity::parseEmbeddingFloatEntry,
+     toLowerCase(CohereEmbeddingType.INT8),
+     CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry
+ );
+
/**
* Parses the OpenAI json response.
* For a request like:
@@ -81,23 +96,77 @@ public static TextEmbeddingResults fromResponse(Request request, HttpResult resp
XContentParser.Token token = jsonParser.currentToken();
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
- positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE);
+ // TODO can't really rely on the texts field being before embeddings, this will need to be a loop
+ // we don't return the is_truncated result for text embeddings so we don't really need this yet
+ positionParserAtTokenAfterField(jsonParser, "texts", FAILED_TO_FIND_FIELD_TEMPLATE);
+ XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, jsonParser.currentToken(), jsonParser);
- List embeddingList = XContentParserUtils.parseList(
+ List inputAfterTruncation = XContentParserUtils.parseList(
jsonParser,
- CohereEmbeddingsResponseEntity::parseEmbeddingObject
+ CohereEmbeddingsResponseEntity::parseTruncatedTextsField
);
- return new TextEmbeddingResults(embeddingList);
+ positionParserAtTokenAfterField(jsonParser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
+ // TODO we don't need this yet, just require that the embedding type be a single string so that
+ // this is always the second style
+ token = jsonParser.currentToken();
+ if (token == XContentParser.Token.START_OBJECT) {
+ return parseEmbeddingsObject(jsonParser);
+ } else if (token == XContentParser.Token.START_ARRAY) {
+ List embeddingList = XContentParserUtils.parseList(
+ jsonParser,
+ parser -> CohereEmbeddingsResponseEntity.parseEmbeddingsArray(
+ parser,
+ CohereEmbeddingsResponseEntity::parseEmbeddingFloatEntry
+ )
+ );
+
+ return new TextEmbeddingResults(embeddingList);
+ } else {
+ throwUnknownToken(token, jsonParser);
+ }
+
+ // This should never be reached. The above code should either return successfully or hit the throwUnknownToken
+ // or throw a parsing exception
+ throw new IllegalStateException("Reached an invalid state while parsing the Cohere response");
}
}
- private static TextEmbeddingResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
- XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
+ /**
+  * Entry parser for the {@code texts} array of the response: requires the current token to be a string
+  * and returns its text.
+  */
+ private static String parseTruncatedTextsField(XContentParser parser) throws IOException {
+ XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser);
+ return parser.text();
+ }
+
+ /**
+  * Parses the object form of the {@code embeddings} field, in which each key is an embedding type name.
+  * <p>
+  * The first field whose name has an entry in {@code EMBEDDING_PARSERS} is parsed as a list of embedding
+  * arrays and returned. The value of every unrecognized field is skipped in full.
+  *
+  * @param parser parser whose current token is the START_OBJECT of the embeddings object
+  * @return the embeddings of the first supported embedding type
+  * @throws IOException if reading from the parser fails
+  * @throws IllegalStateException if the object ends without containing a supported embedding type
+  */
+ private static TextEmbeddingResults parseEmbeddingsObject(XContentParser parser) throws IOException {
+     XContentParser.Token token;
+
+     while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
+         if (token != XContentParser.Token.FIELD_NAME) {
+             continue;
+         }
+
+         var embeddingValueParser = EMBEDDING_PARSERS.get(parser.currentName());
+
+         // move from the field name to the first token of its value
+         parser.nextToken();
+
+         if (embeddingValueParser == null) {
+             // Unsupported embedding type: skip the whole value. Without this a nested object's END_OBJECT would
+             // end the loop early and a nested FIELD_NAME could be mistaken for an embedding type.
+             parser.skipChildren();
+             continue;
+         }
+
+         var embeddingList = XContentParserUtils.parseList(
+             parser,
+             listParser -> CohereEmbeddingsResponseEntity.parseEmbeddingsArray(listParser, embeddingValueParser)
+         );
+
+         return new TextEmbeddingResults(embeddingList);
+     }
+
+     throw new IllegalStateException("Failed to find a supported embedding type");
+ }
- positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
+ private static TextEmbeddingResults.Embedding parseEmbeddingsArray(
+ XContentParser parser,
+ CheckedFunction parseEntry
+ ) throws IOException {
+ XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
- List embeddingValues = XContentParserUtils.parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingList);
+ List embeddingValues = XContentParserUtils.parseList(parser, parseEntry);
// the parser is currently sitting at an ARRAY_END so go to the next token
parser.nextToken();
@@ -107,10 +176,25 @@ private static TextEmbeddingResults.Embedding parseEmbeddingObject(XContentParse
return new TextEmbeddingResults.Embedding(embeddingValues);
}
- private static float parseEmbeddingList(XContentParser parser) throws IOException {
+ private static FloatValue parseEmbeddingFloatEntry(XContentParser parser) throws IOException {
XContentParser.Token token = parser.currentToken();
XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser);
- return parser.floatValue();
+ return new FloatValue(parser.floatValue());
+ }
+
+ /**
+  * Parses one entry of an int8 embedding array.
+  *
+  * @param parser parser whose current token must be a number
+  * @return the entry wrapped in a {@link ByteValue}
+  * @throws IOException if reading from the parser fails
+  * @throws IllegalArgumentException if the number does not fit in a byte
+  */
+ private static ByteValue parseEmbeddingInt8Entry(XContentParser parser) throws IOException {
+     XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, parser.currentToken(), parser);
+
+     // read as a short and validate the range before narrowing to a byte
+     short entry = parser.shortValue();
+     checkByteBounds(entry);
+
+     byte narrowedEntry = (byte) entry;
+     return new ByteValue(narrowedEntry);
+ }
+
+ /**
+  * Verifies that {@code value} fits in a signed byte.
+  *
+  * @throws IllegalArgumentException if the value is below {@link Byte#MIN_VALUE} or above {@link Byte#MAX_VALUE}
+  */
+ private static void checkByteBounds(short value) {
+     boolean fitsInByte = value >= Byte.MIN_VALUE && value <= Byte.MAX_VALUE;
+     if (fitsInByte) {
+         return;
+     }
+
+     throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte");
}
private CohereEmbeddingsResponseEntity() {}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java
index b74b03891034f..9bd01bac7bee1 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java
@@ -149,7 +149,7 @@ private static TextEmbeddingResults.Embedding parseEmbeddingEntry(XContentParser
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
List embeddingValues = XContentParserUtils.parseList(parser, HuggingFaceEmbeddingsResponseEntity::parseEmbeddingList);
- return new TextEmbeddingResults.Embedding(embeddingValues);
+ return TextEmbeddingResults.Embedding.of(embeddingValues);
}
private static float parseEmbeddingList(XContentParser parser) throws IOException {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java
index 4926ba3f0ef6b..a5fe46ab1de5f 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java
@@ -101,7 +101,7 @@ private static TextEmbeddingResults.Embedding parseEmbeddingObject(XContentParse
// if there are additional fields within this object, lets skip them, so we can begin parsing the next embedding array
parser.skipChildren();
- return new TextEmbeddingResults.Embedding(embeddingValues);
+ return TextEmbeddingResults.Embedding.of(embeddingValues);
}
private static float parseEmbeddingList(XContentParser parser) throws IOException {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
index 6fef9dac6095a..37119b872bf12 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
@@ -109,6 +109,10 @@ public static String mustBeNonEmptyString(String settingName, String scope) {
return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName);
}
+ /**
+  * Builds the validation error message reported when the list setting {@code settingName} in {@code scope}
+  * is present but empty.
+  */
+ public static String mustBeNonEmptyList(String settingName, String scope) {
+ return Strings.format("[%s] Invalid value empty list. [%s] must be a non-empty list", scope, settingName);
+ }
+
public static String invalidType(String settingName, String scope, String invalidType, String invalidValue, String requiredType) {
return Strings.format(
"[%s] Invalid type [%s] received for value [%s]. [%s] must be type [%s]",
@@ -191,6 +195,54 @@ public static SimilarityMeasure extractSimilarity(Map map, Strin
return null;
}
+ /**
+  * Reads an optional list setting and converts each of its string entries with {@code converter}.
+  * <p>
+  * Problems are recorded in {@code validationException} instead of being thrown: an empty list, an entry that is
+  * {@code null} or not a string, or an entry that {@code converter} rejects with an
+  * {@link IllegalArgumentException}. Processing stops at the first problem.
+  *
+  * @param map the settings map, read through {@code ServiceUtils.removeAsType}
+  * @param settingName the key of the list setting
+  * @param scope the settings scope named in validation messages
+  * @param converter converts a single string entry; throws {@link IllegalArgumentException} for unsupported values
+  * @param validTypes the accepted values, listed in lowercase in the invalid value message
+  * @param validationException collects the validation errors
+  * @return the converted entries, or {@code null} when the setting is absent or a validation error was recorded
+  */
+ public static <T> List<T> extractOptionalListOfEnums(
+     Map<String, Object> map,
+     String settingName,
+     String scope,
+     CheckedFunction<String, T, IllegalArgumentException> converter,
+     T[] validTypes,
+     ValidationException validationException
+ ) {
+     List<?> listField = ServiceUtils.removeAsType(map, settingName, List.class);
+     if (listField == null) {
+         return null;
+     }
+
+     if (listField.isEmpty()) {
+         validationException.addValidationError(ServiceUtils.mustBeNonEmptyList(settingName, scope));
+         return null;
+     }
+
+     var validTypesAsStrings = Arrays.stream(validTypes).map(type -> type.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new);
+
+     List<T> castedList = new ArrayList<>(listField.size());
+
+     for (Object listEntry : listField) {
+         if (listEntry instanceof String == false) {
+             // listEntry may be null (for example a JSON null inside the array), so it must not be dereferenced
+             var entryType = listEntry == null ? "null" : listEntry.getClass().getSimpleName();
+             validationException.addValidationError(
+                 invalidType(settingName, scope, entryType, String.valueOf(listEntry), String.class.getSimpleName())
+             );
+             return null;
+         }
+
+         var stringEntry = (String) listEntry;
+         try {
+             castedList.add(converter.apply(stringEntry));
+         } catch (IllegalArgumentException e) {
+             validationException.addValidationError(invalidValue(settingName, scope, stringEntry, validTypesAsStrings));
+             return null;
+         }
+     }
+
+     return castedList;
+ }
+
@SuppressWarnings("unchecked")
public static List extractOptionalListOfType(
Map map,
@@ -206,7 +258,7 @@ public static List extractOptionalListOfType(
}
if (listField.isEmpty()) {
- validationException.addValidationError(ServiceUtils.mustBeNonEmptyString(settingName, scope));
+ validationException.addValidationError(ServiceUtils.mustBeNonEmptyList(settingName, scope));
return null;
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java
new file mode 100644
index 0000000000000..0e4f9ffeba466
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere.embeddings;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+
+import java.io.IOException;
+import java.util.Locale;
+
+/**
+ * Defines the type of embedding that the cohere api should return for a request.
+ *
+ * <p>
+ * <a href="https://docs.cohere.com/reference/embed">See api docs for details.</a>
+ * </p>
+ */
+public enum CohereEmbeddingType implements Writeable {
+    /**
+     * Use this when you want to get back the default float embeddings. Valid for all models.
+     */
+    FLOAT,
+    /**
+     * Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
+     */
+    INT8;
+
+    /**
+     * Name constant for this type. Declared {@code final} so that it cannot be reassigned by other code.
+     */
+    public static final String NAME = "cohere_embedding_type";
+
+    /**
+     * Returns the constant's name lowercased with {@link Locale#ROOT}, for example {@code int8}.
+     */
+    @Override
+    public String toString() {
+        return name().toLowerCase(Locale.ROOT);
+    }
+
+    /**
+     * Returns the lowercase string form of {@code type}. This is the same value as {@link #toString()}, which
+     * already lowercases the name, so no second conversion is needed.
+     */
+    public static String toLowerCase(CohereEmbeddingType type) {
+        return type.toString();
+    }
+
+    /**
+     * Looks up a constant by name, ignoring surrounding whitespace and letter case.
+     *
+     * @throws IllegalArgumentException if no constant matches {@code name}
+     */
+    public static CohereEmbeddingType fromString(String name) {
+        return valueOf(name.trim().toUpperCase(Locale.ROOT));
+    }
+
+    /**
+     * Reads a constant from {@code in} using {@code StreamInput.readEnum}.
+     */
+    public static CohereEmbeddingType fromStream(StreamInput in) throws IOException {
+        return in.readEnum(CohereEmbeddingType.class);
+    }
+
+    /**
+     * Writes this constant to {@code out} using {@code StreamOutput.writeEnum}.
+     */
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeEnum(this);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java
index ed1295c8ee31c..816e7bc6933cd 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java
@@ -76,4 +76,13 @@ public DefaultSecretSettings getSecretSettings() {
public ExecutableAction accept() {
return null;
}
+
+ /**
+  * Applies request-level task settings on top of this model's task settings.
+  *
+  * @param taskSettings the raw task settings of the request, may be {@code null}
+  * @return this same instance when {@code taskSettings} is {@code null} or empty, otherwise a new model built from
+  *         this one and the merged task settings
+  */
+ public CohereEmbeddingsModel overrideWith(Map taskSettings) {
+     boolean hasOverrides = taskSettings != null && taskSettings.isEmpty() == false;
+     if (hasOverrides == false) {
+         return this;
+     }
+
+     var overrides = CohereEmbeddingsTaskSettings.fromMap(taskSettings);
+     var mergedTaskSettings = getTaskSettings().overrideWith(overrides);
+     return new CohereEmbeddingsModel(this, mergedTaskSettings);
+ }
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java
index c2ad8b2b8e958..849290cfaa66d 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java
@@ -24,7 +24,7 @@
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfType;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfEnums;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.MODEL;
import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.TRUNCATE;
@@ -44,7 +44,7 @@
public record CohereEmbeddingsTaskSettings(
@Nullable String model,
@Nullable InputType inputType,
- @Nullable List embeddingTypes,
+ @Nullable List embeddingTypes,
@Nullable CohereTruncation truncation
) implements TaskSettings {
@@ -69,11 +69,12 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) {
InputType.values(),
validationException
);
- List embeddingTypes = extractOptionalListOfType(
+ List embeddingTypes = extractOptionalListOfEnums(
map,
EMBEDDING_TYPES,
ModelConfigurations.TASK_SETTINGS,
- String.class,
+ CohereEmbeddingType::fromString,
+ CohereEmbeddingType.values(),
validationException
);
CohereTruncation truncation = extractOptionalEnum(
@@ -93,7 +94,12 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) {
}
public CohereEmbeddingsTaskSettings(StreamInput in) throws IOException {
- this(in.readOptionalString(), InputType.fromStream(in), in.readOptionalStringCollectionAsList(), CohereTruncation.fromStream(in));
+ this(
+ in.readOptionalString(),
+ InputType.fromStream(in),
+ in.readOptionalCollectionAsList(CohereEmbeddingType::fromStream),
+ CohereTruncation.fromStream(in)
+ );
}
@Override
@@ -132,7 +138,16 @@ public TransportVersion getMinimalSupportedVersion() {
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(model);
inputType.writeTo(out);
- out.writeOptionalStringCollection(embeddingTypes);
+ out.writeOptionalCollection(embeddingTypes);
truncation.writeTo(out);
}
+
+ /**
+  * Merges these settings with request-level settings. For every field the request value is used when it is
+  * non-null, otherwise the value of this instance is kept.
+  *
+  * @param requestTaskSettings the settings supplied with the request
+  * @return a new settings instance holding the merged values
+  */
+ public CohereEmbeddingsTaskSettings overrideWith(CohereEmbeddingsTaskSettings requestTaskSettings) {
+     return new CohereEmbeddingsTaskSettings(
+         requestTaskSettings.model() != null ? requestTaskSettings.model() : model,
+         requestTaskSettings.inputType() != null ? requestTaskSettings.inputType() : inputType,
+         requestTaskSettings.embeddingTypes() != null ? requestTaskSettings.embeddingTypes() : embeddingTypes,
+         requestTaskSettings.truncation() != null ? requestTaskSettings.truncation() : truncation
+     );
+ }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java
index be3f77251fd08..d6b01486ae3ec 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java
@@ -14,27 +14,59 @@
import org.elasticsearch.xpack.inference.common.SimilarityMeasure;
import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+import org.hamcrest.MatcherAssert;
+
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.sameInstance;
public class CohereEmbeddingsModelTests extends ESTestCase {
+ // Overriding with a map that only sets the model must equal a model created with that model in its task settings
+ public void testOverrideWith_OverridesModel() {
+     var requestSettings = getTaskSettingsMap("model", null, null, null);
+     var expectedTaskSettings = new CohereEmbeddingsTaskSettings("model", null, null, null);
+
+     var actualModel = createModel("url", "api_key", null).overrideWith(requestSettings);
+
+     MatcherAssert.assertThat(actualModel, is(createModel("url", "api_key", expectedTaskSettings, null, null)));
+ }
+
+ // An empty settings map must hand back the very same model instance
+ public void testOverrideWith_DoesNotOverride_WhenSettingsAreEmpty() {
+     var originalModel = createModel("url", "api_key", null);
+
+     MatcherAssert.assertThat(originalModel.overrideWith(Map.of()), sameInstance(originalModel));
+ }
+
+ // A null settings map must hand back the very same model instance
+ public void testOverrideWith_DoesNotOverride_WhenSettingsAreNull() {
+     var originalModel = createModel("url", "api_key", null);
+
+     MatcherAssert.assertThat(originalModel.overrideWith(null), sameInstance(originalModel));
+ }
+
public static CohereEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit) {
- return new CohereEmbeddingsModel(
- "id",
- TaskType.TEXT_EMBEDDING,
- "service",
- new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, 1536, tokenLimit),
- CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
- new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
- );
+ return createModel(url, apiKey, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, null);
}
public static CohereEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit, @Nullable Integer dimensions) {
+ return createModel(url, apiKey, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, dimensions);
+ }
+
+ public static CohereEmbeddingsModel createModel(
+ String url,
+ String apiKey,
+ CohereEmbeddingsTaskSettings taskSettings,
+ @Nullable Integer tokenLimit,
+ @Nullable Integer dimensions
+ ) {
return new CohereEmbeddingsModel(
"id",
TaskType.TEXT_EMBEDDING,
"service",
new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit),
- CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
+ taskSettings,
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
index b5e44b88b9b4c..f2ea1a8c4f3fe 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
@@ -9,6 +9,7 @@
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields;
@@ -16,7 +17,6 @@
import org.hamcrest.MatcherAssert;
import java.io.IOException;
-import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
@@ -29,7 +29,7 @@ public class CohereEmbeddingsTaskSettingsTests extends AbstractWireSerializingTe
public static CohereEmbeddingsTaskSettings createRandom() {
var model = randomBoolean() ? randomAlphaOfLength(15) : null;
var inputType = randomBoolean() ? randomFrom(InputType.values()) : null;
- var embeddingTypes = randomBoolean() ? List.of(randomAlphaOfLength(6)) : null;
+ var embeddingTypes = randomBoolean() ? List.of(randomFrom(CohereEmbeddingType.values())) : null;
var truncation = randomBoolean() ? randomFrom(CohereTruncation.values()) : null;
return new CohereEmbeddingsTaskSettings(model, inputType, embeddingTypes, truncation);
@@ -52,13 +52,20 @@ public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() {
CohereEmbeddingsTaskSettings.INPUT_TYPE,
InputType.INGEST.toString().toLowerCase(Locale.ROOT),
CohereEmbeddingsTaskSettings.EMBEDDING_TYPES,
- List.of("abc", "123"),
+ List.of(CohereEmbeddingType.FLOAT, CohereEmbeddingType.INT8),
CohereServiceFields.TRUNCATE,
CohereTruncation.END.toString().toLowerCase(Locale.ROOT)
)
)
),
- is(new CohereEmbeddingsTaskSettings("abc", InputType.INGEST, List.of("abc", "123"), CohereTruncation.END))
+ is(
+ new CohereEmbeddingsTaskSettings(
+ "abc",
+ InputType.INGEST,
+ List.of(CohereEmbeddingType.FLOAT, CohereEmbeddingType.INT8),
+ CohereTruncation.END
+ )
+ )
);
}
@@ -77,9 +84,7 @@ public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() {
public void testFromMap_ReturnsFailure_WhenEmbeddingTypesAreNotValid() {
var exception = expectThrows(
ValidationException.class,
- () -> CohereEmbeddingsTaskSettings.fromMap(
- new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.EMBEDDING_TYPES, List.of("abc", 123)))
- )
+ () -> CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.EMBEDDING_TYPES, List.of("abc"))))
);
MatcherAssert.assertThat(
@@ -91,6 +96,28 @@ public void testFromMap_ReturnsFailure_WhenEmbeddingTypesAreNotValid() {
);
}
+ // Overriding with the empty settings must leave every original value in place
+ public void testOverrideWith_KeepsOriginalValuesWhenOverridesAreNull() {
+     var originalMap = new HashMap<String, Object>();
+     originalMap.put(CohereServiceFields.MODEL, "model");
+     originalMap.put(CohereServiceFields.TRUNCATE, CohereTruncation.END.toString());
+     var originalSettings = CohereEmbeddingsTaskSettings.fromMap(originalMap);
+
+     MatcherAssert.assertThat(originalSettings.overrideWith(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS), is(originalSettings));
+ }
+
+ // A truncation supplied with the request replaces the stored one while the stored model is kept
+ public void testOverrideWith_UsesOverriddenSettings() {
+     var originalMap = new HashMap<String, Object>();
+     originalMap.put(CohereServiceFields.MODEL, "model");
+     originalMap.put(CohereServiceFields.TRUNCATE, CohereTruncation.END.toString());
+     var originalSettings = CohereEmbeddingsTaskSettings.fromMap(originalMap);
+
+     var requestMap = new HashMap<String, Object>();
+     requestMap.put(CohereServiceFields.TRUNCATE, CohereTruncation.START.toString());
+     var requestSettings = CohereEmbeddingsTaskSettings.fromMap(requestMap);
+
+     var expectedSettings = new CohereEmbeddingsTaskSettings("model", null, null, CohereTruncation.START);
+     MatcherAssert.assertThat(originalSettings.overrideWith(requestSettings), is(expectedSettings));
+ }
+
@Override
protected Writeable.Reader instanceReader() {
return CohereEmbeddingsTaskSettings::new;
@@ -106,7 +133,30 @@ protected CohereEmbeddingsTaskSettings mutateInstance(CohereEmbeddingsTaskSettin
return null;
}
- public static Map getTaskSettingsMap() {
- return new HashMap<>(Collections.emptyMap());
+ /**
+  * Builds a task settings map for {@code CohereEmbeddingsTaskSettings.fromMap}, leaving out every setting that
+  * is {@code null}.
+  * <p>
+  * Enum values are stored in their lowercase string form, the representation the {@code fromMap} tests of this
+  * class put into their maps. The embedding types list in particular is validated entry by entry as strings, so
+  * storing enum constants would produce a map that fails validation.
+  * NOTE(review): assumes the input type and truncation settings are also read as strings - confirm against
+  * {@code ServiceUtils.extractOptionalEnum}.
+  *
+  * @param model the model id, or {@code null} to omit it
+  * @param inputType the input type, or {@code null} to omit it
+  * @param embeddingTypes the embedding types (enum constants or their names), or {@code null} to omit them
+  * @param truncation the truncation mode, or {@code null} to omit it
+  * @return a new mutable map holding the non-null settings
+  */
+ public static Map<String, Object> getTaskSettingsMap(
+     @Nullable String model,
+     @Nullable InputType inputType,
+     @Nullable List<?> embeddingTypes,
+     @Nullable CohereTruncation truncation
+ ) {
+     var map = new HashMap<String, Object>();
+
+     if (model != null) {
+         map.put(CohereServiceFields.MODEL, model);
+     }
+
+     if (inputType != null) {
+         map.put(CohereEmbeddingsTaskSettings.INPUT_TYPE, inputType.toString().toLowerCase(Locale.ROOT));
+     }
+
+     if (embeddingTypes != null) {
+         // CohereEmbeddingType#toString is already lowercase; plain strings pass through unchanged
+         map.put(CohereEmbeddingsTaskSettings.EMBEDDING_TYPES, embeddingTypes.stream().map(Object::toString).toList());
+     }
+
+     if (truncation != null) {
+         map.put(CohereServiceFields.TRUNCATE, truncation.toString().toLowerCase(Locale.ROOT));
+     }
+
+     return map;
}
}
From 58c707d4ad4a85d5a68c97e5095c2d03f21a1222 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Thu, 18 Jan 2024 13:26:32 -0500
Subject: [PATCH 04/13] Working cohere
---
.../core/inference/results/ByteValue.java | 5 +
.../core/inference/results/FloatValue.java | 5 +
.../results/TextEmbeddingResults.java | 18 +-
.../InferenceNamedWriteablesProvider.java | 10 +
.../xpack/inference/InferencePlugin.java | 4 +-
.../external/action/ActionUtils.java | 12 +
.../action/cohere/CohereEmbeddingsAction.java | 17 +-
.../action/openai/OpenAiEmbeddingsAction.java | 14 +-
.../cohere/CohereEmbeddingsRequest.java | 9 -
.../cohere/CohereEmbeddingsRequestEntity.java | 29 +-
.../CohereEmbeddingsResponseEntity.java | 75 +-
.../cohere/CohereErrorResponseEntity.java | 1 +
.../HuggingFaceEmbeddingsResponseEntity.java | 2 +-
.../OpenAiEmbeddingsResponseEntity.java | 2 +-
.../inference/services/ServiceUtils.java | 85 --
.../services/cohere/CohereModel.java | 5 +-
.../services/cohere/CohereService.java | 174 ++++
.../embeddings/CohereEmbeddingType.java | 4 +-
.../embeddings/CohereEmbeddingsModel.java | 7 +-
.../CohereEmbeddingsTaskSettings.java | 29 +-
.../cohere/CohereActionCreatorTests.java | 152 ++++
.../cohere/CohereEmbeddingsActionTests.java | 333 +++++++
.../HuggingFaceActionCreatorTests.java | 6 +-
.../openai/OpenAiActionCreatorTests.java | 10 +-
.../openai/OpenAiEmbeddingsActionTests.java | 4 +-
.../cohere/CohereResponseHandlerTests.java | 165 ++++
.../external/openai/OpenAiClientTests.java | 8 +-
.../CohereEmbeddingsRequestEntityTests.java | 65 ++
.../cohere/CohereEmbeddingsRequestTests.java | 156 ++++
.../CohereEmbeddingsResponseEntityTests.java | 434 +++++++++
.../CohereErrorResponseEntityTests.java | 50 +
...gingFaceEmbeddingsResponseEntityTests.java | 20 +-
.../OpenAiEmbeddingsResponseEntityTests.java | 10 +-
.../inference/results/ByteValueTests.java | 35 +
.../inference/results/FloatValueTests.java | 35 +
.../results/TextEmbeddingResultsTests.java | 222 ++++-
.../services/cohere/CohereServiceTests.java | 853 ++++++++++++++++++
.../cohere/CohereTruncationTests.java | 34 +
.../embeddings/CohereEmbeddingTypeTests.java | 34 +
.../CohereEmbeddingsTaskSettingsTests.java | 33 +-
.../huggingface/HuggingFaceServiceTests.java | 4 +-
.../services/openai/OpenAiServiceTests.java | 4 +-
42 files changed, 2910 insertions(+), 264 deletions(-)
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntityTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ByteValueTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/FloatValueTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncationTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java
index 9e3dc866d304e..5017aa9be0d52 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java
@@ -28,6 +28,11 @@ public ByteValue(StreamInput in) throws IOException {
value = in.readByte();
}
+ /**
+  * Returns the string form of the wrapped byte value.
+  */
+ @Override
+ public String toString() {
+ return value.toString();
+ }
+
@Override
public Byte getValue() {
return value;
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java
index 9ead125d2d48a..5c89720e54b2e 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java
@@ -28,6 +28,11 @@ public FloatValue(StreamInput in) throws IOException {
value = in.readFloat();
}
+ /**
+  * Returns the string form of the wrapped float value.
+  */
+ @Override
+ public String toString() {
+ return value.toString();
+ }
+
@Override
public Number getValue() {
return value;
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java
index 9fc99ba3f03f8..fadffb2782018 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java
@@ -53,7 +53,10 @@ public TextEmbeddingResults(StreamInput in) throws IOException {
@SuppressWarnings("deprecation")
TextEmbeddingResults(LegacyTextEmbeddingResults legacyTextEmbeddingResults) {
this(
- legacyTextEmbeddingResults.embeddings().stream().map(embedding -> Embedding.of(embedding.values())).collect(Collectors.toList())
+ legacyTextEmbeddingResults.embeddings()
+ .stream()
+ .map(embedding -> Embedding.ofFloats(embedding.values()))
+ .collect(Collectors.toList())
);
}
@@ -105,10 +108,16 @@ public Map asMap() {
public static class Embedding implements Writeable, ToXContentObject {
public static final String EMBEDDING = "embedding";
- public static Embedding of(List values) {
+ public static Embedding ofFloats(List values) {
return new Embedding(convertFloatsToEmbeddingValues(values));
}
+ /**
+  * Creates an {@code Embedding} by wrapping each element of {@code values} in a {@link ByteValue}.
+  */
+ public static Embedding ofBytes(List values) {
+ List convertedValues = values.stream().map(ByteValue::new).collect(Collectors.toList());
+
+ return new Embedding(convertedValues);
+ }
+
private final List values;
public Embedding(List values) {
@@ -131,12 +140,15 @@ public List toFloats() {
return values.stream().map(value -> value.getValue().floatValue()).toList();
}
+ /**
+ * Returns the list of values held by this embedding. The field itself is returned; no copy is made.
+ */
+ public List values() {
+ return values;
+ }
+
@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) {
out.writeNamedWriteableCollection(values);
} else {
- // TODO do we need to check that the values are floats here?
out.writeCollection(toFloats(), StreamOutput::writeFloat);
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
index a6d6c3cbee4d5..02d19ba60b0e7 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
@@ -20,6 +20,8 @@
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettings;
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
@@ -92,6 +94,14 @@ public static List getNamedWriteables() {
new NamedWriteableRegistry.Entry(TaskSettings.class, OpenAiEmbeddingsTaskSettings.NAME, OpenAiEmbeddingsTaskSettings::new)
);
+ // Cohere
+ namedWriteables.add(
+ new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new)
+ );
+ namedWriteables.add(
+ new NamedWriteableRegistry.Entry(TaskSettings.class, CohereEmbeddingsTaskSettings.NAME, CohereEmbeddingsTaskSettings::new)
+ );
+
return namedWriteables;
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
index 33d71c65ed643..d5e41f20113cc 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@@ -54,6 +54,7 @@
import org.elasticsearch.xpack.inference.rest.RestInferenceAction;
import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeService;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
@@ -154,7 +155,8 @@ public List getInferenceServiceFactories() {
ElserMlNodeService::new,
context -> new HuggingFaceElserService(httpFactory, serviceComponents),
context -> new HuggingFaceService(httpFactory, serviceComponents),
- context -> new OpenAiService(httpFactory, serviceComponents)
+ context -> new OpenAiService(httpFactory, serviceComponents),
+ context -> new CohereService(httpFactory, serviceComponents)
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java
index 856146fafcb45..4a8519934c63c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java
@@ -11,9 +11,13 @@
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
+import java.net.URI;
+
public class ActionUtils {
public static ActionListener wrapFailuresInElasticsearchException(
@@ -35,5 +39,13 @@ public static ElasticsearchStatusException createInternalServerError(Throwable e
return new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR, e);
}
+ /**
+ * Builds the failure message used when a request to an external service cannot be sent.
+ * The text always starts with "Failed to send", so {@code message} should only describe
+ * the request (for example "Cohere embeddings") and must not repeat the verb.
+ *
+ * @param uri the target of the request, or {@code null} to omit the destination from the message
+ * @param message short description of the request, inserted between "Failed to send" and "request"
+ * @return the formatted error message
+ */
+ public static String getErrorMessage(@Nullable URI uri, String message) {
+ return uri == null
+ ? Strings.format("Failed to send %s request", message)
+ : Strings.format("Failed to send %s request to [%s]", message, uri);
+ }
+
private ActionUtils() {}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java
index cc84b06a7ba8f..62afbc190d573 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java
@@ -11,7 +11,6 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.cohere.CohereAccount;
@@ -25,12 +24,11 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
-import java.net.URI;
import java.util.List;
import java.util.Objects;
-import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.getErrorMessage;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
public class CohereEmbeddingsAction implements ExecutableAction {
@@ -45,7 +43,7 @@ public class CohereEmbeddingsAction implements ExecutableAction {
public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model, ServiceComponents serviceComponents) {
this.model = Objects.requireNonNull(model);
this.account = new CohereAccount(this.model.getServiceSettings().uri(), this.model.getSecretSettings().apiKey());
- this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri());
+ this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri(), "Cohere embeddings");
this.sender = new RetryingHttpSender(
Objects.requireNonNull(sender),
serviceComponents.throttlerManager(),
@@ -55,20 +53,9 @@ public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model, Servic
);
}
- private static String getErrorMessage(@Nullable URI uri) {
- if (uri != null) {
- return format("Failed to send Cohere embeddings request to [%s]", uri.toString());
- }
-
- return "Failed to send Cohere embeddings request";
- }
-
@Override
public void execute(List input, ActionListener listener) {
try {
- // TODO only truncate if the setting is NONE?
- // var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
-
CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(account, input, model.getTaskSettings());
ActionListener wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener);
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java
index 20128f1168bb9..417935f8c920c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java
@@ -9,7 +9,6 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
@@ -20,13 +19,12 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;
-import java.net.URI;
import java.util.List;
import java.util.Objects;
-import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.common.Truncator.truncate;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.getErrorMessage;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
public class OpenAiEmbeddingsAction implements ExecutableAction {
@@ -45,18 +43,10 @@ public OpenAiEmbeddingsAction(Sender sender, OpenAiEmbeddingsModel model, Servic
this.model.getSecretSettings().apiKey()
);
this.client = new OpenAiClient(Objects.requireNonNull(sender), Objects.requireNonNull(serviceComponents));
- this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri());
+ this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri(), "OpenAI embeddings");
this.truncator = Objects.requireNonNull(serviceComponents.truncator());
}
- private static String getErrorMessage(@Nullable URI uri) {
- if (uri != null) {
- return format("Failed to send OpenAI embeddings request to [%s]", uri.toString());
- }
-
- return "Failed to send OpenAI embeddings request";
- }
-
@Override
public void execute(List input, ActionListener listener) {
try {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java
index ba0e3d258dba9..277996140001f 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java
@@ -63,12 +63,7 @@ public URI getURI() {
@Override
public Request truncate() {
-
return this;
- // TODO only do this is the truncate setting is NONE?
- // var truncatedInput = truncator.truncate(truncationResult.input());
- //
- // return new CohereEmbeddingsRequest(truncator, account, truncatedInput, taskSettings);
}
@Override
@@ -76,10 +71,6 @@ public boolean[] getTruncationInfo() {
return null;
}
- public CohereEmbeddingsTaskSettings getTaskSettings() {
- return taskSettings;
- }
-
// default for testing
static URI buildDefaultUri() throws URISyntaxException {
return new URIBuilder().setScheme("https")
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java
index 7331426b5c683..831fb07b30e24 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java
@@ -7,6 +7,7 @@
package org.elasticsearch.xpack.inference.external.request.cohere;
+import org.elasticsearch.inference.InputType;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields;
@@ -14,10 +15,22 @@
import java.io.IOException;
import java.util.List;
+import java.util.Map;
import java.util.Objects;
public record CohereEmbeddingsRequestEntity(List input, CohereEmbeddingsTaskSettings taskSettings) implements ToXContentObject {
+ private static final String SEARCH_DOCUMENT = "search_document";
+ private static final String SEARCH_QUERY = "search_query";
+ /**
+ * Maps the {@link InputType} to the expected value for cohere for the input_type field in the request
+ */
+ private static final Map INPUT_TYPE_MAPPING = Map.of(
+ InputType.INGEST,
+ SEARCH_DOCUMENT,
+ InputType.SEARCH,
+ SEARCH_QUERY
+ );
private static final String TEXTS_FIELD = "texts";
static final String INPUT_TYPE_FIELD = "input_type";
@@ -37,11 +50,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
if (taskSettings.inputType() != null) {
- builder.field(INPUT_TYPE_FIELD, taskSettings.inputType());
+ builder.field(INPUT_TYPE_FIELD, covertToString(taskSettings.inputType()));
}
- if (taskSettings.embeddingTypes() != null) {
- builder.field(EMBEDDING_TYPES_FIELD, taskSettings.embeddingTypes());
+ if (taskSettings.embeddingType() != null) {
+ builder.field(EMBEDDING_TYPES_FIELD, List.of(taskSettings.embeddingType()));
}
if (taskSettings.truncation() != null) {
@@ -51,4 +64,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.endObject();
return builder;
}
+
+ /**
+ * Maps an {@link InputType} to the string written to the {@code input_type} field of the request.
+ * Input types without an entry in {@code INPUT_TYPE_MAPPING} fall back to {@code search_document}.
+ *
+ * NOTE(review): the method name looks like a typo of "convertToString"; renaming it also
+ * requires updating the call site in toXContent.
+ *
+ * @param inputType the input type to translate
+ * @return the mapped string, or {@code search_document} when no mapping exists
+ */
+ private static String covertToString(InputType inputType) {
+ return INPUT_TYPE_MAPPING.getOrDefault(inputType, SEARCH_DOCUMENT);
+ }
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
index 8f0eabeda3a04..8f6d6c75d8c71 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
@@ -7,6 +7,7 @@
package org.elasticsearch.xpack.inference.external.response.cohere;
+import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.XContentParserUtils;
import org.elasticsearch.core.CheckedFunction;
@@ -40,6 +41,12 @@ public class CohereEmbeddingsResponseEntity {
toLowerCase(CohereEmbeddingType.INT8),
CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry
);
+ private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes();
+
+ /**
+ * Builds the comma separated list of embedding types that have a parser registered in
+ * {@code EMBEDDING_PARSERS}, for use in error messages.
+ * The keys are sorted so the resulting text does not depend on the iteration order of the map's key set.
+ *
+ * @return the supported embedding type names in natural order, joined with ", "
+ */
+ private static String supportedEmbeddingTypes() {
+ var validTypes = EMBEDDING_PARSERS.keySet().stream().sorted().toArray(String[]::new);
+ return String.join(", ", validTypes);
+ }
/**
* Parses the OpenAI json response.
@@ -86,6 +93,42 @@ public class CohereEmbeddingsResponseEntity {
* }
*
*
+ *
+ * Or this:
+ *
+ *
+ *
+ * {
+ * "id": "da4f9ea6-37e4-41ab-b5e1-9e2985609555",
+ * "texts": [
+ * "hello",
+ * "awesome"
+ * ],
+ * "embeddings": {
+ * "float": [
+ * [
+ * 123
+ * ],
+ * [
+ * 123
+ * ]
+ * ]
+ * },
+ * "meta": {
+ * "api_version": {
+ * "version": "1"
+ * },
+ * "warnings": [
+ * "default model on embed will be deprecated in the future, please specify a model in the request."
+ * ],
+ * "billed_units": {
+ * "input_tokens": 3
+ * }
+ * },
+ * "response_type": "embeddings_floats"
+ * }
+ *
+ *
*/
public static TextEmbeddingResults fromResponse(Request request, HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
@@ -96,23 +139,13 @@ public static TextEmbeddingResults fromResponse(Request request, HttpResult resp
XContentParser.Token token = jsonParser.currentToken();
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
- // TODO can't really rely on the texts field being before embeddings, this will need to be a loop
- // we don't return the is_truncated result for text embeddings so we don't really need this yet
- positionParserAtTokenAfterField(jsonParser, "texts", FAILED_TO_FIND_FIELD_TEMPLATE);
- XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, jsonParser.currentToken(), jsonParser);
-
- List inputAfterTruncation = XContentParserUtils.parseList(
- jsonParser,
- CohereEmbeddingsResponseEntity::parseTruncatedTextsField
- );
-
positionParserAtTokenAfterField(jsonParser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
- // TODO we don't need this yet, just require that the embedding type be a single string so that
- // this is always the second style
+
token = jsonParser.currentToken();
if (token == XContentParser.Token.START_OBJECT) {
return parseEmbeddingsObject(jsonParser);
} else if (token == XContentParser.Token.START_ARRAY) {
+ // if the request did not specify the embedding types then it will default to floats
List embeddingList = XContentParserUtils.parseList(
jsonParser,
parser -> CohereEmbeddingsResponseEntity.parseEmbeddingsArray(
@@ -132,11 +165,6 @@ public static TextEmbeddingResults fromResponse(Request request, HttpResult resp
}
}
- private static String parseTruncatedTextsField(XContentParser parser) throws IOException {
- XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser);
- return parser.text();
- }
-
private static TextEmbeddingResults parseEmbeddingsObject(XContentParser parser) throws IOException {
XContentParser.Token token;
@@ -157,7 +185,12 @@ private static TextEmbeddingResults parseEmbeddingsObject(XContentParser parser)
}
}
- throw new IllegalStateException("Failed to find a supported embedding type");
+ throw new IllegalStateException(
+ Strings.format(
+ "Failed to find a supported embedding type in the Cohere embeddings response. Supported types are [%s]",
+ VALID_EMBEDDING_TYPES_STRING
+ )
+ );
}
private static TextEmbeddingResults.Embedding parseEmbeddingsArray(
@@ -165,14 +198,8 @@ private static TextEmbeddingResults.Embedding parseEmbeddingsArray(
CheckedFunction parseEntry
) throws IOException {
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
-
List embeddingValues = XContentParserUtils.parseList(parser, parseEntry);
- // the parser is currently sitting at an ARRAY_END so go to the next token
- parser.nextToken();
- // if there are additional fields within this object, lets skip them, so we can begin parsing the next embedding array
- parser.skipChildren();
-
return new TextEmbeddingResults.Embedding(embeddingValues);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntity.java
index 6f947b45955be..7d1731105e2f5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntity.java
@@ -22,6 +22,7 @@ private CohereErrorResponseEntity(String errorMessage) {
this.errorMessage = errorMessage;
}
+ @Override
public String getErrorMessage() {
return errorMessage;
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java
index 9bd01bac7bee1..b3706b318439b 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java
@@ -149,7 +149,7 @@ private static TextEmbeddingResults.Embedding parseEmbeddingEntry(XContentParser
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
List embeddingValues = XContentParserUtils.parseList(parser, HuggingFaceEmbeddingsResponseEntity::parseEmbeddingList);
- return TextEmbeddingResults.Embedding.of(embeddingValues);
+ return TextEmbeddingResults.Embedding.ofFloats(embeddingValues);
}
private static float parseEmbeddingList(XContentParser parser) throws IOException {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java
index a5fe46ab1de5f..0640821d0b9e1 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java
@@ -101,7 +101,7 @@ private static TextEmbeddingResults.Embedding parseEmbeddingObject(XContentParse
// if there are additional fields within this object, lets skip them, so we can begin parsing the next embedding array
parser.skipChildren();
- return TextEmbeddingResults.Embedding.of(embeddingValues);
+ return TextEmbeddingResults.Embedding.ofFloats(embeddingValues);
}
private static float parseEmbeddingList(XContentParser parser) throws IOException {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
index 37119b872bf12..3e9f6e1f75a8a 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
@@ -22,7 +22,6 @@
import java.net.URI;
import java.net.URISyntaxException;
-import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
@@ -195,90 +194,6 @@ public static SimilarityMeasure extractSimilarity(Map map, Strin
return null;
}
- public static List extractOptionalListOfEnums(
- Map map,
- String settingName,
- String scope,
- CheckedFunction converter,
- T[] validTypes,
- ValidationException validationException
- ) {
- List> listField = ServiceUtils.removeAsType(map, settingName, List.class);
- if (listField == null) {
- return null;
- }
-
- if (listField.isEmpty()) {
- validationException.addValidationError(ServiceUtils.mustBeNonEmptyList(settingName, scope));
- return null;
- }
-
- var validTypesAsStrings = Arrays.stream(validTypes).map(type -> type.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new);
-
- List castedList = new ArrayList<>(listField.size());
-
- for (Object listEntry : listField) {
- if (listEntry instanceof String == false) {
- validationException.addValidationError(
- invalidType(
- settingName,
- scope,
- listEntry.getClass().getSimpleName(),
- listEntry.toString(),
- String.class.getSimpleName()
- )
- );
- return null;
- }
-
- var stringEntry = (String) listEntry;
- try {
- castedList.add(converter.apply(stringEntry));
- } catch (IllegalArgumentException e) {
- validationException.addValidationError(invalidValue(settingName, scope, stringEntry, validTypesAsStrings));
- return null;
- }
- }
-
- return castedList;
- }
-
- @SuppressWarnings("unchecked")
- public static List extractOptionalListOfType(
- Map map,
- String settingName,
- String scope,
- Class type,
- ValidationException validationException
- ) {
- List> listField = ServiceUtils.removeAsType(map, settingName, List.class);
-
- if (listField == null) {
- return null;
- }
-
- if (listField.isEmpty()) {
- validationException.addValidationError(ServiceUtils.mustBeNonEmptyList(settingName, scope));
- return null;
- }
-
- List castedList = new ArrayList<>(listField.size());
-
- for (Object listEntry : listField) {
- if (type.isAssignableFrom(listEntry.getClass()) == false) {
- // TODO should we just throw here like removeAsType
- validationException.addValidationError(
- invalidType(settingName, scope, listEntry.getClass().getSimpleName(), listEntry.toString(), type.getSimpleName())
- );
- return null;
- }
-
- castedList.add((T) listEntry);
- }
-
- return castedList;
- }
-
public static String extractRequiredString(
Map map,
String settingName,
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java
index d92ea419faef0..1b4843e441248 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java
@@ -13,6 +13,9 @@
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionVisitor;
+
+import java.util.Map;
public abstract class CohereModel extends Model {
public CohereModel(ModelConfigurations configurations, ModelSecrets secrets) {
@@ -27,5 +30,5 @@ protected CohereModel(CohereModel model, ServiceSettings serviceSettings) {
super(model, serviceSettings);
}
- public abstract ExecutableAction accept();
+ public abstract ExecutableAction accept(CohereActionVisitor creator, Map taskSettings);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
new file mode 100644
index 0000000000000..e654316a4fbd4
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
@@ -0,0 +1,174 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere;
+
+import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.inference.common.SimilarityMeasure;
+import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionCreator;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory;
+import org.elasticsearch.xpack.inference.services.SenderService;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
+
+/**
+ * Inference service registered under the name {@code cohere}.
+ * Model configurations are parsed into {@link CohereModel} instances (only the text embedding task type
+ * is handled by {@code createModel}) and inference is executed through an action built by {@link CohereActionCreator}.
+ */
+public class CohereService extends SenderService {
+ public static final String NAME = "cohere";
+
+ public CohereService(SetOnce factory, SetOnce serviceComponents) {
+ super(factory, serviceComponents);
+ }
+
+ @Override
+ public String name() {
+ return NAME;
+ }
+
+ /**
+ * Builds a model from the configuration of a create request.
+ * The service settings and task settings entries are removed from {@code config} and handed to
+ * {@code createModel}. Afterwards {@code throwIfNotEmptyMap} is called on {@code config} and on both
+ * settings maps; presumably this rejects entries that were not consumed while building the model.
+ */
+ @Override
+ public CohereModel parseRequestConfig(
+ String modelId,
+ TaskType taskType,
+ Map config,
+ Set platformArchitectures
+ ) {
+ Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+ Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
+
+ // NOTE(review): serviceSettingsMap is passed a second time as the secret settings argument.
+ // Presumably the request carries the secret inside the service settings; confirm against CohereEmbeddingsModel.
+ CohereModel model = createModel(
+ modelId,
+ taskType,
+ serviceSettingsMap,
+ taskSettingsMap,
+ serviceSettingsMap,
+ TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME)
+ );
+
+ throwIfNotEmptyMap(config, NAME);
+ throwIfNotEmptyMap(serviceSettingsMap, NAME);
+ throwIfNotEmptyMap(taskSettingsMap, NAME);
+
+ return model;
+ }
+
+ /**
+ * Creates the model matching the task type.
+ * Only {@code TEXT_EMBEDDING} is handled; every other task type results in an
+ * {@link ElasticsearchStatusException} carrying {@code failureMessage} and {@code BAD_REQUEST}.
+ *
+ * @param secretSettings the map holding the secret settings, may be {@code null}
+ * @param failureMessage the message used when the task type is not handled
+ */
+ private static CohereModel createModel(
+ String modelId,
+ TaskType taskType,
+ Map serviceSettings,
+ Map taskSettings,
+ @Nullable Map secretSettings,
+ String failureMessage
+ ) {
+ return switch (taskType) {
+ case TEXT_EMBEDDING -> new CohereEmbeddingsModel(modelId, taskType, NAME, serviceSettings, taskSettings, secretSettings);
+ default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ };
+ }
+
+ /**
+ * Builds a model from a stored configuration together with its stored secrets.
+ * Unlike {@code parseRequestConfig}, no check for leftover entries is made after the model is created.
+ */
+ @Override
+ public CohereModel parsePersistedConfigWithSecrets(
+ String modelId,
+ TaskType taskType,
+ Map config,
+ Map secrets
+ ) {
+ Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+ Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
+ Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
+
+ return createModel(
+ modelId,
+ taskType,
+ serviceSettingsMap,
+ taskSettingsMap,
+ secretSettingsMap,
+ parsePersistedConfigErrorMsg(modelId, NAME)
+ );
+ }
+
+ /**
+ * Builds a model from a stored configuration without secrets; {@code null} is passed as the secret settings.
+ */
+ @Override
+ public CohereModel parsePersistedConfig(String modelId, TaskType taskType, Map config) {
+ Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+ Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
+
+ return createModel(modelId, taskType, serviceSettingsMap, taskSettingsMap, null, parsePersistedConfigErrorMsg(modelId, NAME));
+ }
+
+ /**
+ * Runs inference for the given input.
+ * If {@code model} is not a {@link CohereModel} the listener is failed with the exception from
+ * {@code createInvalidModelException} and nothing is sent. Otherwise the model creates its action
+ * through a new {@link CohereActionCreator} and the action is executed with the input and listener.
+ */
+ @Override
+ public void doInfer(
+ Model model,
+ List input,
+ Map taskSettings,
+ ActionListener listener
+ ) {
+ if (model instanceof CohereModel == false) {
+ listener.onFailure(createInvalidModelException(model));
+ return;
+ }
+
+ CohereModel cohereModel = (CohereModel) model;
+ var actionCreator = new CohereActionCreator(getSender(), getServiceComponents());
+
+ var action = cohereModel.accept(actionCreator, taskSettings);
+ action.execute(input, listener);
+ }
+
+ /**
+ * For text embedding models get the embedding size and
+ * update the service settings.
+ *
+ * @param model The new model
+ * @param listener The listener
+ */
+ @Override
+ public void checkModelConfig(Model model, ActionListener listener) {
+ if (model instanceof CohereEmbeddingsModel embeddingsModel) {
+ ServiceUtils.getEmbeddingSize(
+ model,
+ this,
+ listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
+ );
+ } else {
+ // models that are not embeddings models are passed through unchanged
+ listener.onResponse(model);
+ }
+ }
+
+ /**
+ * Returns a copy of the model whose service settings carry the given embedding size.
+ * The uri and max input tokens are taken from the existing service settings.
+ *
+ * NOTE(review): the similarity is always set to {@code DOT_PRODUCT}; whatever the existing
+ * service settings hold for it is not read here. Confirm this is intended.
+ */
+ private CohereEmbeddingsModel updateModelWithEmbeddingDetails(CohereEmbeddingsModel model, int embeddingSize) {
+ CohereServiceSettings serviceSettings = new CohereServiceSettings(
+ model.getServiceSettings().uri(),
+ SimilarityMeasure.DOT_PRODUCT,
+ embeddingSize,
+ model.getServiceSettings().maxInputTokens()
+ );
+
+ return new CohereEmbeddingsModel(model, serviceSettings);
+ }
+
+ /**
+ * Returns the transport version constant {@code ML_INFERENCE_COHERE_EMBEDDINGS_ADDED}.
+ */
+ @Override
+ public TransportVersion getMinimalSupportedVersion() {
+ return TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED;
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java
index 0e4f9ffeba466..803382bb8c947 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java
@@ -47,11 +47,11 @@ public static CohereEmbeddingType fromString(String name) {
}
public static CohereEmbeddingType fromStream(StreamInput in) throws IOException {
- return in.readEnum(CohereEmbeddingType.class);
+ return in.readOptionalEnum(CohereEmbeddingType.class); // optional read, the counterpart of the writeOptionalEnum call in writeTo
}
@Override
public void writeTo(StreamOutput out) throws IOException {
- out.writeEnum(this);
+ out.writeOptionalEnum(this); // optional write so the stream layout matches the readOptionalEnum call in fromStream
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java
index 816e7bc6933cd..accba9149a46e 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java
@@ -12,6 +12,7 @@
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionVisitor;
import org.elasticsearch.xpack.inference.services.cohere.CohereModel;
import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
@@ -53,7 +54,7 @@ private CohereEmbeddingsModel(CohereEmbeddingsModel model, CohereEmbeddingsTaskS
super(model, taskSettings);
}
- private CohereEmbeddingsModel(CohereEmbeddingsModel model, CohereServiceSettings serviceSettings) {
+ public CohereEmbeddingsModel(CohereEmbeddingsModel model, CohereServiceSettings serviceSettings) { // public: the service's updateModelWithEmbeddingDetails constructs the model through this overload
super(model, serviceSettings);
}
@@ -73,8 +74,8 @@ public DefaultSecretSettings getSecretSettings() {
}
@Override
- public ExecutableAction accept() {
- return null;
+ public ExecutableAction accept(CohereActionVisitor visitor, Map taskSettings) {
+ return visitor.create(this, taskSettings); // hands this model and the supplied task settings to the visitor, which returns the action
}
public CohereEmbeddingsModel overrideWith(Map taskSettings) {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java
index 849290cfaa66d..b0da7a63bf918 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java
@@ -20,11 +20,9 @@
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
import java.io.IOException;
-import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfEnums;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.MODEL;
import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.TRUNCATE;
@@ -38,20 +36,20 @@
*
* @param model the id of the model to use in the requests to cohere
* @param inputType Specifies the type of input you're giving to the model
- * @param embeddingTypes Specifies the types of embeddings you want to get back
+ * @param embeddingType Specifies the type of embeddings you want to get back (we only support retrieving a single type)
* @param truncation Specifies how the API will handle inputs longer than the maximum token length
*/
public record CohereEmbeddingsTaskSettings(
@Nullable String model,
@Nullable InputType inputType,
- @Nullable List embeddingTypes,
+ @Nullable CohereEmbeddingType embeddingType,
@Nullable CohereTruncation truncation
) implements TaskSettings {
public static final String NAME = "cohere_embeddings_task_settings";
- static final CohereEmbeddingsTaskSettings EMPTY_SETTINGS = new CohereEmbeddingsTaskSettings(null, null, null, null);
+ public static final CohereEmbeddingsTaskSettings EMPTY_SETTINGS = new CohereEmbeddingsTaskSettings(null, null, null, null);
static final String INPUT_TYPE = "input_type";
- static final String EMBEDDING_TYPES = "embedding_types";
+ static final String EMBEDDING_TYPE = "embedding_type";
public static CohereEmbeddingsTaskSettings fromMap(Map map) {
if (map.isEmpty()) {
@@ -69,9 +67,9 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) {
InputType.values(),
validationException
);
- List embeddingTypes = extractOptionalListOfEnums(
+ CohereEmbeddingType embeddingTypes = extractOptionalEnum(
map,
- EMBEDDING_TYPES,
+ EMBEDDING_TYPE,
ModelConfigurations.TASK_SETTINGS,
CohereEmbeddingType::fromString,
CohereEmbeddingType.values(),
@@ -94,12 +92,7 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) {
}
public CohereEmbeddingsTaskSettings(StreamInput in) throws IOException {
- this(
- in.readOptionalString(),
- InputType.fromStream(in),
- in.readOptionalCollectionAsList(CohereEmbeddingType::fromStream),
- CohereTruncation.fromStream(in)
- );
+ this(in.readOptionalString(), InputType.fromStream(in), CohereEmbeddingType.fromStream(in), CohereTruncation.fromStream(in)); // same order writeTo uses: model, inputType, embeddingType, truncation
}
@Override
@@ -113,8 +106,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(INPUT_TYPE, inputType);
}
- if (embeddingTypes != null) {
- builder.field(EMBEDDING_TYPES, embeddingTypes);
+ if (embeddingType != null) {
+ builder.field(EMBEDDING_TYPE, embeddingType);
}
if (truncation != null) {
@@ -138,14 +131,14 @@ public TransportVersion getMinimalSupportedVersion() {
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(model);
inputType.writeTo(out);
- out.writeOptionalCollection(embeddingTypes);
+ // embeddingType is @Nullable (EMPTY_SETTINGS passes null), so it must not be dereferenced here.
+ // writeOptionalEnum is the call CohereEmbeddingType#writeTo itself makes, so the stream layout is unchanged
+ // and still matches the readOptionalEnum call in CohereEmbeddingType.fromStream.
+ // NOTE(review): inputType and truncation are also declared @Nullable and are still dereferenced in this method - verify.
+ out.writeOptionalEnum(embeddingType);
truncation.writeTo(out);
}
public CohereEmbeddingsTaskSettings overrideWith(CohereEmbeddingsTaskSettings requestTaskSettings) {
var modelToUse = requestTaskSettings.model() == null ? model : requestTaskSettings.model();
var inputTypeToUse = requestTaskSettings.inputType() == null ? inputType : requestTaskSettings.inputType();
- var embeddingTypesToUse = requestTaskSettings.embeddingTypes() == null ? embeddingTypes : requestTaskSettings.embeddingTypes();
+ var embeddingTypesToUse = requestTaskSettings.embeddingType() == null ? embeddingType : requestTaskSettings.embeddingType();
var truncationToUse = requestTaskSettings.truncation() == null ? truncation : requestTaskSettings.truncation();
return new CohereEmbeddingsTaskSettings(modelToUse, inputTypeToUse, embeddingTypesToUse, truncationToUse);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java
new file mode 100644
index 0000000000000..e1daa78454d25
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java
@@ -0,0 +1,152 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.cohere;
+
+import org.apache.http.HttpHeaders;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests;
+import org.hamcrest.MatcherAssert;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
+import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+public class CohereActionCreatorTests extends ESTestCase {
+ private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+ private final MockWebServer webServer = new MockWebServer();
+ private ThreadPool threadPool;
+ private HttpClientManager clientManager;
+
+ @Before
+ public void init() throws Exception {
+ webServer.start(); // the mock server the tests enqueue responses on
+ threadPool = createThreadPool(inferenceUtilityPool());
+ clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); // uses the thread pool created on the previous line
+ }
+
+ @After
+ public void shutdown() throws IOException {
+ clientManager.close(); // closed before the thread pool it was created with is terminated
+ terminate(threadPool);
+ webServer.close();
+ }
+
+ /**
+ * Creates an action through {@link CohereActionCreator} using a task settings map and checks that the values
+ * from that map (search input type, int8 embedding type, end truncation) are what the mock server receives,
+ * rather than the settings the model was created with.
+ * NOTE(review): the method name mentions OpenAi although everything exercised here is Cohere; consider renaming.
+ */
+ public void testCreate_OpenAiEmbeddingsModel() throws IOException {
+ var factory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY);
+
+ try (var sender = factory.createSender("test_service")) {
+ sender.start();
+
+ String cohereResponse = """
+ {
+ "id": "de37399c-5df6-47cb-bc57-e3c5680c977b",
+ "texts": [
+ "hello"
+ ],
+ "embeddings": {
+ "float": [
+ [
+ 0.123,
+ -0.123
+ ]
+ ]
+ },
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 1
+ }
+ },
+ "response_type": "embeddings_by_type"
+ }
+ """;
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(cohereResponse));
+
+ // settings the model is created with; the map built below carries different values for three of them
+ var storedSettings = new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START);
+ var model = CohereEmbeddingsModelTests.createModel(getUrl(webServer), "secret", storedSettings, 1024, 1024);
+ var requestSettings = CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(
+ "model",
+ InputType.SEARCH,
+ CohereEmbeddingType.INT8,
+ CohereTruncation.END
+ );
+ var action = new CohereActionCreator(sender, createWithEmptySettings(threadPool)).create(model, requestSettings);
+
+ PlainActionFuture future = new PlainActionFuture<>();
+ action.execute(List.of("abc"), future);
+
+ var result = future.actionGet(TIMEOUT);
+
+ MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.123F, -0.123F)))));
+ MatcherAssert.assertThat(webServer.requests(), hasSize(1));
+
+ var sentRequest = webServer.requests().get(0);
+ assertNull(sentRequest.getUri().getQuery());
+ MatcherAssert.assertThat(sentRequest.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+ MatcherAssert.assertThat(sentRequest.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
+
+ MatcherAssert.assertThat(
+ entityAsMap(sentRequest.getBody()),
+ is(
+ Map.of(
+ "texts",
+ List.of("abc"),
+ "model",
+ "model",
+ "input_type",
+ "search_query",
+ "embedding_types",
+ List.of("int8"),
+ "truncate",
+ "end"
+ )
+ )
+ );
+ }
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java
new file mode 100644
index 0000000000000..d2a8aad19b4c8
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java
@@ -0,0 +1,333 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.cohere;
+
+import org.apache.http.HttpHeaders;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
+import org.hamcrest.MatcherAssert;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationBytes;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
+import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+
+public class CohereEmbeddingsActionTests extends ESTestCase {
+ private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+ private final MockWebServer webServer = new MockWebServer();
+ private ThreadPool threadPool;
+ private HttpClientManager clientManager;
+
+ @Before
+ public void init() throws Exception {
+ webServer.start(); // the mock server the tests enqueue responses on
+ threadPool = createThreadPool(inferenceUtilityPool());
+ clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); // uses the thread pool created on the previous line
+ }
+
+ @After
+ public void shutdown() throws IOException {
+ clientManager.close(); // closed before the thread pool it was created with is terminated
+ terminate(threadPool);
+ webServer.close();
+ }
+
+ /**
+ * Sends one input through a real sender to the mock server, which answers with a "float" embeddings payload,
+ * then checks the parsed float result and the headers and JSON body of the single request that was made.
+ */
+ public void testExecute_ReturnsSuccessfulResponse() throws IOException {
+ var factory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY);
+
+ try (var sender = factory.createSender("test_service")) {
+ sender.start();
+
+ String cohereResponse = """
+ {
+ "id": "de37399c-5df6-47cb-bc57-e3c5680c977b",
+ "texts": [
+ "hello"
+ ],
+ "embeddings": {
+ "float": [
+ [
+ 0.123,
+ -0.123
+ ]
+ ]
+ },
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 1
+ }
+ },
+ "response_type": "embeddings_by_type"
+ }
+ """;
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(cohereResponse));
+
+ var action = createAction(
+ getUrl(webServer),
+ "secret",
+ new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START),
+ sender
+ );
+
+ PlainActionFuture future = new PlainActionFuture<>();
+ action.execute(List.of("abc"), future);
+
+ var result = future.actionGet(TIMEOUT);
+
+ MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.123F, -0.123F)))));
+ MatcherAssert.assertThat(webServer.requests(), hasSize(1));
+
+ var sentRequest = webServer.requests().get(0);
+ assertNull(sentRequest.getUri().getQuery());
+ MatcherAssert.assertThat(sentRequest.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+ MatcherAssert.assertThat(sentRequest.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
+
+ MatcherAssert.assertThat(
+ entityAsMap(sentRequest.getBody()),
+ is(
+ Map.of(
+ "texts",
+ List.of("abc"),
+ "model",
+ "model",
+ "input_type",
+ "search_document",
+ "embedding_types",
+ List.of("float"),
+ "truncate",
+ "start"
+ )
+ )
+ );
+ }
+ }
+
+ /**
+ * Same flow as the float case but with the INT8 embedding type: the mock server answers with an "int8"
+ * embeddings payload and the result is compared against a byte expectation.
+ */
+ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws IOException {
+ var factory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY);
+
+ try (var sender = factory.createSender("test_service")) {
+ sender.start();
+
+ String cohereResponse = """
+ {
+ "id": "de37399c-5df6-47cb-bc57-e3c5680c977b",
+ "texts": [
+ "hello"
+ ],
+ "embeddings": {
+ "int8": [
+ [
+ 0,
+ -1
+ ]
+ ]
+ },
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 1
+ }
+ },
+ "response_type": "embeddings_by_type"
+ }
+ """;
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(cohereResponse));
+
+ var action = createAction(
+ getUrl(webServer),
+ "secret",
+ new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.INT8, CohereTruncation.START),
+ sender
+ );
+
+ PlainActionFuture future = new PlainActionFuture<>();
+ action.execute(List.of("abc"), future);
+
+ var result = future.actionGet(TIMEOUT);
+
+ MatcherAssert.assertThat(result.asMap(), is(buildExpectationBytes(List.of(List.of((byte) 0, (byte) -1)))));
+ MatcherAssert.assertThat(webServer.requests(), hasSize(1));
+
+ var sentRequest = webServer.requests().get(0);
+ assertNull(sentRequest.getUri().getQuery());
+ MatcherAssert.assertThat(sentRequest.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+ MatcherAssert.assertThat(sentRequest.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
+
+ MatcherAssert.assertThat(
+ entityAsMap(sentRequest.getBody()),
+ is(
+ Map.of(
+ "texts",
+ List.of("abc"),
+ "model",
+ "model",
+ "input_type",
+ "search_document",
+ "embedding_types",
+ List.of("int8"),
+ "truncate",
+ "start"
+ )
+ )
+ );
+ }
+ }
+
+ /**
+ * Building an action for the url "^^" fails while the action is being created.
+ * The exception type asserted is {@link IllegalArgumentException}, with a message naming the url.
+ */
+ public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOException {
+ try (var mockSender = mock(Sender.class)) {
+ var failure = expectThrows(
+ IllegalArgumentException.class,
+ () -> createAction("^^", "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, mockSender)
+ );
+
+ MatcherAssert.assertThat(failure.getMessage(), is("unable to parse url [^^]"));
+ }
+ }
+
+ /**
+ * When the sender throws an {@link ElasticsearchException} from {@code send}, waiting on the listener
+ * throws an {@link ElasticsearchException} whose message is the original "failed".
+ */
+ public void testExecute_ThrowsElasticsearchException() {
+ var throwingSender = mock(Sender.class);
+ doThrow(new ElasticsearchException("failed")).when(throwingSender).send(any(), any());
+
+ PlainActionFuture future = new PlainActionFuture<>();
+ createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, throwingSender).execute(
+ List.of("abc"),
+ future
+ );
+
+ var failure = expectThrows(ElasticsearchException.class, () -> future.actionGet(TIMEOUT));
+ MatcherAssert.assertThat(failure.getMessage(), is("failed"));
+ }
+
+ /**
+ * When the sender reports an {@link IllegalStateException} through the listener passed to {@code send},
+ * waiting on the result throws an {@link ElasticsearchException} whose message names the target url.
+ */
+ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() {
+ var failingSender = mock(Sender.class);
+
+ doAnswer(invocation -> {
+ // the second argument of send(...) is the listener; fail it instead of sending anything
+ @SuppressWarnings("unchecked")
+ ActionListener senderListener = (ActionListener) invocation.getArguments()[1];
+ senderListener.onFailure(new IllegalStateException("failed"));
+
+ return Void.TYPE;
+ }).when(failingSender).send(any(), any());
+
+ PlainActionFuture future = new PlainActionFuture<>();
+ createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, failingSender).execute(
+ List.of("abc"),
+ future
+ );
+
+ var failure = expectThrows(ElasticsearchException.class, () -> future.actionGet(TIMEOUT));
+ var expectedMessage = format("Failed to send Cohere embeddings request to [%s]", getUrl(webServer));
+
+ MatcherAssert.assertThat(failure.getMessage(), is(expectedMessage));
+ }
+
+ /**
+ * Same as the sender-onFailure case but with a null url: the resulting
+ * {@link ElasticsearchException} message carries no url suffix.
+ */
+ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled_WhenUrlIsNull() {
+ var failingSender = mock(Sender.class);
+
+ doAnswer(invocation -> {
+ // the second argument of send(...) is the listener; fail it instead of sending anything
+ @SuppressWarnings("unchecked")
+ ActionListener senderListener = (ActionListener) invocation.getArguments()[1];
+ senderListener.onFailure(new IllegalStateException("failed"));
+
+ return Void.TYPE;
+ }).when(failingSender).send(any(), any());
+
+ PlainActionFuture future = new PlainActionFuture<>();
+ createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, failingSender).execute(List.of("abc"), future);
+
+ var failure = expectThrows(ElasticsearchException.class, () -> future.actionGet(TIMEOUT));
+ MatcherAssert.assertThat(failure.getMessage(), is("Failed to send Cohere embeddings request"));
+ }
+
+ /**
+ * When the sender throws an {@link IllegalArgumentException} from {@code send}, waiting on the listener
+ * throws an {@link ElasticsearchException} whose message names the target url.
+ */
+ public void testExecute_ThrowsException() {
+ var throwingSender = mock(Sender.class);
+ doThrow(new IllegalArgumentException("failed")).when(throwingSender).send(any(), any());
+
+ PlainActionFuture future = new PlainActionFuture<>();
+ createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, throwingSender).execute(
+ List.of("abc"),
+ future
+ );
+
+ var failure = expectThrows(ElasticsearchException.class, () -> future.actionGet(TIMEOUT));
+ var expectedMessage = format("Failed to send Cohere embeddings request to [%s]", getUrl(webServer));
+
+ MatcherAssert.assertThat(failure.getMessage(), is(expectedMessage));
+ }
+
+ /**
+ * Same as the throwing-sender case but with a null url: the resulting
+ * {@link ElasticsearchException} message carries no url suffix.
+ */
+ public void testExecute_ThrowsExceptionWithNullUrl() {
+ var throwingSender = mock(Sender.class);
+ doThrow(new IllegalArgumentException("failed")).when(throwingSender).send(any(), any());
+
+ PlainActionFuture future = new PlainActionFuture<>();
+ createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, throwingSender).execute(List.of("abc"), future);
+
+ var failure = expectThrows(ElasticsearchException.class, () -> future.actionGet(TIMEOUT));
+ MatcherAssert.assertThat(failure.getMessage(), is("Failed to send Cohere embeddings request"));
+ }
+
+ /**
+ * Builds the action under test from a model created with the given url, api key and task settings
+ * (the two numeric model arguments are fixed at 1024), the supplied sender and empty service components.
+ */
+ private CohereEmbeddingsAction createAction(String url, String apiKey, CohereEmbeddingsTaskSettings taskSettings, Sender sender) {
+ return new CohereEmbeddingsAction(
+ sender,
+ CohereEmbeddingsModelTests.createModel(url, apiKey, taskSettings, 1024, 1024),
+ createWithEmptySettings(threadPool)
+ );
+ }
+
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java
index 95b69f1231e9d..c17b293c8bc35 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java
@@ -214,7 +214,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectation(List.of(List.of(-0.0123F, 0.123F)))));
+ assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloats(List.of(List.of(-0.0123F, 0.123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
@@ -331,7 +331,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectation(List.of(List.of(-0.0123F, 0.123F)))));
+ assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloats(List.of(List.of(-0.0123F, 0.123F)))));
assertThat(webServer.requests(), hasSize(2));
{
@@ -389,7 +389,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException {
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectation(List.of(List.of(-0.0123F, 0.123F)))));
+ assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloats(List.of(List.of(-0.0123F, 0.123F)))));
assertThat(webServer.requests(), hasSize(1));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java
index 23b6f1ea2fbe3..2d7493a783c51 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java
@@ -33,7 +33,7 @@
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel;
import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsRequestTaskSettingsTests.getRequestTaskSettingsMap;
@@ -100,7 +100,7 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException {
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
@@ -169,7 +169,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(2));
{
assertNull(webServer.requests().get(0).getUri().getQuery());
@@ -252,7 +252,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(2));
{
assertNull(webServer.requests().get(0).getUri().getQuery());
@@ -320,7 +320,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException {
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java
index 6bc8e2d61d579..2d5cea7bd981e 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java
@@ -38,7 +38,7 @@
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel;
import static org.hamcrest.Matchers.equalTo;
@@ -104,7 +104,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException {
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java
new file mode 100644
index 0000000000000..9b7edd12492ea
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java
@@ -0,0 +1,165 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.cohere;
+
+import org.apache.http.Header;
+import org.apache.http.HeaderElement;
+import org.apache.http.HttpResponse;
+import org.apache.http.StatusLine;
+import org.apache.http.client.methods.HttpRequestBase;
+import org.apache.http.message.BasicHeader;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
+import org.hamcrest.MatcherAssert;
+
+import java.nio.charset.StandardCharsets;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.core.Is.is;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class CohereResponseHandlerTests extends ESTestCase {
+ public void testCheckForFailureStatusCode_DoesNotThrowFor200() {
+ callCheckForFailureStatusCode(200); // passes only if the handler raises nothing for a 200 status
+ }
+
+ // 503: expect a non-retryable RetryException whose cause reports BAD_REQUEST.
+ public void testCheckForFailureStatusCode_ThrowsFor503() {
+ var expectedMessage = "Received a server error status code for request [null] status [503]";
+ var retryException = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(503));
+ assertFalse(retryException.shouldRetry());
+ var cause = retryException.getCause();
+ MatcherAssert.assertThat(cause.getMessage(), containsString(expectedMessage));
+ MatcherAssert.assertThat(((ElasticsearchStatusException) cause).status(), is(RestStatus.BAD_REQUEST));
+ }
+
+ // 429: expect a retryable RetryException whose cause reports TOO_MANY_REQUESTS.
+ public void testCheckForFailureStatusCode_ThrowsFor429() {
+ var expectedMessage = "Received a rate limit status code. Monthly request limit";
+ var retryException = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(429));
+ assertTrue(retryException.shouldRetry());
+ var cause = retryException.getCause();
+ MatcherAssert.assertThat(cause.getMessage(), containsString(expectedMessage));
+ MatcherAssert.assertThat(((ElasticsearchStatusException) cause).status(), is(RestStatus.TOO_MANY_REQUESTS));
+ }
+
+ // Plain 400 (empty body): expect a non-retryable RetryException whose cause reports BAD_REQUEST.
+ public void testCheckForFailureStatusCode_ThrowsFor400() {
+ var expectedMessage = "Received an unsuccessful status code for request [null] status [400]";
+ var retryException = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(400));
+ assertFalse(retryException.shouldRetry());
+ var cause = retryException.getCause();
+ MatcherAssert.assertThat(cause.getMessage(), containsString(expectedMessage));
+ MatcherAssert.assertThat(((ElasticsearchStatusException) cause).status(), is(RestStatus.BAD_REQUEST));
+ }
+
+ // A 400 whose body reports too many texts yields the texts-array-too-large error and is not retried.
+ public void testCheckForFailureStatusCode_ThrowsFor400_TextsTooLarge() {
+ var responseMessage = "invalid request: total number of texts must be at most 96 - received 100";
+ var expectedMessage = "Received a texts array too large response for request [null] status [400]";
+
+ var retryException = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(400, responseMessage));
+ assertFalse(retryException.shouldRetry());
+
+ var cause = retryException.getCause();
+ MatcherAssert.assertThat(cause.getMessage(), containsString(expectedMessage));
+ MatcherAssert.assertThat(((ElasticsearchStatusException) cause).status(), is(RestStatus.BAD_REQUEST));
+ }
+
+ // 401: expect a non-retryable RetryException whose cause reports UNAUTHORIZED.
+ public void testCheckForFailureStatusCode_ThrowsFor401() {
+ var expectedMessage = "Received an authentication error status code for request [null] status [401]";
+ var retryException = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(401));
+ assertFalse(retryException.shouldRetry());
+ var cause = retryException.getCause();
+ MatcherAssert.assertThat(cause.getMessage(), containsString(expectedMessage));
+ MatcherAssert.assertThat(((ElasticsearchStatusException) cause).status(), is(RestStatus.UNAUTHORIZED));
+ }
+
+ // 300: expect a non-retryable RetryException whose cause reports MULTIPLE_CHOICES.
+ public void testCheckForFailureStatusCode_ThrowsFor300() {
+ var expectedMessage = "Unhandled redirection for request [null] status [300]";
+ var retryException = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(300));
+ assertFalse(retryException.shouldRetry());
+ var cause = retryException.getCause();
+ MatcherAssert.assertThat(cause.getMessage(), containsString(expectedMessage));
+ MatcherAssert.assertThat(((ElasticsearchStatusException) cause).status(), is(RestStatus.MULTIPLE_CHOICES));
+ }
+
+ // Stubs all three rate limit headers and checks that each value is echoed
+ // in the generated message.
+ public void testBuildRateLimitErrorMessage() {
+ var status = mock(StatusLine.class);
+ when(status.getStatusCode()).thenReturn(429);
+
+ var httpResponse = mock(HttpResponse.class);
+ when(httpResponse.getStatusLine()).thenReturn(status);
+ var result = new HttpResult(httpResponse, new byte[] {});
+
+ // Distinct values show which header fed which placeholder.
+ var monthlyLimit = new BasicHeader(CohereResponseHandler.MONTHLY_REQUESTS_LIMIT, "3000");
+ var perMinuteLimit = new BasicHeader(CohereResponseHandler.TRIAL_REQUEST_LIMIT_PER_MINUTE, "2999");
+ var remaining = new BasicHeader(CohereResponseHandler.TRIAL_REQUESTS_REMAINING, "12");
+ when(httpResponse.getFirstHeader(CohereResponseHandler.MONTHLY_REQUESTS_LIMIT)).thenReturn(monthlyLimit);
+ when(httpResponse.getFirstHeader(CohereResponseHandler.TRIAL_REQUEST_LIMIT_PER_MINUTE)).thenReturn(perMinuteLimit);
+ when(httpResponse.getFirstHeader(CohereResponseHandler.TRIAL_REQUESTS_REMAINING)).thenReturn(remaining);
+
+ var expected = "Monthly request limit [3000], permitted requests per minute [2999], remaining requests [12]";
+ var message = CohereResponseHandler.buildRateLimitErrorMessage(result);
+
+ MatcherAssert.assertThat(message, containsString(expected));
+ }
+
+ // No rate limit headers are stubbed, so every placeholder is expected to read "unknown".
+ public void testBuildRateLimitErrorMessage_FillsWithUnknown_WhenUnableToFindHeader() {
+ var status = mock(StatusLine.class);
+ when(status.getStatusCode()).thenReturn(429);
+ var httpResponse = mock(HttpResponse.class);
+ when(httpResponse.getStatusLine()).thenReturn(status);
+ var result = new HttpResult(httpResponse, new byte[] {});
+
+ var expected = "Monthly request limit [unknown], permitted requests per minute [unknown], remaining requests [unknown]";
+ var message = CohereResponseHandler.buildRateLimitErrorMessage(result);
+
+ MatcherAssert.assertThat(message, containsString(expected));
+ }
+
+ private static void callCheckForFailureStatusCode(int statusCode) {
+ callCheckForFailureStatusCode(statusCode, null); // null error message: the response is built with an empty body
+ }
+
+ // Runs the handler's status check against a mocked response with the given status code; a non-null
+ // errorMessage becomes the "message" field of the JSON body, null means an empty body.
+ private static void callCheckForFailureStatusCode(int statusCode, @Nullable String errorMessage) {
+ String body = Strings.format("""
+ {
+ "message": "%s"
+ }
+ """, errorMessage);
+ byte[] bodyBytes = errorMessage == null ? new byte[] {} : body.getBytes(StandardCharsets.UTF_8);
+
+ var emptyHeader = mock(Header.class);
+ when(emptyHeader.getElements()).thenReturn(new HeaderElement[] {});
+
+ var status = mock(StatusLine.class);
+ when(status.getStatusCode()).thenReturn(statusCode);
+ var response = mock(HttpResponse.class);
+ when(response.getStatusLine()).thenReturn(status);
+ when(response.getFirstHeader(anyString())).thenReturn(emptyHeader);
+
+ var handler = new CohereResponseHandler("", (request, result) -> null);
+ handler.checkForFailureStatusCode(mock(HttpRequestBase.class), new HttpResult(response, bodyBytes));
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiClientTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiClientTests.java
index bb9612f01d8ff..0c4b354e7171f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiClientTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiClientTests.java
@@ -40,7 +40,7 @@
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiEmbeddingsRequestTests.createRequest;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@@ -104,7 +104,7 @@ public void testSend_SuccessfulResponse() throws IOException, URISyntaxException
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
@@ -155,7 +155,7 @@ public void testSend_SuccessfulResponse_WithoutUser() throws IOException, URISyn
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
@@ -205,7 +205,7 @@ public void testSend_SuccessfulResponse_WithoutOrganization() throws IOException
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java
new file mode 100644
index 0000000000000..01a2290446529
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java
@@ -0,0 +1,65 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.cohere;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.hamcrest.CoreMatchers.is;
+
+public class CohereEmbeddingsRequestEntityTests extends ESTestCase {
+ // Every optional task setting is populated, so each one must appear in the serialized body.
+ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
+ var settings = new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START);
+ var requestEntity = new CohereEmbeddingsRequestEntity(List.of("abc"), settings);
+
+ XContentBuilder jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
+ requestEntity.toXContent(jsonBuilder, null);
+ String serialized = Strings.toString(jsonBuilder);
+
+ var expectedJson = """
+ {"texts":["abc"],"model":"model","input_type":"search_document","embedding_types":["float"],"truncate":"start"}""";
+ MatcherAssert.assertThat(serialized, is(expectedJson));
+ }
+
+ // SEARCH / INT8 / NONE are expected to serialize as "search_query", "int8" and "none".
+ public void testXContent_InputTypeSearch_EmbeddingTypesInt8_TruncateNone() throws IOException {
+ var settings = new CohereEmbeddingsTaskSettings("model", InputType.SEARCH, CohereEmbeddingType.INT8, CohereTruncation.NONE);
+ var requestEntity = new CohereEmbeddingsRequestEntity(List.of("abc"), settings);
+
+ XContentBuilder jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
+ requestEntity.toXContent(jsonBuilder, null);
+ String serialized = Strings.toString(jsonBuilder);
+
+ var expectedJson = """
+ {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""";
+ MatcherAssert.assertThat(serialized, is(expectedJson));
+ }
+
+ // With EMPTY_SETTINGS the serialized body is expected to contain the "texts" array and nothing else.
+ public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException {
+ var requestEntity = new CohereEmbeddingsRequestEntity(List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS);
+ XContentBuilder jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
+ requestEntity.toXContent(jsonBuilder, null);
+
+ var expectedJson = """
+ {"texts":["abc"]}""";
+ MatcherAssert.assertThat(Strings.toString(jsonBuilder), is(expectedJson));
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java
new file mode 100644
index 0000000000000..903f6a9500831
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java
@@ -0,0 +1,156 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.cohere;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.cohere.CohereAccount;
+import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+
+public class CohereEmbeddingsRequestTests extends ESTestCase {
+ // Empty task settings: the POST targets the given URL, carries the JSON content type and bearer
+ // token headers, and its body holds only the texts.
+ public void testCreateRequest_UrlDefined() throws URISyntaxException, IOException {
+ var embeddingsRequest = createRequest("url", "secret", List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS);
+ var built = embeddingsRequest.createRequest();
+ MatcherAssert.assertThat(built, instanceOf(HttpPost.class));
+
+ var post = (HttpPost) built;
+ MatcherAssert.assertThat(post.getURI().toString(), is("url"));
+ MatcherAssert.assertThat(post.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+ MatcherAssert.assertThat(post.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+
+ var body = entityAsMap(post.getEntity().getContent());
+ MatcherAssert.assertThat(body, is(Map.of("texts", List.of("abc"))));
+ }
+
+ // Every task setting is supplied, so the body must carry model, input_type, embedding_types and
+ // truncate alongside the texts.
+ public void testCreateRequest_AllOptionsDefined() throws URISyntaxException, IOException {
+ var settings = new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START);
+ var embeddingsRequest = createRequest("url", "secret", List.of("abc"), settings);
+
+ var built = embeddingsRequest.createRequest();
+ MatcherAssert.assertThat(built, instanceOf(HttpPost.class));
+
+ var post = (HttpPost) built;
+
+ // Target URI and headers.
+ MatcherAssert.assertThat(post.getURI().toString(), is("url"));
+ MatcherAssert.assertThat(post.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+ MatcherAssert.assertThat(post.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+
+ // InputType.INGEST is expected in the body as "search_document".
+ var body = entityAsMap(post.getEntity().getContent());
+ MatcherAssert.assertThat(
+ body,
+ is(
+ Map.of(
+ "texts",
+ List.of("abc"),
+ "model",
+ "model",
+ "input_type",
+ "search_document",
+ "embedding_types",
+ List.of("float"),
+ "truncate",
+ "start"
+ )
+ )
+ );
+ }
+
+ // SEARCH / INT8 / END task settings must reach the request body as "search_query", ["int8"]
+ // and "end".
+ public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() throws URISyntaxException, IOException {
+ var settings = new CohereEmbeddingsTaskSettings("model", InputType.SEARCH, CohereEmbeddingType.INT8, CohereTruncation.END);
+ var embeddingsRequest = createRequest("url", "secret", List.of("abc"), settings);
+
+ var built = embeddingsRequest.createRequest();
+ MatcherAssert.assertThat(built, instanceOf(HttpPost.class));
+
+ var post = (HttpPost) built;
+
+ // Target URI and headers.
+ MatcherAssert.assertThat(post.getURI().toString(), is("url"));
+ MatcherAssert.assertThat(post.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+ MatcherAssert.assertThat(post.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+
+ // InputType.SEARCH is expected in the body as "search_query".
+ var body = entityAsMap(post.getEntity().getContent());
+ MatcherAssert.assertThat(
+ body,
+ is(
+ Map.of(
+ "texts",
+ List.of("abc"),
+ "model",
+ "model",
+ "input_type",
+ "search_query",
+ "embedding_types",
+ List.of("int8"),
+ "truncate",
+ "end"
+ )
+ )
+ );
+ }
+
+ // Only truncate is set; the null model, input type and embedding type must not show up in the
+ // request body.
+ public void testCreateRequest_TruncateNone() throws URISyntaxException, IOException {
+ var settings = new CohereEmbeddingsTaskSettings(null, null, null, CohereTruncation.NONE);
+ var embeddingsRequest = createRequest("url", "secret", List.of("abc"), settings);
+
+ var built = embeddingsRequest.createRequest();
+ MatcherAssert.assertThat(built, instanceOf(HttpPost.class));
+
+ var post = (HttpPost) built;
+
+ // Target URI and headers.
+ MatcherAssert.assertThat(post.getURI().toString(), is("url"));
+ MatcherAssert.assertThat(post.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+ MatcherAssert.assertThat(post.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+
+ // Body: the texts plus the single configured option.
+ var body = entityAsMap(post.getEntity().getContent());
+ MatcherAssert.assertThat(body, is(Map.of("texts", List.of("abc"), "truncate", "none")));
+ }
+
+ /** Test helper: builds a request for {@code input}; a null {@code url} is passed to the account as a null URI. */
+ public static CohereEmbeddingsRequest createRequest(
+ @Nullable String url,
+ String apiKey,
+ List<String> input,
+ CohereEmbeddingsTaskSettings taskSettings
+ ) throws URISyntaxException {
+ var uri = url == null ? null : new URI(url);
+ var account = new CohereAccount(uri, new SecureString(apiKey.toCharArray()));
+ return new CohereEmbeddingsRequest(account, input, taskSettings);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java
new file mode 100644
index 0000000000000..1d7af8577a722
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java
@@ -0,0 +1,434 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.cohere;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
+ // "embeddings" given as a bare array of vectors is expected to parse into float embeddings.
+ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
+ String payload = """
+ {
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+ "texts": [
+ "hello"
+ ],
+ "embeddings": [
+ [
+ -0.0018434525,
+ 0.01777649
+ ]
+ ],
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 1
+ }
+ },
+ "response_type": "embeddings_floats"
+ }
+ """;
+
+ // The Request and HttpResponse arguments are unstubbed mocks; the JSON above is the only real input.
+ var httpResult = new HttpResult(mock(HttpResponse.class), payload.getBytes(StandardCharsets.UTF_8));
+ TextEmbeddingResults results = CohereEmbeddingsResponseEntity.fromResponse(mock(Request.class), httpResult);
+
+ MatcherAssert.assertThat(
+ results.embeddings(),
+ is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F))))
+ );
+ }
+
+ // "embeddings" given as an object keyed by type: the "float" entry is expected to parse into float embeddings.
+ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws IOException {
+ String payload = """
+ {
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+ "texts": [
+ "hello"
+ ],
+ "embeddings": {
+ "float": [
+ [
+ -0.0018434525,
+ 0.01777649
+ ]
+ ]
+ },
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 1
+ }
+ },
+ "response_type": "embeddings_floats"
+ }
+ """;
+
+ // The Request and HttpResponse arguments are unstubbed mocks; the JSON above is the only real input.
+ var httpResult = new HttpResult(mock(HttpResponse.class), payload.getBytes(StandardCharsets.UTF_8));
+ TextEmbeddingResults results = CohereEmbeddingsResponseEntity.fromResponse(mock(Request.class), httpResult);
+
+ MatcherAssert.assertThat(
+ results.embeddings(),
+ is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F))))
+ );
+ }
+
+ // With both "float" and "int8" present, the result is expected to hold the float vectors (the first entry).
+ public void testFromResponse_UsesTheFirstValidEmbeddingsEntry() throws IOException {
+ String payload = """
+ {
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+ "texts": [
+ "hello"
+ ],
+ "embeddings": {
+ "float": [
+ [
+ -0.0018434525,
+ 0.01777649
+ ]
+ ],
+ "int8": [
+ [
+ -1,
+ 0
+ ]
+ ]
+ },
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 1
+ }
+ },
+ "response_type": "embeddings_floats"
+ }
+ """;
+
+ // The Request and HttpResponse arguments are unstubbed mocks; the JSON above is the only real input.
+ var httpResult = new HttpResult(mock(HttpResponse.class), payload.getBytes(StandardCharsets.UTF_8));
+ TextEmbeddingResults results = CohereEmbeddingsResponseEntity.fromResponse(mock(Request.class), httpResult);
+
+ MatcherAssert.assertThat(
+ results.embeddings(),
+ is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F))))
+ );
+ }
+
+ // An unrecognized first key is passed over; the following "int8" entry is expected back as byte embeddings.
+ public void testFromResponse_UsesTheFirstValidEmbeddingsEntryInt8_WithInvalidFirst() throws IOException {
+ String payload = """
+ {
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+ "texts": [
+ "hello"
+ ],
+ "embeddings": {
+ "invalid_type": [
+ [
+ -0.0018434525,
+ 0.01777649
+ ]
+ ],
+ "int8": [
+ [
+ -1,
+ 0
+ ]
+ ]
+ },
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 1
+ }
+ },
+ "response_type": "embeddings_floats"
+ }
+ """;
+
+ // The Request and HttpResponse arguments are unstubbed mocks; the JSON above is the only real input.
+ var httpResult = new HttpResult(mock(HttpResponse.class), payload.getBytes(StandardCharsets.UTF_8));
+ TextEmbeddingResults results = CohereEmbeddingsResponseEntity.fromResponse(mock(Request.class), httpResult);
+
+ MatcherAssert.assertThat(
+ results.embeddings(),
+ is(List.of(TextEmbeddingResults.Embedding.ofBytes(List.of((byte) -1, (byte) 0))))
+ );
+ }
+
+ // Two vectors in the bare array form are expected to yield two float embeddings, in the same order.
+ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
+ String payload = """
+ {
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+ "texts": [
+ "hello"
+ ],
+ "embeddings": [
+ [
+ -0.0018434525,
+ 0.01777649
+ ],
+ [
+ -0.123,
+ 0.123
+ ]
+ ],
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 1
+ }
+ },
+ "response_type": "embeddings_floats"
+ }
+ """;
+
+ // The Request and HttpResponse arguments are unstubbed mocks; the JSON above is the only real input.
+ var httpResult = new HttpResult(mock(HttpResponse.class), payload.getBytes(StandardCharsets.UTF_8));
+ TextEmbeddingResults results = CohereEmbeddingsResponseEntity.fromResponse(mock(Request.class), httpResult);
+
+ MatcherAssert.assertThat(
+ results.embeddings(),
+ is(
+ List.of(
+ TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F)),
+ TextEmbeddingResults.Embedding.ofFloats(List.of(-0.123F, 0.123F))
+ )
+ )
+ );
+ }
+
+ // Two vectors under "float" in the object form are expected to yield two float embeddings, in the same order.
+ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throws IOException {
+ String payload = """
+ {
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+ "texts": [
+ "hello"
+ ],
+ "embeddings": {
+ "float": [
+ [
+ -0.0018434525,
+ 0.01777649
+ ],
+ [
+ -0.123,
+ 0.123
+ ]
+ ]
+ },
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 1
+ }
+ },
+ "response_type": "embeddings_floats"
+ }
+ """;
+
+ // The Request and HttpResponse arguments are unstubbed mocks; the JSON above is the only real input.
+ var httpResult = new HttpResult(mock(HttpResponse.class), payload.getBytes(StandardCharsets.UTF_8));
+ TextEmbeddingResults results = CohereEmbeddingsResponseEntity.fromResponse(mock(Request.class), httpResult);
+
+ MatcherAssert.assertThat(
+ results.embeddings(),
+ is(
+ List.of(
+ TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F)),
+ TextEmbeddingResults.Embedding.ofFloats(List.of(-0.123F, 0.123F))
+ )
+ )
+ );
+ }
+
+ // A response without an "embeddings" field is expected to fail with an IllegalStateException naming the field.
+ public void testFromResponse_FailsWhenEmbeddingsFieldIsNotPresent() {
+ String payload = """
+ {
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+ "texts": [
+ "hello"
+ ],
+ "embeddings_not_here": [
+ [
+ -0.0018434525,
+ 0.01777649
+ ]
+ ],
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 1
+ }
+ },
+ "response_type": "embeddings_floats"
+ }
+ """;
+
+ // The Request and HttpResponse arguments are unstubbed mocks; the JSON above is the only real input.
+ var httpResult = new HttpResult(mock(HttpResponse.class), payload.getBytes(StandardCharsets.UTF_8));
+ var thrown = expectThrows(
+ IllegalStateException.class,
+ () -> CohereEmbeddingsResponseEntity.fromResponse(mock(Request.class), httpResult)
+ );
+
+ MatcherAssert.assertThat(
+ thrown.getMessage(),
+ is("Failed to find required field [embeddings] in Cohere embeddings response")
+ );
+ }
+
+ // An int8 value of -129 is below the byte range and is expected to be rejected with an IllegalArgumentException.
+ public void testFromResponse_FailsWhenEmbeddingsByteValue_IsOutsideByteRange_Negative() {
+ String payload = """
+ {
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+ "texts": [
+ "hello"
+ ],
+ "embeddings": {
+ "int8": [
+ [
+ -129,
+ 127
+ ]
+ ]
+ },
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 1
+ }
+ },
+ "response_type": "embeddings_floats"
+ }
+ """;
+
+ // The Request and HttpResponse arguments are unstubbed mocks; the JSON above is the only real input.
+ var httpResult = new HttpResult(mock(HttpResponse.class), payload.getBytes(StandardCharsets.UTF_8));
+ var thrown = expectThrows(
+ IllegalArgumentException.class,
+ () -> CohereEmbeddingsResponseEntity.fromResponse(mock(Request.class), httpResult)
+ );
+
+ MatcherAssert.assertThat(thrown.getMessage(), is("Value [-129] is out of range for a byte"));
+ }
+
+ // An int8 value of 128 is above the byte range and is expected to be rejected with an IllegalArgumentException.
+ public void testFromResponse_FailsWhenEmbeddingsByteValue_IsOutsideByteRange_Positive() {
+ String payload = """
+ {
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+ "texts": [
+ "hello"
+ ],
+ "embeddings": {
+ "int8": [
+ [
+ -128,
+ 128
+ ]
+ ]
+ },
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 1
+ }
+ },
+ "response_type": "embeddings_floats"
+ }
+ """;
+
+ // The Request and HttpResponse arguments are unstubbed mocks; the JSON above is the only real input.
+ var httpResult = new HttpResult(mock(HttpResponse.class), payload.getBytes(StandardCharsets.UTF_8));
+ var thrown = expectThrows(
+ IllegalArgumentException.class,
+ () -> CohereEmbeddingsResponseEntity.fromResponse(mock(Request.class), httpResult)
+ );
+
+ MatcherAssert.assertThat(thrown.getMessage(), is("Value [128] is out of range for a byte"));
+ }
+
+ // When "embeddings" holds only an unsupported type, parsing is expected to fail and list the supported ones.
+ public void testFromResponse_FailsToFindAValidEmbeddingType() {
+ String payload = """
+ {
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+ "texts": [
+ "hello"
+ ],
+ "embeddings": {
+ "invalid_type": [
+ [
+ -0.0018434525,
+ 0.01777649
+ ]
+ ]
+ },
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 1
+ }
+ },
+ "response_type": "embeddings_floats"
+ }
+ """;
+
+ // The Request and HttpResponse arguments are unstubbed mocks; the JSON above is the only real input.
+ var httpResult = new HttpResult(mock(HttpResponse.class), payload.getBytes(StandardCharsets.UTF_8));
+ var thrown = expectThrows(
+ IllegalStateException.class,
+ () -> CohereEmbeddingsResponseEntity.fromResponse(mock(Request.class), httpResult)
+ );
+
+ MatcherAssert.assertThat(
+ thrown.getMessage(),
+ is("Failed to find a supported embedding type in the Cohere embeddings response. Supported types are [float, int8]")
+ );
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntityTests.java
new file mode 100644
index 0000000000000..a2b1c26b2b3d5
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereErrorResponseEntityTests.java
@@ -0,0 +1,50 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.cohere;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.hamcrest.MatcherAssert;
+
+import java.nio.charset.StandardCharsets;
+
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+public class CohereErrorResponseEntityTests extends ESTestCase {
+ // A body with a "message" field is expected to yield an entity exposing that message unchanged.
+ public void testFromResponse() {
+ String payload = """
+ {
+ "message": "invalid request: total number of texts must be at most 96 - received 97"
+ }
+ """;
+
+ var expected = "invalid request: total number of texts must be at most 96 - received 97";
+
+ var httpResult = new HttpResult(mock(HttpResponse.class), payload.getBytes(StandardCharsets.UTF_8));
+ CohereErrorResponseEntity parsed = CohereErrorResponseEntity.fromResponse(httpResult);
+
+ assertNotNull(parsed);
+ MatcherAssert.assertThat(parsed.getErrorMessage(), is(expected));
+ }
+
+ // Without a "message" field the parser is expected to return null instead of an entity.
+ public void testFromResponse_noMessage() {
+ String payload = """
+ {
+ "error": "abc"
+ }
+ """;
+
+ var httpResult = new HttpResult(mock(HttpResponse.class), payload.getBytes(StandardCharsets.UTF_8));
+
+ assertNull(CohereErrorResponseEntity.fromResponse(httpResult));
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java
index 2b6e11fdfafa7..d54bcd91c9eda 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java
@@ -37,7 +37,7 @@ public void testFromResponse_CreatesResultsForASingleItem_ArrayFormat() throws I
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)))));
}
public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws IOException {
@@ -57,7 +57,7 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)))));
}
public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws IOException {
@@ -83,8 +83,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)),
- new TextEmbeddingResults.Embedding(List.of(0.0123F, -0.0123F))
+ TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)),
+ TextEmbeddingResults.Embedding.ofFloats(List.of(0.0123F, -0.0123F))
)
)
);
@@ -115,8 +115,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)),
- new TextEmbeddingResults.Embedding(List.of(0.0123F, -0.0123F))
+ TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)),
+ TextEmbeddingResults.Embedding.ofFloats(List.of(0.0123F, -0.0123F))
)
)
);
@@ -254,7 +254,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ArrayFormat() throw
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(1.0F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(1.0F)))));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() throws IOException {
@@ -273,7 +273,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() thro
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(1.0F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(1.0F)))));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() throws IOException {
@@ -290,7 +290,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() thro
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(4.0294965E10F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(4.0294965E10F)))));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() throws IOException {
@@ -309,7 +309,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() thr
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(4.0294965E10F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(4.0294965E10F)))));
}
public void testFromResponse_FailsWhenEmbeddingValueIsAnObject_ObjectFormat() {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java
index 010e990a3ce80..3e8b50591e3b6 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java
@@ -49,7 +49,7 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)))));
}
public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
@@ -91,8 +91,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)),
- new TextEmbeddingResults.Embedding(List.of(0.0123F, -0.0123F))
+ TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)),
+ TextEmbeddingResults.Embedding.ofFloats(List.of(0.0123F, -0.0123F))
)
)
);
@@ -261,7 +261,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(1.0F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(1.0F)))));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOException {
@@ -290,7 +290,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(4.0294965E10F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(4.0294965E10F)))));
}
public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ByteValueTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ByteValueTests.java
new file mode 100644
index 0000000000000..b89c3c9ab0ec9
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ByteValueTests.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.results;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.inference.results.ByteValue;
+
+import java.io.IOException;
+
+public class ByteValueTests extends AbstractWireSerializingTestCase<ByteValue> { // serialization tests for ByteValue
+    public static ByteValue createRandom() { // shared factory: a ByteValue wrapping a random byte
+        return new ByteValue(randomByte());
+    }
+
+    @Override
+    protected Writeable.Reader<ByteValue> instanceReader() {
+        return ByteValue::new;
+    }
+
+    @Override
+    protected ByteValue createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected ByteValue mutateInstance(ByteValue instance) throws IOException {
+        return null; // NOTE(review): no mutated copy is produced; confirm the base class tolerates null
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/FloatValueTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/FloatValueTests.java
new file mode 100644
index 0000000000000..1d37c0d3ddee1
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/FloatValueTests.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.results;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.inference.results.FloatValue;
+
+import java.io.IOException;
+
+public class FloatValueTests extends AbstractWireSerializingTestCase<FloatValue> { // serialization tests for FloatValue
+    public static FloatValue createRandom() { // shared factory: a FloatValue wrapping a random float
+        return new FloatValue(randomFloat());
+    }
+
+    @Override
+    protected Writeable.Reader<FloatValue> instanceReader() {
+        return FloatValue::new;
+    }
+
+    @Override
+    protected FloatValue createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected FloatValue mutateInstance(FloatValue instance) throws IOException {
+        return null; // NOTE(review): no mutated copy is produced; confirm the base class tolerates null
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java
index 09d9894d98853..1bb237b87234b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java
@@ -7,31 +7,80 @@
package org.elasticsearch.xpack.inference.results;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.inference.results.ByteValue;
+import org.elasticsearch.xpack.core.inference.results.FloatValue;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider;
+import org.hamcrest.MatcherAssert;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
+import java.util.function.Supplier;
+import static org.elasticsearch.TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED;
import static org.hamcrest.Matchers.is;
-public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase {
+public class TextEmbeddingResultsTests extends AbstractBWCWireSerializationTestCase {
+
+ private enum EmbeddingType { // element kinds used when generating random embeddings in this test
+ FLOAT,
+ BYTE
+ }
+
+    private static final Map<EmbeddingType, Supplier<TextEmbeddingResults.Embedding>> EMBEDDING_TYPE_BUILDERS = Map.of( // type -> factory
+        EmbeddingType.FLOAT,
+        TextEmbeddingResultsTests::createRandomFloatEmbedding,
+        EmbeddingType.BYTE,
+        TextEmbeddingResultsTests::createRandomByteEmbedding
+    );
+
public static TextEmbeddingResults createRandomResults() {
+ var embeddingType = randomFrom(EmbeddingType.values());
+ var createFunction = EMBEDDING_TYPE_BUILDERS.get(embeddingType);
+
+ if (createFunction == null) {
+ createFunction = TextEmbeddingResultsTests::createRandomFloatEmbedding;
+ }
+
+ return createRandomResults(createFunction);
+ }
+
+ private static TextEmbeddingResults createRandomResults(Supplier creator) {
int embeddings = randomIntBetween(1, 10);
List embeddingResults = new ArrayList<>(embeddings);
for (int i = 0; i < embeddings; i++) {
- embeddingResults.add(createRandomEmbedding());
+ embeddingResults.add(creator.get());
}
return new TextEmbeddingResults(embeddingResults);
}
- private static TextEmbeddingResults.Embedding createRandomEmbedding() {
+ private static TextEmbeddingResults.Embedding createRandomEmbedding(boolean makeFloatEmbedding) {
+ return makeFloatEmbedding ? createRandomFloatEmbedding() : createRandomByteEmbedding(); // float when true, byte otherwise
+ }
+
+    private static TextEmbeddingResults.Embedding createRandomByteEmbedding() { // embedding of 1-10 random bytes
+        int columns = randomIntBetween(1, 10);
+        List<Byte> bytes = new ArrayList<>(columns);
+
+        for (int i = 0; i < columns; i++) {
+            bytes.add(randomByte());
+        }
+
+        return TextEmbeddingResults.Embedding.ofBytes(bytes);
+    }
+
+ private static TextEmbeddingResults.Embedding createRandomFloatEmbedding() {
int columns = randomIntBetween(1, 10);
List floats = new ArrayList<>(columns);
@@ -39,19 +88,34 @@ private static TextEmbeddingResults.Embedding createRandomEmbedding() {
floats.add(randomFloat());
}
- return new TextEmbeddingResults.Embedding(floats);
+ return TextEmbeddingResults.Embedding.ofFloats(floats);
}
- public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException {
- var entity = new TextEmbeddingResults(List.of(new TextEmbeddingResults.Embedding(List.of(0.1F))));
+    public static Map<String, List<Map<String, List<FloatValue>>>> buildExpectationFloats(List<List<Float>> embeddings) {
+        return Map.of(
+            TextEmbeddingResults.TEXT_EMBEDDING,
+            embeddings.stream() // one {EMBEDDING: [FloatValue, ...]} map per input list, in input order
+                .map(embedding -> Map.of(TextEmbeddingResults.Embedding.EMBEDDING, embedding.stream().map(FloatValue::new).toList()))
+                .toList()
+        );
+    }
- assertThat(
- entity.asMap(),
- is(Map.of(TextEmbeddingResults.TEXT_EMBEDDING, List.of(Map.of(TextEmbeddingResults.Embedding.EMBEDDING, List.of(0.1F)))))
+ public static Map buildExpectationBytes(List> embeddings) {
+ return Map.of(
+ TextEmbeddingResults.TEXT_EMBEDDING,
+ embeddings.stream()
+ .map(embedding -> Map.of(TextEmbeddingResults.Embedding.EMBEDDING, embedding.stream().map(ByteValue::new).toList()))
+ .toList()
);
+ }
+
+ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() {
+ var entity = new TextEmbeddingResults(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.1F))));
+
+ MatcherAssert.assertThat(entity.asMap(), is(buildExpectationFloats(List.of(List.of(0.1F)))));
String xContentResult = Strings.toString(entity, true, true);
- assertThat(xContentResult, is("""
+ MatcherAssert.assertThat(xContentResult, is("""
{
"text_embedding" : [
{
@@ -63,27 +127,33 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE
}"""));
}
- public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException {
- var entity = new TextEmbeddingResults(
- List.of(new TextEmbeddingResults.Embedding(List.of(0.1F)), new TextEmbeddingResults.Embedding(List.of(0.2F)))
+ public void testToXContent_CreatesTheRightFormatForASingleEmbedding_ForBytes() {
+ var entity = new TextEmbeddingResults(List.of(TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 12))));
- );
+ MatcherAssert.assertThat(entity.asMap(), is(buildExpectationBytes(List.of(List.of((byte) 12)))));
- assertThat(
- entity.asMap(),
- is(
- Map.of(
- TextEmbeddingResults.TEXT_EMBEDDING,
- List.of(
- Map.of(TextEmbeddingResults.Embedding.EMBEDDING, List.of(0.1F)),
- Map.of(TextEmbeddingResults.Embedding.EMBEDDING, List.of(0.2F))
- )
- )
- )
+ String xContentResult = Strings.toString(entity, true, true);
+ MatcherAssert.assertThat(xContentResult, is("""
+ {
+ "text_embedding" : [
+ {
+ "embedding" : [
+ 12
+ ]
+ }
+ ]
+ }"""));
+ }
+
+ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() {
+ var entity = new TextEmbeddingResults(
+ List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.1F)), TextEmbeddingResults.Embedding.ofFloats(List.of(0.2F)))
);
+ MatcherAssert.assertThat(entity.asMap(), is(buildExpectationFloats(List.of(List.of(0.1F), List.of(0.2F)))));
+
String xContentResult = Strings.toString(entity, true, true);
- assertThat(xContentResult, is("""
+ MatcherAssert.assertThat(xContentResult, is("""
{
"text_embedding" : [
{
@@ -100,12 +170,40 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I
}"""));
}
+ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings_ForBytes() {
+ var entity = new TextEmbeddingResults( // two embeddings, one byte each: 12 and 34
+ List.of(TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 12)), TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 34)))
+ );
+
+ MatcherAssert.assertThat(entity.asMap(), is(buildExpectationBytes(List.of(List.of((byte) 12), List.of((byte) 34)))));
+
+ String xContentResult = Strings.toString(entity, true, true); // rendered form is compared verbatim to the JSON below
+ MatcherAssert.assertThat(xContentResult, is("""
+ {
+ "text_embedding" : [
+ {
+ "embedding" : [
+ 12
+ ]
+ },
+ {
+ "embedding" : [
+ 34
+ ]
+ }
+ ]
+ }"""));
+ }
+
public void testTransformToCoordinationFormat() {
var results = new TextEmbeddingResults(
- List.of(new TextEmbeddingResults.Embedding(List.of(0.1F, 0.2F)), new TextEmbeddingResults.Embedding(List.of(0.3F, 0.4F)))
+ List.of(
+ TextEmbeddingResults.Embedding.ofFloats(List.of(0.1F, 0.2F)),
+ TextEmbeddingResults.Embedding.ofFloats(List.of(0.3F, 0.4F))
+ )
).transformToCoordinationFormat();
- assertThat(
+ MatcherAssert.assertThat(
results,
is(
List.of(
@@ -124,6 +222,49 @@ public void testTransformToCoordinationFormat() {
);
}
+ public void testTransformToCoordinationFormat_FromBytes() {
+ var results = new TextEmbeddingResults(
+ List.of(
+ TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 12, (byte) 34)),
+ TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 56, (byte) -78))
+ )
+ ).transformToCoordinationFormat(); // expected below: one ml TextEmbeddingResults per input embedding
+
+ MatcherAssert.assertThat(
+ results,
+ is(
+ List.of(
+ new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(
+ TextEmbeddingResults.TEXT_EMBEDDING,
+ new double[] { 12F, 34F }, // float literals widen to double; exact for these integral values
+ false
+ ),
+ new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(
+ TextEmbeddingResults.TEXT_EMBEDDING,
+ new double[] { 56F, -78F }, // the negative byte value keeps its sign
+ false
+ )
+ )
+ )
+ );
+ }
+
+ public void testSerializesToFloats_WhenVersionIsPriorToByteSupport() throws IOException {
+ var instance = createRandomResults(TextEmbeddingResultsTests::createRandomByteEmbedding); // byte-only results
+ var modifiedForOlderVersion = mutateInstanceForVersion(instance, ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED);
+ // NOTE(review): presumably ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED predates byte-embedding support; confirm
+ var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED);
+ assertOnBWCObject(copy, modifiedForOlderVersion, ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED);
+ }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>(); // ML provider entries, then inference plugin entries
+        entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
+        entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables());
+        return new NamedWriteableRegistry(entries);
+    }
+
@Override
protected Writeable.Reader instanceReader() {
return TextEmbeddingResults::new;
@@ -143,15 +284,26 @@ protected TextEmbeddingResults mutateInstance(TextEmbeddingResults instance) thr
return new TextEmbeddingResults(instance.embeddings().subList(0, end));
} else {
List embeddings = new ArrayList<>(instance.embeddings());
- embeddings.add(createRandomEmbedding());
+ embeddings.add(createRandomEmbedding(randomBoolean()));
return new TextEmbeddingResults(embeddings);
}
}
- public static Map buildExpectation(List> embeddings) {
- return Map.of(
- TextEmbeddingResults.TEXT_EMBEDDING,
- embeddings.stream().map(embedding -> Map.of(TextEmbeddingResults.Embedding.EMBEDDING, embedding)).toList()
- );
+ @Override
+ protected TextEmbeddingResults mutateInstanceForVersion(TextEmbeddingResults instance, TransportVersion version) {
+ if (version.before(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) {
+ return convertToFloatEmbeddings(instance); // for older versions the expected object holds float embeddings
+ }
+
+ return instance; // from that version on, the instance is expected unchanged
+ }
+
+ public TextEmbeddingResults convertToFloatEmbeddings(TextEmbeddingResults results) {
+ var floatEmbeddings = results.embeddings()
+ .stream()
+ .map(embedding -> TextEmbeddingResults.Embedding.ofFloats(embedding.toFloats()))
+ .toList();
+
+ return new TextEmbeddingResults(floatEmbeddings);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
new file mode 100644
index 0000000000000..2061ac646c56b
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
@@ -0,0 +1,853 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere;
+
+import org.apache.http.HttpHeaders;
+import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
+import org.hamcrest.MatcherAssert;
+import org.hamcrest.Matchers;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
+import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel;
+import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettingsTests.getServiceSettingsMap;
+import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap;
+import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty;
+import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.when;
+
+public class CohereServiceTests extends ESTestCase {
+ private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+ private final MockWebServer webServer = new MockWebServer();
+ private ThreadPool threadPool;
+ private HttpClientManager clientManager;
+
+ @Before
+ public void init() throws Exception { // per-test fixtures: mock web server, thread pool, HTTP client manager
+ webServer.start();
+ threadPool = createThreadPool(inferenceUtilityPool());
+ clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+ }
+
+ @After
+ public void shutdown() throws IOException { // tears down the fixtures created in init()
+ clientManager.close();
+ terminate(threadPool);
+ webServer.close();
+ }
+
+ public void testParseRequestConfig_CreatesACohereEmbeddingsModel() throws IOException {
+ try ( // the service is closed when this block exits
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var model = service.parseRequestConfig(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ getRequestConfigMap(
+ getServiceSettingsMap("url"),
+ getTaskSettingsMap("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START),
+ getSecretSettingsMap("secret")
+ ),
+ Set.of()
+ );
+
+ MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
+
+ var embeddingsModel = (CohereEmbeddingsModel) model; // cast follows the instanceOf assertion above
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+ MatcherAssert.assertThat(
+ embeddingsModel.getTaskSettings(),
+ is(new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START))
+ );
+ MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var thrownException = expectThrows(
+ ElasticsearchStatusException.class,
+ () -> service.parseRequestConfig(
+ "id",
+ TaskType.SPARSE_EMBEDDING, // the task type this test expects the service to reject
+ getRequestConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty(), getSecretSettingsMap("secret")),
+ Set.of()
+ )
+ );
+
+ MatcherAssert.assertThat(
+ thrownException.getMessage(),
+ is("The [cohere] service does not support task type [sparse_embedding]")
+ );
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var config = getRequestConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty(), getSecretSettingsMap("secret"));
+ config.put("extra_key", "value"); // unknown key at the top level of the request config
+
+ var thrownException = expectThrows(
+ ElasticsearchStatusException.class,
+ () -> service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of())
+ );
+
+ MatcherAssert.assertThat(
+ thrownException.getMessage(),
+ is("Model configuration contains settings [{extra_key=value}] unknown to the [cohere] service")
+ );
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var serviceSettings = getServiceSettingsMap("url");
+ serviceSettings.put("extra_key", "value"); // unknown key placed inside the service settings
+
+ var config = getRequestConfigMap(
+ serviceSettings,
+ getTaskSettingsMap("model", null, null, null),
+ getSecretSettingsMap("secret")
+ );
+
+ var thrownException = expectThrows(
+ ElasticsearchStatusException.class,
+ () -> service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of())
+ );
+
+ MatcherAssert.assertThat(
+ thrownException.getMessage(),
+ is("Model configuration contains settings [{extra_key=value}] unknown to the [cohere] service")
+ );
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var taskSettingsMap = getTaskSettingsMap("model", null, null, null);
+ taskSettingsMap.put("extra_key", "value"); // unknown key placed inside the task settings
+
+ var config = getRequestConfigMap(getServiceSettingsMap("url"), taskSettingsMap, getSecretSettingsMap("secret"));
+
+ var thrownException = expectThrows(
+ ElasticsearchStatusException.class,
+ () -> service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of())
+ );
+
+ MatcherAssert.assertThat(
+ thrownException.getMessage(),
+ is("Model configuration contains settings [{extra_key=value}] unknown to the [cohere] service")
+ );
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var secretSettingsMap = getSecretSettingsMap("secret");
+ secretSettingsMap.put("extra_key", "value"); // unknown key placed inside the secret settings
+
+ var config = getRequestConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty(), secretSettingsMap);
+
+ var thrownException = expectThrows(
+ ElasticsearchStatusException.class,
+ () -> service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of())
+ );
+
+ MatcherAssert.assertThat(
+ thrownException.getMessage(),
+ is("Model configuration contains settings [{extra_key=value}] unknown to the [cohere] service")
+ );
+ }
+ }
+
+ public void testParseRequestConfig_CreatesACohereEmbeddingsModelWithoutUrl() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var model = service.parseRequestConfig(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ getRequestConfigMap(getServiceSettingsMap(null), getTaskSettingsMapEmpty(), getSecretSettingsMap("secret")),
+ Set.of()
+ );
+
+ MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
+
+ var embeddingsModel = (CohereEmbeddingsModel) model;
+ assertNull(embeddingsModel.getServiceSettings().uri()); // service settings were built with a null url above
+ MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
+ MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+ }
+ }
+
+ public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModel() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var persistedConfig = getPersistedConfigMap(
+ getServiceSettingsMap("url"),
+ getTaskSettingsMap("model", null, null, null),
+ getSecretSettingsMap("secret")
+ );
+
+ var model = service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ persistedConfig.config(),
+ persistedConfig.secrets()
+ );
+ // NOTE(review): the method name says "OpenAi" but a Cohere model is asserted; likely a copy-paste leftover
+ MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
+
+ var embeddingsModel = (CohereEmbeddingsModel) model;
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings("model", null, null, null)));
+ MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+ }
+ }
+
+ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var persistedConfig = getPersistedConfigMap(
+ getServiceSettingsMap("url"),
+ getTaskSettingsMapEmpty(),
+ getSecretSettingsMap("secret")
+ );
+
+ var thrownException = expectThrows(
+ ElasticsearchStatusException.class,
+ () -> service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.SPARSE_EMBEDDING, // the task type this test expects to fail when parsing the stored model
+ persistedConfig.config(),
+ persistedConfig.secrets()
+ )
+ );
+
+ MatcherAssert.assertThat(
+ thrownException.getMessage(),
+ is("Failed to parse stored model [id] for [cohere] service, please delete and add the service again")
+ );
+ }
+ }
+
+ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWithoutUrl() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var persistedConfig = getPersistedConfigMap(
+ getServiceSettingsMap(null),
+ getTaskSettingsMap(null, InputType.INGEST, null, null),
+ getSecretSettingsMap("secret")
+ );
+
+ var model = service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ persistedConfig.config(),
+ persistedConfig.secrets()
+ );
+
+ MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
+
+ var embeddingsModel = (CohereEmbeddingsModel) model;
+ assertNull(embeddingsModel.getServiceSettings().uri());
+ MatcherAssert.assertThat(
+ embeddingsModel.getTaskSettings(),
+ is(new CohereEmbeddingsTaskSettings(null, InputType.INGEST, null, null))
+ );
+ MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+ }
+ }
+
+ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var persistedConfig = getPersistedConfigMap(
+ getServiceSettingsMap("url"),
+ getTaskSettingsMap("model", InputType.SEARCH, CohereEmbeddingType.INT8, CohereTruncation.NONE),
+ getSecretSettingsMap("secret")
+ );
+ persistedConfig.config().put("extra_key", "value");
+
+ var model = service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ persistedConfig.config(),
+ persistedConfig.secrets()
+ );
+
+ MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
+
+ var embeddingsModel = (CohereEmbeddingsModel) model;
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+ MatcherAssert.assertThat(
+ embeddingsModel.getTaskSettings(),
+ is(new CohereEmbeddingsTaskSettings("model", InputType.SEARCH, CohereEmbeddingType.INT8, CohereTruncation.NONE))
+ );
+ MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+ }
+ }
+
+ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var secretSettingsMap = getSecretSettingsMap("secret");
+ secretSettingsMap.put("extra_key", "value");
+
+ var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty(), secretSettingsMap);
+
+ var model = service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ persistedConfig.config(),
+ persistedConfig.secrets()
+ );
+
+ MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
+
+ var embeddingsModel = (CohereEmbeddingsModel) model;
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
+ MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+ }
+ }
+
+ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var persistedConfig = getPersistedConfigMap(
+ getServiceSettingsMap("url"),
+ getTaskSettingsMap("model", null, null, null),
+ getSecretSettingsMap("secret")
+ );
+ persistedConfig.secrets().put("extra_key", "value");
+ // The unrecognised top-level secrets key added above must be ignored rather than rejected.
+ var model = service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ persistedConfig.config(),
+ persistedConfig.secrets()
+ );
+
+ MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
+
+ var embeddingsModel = (CohereEmbeddingsModel) model;
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings("model", null, null, null)));
+ MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+ }
+ }
+
+ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var serviceSettingsMap = getServiceSettingsMap("url");
+ serviceSettingsMap.put("extra_key", "value");
+
+ var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMapEmpty(), getSecretSettingsMap("secret"));
+
+ var model = service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ persistedConfig.config(),
+ persistedConfig.secrets()
+ );
+
+ MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
+
+ var embeddingsModel = (CohereEmbeddingsModel) model;
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
+ MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+ }
+ }
+
+ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var taskSettingsMap = getTaskSettingsMap("model", InputType.SEARCH, null, null);
+ taskSettingsMap.put("extra_key", "value");
+
+ var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), taskSettingsMap, getSecretSettingsMap("secret"));
+
+ var model = service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ persistedConfig.config(),
+ persistedConfig.secrets()
+ );
+
+ MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
+
+ var embeddingsModel = (CohereEmbeddingsModel) model;
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+ MatcherAssert.assertThat(
+ embeddingsModel.getTaskSettings(),
+ is(new CohereEmbeddingsTaskSettings("model", InputType.SEARCH, null, null))
+ );
+ MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+ }
+ }
+
+ public void testParsePersistedConfig_CreatesACohereEmbeddingsModel() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var persistedConfig = getPersistedConfigMap(
+ getServiceSettingsMap("url"),
+ getTaskSettingsMap("model", null, null, CohereTruncation.NONE)
+ );
+
+ var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
+
+ MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
+
+ var embeddingsModel = (CohereEmbeddingsModel) model;
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+ MatcherAssert.assertThat(
+ embeddingsModel.getTaskSettings(),
+ is(new CohereEmbeddingsTaskSettings("model", null, null, CohereTruncation.NONE))
+ );
+ assertNull(embeddingsModel.getSecretSettings());
+ }
+ }
+
+ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty());
+
+ var thrownException = expectThrows(
+ ElasticsearchStatusException.class,
+ () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
+ );
+
+ MatcherAssert.assertThat(
+ thrownException.getMessage(),
+ is("Failed to parse stored model [id] for [cohere] service, please delete and add the service again")
+ );
+ }
+ }
+
+ public void testParsePersistedConfig_CreatesACohereEmbeddingsModelWithoutUrl() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var persistedConfig = getPersistedConfigMap(
+ getServiceSettingsMap(null),
+ getTaskSettingsMap("model", null, CohereEmbeddingType.FLOAT, null)
+ );
+ // No url and no secrets are supplied, so the parsed model is expected to expose neither.
+ var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
+
+ MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
+
+ var embeddingsModel = (CohereEmbeddingsModel) model;
+ assertNull(embeddingsModel.getServiceSettings().uri());
+ MatcherAssert.assertThat(
+ embeddingsModel.getTaskSettings(),
+ is(new CohereEmbeddingsTaskSettings("model", null, CohereEmbeddingType.FLOAT, null))
+ );
+ assertNull(embeddingsModel.getSecretSettings());
+ }
+ }
+
+ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty());
+ persistedConfig.config().put("extra_key", "value");
+
+ var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
+
+ MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
+
+ var embeddingsModel = (CohereEmbeddingsModel) model;
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
+ assertNull(embeddingsModel.getSecretSettings());
+ }
+ }
+
+ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var serviceSettingsMap = getServiceSettingsMap("url");
+ serviceSettingsMap.put("extra_key", "value");
+
+ var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMap(null, InputType.SEARCH, null, null));
+
+ var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
+
+ MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
+
+ var embeddingsModel = (CohereEmbeddingsModel) model;
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+ MatcherAssert.assertThat(
+ embeddingsModel.getTaskSettings(),
+ is(new CohereEmbeddingsTaskSettings(null, InputType.SEARCH, null, null))
+ );
+ assertNull(embeddingsModel.getSecretSettings());
+ }
+ }
+
+ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException {
+ try (
+ var service = new CohereService(
+ new SetOnce<>(mock(HttpRequestSenderFactory.class)),
+ new SetOnce<>(createWithEmptySettings(threadPool))
+ )
+ ) {
+ var taskSettingsMap = getTaskSettingsMap("model", InputType.INGEST, null, null);
+ taskSettingsMap.put("extra_key", "value");
+
+ var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), taskSettingsMap);
+
+ var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
+
+ MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
+
+ var embeddingsModel = (CohereEmbeddingsModel) model;
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+ MatcherAssert.assertThat(
+ embeddingsModel.getTaskSettings(),
+ is(new CohereEmbeddingsTaskSettings("model", InputType.INGEST, null, null))
+ );
+ assertNull(embeddingsModel.getSecretSettings());
+ }
+ }
+
+ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException {
+ var sender = mock(Sender.class);
+
+ var factory = mock(HttpRequestSenderFactory.class);
+ when(factory.createSender(anyString())).thenReturn(sender);
+
+ var mockModel = getInvalidModel("model_id", "service_name");
+
+ try (var service = new CohereService(new SetOnce<>(factory), new SetOnce<>(createWithEmptySettings(threadPool)))) {
+ PlainActionFuture listener = new PlainActionFuture<>();
+ service.infer(mockModel, List.of(""), new HashMap<>(), listener);
+
+ var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+ MatcherAssert.assertThat(
+ thrownException.getMessage(),
+ is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
+ );
+
+ verify(factory, times(1)).createSender(anyString());
+ verify(sender, times(1)).start();
+ }
+
+ verify(sender, times(1)).close();
+ verifyNoMoreInteractions(factory);
+ verifyNoMoreInteractions(sender);
+ }
+
+ public void testInfer_SendsRequest() throws IOException {
+ var senderFactory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY);
+
+ try (var service = new CohereService(new SetOnce<>(senderFactory), new SetOnce<>(createWithEmptySettings(threadPool)))) {
+
+ String responseJson = """
+ {
+ "id": "de37399c-5df6-47cb-bc57-e3c5680c977b",
+ "texts": [
+ "hello"
+ ],
+ "embeddings": {
+ "float": [
+ [
+ 0.123,
+ -0.123
+ ]
+ ]
+ },
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 1
+ }
+ },
+ "response_type": "embeddings_by_type"
+ }
+ """;
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+ var model = CohereEmbeddingsModelTests.createModel(
+ getUrl(webServer),
+ "secret",
+ new CohereEmbeddingsTaskSettings("model", InputType.INGEST, null, null),
+ 1024,
+ 1024
+ );
+ PlainActionFuture listener = new PlainActionFuture<>();
+ service.infer(model, List.of("abc"), new HashMap<>(), listener);
+
+ var result = listener.actionGet(TIMEOUT);
+
+ MatcherAssert.assertThat(result.asMap(), Matchers.is(buildExpectationFloats(List.of(List.of(0.123F, -0.123F)))));
+ MatcherAssert.assertThat(webServer.requests(), hasSize(1));
+ assertNull(webServer.requests().get(0).getUri().getQuery());
+ MatcherAssert.assertThat(
+ webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
+ equalTo(XContentType.JSON.mediaType())
+ );
+ MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
+
+ var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+ MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc"), "model", "model", "input_type", "search_document")));
+ }
+ }
+
+ public void testCheckModelConfig_UpdatesDimensions() throws IOException {
+ var senderFactory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY);
+
+ try (var service = new CohereService(new SetOnce<>(senderFactory), new SetOnce<>(createWithEmptySettings(threadPool)))) {
+
+ String responseJson = """
+ {
+ "id": "de37399c-5df6-47cb-bc57-e3c5680c977b",
+ "texts": [
+ "hello"
+ ],
+ "embeddings": {
+ "float": [
+ [
+ 0.123,
+ -0.123
+ ]
+ ]
+ },
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 1
+ }
+ },
+ "response_type": "embeddings_by_type"
+ }
+ """;
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+ var model = CohereEmbeddingsModelTests.createModel(
+ getUrl(webServer),
+ "secret",
+ CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
+ 10,
+ 1
+ );
+ PlainActionFuture listener = new PlainActionFuture<>();
+ service.checkModelConfig(model, listener);
+ var result = listener.actionGet(TIMEOUT);
+
+ MatcherAssert.assertThat(
+ result,
+ // the dimension is set to 2 because the single embedding returned from the mock server contains 2 values
+ is(CohereEmbeddingsModelTests.createModel(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, 10, 2))
+ );
+ }
+ }
+
+ public void testInfer_UnauthorisedResponse() throws IOException {
+ var senderFactory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY);
+
+ try (var service = new CohereService(new SetOnce<>(senderFactory), new SetOnce<>(createWithEmptySettings(threadPool)))) {
+
+ String responseJson = """
+ {
+ "message": "invalid api token"
+ }
+ """;
+ webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));
+
+ var model = CohereEmbeddingsModelTests.createModel(
+ getUrl(webServer),
+ "secret",
+ CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
+ 1024,
+ 1024
+ );
+ PlainActionFuture listener = new PlainActionFuture<>();
+ service.infer(model, List.of("abc"), new HashMap<>(), listener);
+
+ var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+ MatcherAssert.assertThat(error.getMessage(), containsString("Received an authentication error status code for request"));
+ MatcherAssert.assertThat(error.getMessage(), containsString("Error message: [invalid api token]"));
+ MatcherAssert.assertThat(webServer.requests(), hasSize(1));
+ }
+ }
+
+ private Map getRequestConfigMap(
+ Map serviceSettings,
+ Map taskSettings,
+ Map secretSettings
+ ) {
+ var builtServiceSettings = new HashMap<>();
+ builtServiceSettings.putAll(serviceSettings);
+ builtServiceSettings.putAll(secretSettings);
+
+ return new HashMap<>(
+ Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)
+ );
+ }
+
+ private PeristedConfig getPersistedConfigMap(
+ Map serviceSettings,
+ Map taskSettings,
+ Map secretSettings
+ ) {
+
+ return new PeristedConfig(
+ new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)),
+ new HashMap<>(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings))
+ );
+ }
+
+ private PeristedConfig getPersistedConfigMap(Map serviceSettings, Map taskSettings) {
+
+ return new PeristedConfig(
+ new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)),
+ null
+ );
+ }
+
+ private record PeristedConfig(Map config, Map secrets) {}
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncationTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncationTests.java
new file mode 100644
index 0000000000000..600ddb54eddd3
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncationTests.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+
+import java.io.IOException;
+
+public class CohereTruncationTests extends AbstractWireSerializingTestCase {
+ public static CohereTruncation createRandom() {
+ return randomFrom(CohereTruncation.values());
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return CohereTruncation::fromStream;
+ }
+
+ @Override
+ protected CohereTruncation createTestInstance() {
+ return createRandom();
+ }
+
+ @Override
+ protected CohereTruncation mutateInstance(CohereTruncation instance) throws IOException {
+ return null;
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java
new file mode 100644
index 0000000000000..8b907746b0779
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere.embeddings;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+
+import java.io.IOException;
+
+public class CohereEmbeddingTypeTests extends AbstractWireSerializingTestCase {
+ public static CohereEmbeddingType createRandom() {
+ return randomFrom(CohereEmbeddingType.values());
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return CohereEmbeddingType::fromStream;
+ }
+
+ @Override
+ protected CohereEmbeddingType createTestInstance() {
+ return createRandom();
+ }
+
+ @Override
+ protected CohereEmbeddingType mutateInstance(CohereEmbeddingType instance) throws IOException {
+ return null;
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
index f2ea1a8c4f3fe..84e65ed87e372 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
@@ -29,10 +29,10 @@ public class CohereEmbeddingsTaskSettingsTests extends AbstractWireSerializingTe
public static CohereEmbeddingsTaskSettings createRandom() {
var model = randomBoolean() ? randomAlphaOfLength(15) : null;
var inputType = randomBoolean() ? randomFrom(InputType.values()) : null;
- var embeddingTypes = randomBoolean() ? List.of(randomFrom(CohereEmbeddingType.values())) : null;
+ var embeddingType = randomBoolean() ? randomFrom(CohereEmbeddingType.values()) : null;
var truncation = randomBoolean() ? randomFrom(CohereTruncation.values()) : null;
- return new CohereEmbeddingsTaskSettings(model, inputType, embeddingTypes, truncation);
+ return new CohereEmbeddingsTaskSettings(model, inputType, embeddingType, truncation);
}
public void testFromMap_CreatesEmptySettings_WhenAllFieldsAreNull() {
@@ -51,21 +51,14 @@ public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() {
"abc",
CohereEmbeddingsTaskSettings.INPUT_TYPE,
InputType.INGEST.toString().toLowerCase(Locale.ROOT),
- CohereEmbeddingsTaskSettings.EMBEDDING_TYPES,
- List.of(CohereEmbeddingType.FLOAT, CohereEmbeddingType.INT8),
+ CohereEmbeddingsTaskSettings.EMBEDDING_TYPE,
+ CohereEmbeddingType.INT8,
CohereServiceFields.TRUNCATE,
CohereTruncation.END.toString().toLowerCase(Locale.ROOT)
)
)
),
- is(
- new CohereEmbeddingsTaskSettings(
- "abc",
- InputType.INGEST,
- List.of(CohereEmbeddingType.FLOAT, CohereEmbeddingType.INT8),
- CohereTruncation.END
- )
- )
+ is(new CohereEmbeddingsTaskSettings("abc", InputType.INGEST, CohereEmbeddingType.INT8, CohereTruncation.END))
);
}
@@ -84,7 +77,7 @@ public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() {
public void testFromMap_ReturnsFailure_WhenEmbeddingTypesAreNotValid() {
var exception = expectThrows(
ValidationException.class,
- () -> CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.EMBEDDING_TYPES, List.of("abc"))))
+ () -> CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.EMBEDDING_TYPE, List.of("abc"))))
);
MatcherAssert.assertThat(
@@ -133,10 +126,14 @@ protected CohereEmbeddingsTaskSettings mutateInstance(CohereEmbeddingsTaskSettin
return null;
}
+ public static Map getTaskSettingsMapEmpty() {
+ return new HashMap<>();
+ }
+
public static Map getTaskSettingsMap(
@Nullable String model,
@Nullable InputType inputType,
- @Nullable List embeddingTypes,
+ @Nullable CohereEmbeddingType embeddingType,
@Nullable CohereTruncation truncation
) {
var map = new HashMap();
@@ -146,15 +143,15 @@ public static Map getTaskSettingsMap(
}
if (inputType != null) {
- map.put(CohereEmbeddingsTaskSettings.INPUT_TYPE, inputType);
+ map.put(CohereEmbeddingsTaskSettings.INPUT_TYPE, inputType.toString());
}
- if (embeddingTypes != null) {
- map.put(CohereEmbeddingsTaskSettings.EMBEDDING_TYPES, embeddingTypes);
+ if (embeddingType != null) {
+ map.put(CohereEmbeddingsTaskSettings.EMBEDDING_TYPE, embeddingType.toString());
}
if (truncation != null) {
- map.put(CohereServiceFields.TRUNCATE, truncation);
+ map.put(CohereServiceFields.TRUNCATE, truncation.toString());
}
return map;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
index a76cce41b4fe4..aac50b8645993 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
@@ -47,7 +47,7 @@
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettingsTests.getServiceSettingsMap;
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
@@ -496,7 +496,7 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException {
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), Matchers.is(buildExpectation(List.of(List.of(-0.0123F, 0.0123F)))));
+ assertThat(result.asMap(), Matchers.is(buildExpectationFloats(List.of(List.of(-0.0123F, 0.0123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
index 394286ee5287b..ab4ee881a224e 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
@@ -46,7 +46,7 @@
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceSettingsTests.getServiceSettingsMap;
@@ -717,7 +717,7 @@ public void testInfer_SendsRequest() throws IOException {
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), Matchers.is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), Matchers.is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
From 153785af8e27bde3f147508b23040f5ac3a119f9 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Thu, 18 Jan 2024 14:23:08 -0500
Subject: [PATCH 05/13] Fixing tests
---
.../elasticsearch/inference/InputType.java | 16 +--------
.../inference/action/InferenceAction.java | 4 +--
.../action/openai/OpenAiEmbeddingsAction.java | 2 +-
.../CohereEmbeddingsResponseEntity.java | 2 ++
.../services/cohere/CohereTruncation.java | 18 +---------
.../embeddings/CohereEmbeddingType.java | 18 +---------
.../CohereEmbeddingsTaskSettings.java | 13 ++++---
.../action/InferenceActionResponseTests.java | 6 ++--
.../results/TextEmbeddingResultsTests.java | 10 +++---
.../cohere/CohereTruncationTests.java | 34 -------------------
.../embeddings/CohereEmbeddingTypeTests.java | 34 -------------------
.../CohereEmbeddingsTaskSettingsTests.java | 15 ++++----
12 files changed, 31 insertions(+), 141 deletions(-)
delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncationTests.java
delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java
diff --git a/server/src/main/java/org/elasticsearch/inference/InputType.java b/server/src/main/java/org/elasticsearch/inference/InputType.java
index b5ba3f0d1b506..ffc67995c1dda 100644
--- a/server/src/main/java/org/elasticsearch/inference/InputType.java
+++ b/server/src/main/java/org/elasticsearch/inference/InputType.java
@@ -8,17 +8,12 @@
package org.elasticsearch.inference;
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
-
-import java.io.IOException;
import java.util.Locale;
/**
* Defines the type of request, whether the request is to ingest a document or search for a document.
*/
-public enum InputType implements Writeable {
+public enum InputType {
INGEST,
SEARCH;
@@ -32,13 +27,4 @@ public String toString() {
public static InputType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}
-
- public static InputType fromStream(StreamInput in) throws IOException {
- return in.readOptionalEnum(InputType.class);
- }
-
- @Override
- public void writeTo(StreamOutput out) throws IOException {
- out.writeOptionalEnum(this);
- }
}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
index 732bc3d66bedc..30375e36a0e1d 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
@@ -88,7 +88,7 @@ public Request(StreamInput in) throws IOException {
}
this.taskSettings = in.readGenericMap();
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)) {
- this.inputType = InputType.fromStream(in);
+ this.inputType = in.readEnum(InputType.class);
} else {
this.inputType = InputType.INGEST;
}
@@ -141,7 +141,7 @@ public void writeTo(StreamOutput out) throws IOException {
}
out.writeGenericMap(taskSettings);
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)) {
- inputType.writeTo(out);
+ out.writeEnum(inputType);
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java
index 417935f8c920c..3877157188832 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java
@@ -43,7 +43,7 @@ public OpenAiEmbeddingsAction(Sender sender, OpenAiEmbeddingsModel model, Servic
this.model.getSecretSettings().apiKey()
);
this.client = new OpenAiClient(Objects.requireNonNull(sender), Objects.requireNonNull(serviceComponents));
- this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri(), "send OpenAI embeddings");
+ this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri(), "OpenAI embeddings");
this.truncator = Objects.requireNonNull(serviceComponents.truncator());
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
index 8f6d6c75d8c71..b9830b72de010 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
@@ -24,6 +24,7 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
import java.io.IOException;
+import java.util.Arrays;
import java.util.List;
import java.util.Map;
@@ -45,6 +46,7 @@ public class CohereEmbeddingsResponseEntity {
private static String supportedEmbeddingTypes() {
var validTypes = EMBEDDING_PARSERS.keySet().toArray(String[]::new);
+ Arrays.sort(validTypes);
return String.join(", ", validTypes);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java
index 3fec21a9e03b0..e7c9a0247bb1a 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java
@@ -7,11 +7,6 @@
package org.elasticsearch.xpack.inference.services.cohere;
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
-
-import java.io.IOException;
import java.util.Locale;
/**
@@ -22,7 +17,7 @@
* See api docs for details.
*
*/
-public enum CohereTruncation implements Writeable {
+public enum CohereTruncation {
/**
* When the input exceeds the maximum input token length an error will be returned.
*/
@@ -36,8 +31,6 @@ public enum CohereTruncation implements Writeable {
*/
END;
- public static String NAME = "cohere_truncate";
-
@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
@@ -46,13 +39,4 @@ public String toString() {
public static CohereTruncation fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}
-
- public static CohereTruncation fromStream(StreamInput in) throws IOException {
- return in.readOptionalEnum(CohereTruncation.class);
- }
-
- @Override
- public void writeTo(StreamOutput out) throws IOException {
- out.writeOptionalEnum(this);
- }
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java
index 803382bb8c947..82d57cfb92381 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java
@@ -7,11 +7,6 @@
package org.elasticsearch.xpack.inference.services.cohere.embeddings;
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
-
-import java.io.IOException;
import java.util.Locale;
/**
@@ -21,7 +16,7 @@
* See api docs for details.
*
*/
-public enum CohereEmbeddingType implements Writeable {
+public enum CohereEmbeddingType {
/**
* Use this when you want to get back the default float embeddings. Valid for all models.
*/
@@ -31,8 +26,6 @@ public enum CohereEmbeddingType implements Writeable {
*/
INT8;
- public static String NAME = "cohere_embedding_type";
-
@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
@@ -45,13 +38,4 @@ public static String toLowerCase(CohereEmbeddingType type) {
public static CohereEmbeddingType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}
-
- public static CohereEmbeddingType fromStream(StreamInput in) throws IOException {
- return in.readOptionalEnum(CohereEmbeddingType.class);
- }
-
- @Override
- public void writeTo(StreamOutput out) throws IOException {
- out.writeOptionalEnum(this);
- }
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java
index b0da7a63bf918..33bd4095c47cd 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java
@@ -92,7 +92,12 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) {
}
public CohereEmbeddingsTaskSettings(StreamInput in) throws IOException {
- this(in.readOptionalString(), InputType.fromStream(in), CohereEmbeddingType.fromStream(in), CohereTruncation.fromStream(in));
+ this(
+ in.readOptionalString(),
+ in.readOptionalEnum(InputType.class),
+ in.readOptionalEnum(CohereEmbeddingType.class),
+ in.readOptionalEnum(CohereTruncation.class)
+ );
}
@Override
@@ -130,9 +135,9 @@ public TransportVersion getMinimalSupportedVersion() {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(model);
- inputType.writeTo(out);
- embeddingType.writeTo(out);
- truncation.writeTo(out);
+ out.writeOptionalEnum(inputType);
+ out.writeOptionalEnum(embeddingType);
+ out.writeOptionalEnum(truncation);
}
public CohereEmbeddingsTaskSettings overrideWith(CohereEmbeddingsTaskSettings requestTaskSettings) {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java
index 759411cec1212..622cf51798609 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java
@@ -46,7 +46,7 @@ protected Writeable.Reader instanceReader() {
@Override
protected InferenceAction.Response createTestInstance() {
var result = switch (randomIntBetween(0, 2)) {
- case 0 -> TextEmbeddingResultsTests.createRandomResults();
+ case 0 -> TextEmbeddingResultsTests.createRandomFloatResults();
case 1 -> LegacyTextEmbeddingResultsTests.createRandomResults().transformToTextEmbeddingResults();
default -> SparseEmbeddingResultsTests.createRandomResults();
};
@@ -90,7 +90,7 @@ public void testSerializesOpenAiAddedVersion_UsingSparseEmbeddingResult() throws
}
public void testSerializesMultipleInputsVersion_UsingLegacyTextEmbeddingResult() throws IOException {
- var embeddingResults = TextEmbeddingResultsTests.createRandomResults();
+ var embeddingResults = TextEmbeddingResultsTests.createRandomFloatResults();
var instance = new InferenceAction.Response(embeddingResults);
var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), INFERENCE_MULTIPLE_INPUTS);
assertOnBWCObject(copy, instance, INFERENCE_MULTIPLE_INPUTS);
@@ -106,7 +106,7 @@ public void testSerializesMultipleInputsVersion_UsingSparseEmbeddingResult() thr
// Technically we should never see a text embedding result in the transport version of this test because support
// for it wasn't added until openai
public void testSerializesSingleInputVersion_UsingLegacyTextEmbeddingResult() throws IOException {
- var embeddingResults = TextEmbeddingResultsTests.createRandomResults();
+ var embeddingResults = TextEmbeddingResultsTests.createRandomFloatResults();
var instance = new InferenceAction.Response(embeddingResults);
var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), ML_INFERENCE_TASK_SETTINGS_OPTIONAL_ADDED);
assertOnBWCObject(copy, instance, ML_INFERENCE_TASK_SETTINGS_OPTIONAL_ADDED);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java
index 1bb237b87234b..c46c100c9a476 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java
@@ -54,6 +54,10 @@ public static TextEmbeddingResults createRandomResults() {
return createRandomResults(createFunction);
}
+ public static TextEmbeddingResults createRandomFloatResults() {
+ return createRandomResults(TextEmbeddingResultsTests::createRandomFloatEmbedding);
+ }
+
private static TextEmbeddingResults createRandomResults(Supplier creator) {
int embeddings = randomIntBetween(1, 10);
List embeddingResults = new ArrayList<>(embeddings);
@@ -65,10 +69,6 @@ private static TextEmbeddingResults createRandomResults(Supplier bytes = new ArrayList<>(columns);
@@ -284,7 +284,7 @@ protected TextEmbeddingResults mutateInstance(TextEmbeddingResults instance) thr
return new TextEmbeddingResults(instance.embeddings().subList(0, end));
} else {
List embeddings = new ArrayList<>(instance.embeddings());
- embeddings.add(createRandomEmbedding(randomBoolean()));
+ embeddings.add(createRandomFloatEmbedding());
return new TextEmbeddingResults(embeddings);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncationTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncationTests.java
deleted file mode 100644
index 600ddb54eddd3..0000000000000
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncationTests.java
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.services.cohere;
-
-import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
-
-import java.io.IOException;
-
-public class CohereTruncationTests extends AbstractWireSerializingTestCase {
- public static CohereTruncation createRandom() {
- return randomFrom(CohereTruncation.values());
- }
-
- @Override
- protected Writeable.Reader instanceReader() {
- return CohereTruncation::fromStream;
- }
-
- @Override
- protected CohereTruncation createTestInstance() {
- return createRandom();
- }
-
- @Override
- protected CohereTruncation mutateInstance(CohereTruncation instance) throws IOException {
- return null;
- }
-}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java
deleted file mode 100644
index 8b907746b0779..0000000000000
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.services.cohere.embeddings;
-
-import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
-
-import java.io.IOException;
-
-public class CohereEmbeddingTypeTests extends AbstractWireSerializingTestCase {
- public static CohereEmbeddingType createRandom() {
- return randomFrom(CohereEmbeddingType.values());
- }
-
- @Override
- protected Writeable.Reader instanceReader() {
- return CohereEmbeddingType::fromStream;
- }
-
- @Override
- protected CohereEmbeddingType createTestInstance() {
- return createRandom();
- }
-
- @Override
- protected CohereEmbeddingType mutateInstance(CohereEmbeddingType instance) throws IOException {
- return null;
- }
-}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
index 84e65ed87e372..a568a0eaa1e5a 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
@@ -7,6 +7,7 @@
package org.elasticsearch.xpack.inference.services.cohere.embeddings;
+import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
@@ -19,7 +20,6 @@
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
-import java.util.Locale;
import java.util.Map;
import static org.hamcrest.Matchers.is;
@@ -50,11 +50,11 @@ public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() {
CohereServiceFields.MODEL,
"abc",
CohereEmbeddingsTaskSettings.INPUT_TYPE,
- InputType.INGEST.toString().toLowerCase(Locale.ROOT),
+ InputType.INGEST.toString(),
CohereEmbeddingsTaskSettings.EMBEDDING_TYPE,
- CohereEmbeddingType.INT8,
+ CohereEmbeddingType.INT8.toString(),
CohereServiceFields.TRUNCATE,
- CohereTruncation.END.toString().toLowerCase(Locale.ROOT)
+ CohereTruncation.END.toString()
)
)
),
@@ -76,16 +76,13 @@ public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() {
public void testFromMap_ReturnsFailure_WhenEmbeddingTypesAreNotValid() {
var exception = expectThrows(
- ValidationException.class,
+ ElasticsearchStatusException.class,
() -> CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.EMBEDDING_TYPE, List.of("abc"))))
);
MatcherAssert.assertThat(
exception.getMessage(),
- is(
- "Validation Failed: 1: [task_settings] Invalid type [Integer]"
- + " received for value [123]. [embedding_types] must be type [String];"
- )
+ is("field [embedding_type] is not of the expected type. The value [[abc]] cannot be converted to a [String]")
);
}
From 4a127ff1db8b4b054dd5772654a4d7f1c01cbfb1 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Thu, 18 Jan 2024 16:11:12 -0500
Subject: [PATCH 06/13] Removing rate limit error message
---
.../cohere/CohereResponseHandler.java | 37 ++++++-----------
.../cohere/CohereResponseHandlerTests.java | 41 +------------------
2 files changed, 14 insertions(+), 64 deletions(-)
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java
index 299cad20fb012..211550f8dbc08 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java
@@ -9,7 +9,6 @@
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.logging.log4j.Logger;
-import org.elasticsearch.common.Strings;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
@@ -18,14 +17,20 @@
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
-import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown;
+/**
+ * Defines how to handle various errors returned from the Cohere integration.
+ *
+ * NOTE:
+ * These headers are returned for trial API keys only (they also do not exist within 429 responses)
+ *
+ *
+ * x-endpoint-monthly-call-limit
+ * x-trial-endpoint-call-limit
+ * x-trial-endpoint-call-remaining
+ *
+ */
public class CohereResponseHandler extends BaseResponseHandler {
-
- static final String MONTHLY_REQUESTS_LIMIT = "x-endpoint-monthly-call-limit";
- // TODO determine the production versions of these
- static final String TRIAL_REQUEST_LIMIT_PER_MINUTE = "x-trial-endpoint-call-limit";
- static final String TRIAL_REQUESTS_REMAINING = "x-trial-endpoint-call-remaining";
static final String TEXTS_ARRAY_TOO_LARGE_MESSAGE_MATCHER = "invalid request: total number of texts must be at most";
static final String TEXTS_ARRAY_ERROR_MESSAGE = "Received a texts array too large response";
@@ -58,7 +63,7 @@ void checkForFailureStatusCode(HttpRequestBase request, HttpResult result) throw
if (statusCode >= 500) {
throw new RetryException(false, buildError(SERVER_ERROR, request, result));
} else if (statusCode == 429) {
- throw new RetryException(true, buildError(buildRateLimitErrorMessage(result), request, result));
+ throw new RetryException(true, buildError(RATE_LIMIT, request, result));
} else if (isTextsArrayTooLarge(result)) {
throw new RetryException(false, buildError(TEXTS_ARRAY_ERROR_MESSAGE, request, result));
} else if (statusCode == 401) {
@@ -80,20 +85,4 @@ private static boolean isTextsArrayTooLarge(HttpResult result) {
return false;
}
-
- static String buildRateLimitErrorMessage(HttpResult result) {
- var response = result.response();
- var monthlyRequestLimit = getFirstHeaderOrUnknown(response, MONTHLY_REQUESTS_LIMIT);
- var trialRequestsPerMinute = getFirstHeaderOrUnknown(response, TRIAL_REQUEST_LIMIT_PER_MINUTE);
- var trialRequestsRemaining = getFirstHeaderOrUnknown(response, TRIAL_REQUESTS_REMAINING);
-
- var usageMessage = Strings.format(
- "Monthly request limit [%s], permitted requests per minute [%s], remaining requests [%s]",
- monthlyRequestLimit,
- trialRequestsPerMinute,
- trialRequestsRemaining
- );
-
- return RATE_LIMIT + ". " + usageMessage;
- }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java
index 9b7edd12492ea..8873f603dd082 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java
@@ -12,7 +12,6 @@
import org.apache.http.HttpResponse;
import org.apache.http.StatusLine;
import org.apache.http.client.methods.HttpRequestBase;
-import org.apache.http.message.BasicHeader;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
@@ -50,7 +49,7 @@ public void testCheckForFailureStatusCode_ThrowsFor429() {
assertTrue(exception.shouldRetry());
MatcherAssert.assertThat(
exception.getCause().getMessage(),
- containsString("Received a rate limit status code. Monthly request limit")
+ containsString("Received a rate limit status code for request [null] status [429]")
);
MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.TOO_MANY_REQUESTS));
}
@@ -98,44 +97,6 @@ public void testCheckForFailureStatusCode_ThrowsFor300() {
MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.MULTIPLE_CHOICES));
}
- public void testBuildRateLimitErrorMessage() {
- var statusLine = mock(StatusLine.class);
- when(statusLine.getStatusCode()).thenReturn(429);
- var response = mock(HttpResponse.class);
- when(response.getStatusLine()).thenReturn(statusLine);
- var httpResult = new HttpResult(response, new byte[] {});
-
- when(response.getFirstHeader(CohereResponseHandler.MONTHLY_REQUESTS_LIMIT)).thenReturn(
- new BasicHeader(CohereResponseHandler.MONTHLY_REQUESTS_LIMIT, "3000")
- );
- when(response.getFirstHeader(CohereResponseHandler.TRIAL_REQUEST_LIMIT_PER_MINUTE)).thenReturn(
- new BasicHeader(CohereResponseHandler.TRIAL_REQUEST_LIMIT_PER_MINUTE, "2999")
- );
- when(response.getFirstHeader(CohereResponseHandler.TRIAL_REQUESTS_REMAINING)).thenReturn(
- new BasicHeader(CohereResponseHandler.TRIAL_REQUESTS_REMAINING, "12")
- );
-
- var error = CohereResponseHandler.buildRateLimitErrorMessage(httpResult);
- MatcherAssert.assertThat(
- error,
- containsString("Monthly request limit [3000], permitted requests per minute [2999], remaining requests [12]")
- );
- }
-
- public void testBuildRateLimitErrorMessage_FillsWithUnknown_WhenUnableToFindHeader() {
- var statusLine = mock(StatusLine.class);
- when(statusLine.getStatusCode()).thenReturn(429);
- var response = mock(HttpResponse.class);
- when(response.getStatusLine()).thenReturn(statusLine);
- var httpResult = new HttpResult(response, new byte[] {});
-
- var error = CohereResponseHandler.buildRateLimitErrorMessage(httpResult);
- MatcherAssert.assertThat(
- error,
- containsString("Monthly request limit [unknown], permitted requests per minute [unknown], remaining requests [unknown]")
- );
- }
-
private static void callCheckForFailureStatusCode(int statusCode) {
callCheckForFailureStatusCode(statusCode, null);
}
From 3ebed6d73729c9e55a12a86170b49c8a04f94940 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Thu, 18 Jan 2024 16:38:01 -0500
Subject: [PATCH 07/13] Fixing a few comments
---
.../inference/external/action/cohere/CohereActionCreator.java | 2 +-
.../xpack/inference/external/cohere/CohereResponseHandler.java | 1 -
.../xpack/inference/services/cohere/CohereServiceTests.java | 2 +-
3 files changed, 2 insertions(+), 3 deletions(-)
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java
index b3ac96979b68b..8c9d70f0a7323 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java
@@ -16,7 +16,7 @@
import java.util.Objects;
/**
- * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the openai model type.
+ * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the cohere model type.
*/
public class CohereActionCreator implements CohereActionVisitor {
private final Sender sender;
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java
index 211550f8dbc08..59cf6f4ca52f6 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java
@@ -48,7 +48,6 @@ public void validateResponse(ThrottlerManager throttlerManager, Logger logger, H
/**
* Validates the status code throws an RetryException if not in the range [200, 300).
*
- * The OpenAI API error codes are documented here.
* @param request The http request
* @param result The http response and body
* @throws RetryException Throws if status code is {@code >= 300 or < 200 }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
index 2061ac646c56b..26bc51683155c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
@@ -263,7 +263,7 @@ public void testParseRequestConfig_CreatesACohereEmbeddingsModelWithoutUrl() thr
}
}
- public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModel() throws IOException {
+ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModel() throws IOException {
try (
var service = new CohereService(
new SetOnce<>(mock(HttpRequestSenderFactory.class)),
From a4e576f4a3cca762693b6b95b477bcece7888394 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com>
Date: Mon, 22 Jan 2024 08:37:50 -0500
Subject: [PATCH 08/13] Update docs/changelog/104559.yaml
---
docs/changelog/104559.yaml | 5 +++++
1 file changed, 5 insertions(+)
create mode 100644 docs/changelog/104559.yaml
diff --git a/docs/changelog/104559.yaml b/docs/changelog/104559.yaml
new file mode 100644
index 0000000000000..d6d030783c4cc
--- /dev/null
+++ b/docs/changelog/104559.yaml
@@ -0,0 +1,5 @@
+pr: 104559
+summary: Adding support for Cohere inference service
+area: Machine Learning
+type: enhancement
+issues: []
From 82001b2fad754bcdd2a137dfd6a39715c136d005 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Tue, 23 Jan 2024 15:17:54 -0500
Subject: [PATCH 09/13] Addressing most feedback
---
.../InferenceNamedWriteablesProvider.java | 11 +
.../external/action/ActionUtils.java | 2 +-
.../action/cohere/CohereEmbeddingsAction.java | 29 ++-
.../action/openai/OpenAiEmbeddingsAction.java | 4 +-
.../external/request/RequestUtils.java | 7 +-
.../cohere/CohereEmbeddingsRequest.java | 16 +-
.../cohere/CohereEmbeddingsRequestEntity.java | 40 ++--
.../inference/services/ServiceUtils.java | 7 +-
.../services/cohere/CohereService.java | 15 +-
.../cohere/CohereServiceSettings.java | 55 +++--
.../embeddings/CohereEmbeddingsModel.java | 11 +-
.../CohereEmbeddingsServiceSettings.java | 111 +++++++++
.../CohereEmbeddingsTaskSettings.java | 46 +---
.../HuggingFaceServiceSettings.java | 4 +-
.../openai/OpenAiServiceSettings.java | 11 -
.../cohere/CohereActionCreatorTests.java | 17 +-
.../cohere/CohereEmbeddingsActionTests.java | 32 ++-
.../CohereEmbeddingsRequestEntityTests.java | 10 +-
.../cohere/CohereEmbeddingsRequestTests.java | 20 +-
.../results/TextEmbeddingResultsTests.java | 5 +-
.../inference/services/ServiceUtilsTests.java | 5 +-
.../cohere/CohereServiceSettingsTests.java | 19 +-
.../services/cohere/CohereServiceTests.java | 214 +++++++++++-------
.../CohereEmbeddingsModelTests.java | 59 ++++-
.../CohereEmbeddingsServiceSettingsTests.java | 167 ++++++++++++++
.../CohereEmbeddingsTaskSettingsTests.java | 45 +---
26 files changed, 659 insertions(+), 303 deletions(-)
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
index 02d19ba60b0e7..f2f604e22bc24 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
@@ -21,6 +21,7 @@
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettings;
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings;
@@ -98,6 +99,16 @@ public static List getNamedWriteables() {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new)
);
+ namedWriteables.add(
+ new NamedWriteableRegistry.Entry(CohereServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new)
+ );
+ namedWriteables.add(
+ new NamedWriteableRegistry.Entry(
+ ServiceSettings.class,
+ CohereEmbeddingsServiceSettings.NAME,
+ CohereEmbeddingsServiceSettings::new
+ )
+ );
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, CohereEmbeddingsTaskSettings.NAME, CohereEmbeddingsTaskSettings::new)
);
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java
index 4a8519934c63c..e4d6e39fdf1f2 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java
@@ -39,7 +39,7 @@ public static ElasticsearchStatusException createInternalServerError(Throwable e
return new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR, e);
}
- public static String getErrorMessage(@Nullable URI uri, String message) {
+ public static String constructFailedToSendRequestMessage(@Nullable URI uri, String message) {
if (uri != null) {
return Strings.format("Failed to send %s request to [%s]", message, uri);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java
index 62afbc190d573..55611ab06a641 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java
@@ -27,8 +27,8 @@
import java.util.List;
import java.util.Objects;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
-import static org.elasticsearch.xpack.inference.external.action.ActionUtils.getErrorMessage;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
public class CohereEmbeddingsAction implements ExecutableAction {
@@ -37,13 +37,19 @@ public class CohereEmbeddingsAction implements ExecutableAction {
private final CohereAccount account;
private final CohereEmbeddingsModel model;
- private final String errorMessage;
+ private final String failedToSendRequestErrorMessage;
private final RetryingHttpSender sender;
public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model, ServiceComponents serviceComponents) {
this.model = Objects.requireNonNull(model);
- this.account = new CohereAccount(this.model.getServiceSettings().uri(), this.model.getSecretSettings().apiKey());
- this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri(), "Cohere embeddings");
+ this.account = new CohereAccount(
+ this.model.getServiceSettings().getCommonSettings().getUri(),
+ this.model.getSecretSettings().apiKey()
+ );
+ this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
+ this.model.getServiceSettings().getCommonSettings().getUri(),
+ "Cohere embeddings"
+ );
this.sender = new RetryingHttpSender(
Objects.requireNonNull(sender),
serviceComponents.throttlerManager(),
@@ -56,14 +62,23 @@ public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model, Servic
@Override
public void execute(List input, ActionListener listener) {
try {
- CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(account, input, model.getTaskSettings());
- ActionListener wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener);
+ CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(
+ account,
+ input,
+ model.getTaskSettings(),
+ model.getServiceSettings().getCommonSettings().getModel(),
+ model.getServiceSettings().getEmbeddingType()
+ );
+ ActionListener wrappedListener = wrapFailuresInElasticsearchException(
+ failedToSendRequestErrorMessage,
+ listener
+ );
sender.send(request, HANDLER, wrappedListener);
} catch (ElasticsearchException e) {
listener.onFailure(e);
} catch (Exception e) {
- listener.onFailure(createInternalServerError(e, errorMessage));
+ listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage));
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java
index 3877157188832..f7bcd53724168 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsAction.java
@@ -23,8 +23,8 @@
import java.util.Objects;
import static org.elasticsearch.xpack.inference.common.Truncator.truncate;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
-import static org.elasticsearch.xpack.inference.external.action.ActionUtils.getErrorMessage;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
public class OpenAiEmbeddingsAction implements ExecutableAction {
@@ -43,7 +43,7 @@ public OpenAiEmbeddingsAction(Sender sender, OpenAiEmbeddingsModel model, Servic
this.model.getSecretSettings().apiKey()
);
this.client = new OpenAiClient(Objects.requireNonNull(sender), Objects.requireNonNull(serviceComponents));
- this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri(), "OpenAI embeddings");
+ this.errorMessage = constructFailedToSendRequestMessage(this.model.getServiceSettings().uri(), "OpenAI embeddings");
this.truncator = Objects.requireNonNull(serviceComponents.truncator());
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java
index 4cb32b7bc95fd..6116b1cc234c6 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java
@@ -29,11 +29,8 @@ public static URI buildUri(URI accountUri, String service, CheckedSupplier input;
private final URI uri;
private final CohereEmbeddingsTaskSettings taskSettings;
+ private final String model;
+ private final CohereEmbeddingType embeddingType;
- public CohereEmbeddingsRequest(CohereAccount account, List input, CohereEmbeddingsTaskSettings taskSettings) {
+ public CohereEmbeddingsRequest(
+ CohereAccount account,
+ List input,
+ CohereEmbeddingsTaskSettings taskSettings,
+ @Nullable String model,
+ @Nullable CohereEmbeddingType embeddingType
+ ) {
this.account = Objects.requireNonNull(account);
this.input = Objects.requireNonNull(input);
this.uri = buildUri(this.account.url(), "Cohere", CohereEmbeddingsRequest::buildDefaultUri);
this.taskSettings = Objects.requireNonNull(taskSettings);
+ this.model = model;
+ this.embeddingType = embeddingType;
}
@Override
@@ -46,7 +58,7 @@ public HttpRequestBase createRequest() {
HttpPost httpPost = new HttpPost(uri);
ByteArrayEntity byteEntity = new ByteArrayEntity(
- Strings.toString(new CohereEmbeddingsRequestEntity(input, taskSettings)).getBytes(StandardCharsets.UTF_8)
+ Strings.toString(new CohereEmbeddingsRequestEntity(input, taskSettings, model, embeddingType)).getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java
index 831fb07b30e24..a7d8743359f39 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java
@@ -7,30 +7,36 @@
package org.elasticsearch.xpack.inference.external.request.cohere;
+import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import java.io.IOException;
import java.util.List;
-import java.util.Map;
import java.util.Objects;
-public record CohereEmbeddingsRequestEntity(List input, CohereEmbeddingsTaskSettings taskSettings) implements ToXContentObject {
+public record CohereEmbeddingsRequestEntity(
+ List input,
+ CohereEmbeddingsTaskSettings taskSettings,
+ @Nullable String model,
+ @Nullable CohereEmbeddingType embeddingType
+) implements ToXContentObject {
private static final String SEARCH_DOCUMENT = "search_document";
private static final String SEARCH_QUERY = "search_query";
/**
- * Maps the {@link InputType} to the expected value for cohere for the input_type field in the request
+ * Maps the {@link InputType} to the value Cohere expects for the input_type field in the request, indexed by the enum's ordinal.
+ * The order of these entries is important and must match the declaration order of the {@link InputType} constants.
*/
- private static final Map INPUT_TYPE_MAPPING = Map.of(
- InputType.INGEST,
- SEARCH_DOCUMENT,
- InputType.SEARCH,
- SEARCH_QUERY
- );
+ private static final String[] INPUT_TYPE_MAPPING = { SEARCH_DOCUMENT, SEARCH_QUERY };
+ static {
+ assert INPUT_TYPE_MAPPING.length == InputType.values().length : "input type mapping was incorrectly defined";
+ }
+
private static final String TEXTS_FIELD = "texts";
static final String INPUT_TYPE_FIELD = "input_type";
@@ -45,16 +51,16 @@ public record CohereEmbeddingsRequestEntity(List input, CohereEmbeddings
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(TEXTS_FIELD, input);
- if (taskSettings.model() != null) {
- builder.field(CohereServiceFields.MODEL, taskSettings.model());
+ if (model != null) {
+ builder.field(CohereServiceFields.MODEL, model);
}
if (taskSettings.inputType() != null) {
builder.field(INPUT_TYPE_FIELD, covertToString(taskSettings.inputType()));
}
- if (taskSettings.embeddingType() != null) {
- builder.field(EMBEDDING_TYPES_FIELD, List.of(taskSettings.embeddingType()));
+ if (embeddingType != null) {
+ builder.field(EMBEDDING_TYPES_FIELD, List.of(embeddingType));
}
if (taskSettings.truncation() != null) {
@@ -66,12 +72,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
private static String covertToString(InputType inputType) {
- var stringValue = INPUT_TYPE_MAPPING.get(inputType);
-
- if (stringValue == null) {
- return SEARCH_DOCUMENT;
- }
-
- return stringValue;
+ return INPUT_TYPE_MAPPING[inputType.ordinal()];
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
index 3e9f6e1f75a8a..20912a3811af1 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
@@ -12,6 +12,7 @@
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.CheckedFunction;
+import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.Model;
@@ -138,8 +139,12 @@ public static String invalidValue(String settingName, String scope, String inval
}
// TODO improve URI validation logic
- public static URI convertToUri(String url, String settingName, String settingScope, ValidationException validationException) {
+ public static URI convertToUri(@Nullable String url, String settingName, String settingScope, ValidationException validationException) {
try {
+ if (url == null) {
+ return null;
+ }
+
return createUri(url);
} catch (IllegalArgumentException ignored) {
validationException.addValidationError(ServiceUtils.invalidUrlErrorMsg(url, settingName, settingScope));
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
index e654316a4fbd4..8783f12852ec8 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
@@ -26,6 +26,7 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettings;
import java.util.List;
import java.util.Map;
@@ -157,11 +158,15 @@ public void checkModelConfig(Model model, ActionListener listener) {
}
private CohereEmbeddingsModel updateModelWithEmbeddingDetails(CohereEmbeddingsModel model, int embeddingSize) {
- CohereServiceSettings serviceSettings = new CohereServiceSettings(
- model.getServiceSettings().uri(),
- SimilarityMeasure.DOT_PRODUCT,
- embeddingSize,
- model.getServiceSettings().maxInputTokens()
+ CohereEmbeddingsServiceSettings serviceSettings = new CohereEmbeddingsServiceSettings(
+ new CohereServiceSettings(
+ model.getServiceSettings().getCommonSettings().getUri(),
+ SimilarityMeasure.DOT_PRODUCT,
+ embeddingSize,
+ model.getServiceSettings().getCommonSettings().getMaxInputTokens(),
+ model.getServiceSettings().getCommonSettings().getModel()
+ ),
+ model.getServiceSettings().getEmbeddingType()
);
return new CohereEmbeddingsModel(model, serviceSettings);
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java
index bdd5e981b0490..f03371593a340 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java
@@ -35,6 +35,7 @@
public class CohereServiceSettings implements ServiceSettings {
public static final String NAME = "cohere_service_settings";
+ public static final String MODEL = "model";
public static CohereServiceSettings fromMap(Map map) {
ValidationException validationException = new ValidationException();
@@ -44,50 +45,44 @@ public static CohereServiceSettings fromMap(Map map) {
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
-
- // Throw if any of the settings were empty strings or invalid
- if (validationException.validationErrors().isEmpty() == false) {
- throw validationException;
- }
-
- // the url is optional and only for testing
- if (url == null) {
- return new CohereServiceSettings((URI) null, similarity, dims, maxInputTokens);
- }
-
URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
+ String model = extractOptionalString(map, MODEL, ModelConfigurations.SERVICE_SETTINGS, validationException);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
- return new CohereServiceSettings(uri, similarity, dims, maxInputTokens);
+ return new CohereServiceSettings(uri, similarity, dims, maxInputTokens, model);
}
private final URI uri;
private final SimilarityMeasure similarity;
private final Integer dimensions;
private final Integer maxInputTokens;
+ private final String model;
public CohereServiceSettings(
@Nullable URI uri,
@Nullable SimilarityMeasure similarity,
@Nullable Integer dimensions,
- @Nullable Integer maxInputTokens
+ @Nullable Integer maxInputTokens,
+ @Nullable String model
) {
this.uri = uri;
this.similarity = similarity;
this.dimensions = dimensions;
this.maxInputTokens = maxInputTokens;
+ this.model = model;
}
public CohereServiceSettings(
@Nullable String url,
@Nullable SimilarityMeasure similarity,
@Nullable Integer dimensions,
- @Nullable Integer maxInputTokens
+ @Nullable Integer maxInputTokens,
+ @Nullable String model
) {
- this(createOptionalUri(url), similarity, dimensions, maxInputTokens);
+ this(createOptionalUri(url), similarity, dimensions, maxInputTokens, model);
}
public CohereServiceSettings(StreamInput in) throws IOException {
@@ -95,24 +90,29 @@ public CohereServiceSettings(StreamInput in) throws IOException {
similarity = in.readOptionalEnum(SimilarityMeasure.class);
dimensions = in.readOptionalVInt();
maxInputTokens = in.readOptionalVInt();
+ model = in.readOptionalString();
}
- public URI uri() {
+ public URI getUri() {
return uri;
}
- public SimilarityMeasure similarity() {
+ public SimilarityMeasure getSimilarity() {
return similarity;
}
- public Integer dimensions() {
+ public Integer getDimensions() {
return dimensions;
}
- public Integer maxInputTokens() {
+ public Integer getMaxInputTokens() {
return maxInputTokens;
}
+ public String getModel() {
+ return model;
+ }
+
@Override
public String getWriteableName() {
return NAME;
@@ -122,6 +122,13 @@ public String getWriteableName() {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
+ toXContentFragment(builder);
+
+ builder.endObject();
+ return builder;
+ }
+
+ public XContentBuilder toXContentFragment(XContentBuilder builder) throws IOException {
if (uri != null) {
builder.field(URL, uri.toString());
}
@@ -134,8 +141,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (maxInputTokens != null) {
builder.field(MAX_INPUT_TOKENS, maxInputTokens);
}
+ if (model != null) {
+ builder.field(MODEL, model);
+ }
- builder.endObject();
return builder;
}
@@ -151,6 +160,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalEnum(similarity);
out.writeOptionalVInt(dimensions);
out.writeOptionalVInt(maxInputTokens);
+ out.writeOptionalString(model);
}
@Override
@@ -161,11 +171,12 @@ public boolean equals(Object o) {
return Objects.equals(uri, that.uri)
&& Objects.equals(similarity, that.similarity)
&& Objects.equals(dimensions, that.dimensions)
- && Objects.equals(maxInputTokens, that.maxInputTokens);
+ && Objects.equals(maxInputTokens, that.maxInputTokens)
+ && Objects.equals(model, that.model);
}
@Override
public int hashCode() {
- return Objects.hash(uri, similarity, dimensions, maxInputTokens);
+ return Objects.hash(uri, similarity, dimensions, maxInputTokens, model);
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java
index accba9149a46e..c92700e87cd96 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java
@@ -14,7 +14,6 @@
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionVisitor;
import org.elasticsearch.xpack.inference.services.cohere.CohereModel;
-import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import java.util.Map;
@@ -32,7 +31,7 @@ public CohereEmbeddingsModel(
modelId,
taskType,
service,
- CohereServiceSettings.fromMap(serviceSettings),
+ CohereEmbeddingsServiceSettings.fromMap(serviceSettings),
CohereEmbeddingsTaskSettings.fromMap(taskSettings),
DefaultSecretSettings.fromMap(secrets)
);
@@ -43,7 +42,7 @@ public CohereEmbeddingsModel(
String modelId,
TaskType taskType,
String service,
- CohereServiceSettings serviceSettings,
+ CohereEmbeddingsServiceSettings serviceSettings,
CohereEmbeddingsTaskSettings taskSettings,
@Nullable DefaultSecretSettings secretSettings
) {
@@ -54,13 +53,13 @@ private CohereEmbeddingsModel(CohereEmbeddingsModel model, CohereEmbeddingsTaskS
super(model, taskSettings);
}
- public CohereEmbeddingsModel(CohereEmbeddingsModel model, CohereServiceSettings serviceSettings) {
+ public CohereEmbeddingsModel(CohereEmbeddingsModel model, CohereEmbeddingsServiceSettings serviceSettings) {
super(model, serviceSettings);
}
@Override
- public CohereServiceSettings getServiceSettings() {
- return (CohereServiceSettings) super.getServiceSettings();
+ public CohereEmbeddingsServiceSettings getServiceSettings() {
+ return (CohereEmbeddingsServiceSettings) super.getServiceSettings();
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java
new file mode 100644
index 0000000000000..f8400bc168d69
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java
@@ -0,0 +1,111 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere.embeddings;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
+
+public class CohereEmbeddingsServiceSettings implements ServiceSettings {
+ public static final String NAME = "cohere_embeddings_service_settings";
+
+ static final String EMBEDDING_TYPE = "embedding_type";
+
+ public static CohereEmbeddingsServiceSettings fromMap(Map map) {
+ ValidationException validationException = new ValidationException();
+ var commonServiceSettings = CohereServiceSettings.fromMap(map);
+ CohereEmbeddingType embeddingTypes = extractOptionalEnum(
+ map,
+ EMBEDDING_TYPE,
+ ModelConfigurations.SERVICE_SETTINGS,
+ CohereEmbeddingType::fromString,
+ CohereEmbeddingType.values(),
+ validationException
+ );
+
+ if (validationException.validationErrors().isEmpty() == false) {
+ throw validationException;
+ }
+
+ return new CohereEmbeddingsServiceSettings(commonServiceSettings, embeddingTypes);
+ }
+
+ private final CohereServiceSettings commonSettings;
+ private final CohereEmbeddingType embeddingType;
+
+ public CohereEmbeddingsServiceSettings(CohereServiceSettings commonSettings, @Nullable CohereEmbeddingType embeddingType) {
+ this.commonSettings = Objects.requireNonNull(commonSettings);
+ this.embeddingType = embeddingType;
+ }
+
+ public CohereEmbeddingsServiceSettings(StreamInput in) throws IOException {
+ commonSettings = in.readNamedWriteable(CohereServiceSettings.class);
+ embeddingType = in.readOptionalEnum(CohereEmbeddingType.class);
+ }
+
+ public CohereServiceSettings getCommonSettings() {
+ return commonSettings;
+ }
+
+ public CohereEmbeddingType getEmbeddingType() {
+ return embeddingType;
+ }
+
+ @Override
+ public String getWriteableName() {
+ return NAME;
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startObject();
+ commonSettings.toXContentFragment(builder);
+ if (embeddingType != null) {
+ builder.field(EMBEDDING_TYPE, embeddingType);
+ }
+ builder.endObject();
+ return builder;
+ }
+
+ @Override
+ public TransportVersion getMinimalSupportedVersion() {
+ return TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeNamedWriteable(commonSettings);
+ out.writeOptionalEnum(embeddingType);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ CohereEmbeddingsServiceSettings that = (CohereEmbeddingsServiceSettings) o;
+ return Objects.equals(commonSettings, that.commonSettings) && Objects.equals(embeddingType, that.embeddingType);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(commonSettings, embeddingType);
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java
index 33bd4095c47cd..858efdb0d1ace 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java
@@ -23,8 +23,6 @@
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
-import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.MODEL;
import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.TRUNCATE;
/**
@@ -34,22 +32,14 @@
* See api docs for details.
*
*
- * @param model the id of the model to use in the requests to cohere
* @param inputType Specifies the type of input you're giving to the model
- * @param embeddingType Specifies the type of embeddings you want to get back (we only support retrieving a single type)
* @param truncation Specifies how the API will handle inputs longer than the maximum token length
*/
-public record CohereEmbeddingsTaskSettings(
- @Nullable String model,
- @Nullable InputType inputType,
- @Nullable CohereEmbeddingType embeddingType,
- @Nullable CohereTruncation truncation
-) implements TaskSettings {
+public record CohereEmbeddingsTaskSettings(@Nullable InputType inputType, @Nullable CohereTruncation truncation) implements TaskSettings {
public static final String NAME = "cohere_embeddings_task_settings";
- public static final CohereEmbeddingsTaskSettings EMPTY_SETTINGS = new CohereEmbeddingsTaskSettings(null, null, null, null);
+ public static final CohereEmbeddingsTaskSettings EMPTY_SETTINGS = new CohereEmbeddingsTaskSettings(null, null);
static final String INPUT_TYPE = "input_type";
- static final String EMBEDDING_TYPE = "embedding_type";
public static CohereEmbeddingsTaskSettings fromMap(Map map) {
if (map.isEmpty()) {
@@ -58,7 +48,6 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) {
ValidationException validationException = new ValidationException();
- String model = extractOptionalString(map, MODEL, ModelConfigurations.TASK_SETTINGS, validationException);
InputType inputType = extractOptionalEnum(
map,
INPUT_TYPE,
@@ -67,14 +56,6 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) {
InputType.values(),
validationException
);
- CohereEmbeddingType embeddingTypes = extractOptionalEnum(
- map,
- EMBEDDING_TYPE,
- ModelConfigurations.TASK_SETTINGS,
- CohereEmbeddingType::fromString,
- CohereEmbeddingType.values(),
- validationException
- );
CohereTruncation truncation = extractOptionalEnum(
map,
TRUNCATE,
@@ -88,33 +69,20 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) {
throw validationException;
}
- return new CohereEmbeddingsTaskSettings(model, inputType, embeddingTypes, truncation);
+ return new CohereEmbeddingsTaskSettings(inputType, truncation);
}
public CohereEmbeddingsTaskSettings(StreamInput in) throws IOException {
- this(
- in.readOptionalString(),
- in.readOptionalEnum(InputType.class),
- in.readOptionalEnum(CohereEmbeddingType.class),
- in.readOptionalEnum(CohereTruncation.class)
- );
+ this(in.readOptionalEnum(InputType.class), in.readOptionalEnum(CohereTruncation.class));
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
- if (model != null) {
- builder.field(MODEL, model);
- }
-
if (inputType != null) {
builder.field(INPUT_TYPE, inputType);
}
- if (embeddingType != null) {
- builder.field(EMBEDDING_TYPE, embeddingType);
- }
-
if (truncation != null) {
builder.field(TRUNCATE, truncation);
}
@@ -134,18 +102,14 @@ public TransportVersion getMinimalSupportedVersion() {
@Override
public void writeTo(StreamOutput out) throws IOException {
- out.writeOptionalString(model);
out.writeOptionalEnum(inputType);
- out.writeOptionalEnum(embeddingType);
out.writeOptionalEnum(truncation);
}
public CohereEmbeddingsTaskSettings overrideWith(CohereEmbeddingsTaskSettings requestTaskSettings) {
- var modelToUse = requestTaskSettings.model() == null ? model : requestTaskSettings.model();
var inputTypeToUse = requestTaskSettings.inputType() == null ? inputType : requestTaskSettings.inputType();
- var embeddingTypesToUse = requestTaskSettings.embeddingType() == null ? embeddingType : requestTaskSettings.embeddingType();
var truncationToUse = requestTaskSettings.truncation() == null ? truncation : requestTaskSettings.truncation();
- return new CohereEmbeddingsTaskSettings(modelToUse, inputTypeToUse, embeddingTypesToUse, truncationToUse);
+ return new CohereEmbeddingsTaskSettings(inputTypeToUse, truncationToUse);
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java
index 6464ca0e0fda8..b3b130b22a1fa 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java
@@ -52,9 +52,7 @@ public static HuggingFaceServiceSettings fromMap(Map map) {
public static URI extractUri(Map map, String fieldName, ValidationException validationException) {
String parsedUrl = extractRequiredString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
- if (parsedUrl == null) {
- return null;
- }
+
return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java
index 553a5eaf60dae..4e96ac73157ad 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java
@@ -50,17 +50,6 @@ public static OpenAiServiceSettings fromMap(Map map) {
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
-
- // Throw if any of the settings were empty strings or invalid
- if (validationException.validationErrors().isEmpty() == false) {
- throw validationException;
- }
-
- // the url is optional and only for testing
- if (url == null) {
- return new OpenAiServiceSettings((URI) null, organizationId, similarity, dims, maxInputTokens);
- }
-
URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
if (validationException.validationErrors().isEmpty() == false) {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java
index e1daa78454d25..acd0145cbad24 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java
@@ -66,7 +66,7 @@ public void shutdown() throws IOException {
webServer.close();
}
- public void testCreate_OpenAiEmbeddingsModel() throws IOException {
+ public void testCreate_CohereEmbeddingsModel() throws IOException {
var senderFactory = new HttpRequestSenderFactory(threadPool, clientManager, mockClusterServiceEmpty(), Settings.EMPTY);
try (var sender = senderFactory.createSender("test_service")) {
@@ -102,17 +102,14 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException {
var model = CohereEmbeddingsModelTests.createModel(
getUrl(webServer),
"secret",
- new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START),
+ new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START),
+ 1024,
1024,
- 1024
- );
- var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool));
- var overriddenTaskSettings = CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(
"model",
- InputType.SEARCH,
- CohereEmbeddingType.INT8,
- CohereTruncation.END
+ CohereEmbeddingType.FLOAT
);
+ var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool));
+ var overriddenTaskSettings = CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, CohereTruncation.END);
var action = actionCreator.create(model, overriddenTaskSettings);
PlainActionFuture listener = new PlainActionFuture<>();
@@ -141,7 +138,7 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException {
"input_type",
"search_query",
"embedding_types",
- List.of("int8"),
+ List.of("float"),
"truncate",
"end"
)
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java
index d2a8aad19b4c8..86ba29c95b4bc 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java
@@ -12,6 +12,7 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
@@ -110,7 +111,9 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException {
var action = createAction(
getUrl(webServer),
"secret",
- new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START),
+ new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START),
+ "model",
+ CohereEmbeddingType.FLOAT,
sender
);
@@ -185,7 +188,9 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I
var action = createAction(
getUrl(webServer),
"secret",
- new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.INT8, CohereTruncation.START),
+ new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START),
+ "model",
+ CohereEmbeddingType.INT8,
sender
);
@@ -228,7 +233,7 @@ public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOExcept
try (var sender = mock(Sender.class)) {
var thrownException = expectThrows(
IllegalArgumentException.class,
- () -> createAction("^^", "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender)
+ () -> createAction("^^", "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender)
);
MatcherAssert.assertThat(thrownException.getMessage(), is("unable to parse url [^^]"));
}
@@ -238,7 +243,7 @@ public void testExecute_ThrowsElasticsearchException() {
var sender = mock(Sender.class);
doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any());
- var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender);
+ var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender);
PlainActionFuture listener = new PlainActionFuture<>();
action.execute(List.of("abc"), listener);
@@ -259,7 +264,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled
return Void.TYPE;
}).when(sender).send(any(), any());
- var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender);
+ var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender);
PlainActionFuture listener = new PlainActionFuture<>();
action.execute(List.of("abc"), listener);
@@ -283,7 +288,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled
return Void.TYPE;
}).when(sender).send(any(), any());
- var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender);
+ var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender);
PlainActionFuture listener = new PlainActionFuture<>();
action.execute(List.of("abc"), listener);
@@ -297,7 +302,7 @@ public void testExecute_ThrowsException() {
var sender = mock(Sender.class);
doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any());
- var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender);
+ var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender);
PlainActionFuture listener = new PlainActionFuture<>();
action.execute(List.of("abc"), listener);
@@ -314,7 +319,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() {
var sender = mock(Sender.class);
doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any());
- var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, sender);
+ var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender);
PlainActionFuture listener = new PlainActionFuture<>();
action.execute(List.of("abc"), listener);
@@ -324,8 +329,15 @@ public void testExecute_ThrowsExceptionWithNullUrl() {
MatcherAssert.assertThat(thrownException.getMessage(), is("Failed to send Cohere embeddings request"));
}
- private CohereEmbeddingsAction createAction(String url, String apiKey, CohereEmbeddingsTaskSettings taskSettings, Sender sender) {
- var model = CohereEmbeddingsModelTests.createModel(url, apiKey, taskSettings, 1024, 1024);
+ private CohereEmbeddingsAction createAction(
+ String url,
+ String apiKey,
+ CohereEmbeddingsTaskSettings taskSettings,
+ @Nullable String modelName,
+ @Nullable CohereEmbeddingType embeddingType,
+ Sender sender
+ ) {
+ var model = CohereEmbeddingsModelTests.createModel(url, apiKey, taskSettings, 1024, 1024, modelName, embeddingType);
return new CohereEmbeddingsAction(sender, model, createWithEmptySettings(threadPool));
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java
index 01a2290446529..8ef9ea4b0316b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java
@@ -27,7 +27,9 @@ public class CohereEmbeddingsRequestEntityTests extends ESTestCase {
public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
var entity = new CohereEmbeddingsRequestEntity(
List.of("abc"),
- new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START)
+ new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START),
+ "model",
+ CohereEmbeddingType.FLOAT
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -41,7 +43,9 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException
public void testXContent_InputTypeSearch_EmbeddingTypesInt8_TruncateNone() throws IOException {
var entity = new CohereEmbeddingsRequestEntity(
List.of("abc"),
- new CohereEmbeddingsTaskSettings("model", InputType.SEARCH, CohereEmbeddingType.INT8, CohereTruncation.NONE)
+ new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE),
+ "model",
+ CohereEmbeddingType.INT8
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -53,7 +57,7 @@ public void testXContent_InputTypeSearch_EmbeddingTypesInt8_TruncateNone() throw
}
public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException {
- var entity = new CohereEmbeddingsRequestEntity(List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS);
+ var entity = new CohereEmbeddingsRequestEntity(List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java
index 903f6a9500831..a8a661149b120 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java
@@ -32,7 +32,7 @@
public class CohereEmbeddingsRequestTests extends ESTestCase {
public void testCreateRequest_UrlDefined() throws URISyntaxException, IOException {
- var request = createRequest("url", "secret", List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS);
+ var request = createRequest("url", "secret", List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null);
var httpRequest = request.createRequest();
MatcherAssert.assertThat(httpRequest, instanceOf(HttpPost.class));
@@ -52,7 +52,9 @@ public void testCreateRequest_AllOptionsDefined() throws URISyntaxException, IOE
"url",
"secret",
List.of("abc"),
- new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START)
+ new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START),
+ "model",
+ CohereEmbeddingType.FLOAT
);
var httpRequest = request.createRequest();
@@ -89,7 +91,9 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() th
"url",
"secret",
List.of("abc"),
- new CohereEmbeddingsTaskSettings("model", InputType.SEARCH, CohereEmbeddingType.INT8, CohereTruncation.END)
+ new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.END),
+ "model",
+ CohereEmbeddingType.INT8
);
var httpRequest = request.createRequest();
@@ -126,7 +130,9 @@ public void testCreateRequest_TruncateNone() throws URISyntaxException, IOExcept
"url",
"secret",
List.of("abc"),
- new CohereEmbeddingsTaskSettings(null, null, null, CohereTruncation.NONE)
+ new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE),
+ null,
+ null
);
var httpRequest = request.createRequest();
@@ -146,11 +152,13 @@ public static CohereEmbeddingsRequest createRequest(
@Nullable String url,
String apiKey,
List input,
- CohereEmbeddingsTaskSettings taskSettings
+ CohereEmbeddingsTaskSettings taskSettings,
+ @Nullable String model,
+ @Nullable CohereEmbeddingType embeddingType
) throws URISyntaxException {
var uri = url == null ? null : new URI(url);
var account = new CohereAccount(uri, new SecureString(apiKey.toCharArray()));
- return new CohereEmbeddingsRequest(account, input, taskSettings);
+ return new CohereEmbeddingsRequest(account, input, taskSettings, model, embeddingType);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java
index c46c100c9a476..9c154a20531f7 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java
@@ -46,10 +46,7 @@ private enum EmbeddingType {
public static TextEmbeddingResults createRandomResults() {
var embeddingType = randomFrom(EmbeddingType.values());
var createFunction = EMBEDDING_TYPE_BUILDERS.get(embeddingType);
-
- if (createFunction == null) {
- createFunction = TextEmbeddingResultsTests::createRandomFloatEmbedding;
- }
+ assert createFunction != null : "the embeddings type map is missing a value from the EmbeddingType enum";
return createRandomResults(createFunction);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java
index eb54745806a68..c72b161941ad2 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java
@@ -105,10 +105,11 @@ public void testConvertToUri_CreatesUri() {
assertThat(uri.toString(), is("www.elastic.co"));
}
- public void testConvertToUri_ThrowsNullPointerException_WhenPassedNull() {
+ public void testConvertToUri_DoesNotThrowNullPointerException_WhenPassedNull() {
var validation = new ValidationException();
- expectThrows(NullPointerException.class, () -> convertToUri(null, "name", "scope", validation));
+ var uri = convertToUri(null, "name", "scope", validation);
+ assertNull(uri);
assertTrue(validation.validationErrors().isEmpty());
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java
index 21558eec87302..8f829c11c0a11 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java
@@ -47,7 +47,9 @@ private static CohereServiceSettings createRandom(String url) {
dims = 1536;
}
Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256);
- return new CohereServiceSettings(ServiceUtils.createUri(url), similarityMeasure, dims, maxInputTokens);
+ var model = randomBoolean() ? randomAlphaOfLength(15) : null;
+
+ return new CohereServiceSettings(ServiceUtils.createOptionalUri(url), similarityMeasure, dims, maxInputTokens, model);
}
public void testFromMap() {
@@ -55,6 +57,7 @@ public void testFromMap() {
var similarity = SimilarityMeasure.DOT_PRODUCT.toString();
var dims = 1536;
var maxInputTokens = 512;
+ var model = "model";
var serviceSettings = CohereServiceSettings.fromMap(
new HashMap<>(
Map.of(
@@ -65,20 +68,22 @@ public void testFromMap() {
ServiceFields.DIMENSIONS,
dims,
ServiceFields.MAX_INPUT_TOKENS,
- maxInputTokens
+ maxInputTokens,
+ CohereServiceSettings.MODEL,
+ model
)
)
);
MatcherAssert.assertThat(
serviceSettings,
- is(new CohereServiceSettings(ServiceUtils.createUri(url), SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens))
+ is(new CohereServiceSettings(ServiceUtils.createUri(url), SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens, model))
);
}
public void testFromMap_MissingUrl_DoesNotThrowException() {
var serviceSettings = CohereServiceSettings.fromMap(new HashMap<>(Map.of()));
- assertNull(serviceSettings.uri());
+ assertNull(serviceSettings.getUri());
}
public void testFromMap_EmptyUrl_ThrowsError() {
@@ -139,13 +144,17 @@ protected CohereServiceSettings mutateInstance(CohereServiceSettings instance) t
return createRandomWithNonNullUrl();
}
- public static Map getServiceSettingsMap(@Nullable String url) {
+ public static Map getServiceSettingsMap(@Nullable String url, @Nullable String model) {
var map = new HashMap();
if (url != null) {
map.put(ServiceFields.URL, url);
}
+ if (model != null) {
+ map.put(CohereServiceSettings.MODEL, model);
+ }
+
return map;
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
index 26bc51683155c..dc3c8eafb46ac 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
@@ -32,6 +32,7 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettingsTests;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
@@ -52,7 +53,6 @@
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel;
-import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettingsTests.getServiceSettingsMap;
import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap;
import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty;
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
@@ -99,8 +99,8 @@ public void testParseRequestConfig_CreatesACohereEmbeddingsModel() throws IOExce
"id",
TaskType.TEXT_EMBEDDING,
getRequestConfigMap(
- getServiceSettingsMap("url"),
- getTaskSettingsMap("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START),
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", CohereEmbeddingType.FLOAT),
+ getTaskSettingsMap(InputType.INGEST, CohereTruncation.START),
getSecretSettingsMap("secret")
),
Set.of()
@@ -109,10 +109,12 @@ public void testParseRequestConfig_CreatesACohereEmbeddingsModel() throws IOExce
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
var embeddingsModel = (CohereEmbeddingsModel) model;
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModel(), is("model"));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT));
MatcherAssert.assertThat(
embeddingsModel.getTaskSettings(),
- is(new CohereEmbeddingsTaskSettings("model", InputType.INGEST, CohereEmbeddingType.FLOAT, CohereTruncation.START))
+ is(new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START))
);
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
}
@@ -130,7 +132,11 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti
() -> service.parseRequestConfig(
"id",
TaskType.SPARSE_EMBEDDING,
- getRequestConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty(), getSecretSettingsMap("secret")),
+ getRequestConfigMap(
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null),
+ getTaskSettingsMapEmpty(),
+ getSecretSettingsMap("secret")
+ ),
Set.of()
)
);
@@ -149,7 +155,11 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I
new SetOnce<>(createWithEmptySettings(threadPool))
)
) {
- var config = getRequestConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty(), getSecretSettingsMap("secret"));
+ var config = getRequestConfigMap(
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null),
+ getTaskSettingsMapEmpty(),
+ getSecretSettingsMap("secret")
+ );
config.put("extra_key", "value");
var thrownException = expectThrows(
@@ -171,14 +181,10 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa
new SetOnce<>(createWithEmptySettings(threadPool))
)
) {
- var serviceSettings = getServiceSettingsMap("url");
+ var serviceSettings = CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", null);
serviceSettings.put("extra_key", "value");
- var config = getRequestConfigMap(
- serviceSettings,
- getTaskSettingsMap("model", null, null, null),
- getSecretSettingsMap("secret")
- );
+ var config = getRequestConfigMap(serviceSettings, getTaskSettingsMap(null, null), getSecretSettingsMap("secret"));
var thrownException = expectThrows(
ElasticsearchStatusException.class,
@@ -199,10 +205,14 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap()
new SetOnce<>(createWithEmptySettings(threadPool))
)
) {
- var taskSettingsMap = getTaskSettingsMap("model", null, null, null);
+ var taskSettingsMap = getTaskSettingsMap(InputType.INGEST, null);
taskSettingsMap.put("extra_key", "value");
- var config = getRequestConfigMap(getServiceSettingsMap("url"), taskSettingsMap, getSecretSettingsMap("secret"));
+ var config = getRequestConfigMap(
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", null),
+ taskSettingsMap,
+ getSecretSettingsMap("secret")
+ );
var thrownException = expectThrows(
ElasticsearchStatusException.class,
@@ -226,7 +236,11 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap
var secretSettingsMap = getSecretSettingsMap("secret");
secretSettingsMap.put("extra_key", "value");
- var config = getRequestConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty(), secretSettingsMap);
+ var config = getRequestConfigMap(
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null),
+ getTaskSettingsMapEmpty(),
+ secretSettingsMap
+ );
var thrownException = expectThrows(
ElasticsearchStatusException.class,
@@ -250,14 +264,18 @@ public void testParseRequestConfig_CreatesACohereEmbeddingsModelWithoutUrl() thr
var model = service.parseRequestConfig(
"id",
TaskType.TEXT_EMBEDDING,
- getRequestConfigMap(getServiceSettingsMap(null), getTaskSettingsMapEmpty(), getSecretSettingsMap("secret")),
+ getRequestConfigMap(
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null),
+ getTaskSettingsMapEmpty(),
+ getSecretSettingsMap("secret")
+ ),
Set.of()
);
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
var embeddingsModel = (CohereEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
+ assertNull(embeddingsModel.getServiceSettings().getCommonSettings().getUri());
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
}
@@ -271,8 +289,8 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModel()
)
) {
var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("url"),
- getTaskSettingsMap("model", null, null, null),
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", null),
+ getTaskSettingsMap(null, null),
getSecretSettingsMap("secret")
);
@@ -286,8 +304,9 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModel()
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
var embeddingsModel = (CohereEmbeddingsModel) model;
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings("model", null, null, null)));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModel(), is("model"));
+ MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null)));
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
}
}
@@ -300,7 +319,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
)
) {
var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("url"),
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null),
getTaskSettingsMapEmpty(),
getSecretSettingsMap("secret")
);
@@ -330,8 +349,8 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWit
)
) {
var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap(null),
- getTaskSettingsMap(null, InputType.INGEST, null, null),
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null),
+ getTaskSettingsMap(InputType.INGEST, null),
getSecretSettingsMap("secret")
);
@@ -345,11 +364,8 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWit
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
var embeddingsModel = (CohereEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- MatcherAssert.assertThat(
- embeddingsModel.getTaskSettings(),
- is(new CohereEmbeddingsTaskSettings(null, InputType.INGEST, null, null))
- );
+ assertNull(embeddingsModel.getServiceSettings().getCommonSettings().getUri());
+ MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null)));
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
}
}
@@ -362,8 +378,8 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists
)
) {
var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("url"),
- getTaskSettingsMap("model", InputType.SEARCH, CohereEmbeddingType.INT8, CohereTruncation.NONE),
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", CohereEmbeddingType.INT8),
+ getTaskSettingsMap(InputType.SEARCH, CohereTruncation.NONE),
getSecretSettingsMap("secret")
);
persistedConfig.config().put("extra_key", "value");
@@ -378,10 +394,12 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
var embeddingsModel = (CohereEmbeddingsModel) model;
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModel(), is("model"));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.INT8));
MatcherAssert.assertThat(
embeddingsModel.getTaskSettings(),
- is(new CohereEmbeddingsTaskSettings("model", InputType.SEARCH, CohereEmbeddingType.INT8, CohereTruncation.NONE))
+ is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE))
);
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
}
@@ -397,7 +415,11 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists
var secretSettingsMap = getSecretSettingsMap("secret");
secretSettingsMap.put("extra_key", "value");
- var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty(), secretSettingsMap);
+ var persistedConfig = getPersistedConfigMap(
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null),
+ getTaskSettingsMapEmpty(),
+ secretSettingsMap
+ );
var model = service.parsePersistedConfigWithSecrets(
"id",
@@ -409,7 +431,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
var embeddingsModel = (CohereEmbeddingsModel) model;
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
}
@@ -423,8 +445,8 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe
)
) {
var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("url"),
- getTaskSettingsMap("model", null, null, null),
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", null),
+ getTaskSettingsMap(null, null),
getSecretSettingsMap("secret")
);
persistedConfig.secrets.put("extra_key", "value");
@@ -439,8 +461,9 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
var embeddingsModel = (CohereEmbeddingsModel) model;
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings("model", null, null, null)));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModel(), is("model"));
+ MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null)));
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
}
}
@@ -452,7 +475,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe
new SetOnce<>(createWithEmptySettings(threadPool))
)
) {
- var serviceSettingsMap = getServiceSettingsMap("url");
+ var serviceSettingsMap = CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null);
serviceSettingsMap.put("extra_key", "value");
var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMapEmpty(), getSecretSettingsMap("secret"));
@@ -467,7 +490,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
var embeddingsModel = (CohereEmbeddingsModel) model;
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
}
@@ -480,10 +503,14 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa
new SetOnce<>(createWithEmptySettings(threadPool))
)
) {
- var taskSettingsMap = getTaskSettingsMap("model", InputType.SEARCH, null, null);
+ var taskSettingsMap = getTaskSettingsMap(InputType.SEARCH, null);
taskSettingsMap.put("extra_key", "value");
- var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), taskSettingsMap, getSecretSettingsMap("secret"));
+ var persistedConfig = getPersistedConfigMap(
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", null),
+ taskSettingsMap,
+ getSecretSettingsMap("secret")
+ );
var model = service.parsePersistedConfigWithSecrets(
"id",
@@ -495,11 +522,9 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
var embeddingsModel = (CohereEmbeddingsModel) model;
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- MatcherAssert.assertThat(
- embeddingsModel.getTaskSettings(),
- is(new CohereEmbeddingsTaskSettings("model", InputType.SEARCH, null, null))
- );
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModel(), is("model"));
+ MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, null)));
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
}
}
@@ -512,8 +537,8 @@ public void testParsePersistedConfig_CreatesACohereEmbeddingsModel() throws IOEx
)
) {
var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("url"),
- getTaskSettingsMap("model", null, null, CohereTruncation.NONE)
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", null),
+ getTaskSettingsMap(null, CohereTruncation.NONE)
);
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
@@ -521,11 +546,9 @@ public void testParsePersistedConfig_CreatesACohereEmbeddingsModel() throws IOEx
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
var embeddingsModel = (CohereEmbeddingsModel) model;
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- MatcherAssert.assertThat(
- embeddingsModel.getTaskSettings(),
- is(new CohereEmbeddingsTaskSettings("model", null, null, CohereTruncation.NONE))
- );
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModel(), is("model"));
+ MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE)));
assertNull(embeddingsModel.getSecretSettings());
}
}
@@ -537,7 +560,10 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro
new SetOnce<>(createWithEmptySettings(threadPool))
)
) {
- var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty());
+ var persistedConfig = getPersistedConfigMap(
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null),
+ getTaskSettingsMapEmpty()
+ );
var thrownException = expectThrows(
ElasticsearchStatusException.class,
@@ -551,7 +577,7 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro
}
}
- public void testParsePersistedConfig_CreatesAnCohereEmbeddingsModelWithoutUrl() throws IOException {
+ public void testParsePersistedConfig_CreatesACohereEmbeddingsModelWithoutUrl() throws IOException {
try (
var service = new CohereService(
new SetOnce<>(mock(HttpRequestSenderFactory.class)),
@@ -559,8 +585,8 @@ public void testParsePersistedConfig_CreatesAnCohereEmbeddingsModelWithoutUrl()
)
) {
var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap(null),
- getTaskSettingsMap("model", null, CohereEmbeddingType.FLOAT, null)
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model", CohereEmbeddingType.FLOAT),
+ getTaskSettingsMap(null, null)
);
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
@@ -568,11 +594,10 @@ public void testParsePersistedConfig_CreatesAnCohereEmbeddingsModelWithoutUrl()
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
var embeddingsModel = (CohereEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- MatcherAssert.assertThat(
- embeddingsModel.getTaskSettings(),
- is(new CohereEmbeddingsTaskSettings("model", null, CohereEmbeddingType.FLOAT, null))
- );
+ assertNull(embeddingsModel.getServiceSettings().getCommonSettings().getUri());
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModel(), is("model"));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT));
+ MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null)));
assertNull(embeddingsModel.getSecretSettings());
}
}
@@ -584,7 +609,10 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig()
new SetOnce<>(createWithEmptySettings(threadPool))
)
) {
- var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), getTaskSettingsMapEmpty());
+ var persistedConfig = getPersistedConfigMap(
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null),
+ getTaskSettingsMapEmpty()
+ );
persistedConfig.config().put("extra_key", "value");
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
@@ -592,7 +620,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig()
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
var embeddingsModel = (CohereEmbeddingsModel) model;
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
assertNull(embeddingsModel.getSecretSettings());
}
@@ -605,21 +633,18 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettin
new SetOnce<>(createWithEmptySettings(threadPool))
)
) {
- var serviceSettingsMap = getServiceSettingsMap("url");
+ var serviceSettingsMap = CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null);
serviceSettingsMap.put("extra_key", "value");
- var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMap(null, InputType.SEARCH, null, null));
+ var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMap(InputType.SEARCH, null));
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
var embeddingsModel = (CohereEmbeddingsModel) model;
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- MatcherAssert.assertThat(
- embeddingsModel.getTaskSettings(),
- is(new CohereEmbeddingsTaskSettings(null, InputType.SEARCH, null, null))
- );
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, null)));
assertNull(embeddingsModel.getSecretSettings());
}
}
@@ -631,21 +656,22 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings(
new SetOnce<>(createWithEmptySettings(threadPool))
)
) {
- var taskSettingsMap = getTaskSettingsMap("model", InputType.INGEST, null, null);
+ var taskSettingsMap = getTaskSettingsMap(InputType.INGEST, null);
taskSettingsMap.put("extra_key", "value");
- var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), taskSettingsMap);
+ var persistedConfig = getPersistedConfigMap(
+ CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", null),
+ taskSettingsMap
+ );
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
var embeddingsModel = (CohereEmbeddingsModel) model;
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- MatcherAssert.assertThat(
- embeddingsModel.getTaskSettings(),
- is(new CohereEmbeddingsTaskSettings("model", InputType.INGEST, null, null))
- );
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModel(), is("model"));
+ MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null)));
assertNull(embeddingsModel.getSecretSettings());
}
}
@@ -712,9 +738,11 @@ public void testInfer_SendsRequest() throws IOException {
var model = CohereEmbeddingsModelTests.createModel(
getUrl(webServer),
"secret",
- new CohereEmbeddingsTaskSettings("model", InputType.INGEST, null, null),
+ new CohereEmbeddingsTaskSettings(InputType.INGEST, null),
1024,
- 1024
+ 1024,
+ "model",
+ null
);
PlainActionFuture listener = new PlainActionFuture<>();
service.infer(model, List.of("abc"), new HashMap<>(), listener);
@@ -772,7 +800,9 @@ public void testCheckModelConfig_UpdatesDimensions() throws IOException {
"secret",
CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
10,
- 1
+ 1,
+ null,
+ null
);
PlainActionFuture listener = new PlainActionFuture<>();
service.checkModelConfig(model, listener);
@@ -781,7 +811,17 @@ public void testCheckModelConfig_UpdatesDimensions() throws IOException {
MatcherAssert.assertThat(
result,
// the dimension is set to 2 because there are 2 embeddings returned from the mock server
- is(CohereEmbeddingsModelTests.createModel(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, 10, 2))
+ is(
+ CohereEmbeddingsModelTests.createModel(
+ getUrl(webServer),
+ "secret",
+ CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
+ 10,
+ 2,
+ null,
+ null
+ )
+ )
);
}
}
@@ -803,7 +843,9 @@ public void testInfer_UnauthorisedResponse() throws IOException {
"secret",
CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
1024,
- 1024
+ 1024,
+ null,
+ null
);
PlainActionFuture listener = new PlainActionFuture<>();
service.infer(model, List.of("abc"), new HashMap<>(), listener);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java
index d6b01486ae3ec..1961d6b168d54 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java
@@ -9,6 +9,7 @@
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.common.SimilarityMeasure;
@@ -24,34 +25,63 @@
public class CohereEmbeddingsModelTests extends ESTestCase {
- public void testOverrideWith_OverridesModel() {
- var model = createModel("url", "api_key", null);
+ public void testOverrideWith_OverridesInputType_WithSearch() {
+ var model = createModel(
+ "url",
+ "api_key",
+ new CohereEmbeddingsTaskSettings(InputType.INGEST, null),
+ null,
+ null,
+ "model",
+ CohereEmbeddingType.FLOAT
+ );
- var overriddenModel = model.overrideWith(getTaskSettingsMap("model", null, null, null));
- var expectedModel = createModel("url", "api_key", new CohereEmbeddingsTaskSettings("model", null, null, null), null, null);
+ var overriddenModel = model.overrideWith(getTaskSettingsMap(InputType.SEARCH, null));
+ var expectedModel = createModel(
+ "url",
+ "api_key",
+ new CohereEmbeddingsTaskSettings(InputType.SEARCH, null),
+ null,
+ null,
+ "model",
+ CohereEmbeddingType.FLOAT
+ );
MatcherAssert.assertThat(overriddenModel, is(expectedModel));
}
public void testOverrideWith_DoesNotOverride_WhenSettingsAreEmpty() {
- var model = createModel("url", "api_key", null);
+ var model = createModel("url", "api_key", null, null, null);
var overriddenModel = model.overrideWith(Map.of());
MatcherAssert.assertThat(overriddenModel, sameInstance(model));
}
public void testOverrideWith_DoesNotOverride_WhenSettingsAreNull() {
- var model = createModel("url", "api_key", null);
+ var model = createModel("url", "api_key", null, null, null);
var overriddenModel = model.overrideWith(null);
MatcherAssert.assertThat(overriddenModel, sameInstance(model));
}
- public static CohereEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit) {
- return createModel(url, apiKey, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, null);
+ public static CohereEmbeddingsModel createModel(
+ String url,
+ String apiKey,
+ @Nullable Integer tokenLimit,
+ @Nullable String model,
+ @Nullable CohereEmbeddingType embeddingType
+ ) {
+ return createModel(url, apiKey, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, null, model, embeddingType);
}
- public static CohereEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit, @Nullable Integer dimensions) {
- return createModel(url, apiKey, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, dimensions);
+ public static CohereEmbeddingsModel createModel(
+ String url,
+ String apiKey,
+ @Nullable Integer tokenLimit,
+ @Nullable Integer dimensions,
+ @Nullable String model,
+ @Nullable CohereEmbeddingType embeddingType
+ ) {
+ return createModel(url, apiKey, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, dimensions, model, embeddingType);
}
public static CohereEmbeddingsModel createModel(
@@ -59,13 +89,18 @@ public static CohereEmbeddingsModel createModel(
String apiKey,
CohereEmbeddingsTaskSettings taskSettings,
@Nullable Integer tokenLimit,
- @Nullable Integer dimensions
+ @Nullable Integer dimensions,
+ @Nullable String model,
+ @Nullable CohereEmbeddingType embeddingType
) {
return new CohereEmbeddingsModel(
"id",
TaskType.TEXT_EMBEDDING,
"service",
- new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit),
+ new CohereEmbeddingsServiceSettings(
+ new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit, model),
+ embeddingType
+ ),
taskSettings,
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java
new file mode 100644
index 0000000000000..8daa5a27f9618
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java
@@ -0,0 +1,183 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere.embeddings;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider;
+import org.elasticsearch.xpack.inference.common.SimilarityMeasure;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
+import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettingsTests;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Tests for {@link CohereEmbeddingsServiceSettings}: parsing from a settings map, validation of the
+ * embedding type field, and the wire serialization checks inherited from the base test case.
+ */
+public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase<CohereEmbeddingsServiceSettings> {
+
+    /**
+     * Creates service settings from random common settings and either a randomly chosen embedding type or {@code null}.
+     */
+    public static CohereEmbeddingsServiceSettings createRandom() {
+        var commonSettings = CohereServiceSettingsTests.createRandom();
+        var embeddingType = randomBoolean() ? randomFrom(CohereEmbeddingType.values()) : null;
+
+        return new CohereEmbeddingsServiceSettings(commonSettings, embeddingType);
+    }
+
+    /** Verifies that a map containing url, similarity, dimensions, max tokens, model and embedding type parses to equal settings. */
+    public void testFromMap() {
+        var url = "https://www.abc.com";
+        var similarity = SimilarityMeasure.DOT_PRODUCT.toString();
+        var dims = 1536;
+        var maxInputTokens = 512;
+        var model = "model";
+        var serviceSettings = CohereEmbeddingsServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    ServiceFields.URL,
+                    url,
+                    ServiceFields.SIMILARITY,
+                    similarity,
+                    ServiceFields.DIMENSIONS,
+                    dims,
+                    ServiceFields.MAX_INPUT_TOKENS,
+                    maxInputTokens,
+                    CohereServiceSettings.MODEL,
+                    model,
+                    CohereEmbeddingsServiceSettings.EMBEDDING_TYPE,
+                    CohereEmbeddingType.INT8.toString()
+                )
+            )
+        );
+
+        MatcherAssert.assertThat(
+            serviceSettings,
+            is(
+                new CohereEmbeddingsServiceSettings(
+                    new CohereServiceSettings(ServiceUtils.createUri(url), SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens, model),
+                    CohereEmbeddingType.INT8
+                )
+            )
+        );
+    }
+
+    /** Verifies that the embedding type is optional: parsing an empty map succeeds and yields a {@code null} embedding type. */
+    public void testFromMap_MissingEmbeddingType_DoesNotThrowException() {
+        var serviceSettings = CohereEmbeddingsServiceSettings.fromMap(new HashMap<>(Map.of()));
+        assertNull(serviceSettings.getEmbeddingType());
+    }
+
+    /** Verifies that an empty-string embedding type is rejected with a validation error. */
+    public void testFromMap_EmptyEmbeddingType_ThrowsError() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> CohereEmbeddingsServiceSettings.fromMap(new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, "")))
+        );
+
+        MatcherAssert.assertThat(
+            thrownException.getMessage(),
+            containsString(
+                Strings.format(
+                    "Validation Failed: 1: [service_settings] Invalid value empty string. [%s] must be a non-empty string;",
+                    CohereEmbeddingsServiceSettings.EMBEDDING_TYPE
+                )
+            )
+        );
+    }
+
+    /** Verifies that an unrecognized embedding type name is rejected and that the error lists the accepted values. */
+    public void testFromMap_InvalidEmbeddingType_ThrowsError() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> CohereEmbeddingsServiceSettings.fromMap(new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, "abc")))
+        );
+
+        MatcherAssert.assertThat(
+            thrownException.getMessage(),
+            is("Validation Failed: 1: [service_settings] Invalid value [abc] received. [embedding_type] must be one of [float, int8];")
+        );
+    }
+
+    /** Verifies that a non-string embedding type value (here a list) fails with a type conversion error. */
+    public void testFromMap_ReturnsFailure_WhenEmbeddingTypesAreNotValid() {
+        var exception = expectThrows(
+            ElasticsearchStatusException.class,
+            () -> CohereEmbeddingsServiceSettings.fromMap(
+                new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, List.of("abc")))
+            )
+        );
+
+        MatcherAssert.assertThat(
+            exception.getMessage(),
+            is("field [embedding_type] is not of the expected type. The value [[abc]] cannot be converted to a [String]")
+        );
+    }
+
+    @Override
+    protected Writeable.Reader<CohereEmbeddingsServiceSettings> instanceReader() {
+        return CohereEmbeddingsServiceSettings::new;
+    }
+
+    @Override
+    protected CohereEmbeddingsServiceSettings createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected CohereEmbeddingsServiceSettings mutateInstance(CohereEmbeddingsServiceSettings instance) throws IOException {
+        // createRandom() does not look at the input, so nothing stops it from returning an instance equal to it;
+        // redraw until the result differs so the returned mutation is never equal to the original.
+        return randomValueOtherThan(instance, CohereEmbeddingsServiceSettingsTests::createRandom);
+    }
+
+    /** Builds a registry from the named writeables of the ML inference provider and the inference plugin provider. */
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
+        entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
+        entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables());
+        return new NamedWriteableRegistry(entries);
+    }
+
+    /**
+     * Builds a mutable service settings map for tests: a copy of the common Cohere settings map for {@code url} and
+     * {@code model}, plus the embedding type entry when {@code embeddingType} is not {@code null}.
+     */
+    public static Map<String, Object> getServiceSettingsMap(
+        @Nullable String url,
+        @Nullable String model,
+        @Nullable CohereEmbeddingType embeddingType
+    ) {
+        var map = new HashMap<>(CohereServiceSettingsTests.getServiceSettingsMap(url, model));
+
+        if (embeddingType != null) {
+            map.put(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, embeddingType.toString());
+        }
+
+        return map;
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
index a568a0eaa1e5a..cf16473bdb70f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.cohere.embeddings;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
@@ -19,7 +18,6 @@
import java.io.IOException;
import java.util.HashMap;
-import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.is;
@@ -27,18 +25,16 @@
public class CohereEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase {
public static CohereEmbeddingsTaskSettings createRandom() {
- var model = randomBoolean() ? randomAlphaOfLength(15) : null;
var inputType = randomBoolean() ? randomFrom(InputType.values()) : null;
- var embeddingType = randomBoolean() ? randomFrom(CohereEmbeddingType.values()) : null;
var truncation = randomBoolean() ? randomFrom(CohereTruncation.values()) : null;
- return new CohereEmbeddingsTaskSettings(model, inputType, embeddingType, truncation);
+ return new CohereEmbeddingsTaskSettings(inputType, truncation);
}
public void testFromMap_CreatesEmptySettings_WhenAllFieldsAreNull() {
MatcherAssert.assertThat(
CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of())),
- is(new CohereEmbeddingsTaskSettings(null, null, null, null))
+ is(new CohereEmbeddingsTaskSettings(null, null))
);
}
@@ -47,18 +43,14 @@ public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() {
CohereEmbeddingsTaskSettings.fromMap(
new HashMap<>(
Map.of(
- CohereServiceFields.MODEL,
- "abc",
CohereEmbeddingsTaskSettings.INPUT_TYPE,
InputType.INGEST.toString(),
- CohereEmbeddingsTaskSettings.EMBEDDING_TYPE,
- CohereEmbeddingType.INT8.toString(),
CohereServiceFields.TRUNCATE,
CohereTruncation.END.toString()
)
)
),
- is(new CohereEmbeddingsTaskSettings("abc", InputType.INGEST, CohereEmbeddingType.INT8, CohereTruncation.END))
+ is(new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.END))
);
}
@@ -74,18 +66,6 @@ public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() {
);
}
- public void testFromMap_ReturnsFailure_WhenEmbeddingTypesAreNotValid() {
- var exception = expectThrows(
- ElasticsearchStatusException.class,
- () -> CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(CohereEmbeddingsTaskSettings.EMBEDDING_TYPE, List.of("abc"))))
- );
-
- MatcherAssert.assertThat(
- exception.getMessage(),
- is("field [embedding_type] is not of the expected type. The value [[abc]] cannot be converted to a [String]")
- );
- }
-
public void testOverrideWith_KeepsOriginalValuesWhenOverridesAreNull() {
var taskSettings = CohereEmbeddingsTaskSettings.fromMap(
new HashMap<>(Map.of(CohereServiceFields.MODEL, "model", CohereServiceFields.TRUNCATE, CohereTruncation.END.toString()))
@@ -97,7 +77,7 @@ public void testOverrideWith_KeepsOriginalValuesWhenOverridesAreNull() {
public void testOverrideWith_UsesOverriddenSettings() {
var taskSettings = CohereEmbeddingsTaskSettings.fromMap(
- new HashMap<>(Map.of(CohereServiceFields.MODEL, "model", CohereServiceFields.TRUNCATE, CohereTruncation.END.toString()))
+ new HashMap<>(Map.of(CohereServiceFields.TRUNCATE, CohereTruncation.END.toString()))
);
var requestTaskSettings = CohereEmbeddingsTaskSettings.fromMap(
@@ -105,7 +85,7 @@ public void testOverrideWith_UsesOverriddenSettings() {
);
var overriddenTaskSettings = taskSettings.overrideWith(requestTaskSettings);
- MatcherAssert.assertThat(overriddenTaskSettings, is(new CohereEmbeddingsTaskSettings("model", null, null, CohereTruncation.START)));
+ MatcherAssert.assertThat(overriddenTaskSettings, is(new CohereEmbeddingsTaskSettings(null, CohereTruncation.START)));
}
@Override
@@ -127,26 +107,13 @@ public static Map getTaskSettingsMapEmpty() {
return new HashMap<>();
}
- public static Map getTaskSettingsMap(
- @Nullable String model,
- @Nullable InputType inputType,
- @Nullable CohereEmbeddingType embeddingType,
- @Nullable CohereTruncation truncation
- ) {
+ public static Map getTaskSettingsMap(@Nullable InputType inputType, @Nullable CohereTruncation truncation) {
var map = new HashMap();
- if (model != null) {
- map.put(CohereServiceFields.MODEL, model);
- }
-
if (inputType != null) {
map.put(CohereEmbeddingsTaskSettings.INPUT_TYPE, inputType.toString());
}
- if (embeddingType != null) {
- map.put(CohereEmbeddingsTaskSettings.EMBEDDING_TYPE, embeddingType.toString());
- }
-
if (truncation != null) {
map.put(CohereServiceFields.TRUNCATE, truncation.toString());
}
From 83caa87d5b8196fec76142d83d9ab3fd510f203d Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Wed, 24 Jan 2024 10:57:13 -0500
Subject: [PATCH 10/13] Using separate named writeables for byte results and
floats
---
.../core/inference/results/ByteValue.java | 69 ------
...{EmbeddingValue.java => EmbeddingInt.java} | 7 +-
.../core/inference/results/FloatValue.java | 69 ------
.../core/inference/results/TextEmbedding.java | 18 ++
.../results/TextEmbeddingByteResults.java | 146 ++++++++++++
.../results/TextEmbeddingResults.java | 74 ++----
.../inference/results/TextEmbeddingUtils.java | 30 +++
.../InferenceNamedWriteablesProvider.java | 9 +-
.../CohereEmbeddingsResponseEntity.java | 74 +++---
.../HuggingFaceEmbeddingsResponseEntity.java | 2 +-
.../OpenAiEmbeddingsResponseEntity.java | 2 +-
.../inference/services/ServiceUtils.java | 35 +--
.../action/InferenceActionResponseTests.java | 6 +-
.../cohere/CohereActionCreatorTests.java | 4 +-
.../cohere/CohereEmbeddingsActionTests.java | 11 +-
.../HuggingFaceActionCreatorTests.java | 6 +-
.../openai/OpenAiActionCreatorTests.java | 10 +-
.../openai/OpenAiEmbeddingsActionTests.java | 4 +-
.../external/openai/OpenAiClientTests.java | 8 +-
.../CohereEmbeddingsResponseEntityTests.java | 34 +--
...gingFaceEmbeddingsResponseEntityTests.java | 20 +-
.../OpenAiEmbeddingsResponseEntityTests.java | 10 +-
.../inference/results/ByteValueTests.java | 35 ---
.../inference/results/FloatValueTests.java | 35 ---
.../TextEmbeddingByteResultsTests.java | 165 +++++++++++++
.../results/TextEmbeddingResultsTests.java | 219 +++---------------
.../inference/services/ServiceUtilsTests.java | 139 +++++++++++
.../services/cohere/CohereServiceTests.java | 4 +-
.../huggingface/HuggingFaceServiceTests.java | 4 +-
.../services/openai/OpenAiServiceTests.java | 4 +-
30 files changed, 668 insertions(+), 585 deletions(-)
delete mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java
rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/{EmbeddingValue.java => EmbeddingInt.java} (59%)
delete mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java
create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbedding.java
create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java
create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java
delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ByteValueTests.java
delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/FloatValueTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java
deleted file mode 100644
index 5017aa9be0d52..0000000000000
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteValue.java
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.core.inference.results;
-
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.xcontent.XContentBuilder;
-
-import java.io.IOException;
-import java.util.Objects;
-
-public class ByteValue implements EmbeddingValue {
-
- public static final String NAME = "byte_value";
-
- private final Byte value;
-
- public ByteValue(Byte value) {
- this.value = value;
- }
-
- public ByteValue(StreamInput in) throws IOException {
- value = in.readByte();
- }
-
- @Override
- public String toString() {
- return value.toString();
- }
-
- @Override
- public Byte getValue() {
- return value;
- }
-
- @Override
- public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
- builder.value(value);
- return builder;
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
- ByteValue that = (ByteValue) o;
- return Objects.equals(value, that.value);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(value);
- }
-
- @Override
- public void writeTo(StreamOutput out) throws IOException {
- out.writeByte(value);
- }
-
- @Override
- public String getWriteableName() {
- return NAME;
- }
-}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingInt.java
similarity index 59%
rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingValue.java
rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingInt.java
index f3fd641db395d..05fc8a3cef1b6 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingValue.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingInt.java
@@ -7,9 +7,6 @@
package org.elasticsearch.xpack.core.inference.results;
-import org.elasticsearch.common.io.stream.NamedWriteable;
-import org.elasticsearch.xcontent.ToXContentFragment;
-
-public interface EmbeddingValue extends NamedWriteable, ToXContentFragment {
- Number getValue();
+public interface EmbeddingInt {
+ int getSize();
}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java
deleted file mode 100644
index 5c89720e54b2e..0000000000000
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatValue.java
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.core.inference.results;
-
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.xcontent.XContentBuilder;
-
-import java.io.IOException;
-import java.util.Objects;
-
-public class FloatValue implements EmbeddingValue {
-
- public static final String NAME = "float_value";
-
- private final Float value;
-
- public FloatValue(Float value) {
- this.value = value;
- }
-
- public FloatValue(StreamInput in) throws IOException {
- value = in.readFloat();
- }
-
- @Override
- public String toString() {
- return value.toString();
- }
-
- @Override
- public Number getValue() {
- return value;
- }
-
- @Override
- public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
- builder.value(value);
- return builder;
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
- FloatValue that = (FloatValue) o;
- return Objects.equals(value, that.value);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(value);
- }
-
- @Override
- public void writeTo(StreamOutput out) throws IOException {
- out.writeFloat(value);
- }
-
- @Override
- public String getWriteableName() {
- return NAME;
- }
-}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbedding.java
new file mode 100644
index 0000000000000..a185c2938223e
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbedding.java
@@ -0,0 +1,18 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+public interface TextEmbedding {
+
+ /**
+ * Returns the array size of the first text embedding entry in the result list.
+ * @return the size of the text embedding
+ * @throws IllegalStateException if the list of embeddings is empty
+ */
+ int getFirstEmbeddingSize() throws IllegalStateException;
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java
new file mode 100644
index 0000000000000..4ffef36359589
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java
@@ -0,0 +1,146 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.inference.InferenceResults;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * Writes a text embedding result in the following JSON format
+ * {
+ * "text_embedding": [
+ * {
+ * "embedding": [
+ * 23
+ * ]
+ * },
+ * {
+ * "embedding": [
+ * -23
+ * ]
+ * }
+ * ]
+ * }
+ */
+public record TextEmbeddingByteResults(List embeddings) implements InferenceServiceResults, TextEmbedding {
+ public static final String NAME = "text_embedding_service_byte_results";
+ public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString();
+
+ public TextEmbeddingByteResults(StreamInput in) throws IOException {
+ this(in.readCollectionAsList(Embedding::new));
+ }
+
+ @Override
+ public int getFirstEmbeddingSize() {
+ return TextEmbeddingUtils.getFirstEmbeddingSize(new ArrayList<>(embeddings));
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startArray(TEXT_EMBEDDING);
+ for (Embedding embedding : embeddings) {
+ embedding.toXContent(builder, params);
+ }
+ builder.endArray();
+ return builder;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeCollection(embeddings);
+ }
+
+ @Override
+ public String getWriteableName() {
+ return NAME;
+ }
+
+ @Override
+ public List extends InferenceResults> transformToCoordinationFormat() {
+ return embeddings.stream()
+ .map(embedding -> embedding.values.stream().mapToDouble(value -> value).toArray())
+ .map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false))
+ .toList();
+ }
+
+ @Override
+ @SuppressWarnings("deprecation")
+ public List extends InferenceResults> transformToLegacyFormat() {
+ var legacyEmbedding = new LegacyTextEmbeddingResults(
+ embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.toFloats())).toList()
+ );
+
+ return List.of(legacyEmbedding);
+ }
+
+ public Map asMap() {
+ Map map = new LinkedHashMap<>();
+ map.put(TEXT_EMBEDDING, embeddings.stream().map(Embedding::asMap).collect(Collectors.toList()));
+
+ return map;
+ }
+
+ public record Embedding(List values) implements Writeable, ToXContentObject, EmbeddingInt {
+ public static final String EMBEDDING = "embedding";
+
+ public Embedding(StreamInput in) throws IOException {
+ this(in.readCollectionAsImmutableList(StreamInput::readByte));
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeCollection(values, StreamOutput::writeByte);
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startObject();
+
+ builder.startArray(EMBEDDING);
+ for (Byte value : values) {
+ builder.value(value);
+ }
+ builder.endArray();
+
+ builder.endObject();
+ return builder;
+ }
+
+ @Override
+ public String toString() {
+ return Strings.toString(this);
+ }
+
+ public Map asMap() {
+ return Map.of(EMBEDDING, values);
+ }
+
+ public List toFloats() {
+ return values.stream().map(Byte::floatValue).toList();
+ }
+
+ @Override
+ public int getSize() {
+ return values().size();
+ }
+ }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java
index fadffb2782018..75eb4ebc19902 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.core.inference.results;
-import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -19,10 +18,10 @@
import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
-import java.util.Objects;
import java.util.stream.Collectors;
/**
@@ -42,7 +41,7 @@
* ]
* }
*/
-public record TextEmbeddingResults(List embeddings) implements InferenceServiceResults {
+public record TextEmbeddingResults(List embeddings) implements InferenceServiceResults, TextEmbedding {
public static final String NAME = "text_embedding_service_results";
public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString();
@@ -55,11 +54,16 @@ public TextEmbeddingResults(StreamInput in) throws IOException {
this(
legacyTextEmbeddingResults.embeddings()
.stream()
- .map(embedding -> Embedding.ofFloats(embedding.values()))
+ .map(embedding -> new Embedding(embedding.values()))
.collect(Collectors.toList())
);
}
+ @Override
+ public int getFirstEmbeddingSize() {
+ return TextEmbeddingUtils.getFirstEmbeddingSize(new ArrayList<>(embeddings));
+ }
+
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startArray(TEXT_EMBEDDING);
@@ -83,7 +87,7 @@ public String getWriteableName() {
@Override
public List extends InferenceResults> transformToCoordinationFormat() {
return embeddings.stream()
- .map(embedding -> embedding.values.stream().mapToDouble(value -> value.getValue().doubleValue()).toArray())
+ .map(embedding -> embedding.values.stream().mapToDouble(value -> value).toArray())
.map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false))
.toList();
}
@@ -92,7 +96,7 @@ public List extends InferenceResults> transformToCoordinationFormat() {
@SuppressWarnings("deprecation")
public List extends InferenceResults> transformToLegacyFormat() {
var legacyEmbedding = new LegacyTextEmbeddingResults(
- embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.toFloats())).toList()
+ embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.values)).toList()
);
return List.of(legacyEmbedding);
@@ -105,52 +109,21 @@ public Map asMap() {
return map;
}
- public static class Embedding implements Writeable, ToXContentObject {
+ public record Embedding(List values) implements Writeable, ToXContentObject, EmbeddingInt {
public static final String EMBEDDING = "embedding";
- public static Embedding ofFloats(List values) {
- return new Embedding(convertFloatsToEmbeddingValues(values));
- }
-
- public static Embedding ofBytes(List values) {
- List convertedValues = values.stream().map(ByteValue::new).collect(Collectors.toList());
-
- return new Embedding(convertedValues);
- }
-
- private final List values;
-
- public Embedding(List values) {
- this.values = values;
- }
-
public Embedding(StreamInput in) throws IOException {
- if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) {
- values = in.readNamedWriteableCollectionAsList(EmbeddingValue.class);
- } else {
- values = convertFloatsToEmbeddingValues(in.readCollectionAsImmutableList(StreamInput::readFloat));
- }
+ this(in.readCollectionAsImmutableList(StreamInput::readFloat));
}
- private static List convertFloatsToEmbeddingValues(List floats) {
- return floats.stream().map(FloatValue::new).collect(Collectors.toList());
- }
-
- public List toFloats() {
- return values.stream().map(value -> value.getValue().floatValue()).toList();
- }
-
- public List values() {
- return values;
+ @Override
+ public int getSize() {
+ return values.size();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
- if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) {
- out.writeNamedWriteableCollection(values);
- } else {
- out.writeCollection(toFloats(), StreamOutput::writeFloat);
- }
+ out.writeCollection(values, StreamOutput::writeFloat);
}
@Override
@@ -158,7 +131,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();
builder.startArray(EMBEDDING);
- for (EmbeddingValue value : values) {
+ for (Float value : values) {
builder.value(value);
}
builder.endArray();
@@ -172,19 +145,6 @@ public String toString() {
return Strings.toString(this);
}
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
- Embedding embedding = (Embedding) o;
- return Objects.equals(values, embedding.values);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(values);
- }
-
public Map asMap() {
return Map.of(EMBEDDING, values);
}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java
new file mode 100644
index 0000000000000..02cb3b878c7fe
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java
@@ -0,0 +1,30 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import java.util.List;
+
+public class TextEmbeddingUtils {
+
+ /**
+ * Returns the first text embedding entry's array size.
+ * @param embeddings the list of embeddings
+ * @return the size of the text embedding
+ * @throws IllegalStateException if the list of embeddings is empty
+ */
+ public static int getFirstEmbeddingSize(List embeddings) throws IllegalStateException {
+ if (embeddings.isEmpty()) {
+ throw new IllegalStateException("Embeddings list is empty");
+ }
+
+ return embeddings.get(0).getSize();
+ }
+
+ private TextEmbeddingUtils() {}
+
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
index f2f604e22bc24..982c33a08d1fc 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
@@ -14,11 +14,9 @@
import org.elasticsearch.inference.SecretSettings;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
-import org.elasticsearch.xpack.core.inference.results.ByteValue;
-import org.elasticsearch.xpack.core.inference.results.EmbeddingValue;
-import org.elasticsearch.xpack.core.inference.results.FloatValue;
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettings;
@@ -55,8 +53,9 @@ public static List getNamedWriteables() {
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, TextEmbeddingResults.NAME, TextEmbeddingResults::new)
);
- namedWriteables.add(new NamedWriteableRegistry.Entry(EmbeddingValue.class, FloatValue.NAME, FloatValue::new));
- namedWriteables.add(new NamedWriteableRegistry.Entry(EmbeddingValue.class, ByteValue.NAME, ByteValue::new));
+ namedWriteables.add(
+ new NamedWriteableRegistry.Entry(InferenceServiceResults.class, TextEmbeddingByteResults.NAME, TextEmbeddingByteResults::new)
+ );
// Empty default task settings
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new));
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
index b9830b72de010..bd808c225d7e3 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
@@ -11,13 +11,12 @@
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.XContentParserUtils;
import org.elasticsearch.core.CheckedFunction;
+import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.ByteValue;
-import org.elasticsearch.xpack.core.inference.results.EmbeddingValue;
-import org.elasticsearch.xpack.core.inference.results.FloatValue;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -36,11 +35,11 @@
public class CohereEmbeddingsResponseEntity {
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Cohere embeddings response";
- private static final Map> EMBEDDING_PARSERS = Map.of(
+ private static final Map> EMBEDDING_PARSERS = Map.of(
toLowerCase(CohereEmbeddingType.FLOAT),
- CohereEmbeddingsResponseEntity::parseEmbeddingFloatEntry,
+ CohereEmbeddingsResponseEntity::parseFloatEmbeddingsArray,
toLowerCase(CohereEmbeddingType.INT8),
- CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry
+ CohereEmbeddingsResponseEntity::parseByteEmbeddingsArray
);
private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes();
@@ -132,7 +131,7 @@ private static String supportedEmbeddingTypes() {
*
*
*/
- public static TextEmbeddingResults fromResponse(Request request, HttpResult response) throws IOException {
+ public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
@@ -148,15 +147,7 @@ public static TextEmbeddingResults fromResponse(Request request, HttpResult resp
return parseEmbeddingsObject(jsonParser);
} else if (token == XContentParser.Token.START_ARRAY) {
// if the request did not specify the embedding types then it will default to floats
- List embeddingList = XContentParserUtils.parseList(
- jsonParser,
- parser -> CohereEmbeddingsResponseEntity.parseEmbeddingsArray(
- parser,
- CohereEmbeddingsResponseEntity::parseEmbeddingFloatEntry
- )
- );
-
- return new TextEmbeddingResults(embeddingList);
+ return parseFloatEmbeddingsArray(jsonParser);
} else {
throwUnknownToken(token, jsonParser);
}
@@ -167,7 +158,7 @@ public static TextEmbeddingResults fromResponse(Request request, HttpResult resp
}
}
- private static TextEmbeddingResults parseEmbeddingsObject(XContentParser parser) throws IOException {
+ private static InferenceServiceResults parseEmbeddingsObject(XContentParser parser) throws IOException {
XContentParser.Token token;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
@@ -178,12 +169,7 @@ private static TextEmbeddingResults parseEmbeddingsObject(XContentParser parser)
}
parser.nextToken();
- var embeddingList = XContentParserUtils.parseList(
- parser,
- listParser -> CohereEmbeddingsResponseEntity.parseEmbeddingsArray(listParser, embeddingValueParser)
- );
-
- return new TextEmbeddingResults(embeddingList);
+ return embeddingValueParser.apply(parser);
}
}
@@ -195,29 +181,26 @@ private static TextEmbeddingResults parseEmbeddingsObject(XContentParser parser)
);
}
- private static TextEmbeddingResults.Embedding parseEmbeddingsArray(
- XContentParser parser,
- CheckedFunction parseEntry
- ) throws IOException {
- XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
- List embeddingValues = XContentParserUtils.parseList(parser, parseEntry);
+ private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser) throws IOException {
+ var embeddingList = XContentParserUtils.parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry);
- return new TextEmbeddingResults.Embedding(embeddingValues);
+ return new TextEmbeddingByteResults(embeddingList);
}
- private static FloatValue parseEmbeddingFloatEntry(XContentParser parser) throws IOException {
- XContentParser.Token token = parser.currentToken();
- XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser);
- return new FloatValue(parser.floatValue());
+ private static TextEmbeddingByteResults.Embedding parseByteArrayEntry(XContentParser parser) throws IOException {
+ XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
+ List embeddingValues = XContentParserUtils.parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry);
+
+ return new TextEmbeddingByteResults.Embedding(embeddingValues);
}
- private static ByteValue parseEmbeddingInt8Entry(XContentParser parser) throws IOException {
+ private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException {
XContentParser.Token token = parser.currentToken();
XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser);
var parsedByte = parser.shortValue();
checkByteBounds(parsedByte);
- return new ByteValue((byte) parsedByte);
+ return (byte) parsedByte;
}
private static void checkByteBounds(short value) {
@@ -226,5 +209,24 @@ private static void checkByteBounds(short value) {
}
}
+ private static InferenceServiceResults parseFloatEmbeddingsArray(XContentParser parser) throws IOException {
+ var embeddingList = XContentParserUtils.parseList(parser, CohereEmbeddingsResponseEntity::parseFloatArrayEntry);
+
+ return new TextEmbeddingResults(embeddingList);
+ }
+
+ private static TextEmbeddingResults.Embedding parseFloatArrayEntry(XContentParser parser) throws IOException {
+ XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
+ List embeddingValues = XContentParserUtils.parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingFloatEntry);
+
+ return new TextEmbeddingResults.Embedding(embeddingValues);
+ }
+
+ private static Float parseEmbeddingFloatEntry(XContentParser parser) throws IOException {
+ XContentParser.Token token = parser.currentToken();
+ XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser);
+ return parser.floatValue();
+ }
+
private CohereEmbeddingsResponseEntity() {}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java
index b3706b318439b..b74b03891034f 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java
@@ -149,7 +149,7 @@ private static TextEmbeddingResults.Embedding parseEmbeddingEntry(XContentParser
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
List embeddingValues = XContentParserUtils.parseList(parser, HuggingFaceEmbeddingsResponseEntity::parseEmbeddingList);
- return TextEmbeddingResults.Embedding.ofFloats(embeddingValues);
+ return new TextEmbeddingResults.Embedding(embeddingValues);
}
private static float parseEmbeddingList(XContentParser parser) throws IOException {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java
index 0640821d0b9e1..4926ba3f0ef6b 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java
@@ -101,7 +101,7 @@ private static TextEmbeddingResults.Embedding parseEmbeddingObject(XContentParse
// if there are additional fields within this object, lets skip them, so we can begin parsing the next embedding array
parser.skipChildren();
- return TextEmbeddingResults.Embedding.ofFloats(embeddingValues);
+ return new TextEmbeddingResults.Embedding(embeddingValues);
}
private static float parseEmbeddingList(XContentParser parser) throws IOException {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
index 20912a3811af1..7029f9ca3bf56 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
@@ -18,6 +18,7 @@
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.core.inference.results.TextEmbedding;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.inference.common.SimilarityMeasure;
@@ -109,25 +110,6 @@ public static String mustBeNonEmptyString(String settingName, String scope) {
return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName);
}
- public static String mustBeNonEmptyList(String settingName, String scope) {
- return Strings.format("[%s] Invalid value empty list. [%s] must be a non-empty list", scope, settingName);
- }
-
- public static String invalidType(String settingName, String scope, String invalidType, String invalidValue, String requiredType) {
- return Strings.format(
- "[%s] Invalid type [%s] received for value [%s]. [%s] must be type [%s]",
- scope,
- invalidType,
- invalidValue,
- settingName,
- requiredType
- );
- }
-
- public static String invalidValue(String settingName, String scope, String invalidValue, String validValue) {
- return invalidValue(settingName, scope, invalidValue, new String[] { validValue });
- }
-
public static String invalidValue(String settingName, String scope, String invalidType, String... requiredTypes) {
return Strings.format(
"[%s] Invalid value [%s] received. [%s] must be one of [%s]",
@@ -287,16 +269,11 @@ public static void getEmbeddingSize(Model model, InferenceService service, Actio
assert model.getTaskType() == TaskType.TEXT_EMBEDDING;
service.infer(model, List.of(TEST_EMBEDDING_INPUT), Map.of(), listener.delegateFailureAndWrap((delegate, r) -> {
- if (r instanceof TextEmbeddingResults embeddingResults) {
- if (embeddingResults.embeddings().isEmpty()) {
- delegate.onFailure(
- new ElasticsearchStatusException(
- "Could not determine embedding size, no embeddings were returned in test call",
- RestStatus.BAD_REQUEST
- )
- );
- } else {
- delegate.onResponse(embeddingResults.embeddings().get(0).values().size());
+ if (r instanceof TextEmbedding embeddingResults) {
+ try {
+ delegate.onResponse(embeddingResults.getFirstEmbeddingSize());
+ } catch (Exception e) {
+ delegate.onFailure(new ElasticsearchStatusException("Could not determine embedding size", RestStatus.BAD_REQUEST, e));
}
} else {
delegate.onFailure(
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java
index 622cf51798609..759411cec1212 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java
@@ -46,7 +46,7 @@ protected Writeable.Reader instanceReader() {
@Override
protected InferenceAction.Response createTestInstance() {
var result = switch (randomIntBetween(0, 2)) {
- case 0 -> TextEmbeddingResultsTests.createRandomFloatResults();
+ case 0 -> TextEmbeddingResultsTests.createRandomResults();
case 1 -> LegacyTextEmbeddingResultsTests.createRandomResults().transformToTextEmbeddingResults();
default -> SparseEmbeddingResultsTests.createRandomResults();
};
@@ -90,7 +90,7 @@ public void testSerializesOpenAiAddedVersion_UsingSparseEmbeddingResult() throws
}
public void testSerializesMultipleInputsVersion_UsingLegacyTextEmbeddingResult() throws IOException {
- var embeddingResults = TextEmbeddingResultsTests.createRandomFloatResults();
+ var embeddingResults = TextEmbeddingResultsTests.createRandomResults();
var instance = new InferenceAction.Response(embeddingResults);
var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), INFERENCE_MULTIPLE_INPUTS);
assertOnBWCObject(copy, instance, INFERENCE_MULTIPLE_INPUTS);
@@ -106,7 +106,7 @@ public void testSerializesMultipleInputsVersion_UsingSparseEmbeddingResult() thr
// Technically we should never see a text embedding result in the transport version of this test because support
// for it wasn't added until openai
public void testSerializesSingleInputVersion_UsingLegacyTextEmbeddingResult() throws IOException {
- var embeddingResults = TextEmbeddingResultsTests.createRandomFloatResults();
+ var embeddingResults = TextEmbeddingResultsTests.createRandomResults();
var instance = new InferenceAction.Response(embeddingResults);
var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), ML_INFERENCE_TASK_SETTINGS_OPTIONAL_ADDED);
assertOnBWCObject(copy, instance, ML_INFERENCE_TASK_SETTINGS_OPTIONAL_ADDED);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java
index acd0145cbad24..67a95265f093d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java
@@ -39,7 +39,7 @@
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@@ -117,7 +117,7 @@ public void testCreate_CohereEmbeddingsModel() throws IOException {
var result = listener.actionGet(TIMEOUT);
- MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.123F, -0.123F)))));
+ MatcherAssert.assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.123F, -0.123F)))));
MatcherAssert.assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
MatcherAssert.assertThat(
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java
index 86ba29c95b4bc..501d5a5e42bfe 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java
@@ -26,6 +26,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.results.TextEmbeddingByteResultsTests;
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests;
@@ -44,8 +45,7 @@
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationBytes;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@@ -122,7 +122,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException {
var result = listener.actionGet(TIMEOUT);
- MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.123F, -0.123F)))));
+ MatcherAssert.assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.123F, -0.123F)))));
MatcherAssert.assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
MatcherAssert.assertThat(
@@ -199,7 +199,10 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I
var result = listener.actionGet(TIMEOUT);
- MatcherAssert.assertThat(result.asMap(), is(buildExpectationBytes(List.of(List.of((byte) 0, (byte) -1)))));
+ MatcherAssert.assertThat(
+ result.asMap(),
+ is(TextEmbeddingByteResultsTests.buildExpectation(List.of(List.of((byte) 0, (byte) -1))))
+ );
MatcherAssert.assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
MatcherAssert.assertThat(
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java
index c17b293c8bc35..95b69f1231e9d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java
@@ -214,7 +214,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloats(List.of(List.of(-0.0123F, 0.123F)))));
+ assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectation(List.of(List.of(-0.0123F, 0.123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
@@ -331,7 +331,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloats(List.of(List.of(-0.0123F, 0.123F)))));
+ assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectation(List.of(List.of(-0.0123F, 0.123F)))));
assertThat(webServer.requests(), hasSize(2));
{
@@ -389,7 +389,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException {
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloats(List.of(List.of(-0.0123F, 0.123F)))));
+ assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectation(List.of(List.of(-0.0123F, 0.123F)))));
assertThat(webServer.requests(), hasSize(1));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java
index 2d7493a783c51..23b6f1ea2fbe3 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java
@@ -33,7 +33,7 @@
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel;
import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsRequestTaskSettingsTests.getRequestTaskSettingsMap;
@@ -100,7 +100,7 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException {
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
@@ -169,7 +169,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(2));
{
assertNull(webServer.requests().get(0).getUri().getQuery());
@@ -252,7 +252,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(2));
{
assertNull(webServer.requests().get(0).getUri().getQuery());
@@ -320,7 +320,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException {
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java
index 2d5cea7bd981e..6bc8e2d61d579 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java
@@ -38,7 +38,7 @@
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel;
import static org.hamcrest.Matchers.equalTo;
@@ -104,7 +104,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException {
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiClientTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiClientTests.java
index 0c4b354e7171f..bb9612f01d8ff 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiClientTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiClientTests.java
@@ -40,7 +40,7 @@
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiEmbeddingsRequestTests.createRequest;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@@ -104,7 +104,7 @@ public void testSend_SuccessfulResponse() throws IOException, URISyntaxException
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
@@ -155,7 +155,7 @@ public void testSend_SuccessfulResponse_WithoutUser() throws IOException, URISyn
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
@@ -205,7 +205,7 @@ public void testSend_SuccessfulResponse_WithoutOrganization() throws IOException
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java
index 1d7af8577a722..76aecd997414c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java
@@ -8,7 +8,9 @@
package org.elasticsearch.xpack.inference.external.response.cohere;
import org.apache.http.HttpResponse;
+import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -18,6 +20,7 @@
import java.nio.charset.StandardCharsets;
import java.util.List;
+import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
@@ -47,14 +50,15 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
}
""";
- TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse(
+ InferenceServiceResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
+ MatcherAssert.assertThat(parsedResults, instanceOf(TextEmbeddingResults.class));
MatcherAssert.assertThat(
- parsedResults.embeddings(),
- is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F))))
+ ((TextEmbeddingResults) parsedResults).embeddings(),
+ is(List.of(new TextEmbeddingResults.Embedding(List.of(-0.0018434525F, 0.01777649F))))
);
}
@@ -85,14 +89,14 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws
}
""";
- TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse(
+ TextEmbeddingResults parsedResults = (TextEmbeddingResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(
parsedResults.embeddings(),
- is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F))))
+ is(List.of(new TextEmbeddingResults.Embedding(List.of(-0.0018434525F, 0.01777649F))))
);
}
@@ -129,14 +133,14 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntry() throws IOExcepti
}
""";
- TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse(
+ TextEmbeddingResults parsedResults = (TextEmbeddingResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(
parsedResults.embeddings(),
- is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F))))
+ is(List.of(new TextEmbeddingResults.Embedding(List.of(-0.0018434525F, 0.01777649F))))
);
}
@@ -173,14 +177,14 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntryInt8_WithInvalidFir
}
""";
- TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse(
+ TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(
parsedResults.embeddings(),
- is(List.of(TextEmbeddingResults.Embedding.ofBytes(List.of((byte) -1, (byte) 0))))
+ is(List.of(new TextEmbeddingByteResults.Embedding(List.of((byte) -1, (byte) 0))))
);
}
@@ -213,7 +217,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
}
""";
- TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse(
+ TextEmbeddingResults parsedResults = (TextEmbeddingResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -222,8 +226,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
parsedResults.embeddings(),
is(
List.of(
- TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F)),
- TextEmbeddingResults.Embedding.ofFloats(List.of(-0.123F, 0.123F))
+ new TextEmbeddingResults.Embedding(List.of(-0.0018434525F, 0.01777649F)),
+ new TextEmbeddingResults.Embedding(List.of(-0.123F, 0.123F))
)
)
);
@@ -260,7 +264,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw
}
""";
- TextEmbeddingResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse(
+ TextEmbeddingResults parsedResults = (TextEmbeddingResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -269,8 +273,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw
parsedResults.embeddings(),
is(
List.of(
- TextEmbeddingResults.Embedding.ofFloats(List.of(-0.0018434525F, 0.01777649F)),
- TextEmbeddingResults.Embedding.ofFloats(List.of(-0.123F, 0.123F))
+ new TextEmbeddingResults.Embedding(List.of(-0.0018434525F, 0.01777649F)),
+ new TextEmbeddingResults.Embedding(List.of(-0.123F, 0.123F))
)
)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java
index d54bcd91c9eda..2b6e11fdfafa7 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java
@@ -37,7 +37,7 @@ public void testFromResponse_CreatesResultsForASingleItem_ArrayFormat() throws I
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)))));
}
public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws IOException {
@@ -57,7 +57,7 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)))));
}
public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws IOException {
@@ -83,8 +83,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws
parsedResults.embeddings(),
is(
List.of(
- TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)),
- TextEmbeddingResults.Embedding.ofFloats(List.of(0.0123F, -0.0123F))
+ new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)),
+ new TextEmbeddingResults.Embedding(List.of(0.0123F, -0.0123F))
)
)
);
@@ -115,8 +115,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw
parsedResults.embeddings(),
is(
List.of(
- TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)),
- TextEmbeddingResults.Embedding.ofFloats(List.of(0.0123F, -0.0123F))
+ new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)),
+ new TextEmbeddingResults.Embedding(List.of(0.0123F, -0.0123F))
)
)
);
@@ -254,7 +254,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ArrayFormat() throw
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(1.0F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(1.0F)))));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() throws IOException {
@@ -273,7 +273,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() thro
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(1.0F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(1.0F)))));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() throws IOException {
@@ -290,7 +290,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() thro
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(4.0294965E10F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(4.0294965E10F)))));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() throws IOException {
@@ -309,7 +309,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() thr
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(4.0294965E10F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(4.0294965E10F)))));
}
public void testFromResponse_FailsWhenEmbeddingValueIsAnObject_ObjectFormat() {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java
index 3e8b50591e3b6..010e990a3ce80 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java
@@ -49,7 +49,7 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)))));
}
public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
@@ -91,8 +91,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
parsedResults.embeddings(),
is(
List.of(
- TextEmbeddingResults.Embedding.ofFloats(List.of(0.014539449F, -0.015288644F)),
- TextEmbeddingResults.Embedding.ofFloats(List.of(0.0123F, -0.0123F))
+ new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)),
+ new TextEmbeddingResults.Embedding(List.of(0.0123F, -0.0123F))
)
)
);
@@ -261,7 +261,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(1.0F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(1.0F)))));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOException {
@@ -290,7 +290,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(4.0294965E10F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(4.0294965E10F)))));
}
public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ByteValueTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ByteValueTests.java
deleted file mode 100644
index b89c3c9ab0ec9..0000000000000
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ByteValueTests.java
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.results;
-
-import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.inference.results.ByteValue;
-
-import java.io.IOException;
-
-public class ByteValueTests extends AbstractWireSerializingTestCase {
- public static ByteValue createRandom() {
- return new ByteValue(randomByte());
- }
-
- @Override
- protected Writeable.Reader instanceReader() {
- return ByteValue::new;
- }
-
- @Override
- protected ByteValue createTestInstance() {
- return createRandom();
- }
-
- @Override
- protected ByteValue mutateInstance(ByteValue instance) throws IOException {
- return null;
- }
-}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/FloatValueTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/FloatValueTests.java
deleted file mode 100644
index 1d37c0d3ddee1..0000000000000
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/FloatValueTests.java
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.results;
-
-import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.inference.results.FloatValue;
-
-import java.io.IOException;
-
-public class FloatValueTests extends AbstractWireSerializingTestCase {
- public static FloatValue createRandom() {
- return new FloatValue(randomFloat());
- }
-
- @Override
- protected Writeable.Reader instanceReader() {
- return FloatValue::new;
- }
-
- @Override
- protected FloatValue createTestInstance() {
- return createRandom();
- }
-
- @Override
- protected FloatValue mutateInstance(FloatValue instance) throws IOException {
- return null;
- }
-}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java
new file mode 100644
index 0000000000000..b9318db6ece34
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java
@@ -0,0 +1,177 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.results;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Tests for {@link TextEmbeddingByteResults}: wire serialization round trips, map and XContent
+ * rendering, and conversion to the coordination format.
+ */
+public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase<TextEmbeddingByteResults> {
+    /** Builds results holding between 1 and 10 random embeddings. */
+    public static TextEmbeddingByteResults createRandomResults() {
+        int embeddings = randomIntBetween(1, 10);
+        List<TextEmbeddingByteResults.Embedding> embeddingResults = new ArrayList<>(embeddings);
+
+        for (int i = 0; i < embeddings; i++) {
+            embeddingResults.add(createRandomEmbedding());
+        }
+
+        return new TextEmbeddingByteResults(embeddingResults);
+    }
+
+    /** Builds an embedding of between 1 and 10 random byte values. */
+    private static TextEmbeddingByteResults.Embedding createRandomEmbedding() {
+        int columns = randomIntBetween(1, 10);
+        List<Byte> bytes = new ArrayList<>(columns);
+
+        for (int i = 0; i < columns; i++) {
+            bytes.add(randomByte());
+        }
+
+        return new TextEmbeddingByteResults.Embedding(bytes);
+    }
+
+    public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException {
+        var entity = new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(List.of((byte) 23))));
+
+        assertThat(
+            entity.asMap(),
+            is(
+                Map.of(
+                    TextEmbeddingByteResults.TEXT_EMBEDDING,
+                    List.of(Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, List.of((byte) 23)))
+                )
+            )
+        );
+
+        String xContentResult = Strings.toString(entity, true, true);
+        assertThat(xContentResult, is("""
+            {
+              "text_embedding" : [
+                {
+                  "embedding" : [
+                    23
+                  ]
+                }
+              ]
+            }"""));
+    }
+
+    public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException {
+        var entity = new TextEmbeddingByteResults(
+            List.of(new TextEmbeddingByteResults.Embedding(List.of((byte) 23)), new TextEmbeddingByteResults.Embedding(List.of((byte) 24)))
+
+        );
+
+        assertThat(
+            entity.asMap(),
+            is(
+                Map.of(
+                    TextEmbeddingByteResults.TEXT_EMBEDDING,
+                    List.of(
+                        Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, List.of((byte) 23)),
+                        Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, List.of((byte) 24))
+                    )
+                )
+            )
+        );
+
+        String xContentResult = Strings.toString(entity, true, true);
+        assertThat(xContentResult, is("""
+            {
+              "text_embedding" : [
+                {
+                  "embedding" : [
+                    23
+                  ]
+                },
+                {
+                  "embedding" : [
+                    24
+                  ]
+                }
+              ]
+            }"""));
+    }
+
+    /** Each byte embedding is expected to become one ML {@code TextEmbeddingResults} carrying the same values as doubles. */
+    public void testTransformToCoordinationFormat() {
+        var results = new TextEmbeddingByteResults(
+            List.of(
+                new TextEmbeddingByteResults.Embedding(List.of((byte) 23, (byte) 24)),
+                new TextEmbeddingByteResults.Embedding(List.of((byte) 25, (byte) 26))
+            )
+        ).transformToCoordinationFormat();
+
+        assertThat(
+            results,
+            is(
+                List.of(
+                    new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(
+                        TextEmbeddingByteResults.TEXT_EMBEDDING,
+                        new double[] { 23F, 24F },
+                        false
+                    ),
+                    new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(
+                        TextEmbeddingByteResults.TEXT_EMBEDDING,
+                        new double[] { 25F, 26F },
+                        false
+                    )
+                )
+            )
+        );
+    }
+
+    @Override
+    protected Writeable.Reader<TextEmbeddingByteResults> instanceReader() {
+        return TextEmbeddingByteResults::new;
+    }
+
+    @Override
+    protected TextEmbeddingByteResults createTestInstance() {
+        return createRandomResults();
+    }
+
+    /** Produces a result with a different number of embeddings: a shorter prefix, or the same list plus one random embedding. */
+    @Override
+    protected TextEmbeddingByteResults mutateInstance(TextEmbeddingByteResults instance) throws IOException {
+        // if true we reduce the embeddings list by a random amount, if false we add an embedding to the list
+        if (randomBoolean()) {
+            // -1 to remove at least one item from the list
+            int end = randomInt(instance.embeddings().size() - 1);
+            return new TextEmbeddingByteResults(instance.embeddings().subList(0, end));
+        } else {
+            List<TextEmbeddingByteResults.Embedding> embeddings = new ArrayList<>(instance.embeddings());
+            embeddings.add(createRandomEmbedding());
+            return new TextEmbeddingByteResults(embeddings);
+        }
+    }
+
+    /**
+     * Builds the map the tests above expect from {@code asMap()}: the {@code TEXT_EMBEDDING} key
+     * mapped to a list holding one single-entry {@code EMBEDDING} map per supplied embedding.
+     */
+    public static Map<String, Object> buildExpectation(List<List<Byte>> embeddings) {
+        return Map.of(
+            TextEmbeddingByteResults.TEXT_EMBEDDING,
+            embeddings.stream().map(embedding -> Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList()
+        );
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java
index 9c154a20531f7..09d9894d98853 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java
@@ -7,77 +7,31 @@
package org.elasticsearch.xpack.inference.results;
-import org.elasticsearch.TransportVersion;
-import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.xpack.core.inference.results.ByteValue;
-import org.elasticsearch.xpack.core.inference.results.FloatValue;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
-import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
-import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
-import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider;
-import org.hamcrest.MatcherAssert;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
-import java.util.function.Supplier;
-import static org.elasticsearch.TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED;
import static org.hamcrest.Matchers.is;
-public class TextEmbeddingResultsTests extends AbstractBWCWireSerializationTestCase {
-
- private enum EmbeddingType {
- FLOAT,
- BYTE
- }
-
- private static Map> EMBEDDING_TYPE_BUILDERS = Map.of(
- EmbeddingType.FLOAT,
- TextEmbeddingResultsTests::createRandomFloatEmbedding,
- EmbeddingType.BYTE,
- TextEmbeddingResultsTests::createRandomByteEmbedding
- );
-
+public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase {
public static TextEmbeddingResults createRandomResults() {
- var embeddingType = randomFrom(EmbeddingType.values());
- var createFunction = EMBEDDING_TYPE_BUILDERS.get(embeddingType);
- assert createFunction != null : "the embeddings type map is missing a value from the EmbeddingType enum";
-
- return createRandomResults(createFunction);
- }
-
- public static TextEmbeddingResults createRandomFloatResults() {
- return createRandomResults(TextEmbeddingResultsTests::createRandomFloatEmbedding);
- }
-
- private static TextEmbeddingResults createRandomResults(Supplier creator) {
int embeddings = randomIntBetween(1, 10);
List embeddingResults = new ArrayList<>(embeddings);
for (int i = 0; i < embeddings; i++) {
- embeddingResults.add(creator.get());
+ embeddingResults.add(createRandomEmbedding());
}
return new TextEmbeddingResults(embeddingResults);
}
- private static TextEmbeddingResults.Embedding createRandomByteEmbedding() {
- int columns = randomIntBetween(1, 10);
- List bytes = new ArrayList<>(columns);
-
- for (int i = 0; i < columns; i++) {
- bytes.add(randomByte());
- }
-
- return TextEmbeddingResults.Embedding.ofBytes(bytes);
- }
-
- private static TextEmbeddingResults.Embedding createRandomFloatEmbedding() {
+ private static TextEmbeddingResults.Embedding createRandomEmbedding() {
int columns = randomIntBetween(1, 10);
List floats = new ArrayList<>(columns);
@@ -85,34 +39,19 @@ private static TextEmbeddingResults.Embedding createRandomFloatEmbedding() {
floats.add(randomFloat());
}
- return TextEmbeddingResults.Embedding.ofFloats(floats);
+ return new TextEmbeddingResults.Embedding(floats);
}
- public static Map buildExpectationFloats(List> embeddings) {
- return Map.of(
- TextEmbeddingResults.TEXT_EMBEDDING,
- embeddings.stream()
- .map(embedding -> Map.of(TextEmbeddingResults.Embedding.EMBEDDING, embedding.stream().map(FloatValue::new).toList()))
- .toList()
- );
- }
+ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException {
+ var entity = new TextEmbeddingResults(List.of(new TextEmbeddingResults.Embedding(List.of(0.1F))));
- public static Map buildExpectationBytes(List> embeddings) {
- return Map.of(
- TextEmbeddingResults.TEXT_EMBEDDING,
- embeddings.stream()
- .map(embedding -> Map.of(TextEmbeddingResults.Embedding.EMBEDDING, embedding.stream().map(ByteValue::new).toList()))
- .toList()
+ assertThat(
+ entity.asMap(),
+ is(Map.of(TextEmbeddingResults.TEXT_EMBEDDING, List.of(Map.of(TextEmbeddingResults.Embedding.EMBEDDING, List.of(0.1F)))))
);
- }
-
- public void testToXContent_CreatesTheRightFormatForASingleEmbedding() {
- var entity = new TextEmbeddingResults(List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.1F))));
-
- MatcherAssert.assertThat(entity.asMap(), is(buildExpectationFloats(List.of(List.of(0.1F)))));
String xContentResult = Strings.toString(entity, true, true);
- MatcherAssert.assertThat(xContentResult, is("""
+ assertThat(xContentResult, is("""
{
"text_embedding" : [
{
@@ -124,33 +63,27 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() {
}"""));
}
- public void testToXContent_CreatesTheRightFormatForASingleEmbedding_ForBytes() {
- var entity = new TextEmbeddingResults(List.of(TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 12))));
-
- MatcherAssert.assertThat(entity.asMap(), is(buildExpectationBytes(List.of(List.of((byte) 12)))));
-
- String xContentResult = Strings.toString(entity, true, true);
- MatcherAssert.assertThat(xContentResult, is("""
- {
- "text_embedding" : [
- {
- "embedding" : [
- 12
- ]
- }
- ]
- }"""));
- }
-
- public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() {
+ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException {
var entity = new TextEmbeddingResults(
- List.of(TextEmbeddingResults.Embedding.ofFloats(List.of(0.1F)), TextEmbeddingResults.Embedding.ofFloats(List.of(0.2F)))
+ List.of(new TextEmbeddingResults.Embedding(List.of(0.1F)), new TextEmbeddingResults.Embedding(List.of(0.2F)))
+
);
- MatcherAssert.assertThat(entity.asMap(), is(buildExpectationFloats(List.of(List.of(0.1F), List.of(0.2F)))));
+ assertThat(
+ entity.asMap(),
+ is(
+ Map.of(
+ TextEmbeddingResults.TEXT_EMBEDDING,
+ List.of(
+ Map.of(TextEmbeddingResults.Embedding.EMBEDDING, List.of(0.1F)),
+ Map.of(TextEmbeddingResults.Embedding.EMBEDDING, List.of(0.2F))
+ )
+ )
+ )
+ );
String xContentResult = Strings.toString(entity, true, true);
- MatcherAssert.assertThat(xContentResult, is("""
+ assertThat(xContentResult, is("""
{
"text_embedding" : [
{
@@ -167,40 +100,12 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() {
}"""));
}
- public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings_ForBytes() {
- var entity = new TextEmbeddingResults(
- List.of(TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 12)), TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 34)))
- );
-
- MatcherAssert.assertThat(entity.asMap(), is(buildExpectationBytes(List.of(List.of((byte) 12), List.of((byte) 34)))));
-
- String xContentResult = Strings.toString(entity, true, true);
- MatcherAssert.assertThat(xContentResult, is("""
- {
- "text_embedding" : [
- {
- "embedding" : [
- 12
- ]
- },
- {
- "embedding" : [
- 34
- ]
- }
- ]
- }"""));
- }
-
public void testTransformToCoordinationFormat() {
var results = new TextEmbeddingResults(
- List.of(
- TextEmbeddingResults.Embedding.ofFloats(List.of(0.1F, 0.2F)),
- TextEmbeddingResults.Embedding.ofFloats(List.of(0.3F, 0.4F))
- )
+ List.of(new TextEmbeddingResults.Embedding(List.of(0.1F, 0.2F)), new TextEmbeddingResults.Embedding(List.of(0.3F, 0.4F)))
).transformToCoordinationFormat();
- MatcherAssert.assertThat(
+ assertThat(
results,
is(
List.of(
@@ -219,49 +124,6 @@ public void testTransformToCoordinationFormat() {
);
}
- public void testTransformToCoordinationFormat_FromBytes() {
- var results = new TextEmbeddingResults(
- List.of(
- TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 12, (byte) 34)),
- TextEmbeddingResults.Embedding.ofBytes(List.of((byte) 56, (byte) -78))
- )
- ).transformToCoordinationFormat();
-
- MatcherAssert.assertThat(
- results,
- is(
- List.of(
- new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(
- TextEmbeddingResults.TEXT_EMBEDDING,
- new double[] { 12F, 34F },
- false
- ),
- new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(
- TextEmbeddingResults.TEXT_EMBEDDING,
- new double[] { 56F, -78F },
- false
- )
- )
- )
- );
- }
-
- public void testSerializesToFloats_WhenVersionIsPriorToByteSupport() throws IOException {
- var instance = createRandomResults(TextEmbeddingResultsTests::createRandomByteEmbedding);
- var modifiedForOlderVersion = mutateInstanceForVersion(instance, ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED);
-
- var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED);
- assertOnBWCObject(copy, modifiedForOlderVersion, ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED);
- }
-
- @Override
- protected NamedWriteableRegistry getNamedWriteableRegistry() {
- List entries = new ArrayList<>();
- entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
- entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables());
- return new NamedWriteableRegistry(entries);
- }
-
@Override
protected Writeable.Reader instanceReader() {
return TextEmbeddingResults::new;
@@ -281,26 +143,15 @@ protected TextEmbeddingResults mutateInstance(TextEmbeddingResults instance) thr
return new TextEmbeddingResults(instance.embeddings().subList(0, end));
} else {
List embeddings = new ArrayList<>(instance.embeddings());
- embeddings.add(createRandomFloatEmbedding());
+ embeddings.add(createRandomEmbedding());
return new TextEmbeddingResults(embeddings);
}
}
- @Override
- protected TextEmbeddingResults mutateInstanceForVersion(TextEmbeddingResults instance, TransportVersion version) {
- if (version.before(TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED)) {
- return convertToFloatEmbeddings(instance);
- }
-
- return instance;
- }
-
- public TextEmbeddingResults convertToFloatEmbeddings(TextEmbeddingResults results) {
- var floatEmbeddings = results.embeddings()
- .stream()
- .map(embedding -> TextEmbeddingResults.Embedding.ofFloats(embedding.toFloats()))
- .toList();
-
- return new TextEmbeddingResults(floatEmbeddings);
+    public static Map<String, Object> buildExpectation(List<List<Float>> embeddings) {
+        return Map.of(
+            TextEmbeddingResults.TEXT_EMBEDDING,
+            embeddings.stream().map(embedding -> Map.of(TextEmbeddingResults.Embedding.EMBEDDING, embedding)).toList()
+        );
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java
index c72b161941ad2..b935c5a8c64b3 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java
@@ -8,24 +8,45 @@
package org.elasticsearch.xpack.inference.services;
import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.inference.results.TextEmbeddingByteResultsTests;
+import org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.getEmbeddingSize;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
public class ServiceUtilsTests extends ESTestCase {
+ private static final TimeValue TIMEOUT = TimeValue.timeValueSeconds(30);
+
public void testRemoveAsTypeWithTheCorrectType() {
Map map = new HashMap<>(Map.of("a", 5, "b", "a string", "c", Boolean.TRUE, "d", 1.0));
@@ -237,6 +258,124 @@ public void testExtractOptionalString_AddsException_WhenFieldIsEmpty() {
assertThat(validation.validationErrors().get(0), is("[scope] Invalid value empty string. [key] must be a non-empty string"));
}
+    public void testExtractOptionalEnum_ReturnsNull_WhenFieldDoesNotExist() {
+        var validation = new ValidationException();
+        Map<String, Object> map = modifiableMap(Map.of("key", "value"));
+        var createdEnum = extractOptionalEnum(map, "abc", "scope", InputType::fromString, InputType.values(), validation);
+
+        assertNull(createdEnum);
+        assertTrue(validation.validationErrors().isEmpty());
+        assertThat(map.size(), is(1));
+    }
+
+    public void testExtractOptionalEnum_ReturnsNullAndAddsException_WhenAnInvalidValueExists() {
+        var validation = new ValidationException();
+        Map<String, Object> map = modifiableMap(Map.of("key", "invalid_value"));
+        var createdEnum = extractOptionalEnum(map, "key", "scope", InputType::fromString, InputType.values(), validation);
+
+        assertNull(createdEnum);
+        assertFalse(validation.validationErrors().isEmpty());
+        assertTrue(map.isEmpty());
+        assertThat(
+            validation.validationErrors().get(0),
+            is("[scope] Invalid value [invalid_value] received. [key] must be one of [ingest, search]")
+        );
+    }
+
+    public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingResults_IsEmpty() {
+        var service = mock(InferenceService.class);
+
+        var model = mock(Model.class);
+        when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[3];
+            listener.onResponse(new TextEmbeddingResults(List.of()));
+
+            return Void.TYPE;
+        }).when(service).infer(any(), any(), any(), any());
+
+        PlainActionFuture<Integer> listener = new PlainActionFuture<>();
+        getEmbeddingSize(model, service, listener);
+
+        var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+
+        assertThat(thrownException.getMessage(), is("Could not determine embedding size"));
+        assertThat(thrownException.getCause().getMessage(), is("Embeddings list is empty"));
+    }
+
+    public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingByteResults_IsEmpty() {
+        var service = mock(InferenceService.class);
+
+        var model = mock(Model.class);
+        when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[3];
+            listener.onResponse(new TextEmbeddingByteResults(List.of()));
+
+            return Void.TYPE;
+        }).when(service).infer(any(), any(), any(), any());
+
+        PlainActionFuture<Integer> listener = new PlainActionFuture<>();
+        getEmbeddingSize(model, service, listener);
+
+        var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+
+        assertThat(thrownException.getMessage(), is("Could not determine embedding size"));
+        assertThat(thrownException.getCause().getMessage(), is("Embeddings list is empty"));
+    }
+
+    public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingResults() {
+        var service = mock(InferenceService.class);
+
+        var model = mock(Model.class);
+        when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
+
+        var textEmbedding = TextEmbeddingResultsTests.createRandomResults();
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[3];
+            listener.onResponse(textEmbedding);
+
+            return Void.TYPE;
+        }).when(service).infer(any(), any(), any(), any());
+
+        PlainActionFuture<Integer> listener = new PlainActionFuture<>();
+        getEmbeddingSize(model, service, listener);
+
+        var size = listener.actionGet(TIMEOUT);
+
+        assertThat(size, is(textEmbedding.embeddings().get(0).values().size()));
+    }
+
+    public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingByteResults() {
+        var service = mock(InferenceService.class);
+
+        var model = mock(Model.class);
+        when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
+
+        var textEmbedding = TextEmbeddingByteResultsTests.createRandomResults();
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[3];
+            listener.onResponse(textEmbedding);
+
+            return Void.TYPE;
+        }).when(service).infer(any(), any(), any(), any());
+
+        PlainActionFuture<Integer> listener = new PlainActionFuture<>();
+        getEmbeddingSize(model, service, listener);
+
+        var size = listener.actionGet(TIMEOUT);
+
+        assertThat(size, is(textEmbedding.embeddings().get(0).values().size()));
+    }
+
private static Map modifiableMap(Map aMap) {
return new HashMap<>(aMap);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
index dc3c8eafb46ac..0250e08a48452 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
@@ -50,7 +50,7 @@
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap;
@@ -749,7 +749,7 @@ public void testInfer_SendsRequest() throws IOException {
var result = listener.actionGet(TIMEOUT);
- MatcherAssert.assertThat(result.asMap(), Matchers.is(buildExpectationFloats(List.of(List.of(0.123F, -0.123F)))));
+ MatcherAssert.assertThat(result.asMap(), Matchers.is(buildExpectation(List.of(List.of(0.123F, -0.123F)))));
MatcherAssert.assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
MatcherAssert.assertThat(
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
index aac50b8645993..a76cce41b4fe4 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
@@ -47,7 +47,7 @@
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettingsTests.getServiceSettingsMap;
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
@@ -496,7 +496,7 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException {
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), Matchers.is(buildExpectationFloats(List.of(List.of(-0.0123F, 0.0123F)))));
+ assertThat(result.asMap(), Matchers.is(buildExpectation(List.of(List.of(-0.0123F, 0.0123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
index ab4ee881a224e..394286ee5287b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
@@ -46,7 +46,7 @@
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloats;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceSettingsTests.getServiceSettingsMap;
@@ -717,7 +717,7 @@ public void testInfer_SendsRequest() throws IOException {
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), Matchers.is(buildExpectationFloats(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(result.asMap(), Matchers.is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
From 8102408efb3501ad75a8973e8e04247fd89a38fc Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Wed, 24 Jan 2024 11:17:42 -0500
Subject: [PATCH 11/13] Fixing a few comments and adding tests
---
.../cohere/CohereEmbeddingsRequestEntity.java | 3 +-
.../services/cohere/CohereServiceFields.java | 2 -
.../cohere/CohereServiceSettings.java | 2 +-
.../CohereEmbeddingsResponseEntityTests.java | 38 +++++++++++++++++++
.../CohereEmbeddingsTaskSettingsTests.java | 3 +-
5 files changed, 43 insertions(+), 5 deletions(-)
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java
index a7d8743359f39..a0b5444ee45e4 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java
@@ -12,6 +12,7 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields;
+import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
@@ -52,7 +53,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();
builder.field(TEXTS_FIELD, input);
if (model != null) {
- builder.field(CohereServiceFields.MODEL, model);
+ builder.field(CohereServiceSettings.MODEL, model);
}
if (taskSettings.inputType() != null) {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java
index ccfe1cb2593c6..807520637f971 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java
@@ -8,7 +8,5 @@
package org.elasticsearch.xpack.inference.services.cohere;
public class CohereServiceFields {
- public static final String MODEL = "model";
public static final String TRUNCATE = "truncate";
-
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java
index f03371593a340..7964741d90343 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java
@@ -46,7 +46,7 @@ public static CohereServiceSettings fromMap(Map map) {
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
- String model = extractOptionalString(map, MODEL, ModelConfigurations.TASK_SETTINGS, validationException);
+ String model = extractOptionalString(map, MODEL, ModelConfigurations.SERVICE_SETTINGS, validationException);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java
index 76aecd997414c..f04715be0838f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java
@@ -188,6 +188,44 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntryInt8_WithInvalidFir
);
}
+ public void testFromResponse_ParsesBytes() throws IOException { // An "int8" embeddings object should parse into byte results.
+ String responseJson = """
+ {
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+ "texts": [
+ "hello"
+ ],
+ "embeddings": {
+ "int8": [
+ [
+ -1,
+ 0
+ ]
+ ]
+ },
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 1
+ }
+ },
+ "response_type": "embeddings_floats"
+ }
+ """; // NOTE(review): "response_type" reads "embeddings_floats" although the payload is int8 - confirm the parser ignores it.
+
+ TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse(
+ mock(Request.class),
+ new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+ ); // The cast above fails the test if fromResponse returns any other result type.
+
+ MatcherAssert.assertThat(
+ parsedResults.embeddings(),
+ is(List.of(new TextEmbeddingByteResults.Embedding(List.of((byte) -1, (byte) 0)))) // one embedding holding both JSON values as bytes
+ );
+ }
+
public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
String responseJson = """
{
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
index cf16473bdb70f..164d3998f138f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java
@@ -13,6 +13,7 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields;
+import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
import org.hamcrest.MatcherAssert;
@@ -68,7 +69,7 @@ public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() {
public void testOverrideWith_KeepsOriginalValuesWhenOverridesAreNull() {
var taskSettings = CohereEmbeddingsTaskSettings.fromMap(
- new HashMap<>(Map.of(CohereServiceFields.MODEL, "model", CohereServiceFields.TRUNCATE, CohereTruncation.END.toString()))
+ new HashMap<>(Map.of(CohereServiceSettings.MODEL, "model", CohereServiceFields.TRUNCATE, CohereTruncation.END.toString()))
);
var overriddenTaskSettings = taskSettings.overrideWith(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS);
From 6c8343ffa12edee0abccc3d52a0c42512e012dbd Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Wed, 24 Jan 2024 11:46:12 -0500
Subject: [PATCH 12/13] Fixing mutation issue
---
.../inference/services/cohere/CohereServiceSettingsTests.java | 2 +-
.../cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java
index 8f829c11c0a11..6f47d5c74d81c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java
@@ -141,7 +141,7 @@ protected CohereServiceSettings createTestInstance() {
@Override
protected CohereServiceSettings mutateInstance(CohereServiceSettings instance) throws IOException {
- return createRandomWithNonNullUrl();
+ return null; // Returns no mutated copy. NOTE(review): presumably null makes the base test case skip its mutation check - confirm.
}
public static Map getServiceSettingsMap(@Nullable String url, @Nullable String model) {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java
index 8daa5a27f9618..e0b29ce9c34da 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java
@@ -140,7 +140,7 @@ protected CohereEmbeddingsServiceSettings createTestInstance() {
@Override
protected CohereEmbeddingsServiceSettings mutateInstance(CohereEmbeddingsServiceSettings instance) throws IOException {
- return createRandom();
+ return null; // Returns no mutated copy. NOTE(review): presumably null makes the base test case skip its mutation check - confirm.
}
@Override
From de682b5568463085df6c115b60e8eeb576baa239 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Wed, 24 Jan 2024 15:16:15 -0500
Subject: [PATCH 13/13] Removing cohere service settings from named writeable
registry
---
.../xpack/inference/InferenceNamedWriteablesProvider.java | 3 ---
.../cohere/embeddings/CohereEmbeddingsServiceSettings.java | 4 ++--
2 files changed, 2 insertions(+), 5 deletions(-)
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
index 982c33a08d1fc..c23e245b5696c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
@@ -98,9 +98,6 @@ public static List getNamedWriteables() {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new)
);
- namedWriteables.add(
- new NamedWriteableRegistry.Entry(CohereServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new)
- );
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java
index f8400bc168d69..5327bcbcf22dd 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java
@@ -57,7 +57,7 @@ public CohereEmbeddingsServiceSettings(CohereServiceSettings commonSettings, @Nu
}
public CohereEmbeddingsServiceSettings(StreamInput in) throws IOException {
- commonSettings = in.readNamedWriteable(CohereServiceSettings.class);
+ commonSettings = new CohereServiceSettings(in); // Read inline via the StreamInput constructor; read order matches writeTo.
embeddingType = in.readOptionalEnum(CohereEmbeddingType.class);
}
@@ -92,7 +92,7 @@ public TransportVersion getMinimalSupportedVersion() {
@Override
public void writeTo(StreamOutput out) throws IOException {
- out.writeNamedWriteable(commonSettings);
+ commonSettings.writeTo(out); // Written inline, before the optional embedding type; the StreamInput constructor reads in this order.
out.writeOptionalEnum(embeddingType);
}