From 31f1371762474535459fa18bf8ed729356c150f7 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 30 Jan 2024 10:03:49 -0500 Subject: [PATCH] Adding request source for cohere --- .../request/cohere/CohereEmbeddingsRequest.java | 1 + .../external/request/cohere/CohereUtils.java | 9 +++++++++ .../cohere/CohereEmbeddingsActionTests.java | 9 +++++++++ .../cohere/CohereEmbeddingsRequestTests.java | 16 ++++++++++++++++ 4 files changed, 35 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java index 8cacbd0f16aa..30427aaa3586 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java @@ -62,6 +62,7 @@ public HttpRequest createHttpRequest() { httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); httpPost.setHeader(createAuthBearerHeader(account.apiKey())); + httpPost.setHeader(CohereUtils.createRequestSourceHeader()); return new HttpRequest(httpPost, getInferenceEntityId()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereUtils.java index f8ccd91d4e3d..e54328df1dbf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereUtils.java @@ -7,10 +7,19 @@ package org.elasticsearch.xpack.inference.external.request.cohere; +import org.apache.http.Header; +import org.apache.http.message.BasicHeader; + public class CohereUtils { public static final String HOST = "api.cohere.ai"; public static final String VERSION_1 = "v1"; public static final String EMBEDDINGS_PATH = "embed"; + public static final String REQUEST_SOURCE_HEADER = "Request-Source"; + public static final String ELASTIC_REQUEST_SOURCE = "unspecified:elasticsearch"; + + public static Header createRequestSourceHeader() { + return new BasicHeader(REQUEST_SOURCE_HEADER, ELASTIC_REQUEST_SOURCE); + } private CohereUtils() {} } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java index 501d5a5e42bf..7fd33f7bba58 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.request.cohere.CohereUtils; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.results.TextEmbeddingByteResultsTests; import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; @@ -130,6 +131,10 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { equalTo(XContentType.JSON.mediaType()) ); MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(CohereUtils.REQUEST_SOURCE_HEADER), + equalTo(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); MatcherAssert.assertThat( @@ -210,6 +215,10 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I equalTo(XContentType.JSON.mediaType()) ); MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(CohereUtils.REQUEST_SOURCE_HEADER), + equalTo(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); MatcherAssert.assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java index df61417ffff9..d3783f6fed76 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java @@ -44,6 +44,10 @@ public void testCreateRequest_UrlDefined() throws URISyntaxException, IOExceptio MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); var requestMap = entityAsMap(httpPost.getEntity().getContent()); MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc")))); @@ -71,6 +75,10 @@ public void testCreateRequest_AllOptionsDefined() throws URISyntaxException, IOE MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); var requestMap = entityAsMap(httpPost.getEntity().getContent()); MatcherAssert.assertThat( @@ -114,6 +122,10 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() th MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); var requestMap = entityAsMap(httpPost.getEntity().getContent()); MatcherAssert.assertThat( @@ -157,6 +169,10 @@ public void testCreateRequest_TruncateNone() throws URISyntaxException, IOExcept MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); var requestMap = entityAsMap(httpPost.getEntity().getContent()); MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc"), "truncate", "none")));