diff --git a/docs/changelog/117589.yaml b/docs/changelog/117589.yaml new file mode 100644 index 0000000000000..e6880fd9477b5 --- /dev/null +++ b/docs/changelog/117589.yaml @@ -0,0 +1,5 @@ +pr: 117589 +summary: "Add Inference Unified API for chat completions for OpenAI" +area: Machine Learning +type: enhancement +issues: [] diff --git a/muted-tests.yml b/muted-tests.yml index 07072e9743c98..7c5df966e6bd6 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -239,9 +239,6 @@ tests: - class: org.elasticsearch.packaging.test.ConfigurationTests method: test30SymlinkedDataPath issue: https://github.com/elastic/elasticsearch/issues/118111 -- class: org.elasticsearch.datastreams.ResolveClusterDataStreamIT - method: testClusterResolveWithDataStreamsUsingAlias - issue: https://github.com/elastic/elasticsearch/issues/118124 - class: org.elasticsearch.packaging.test.KeystoreManagementTests method: test30KeystorePasswordFromFile issue: https://github.com/elastic/elasticsearch/issues/118123 @@ -266,6 +263,15 @@ tests: - class: org.elasticsearch.xpack.inference.DefaultEndPointsIT method: testInferDeploysDefaultRerank issue: https://github.com/elastic/elasticsearch/issues/118184 +- class: org.elasticsearch.xpack.esql.action.EsqlActionTaskIT + method: testCancelRequestWhenFailingFetchingPages + issue: https://github.com/elastic/elasticsearch/issues/118193 +- class: org.elasticsearch.packaging.test.MemoryLockingTests + method: test20MemoryLockingEnabled + issue: https://github.com/elastic/elasticsearch/issues/118195 +- class: org.elasticsearch.packaging.test.ArchiveTests + method: test42AutoconfigurationNotTriggeredWhenNodeCannotBecomeMaster + issue: https://github.com/elastic/elasticsearch/issues/118196 # Examples: # diff --git a/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java b/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java index 2e78cc6f516b1..6a5aa2943de92 100644 --- 
a/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java +++ b/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.xcontent.ToXContent; +import java.util.Collections; import java.util.Iterator; public enum ChunkedToXContentHelper { @@ -53,6 +54,14 @@ public static Iterator field(String name, String value) { return Iterators.single(((builder, params) -> builder.field(name, value))); } + public static Iterator optionalField(String name, String value) { + if (value == null) { + return Collections.emptyIterator(); + } else { + return field(name, value); + } + } + /** * Creates an Iterator of a single ToXContent object that serializes the given object as a single chunk. Just wraps {@link * Iterators#single}, but still useful because it avoids any type ambiguity. diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 4497254aad1f0..c2d690d8160ac 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -112,6 +112,23 @@ void infer( ); /** + * Perform completion inference on the model using the unified schema. + * + * @param model The model + * @param request Parameters for the request + * @param timeout The timeout for the request + * @param listener Inference result listener + */ + void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ); + + /** + * Chunk long text. 
+ * * @param model The model * @param query Inference query, mainly for re-ranking * @param input Inference input diff --git a/server/src/main/java/org/elasticsearch/inference/TaskType.java b/server/src/main/java/org/elasticsearch/inference/TaskType.java index b0e5bababbbc0..fcb8ea7213795 100644 --- a/server/src/main/java/org/elasticsearch/inference/TaskType.java +++ b/server/src/main/java/org/elasticsearch/inference/TaskType.java @@ -38,6 +38,10 @@ public static TaskType fromString(String name) { } public static TaskType fromStringOrStatusException(String name) { + if (name == null) { + throw new ElasticsearchStatusException("Task type must not be null", RestStatus.BAD_REQUEST); + } + try { TaskType taskType = TaskType.fromString(name); return Objects.requireNonNull(taskType); diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java new file mode 100644 index 0000000000000..e596be626b518 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -0,0 +1,425 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.inference; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public record UnifiedCompletionRequest( + List messages, + @Nullable String model, + @Nullable Long maxCompletionTokens, + @Nullable List stop, + @Nullable Float temperature, + @Nullable ToolChoice toolChoice, + @Nullable List tools, + @Nullable Float topP +) implements Writeable { + + public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString {} + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + UnifiedCompletionRequest.class.getSimpleName(), + args -> new UnifiedCompletionRequest( + (List) args[0], + (String) args[1], + (Long) args[2], + (List) args[3], + (Float) args[4], + (ToolChoice) args[5], + (List) args[6], + (Float) args[7] + ) + ); + + static { + PARSER.declareObjectArray(constructorArg(), Message.PARSER::apply, new ParseField("messages")); + PARSER.declareString(optionalConstructorArg(), new ParseField("model")); + PARSER.declareLong(optionalConstructorArg(), new 
ParseField("max_completion_tokens")); + PARSER.declareStringArray(optionalConstructorArg(), new ParseField("stop")); + PARSER.declareFloat(optionalConstructorArg(), new ParseField("temperature")); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> parseToolChoice(p), + new ParseField("tool_choice"), + ObjectParser.ValueType.OBJECT_OR_STRING + ); + PARSER.declareObjectArray(optionalConstructorArg(), Tool.PARSER::apply, new ParseField("tools")); + PARSER.declareFloat(optionalConstructorArg(), new ParseField("top_p")); + } + + public static List getNamedWriteables() { + return List.of( + new NamedWriteableRegistry.Entry(Content.class, ContentObjects.NAME, ContentObjects::new), + new NamedWriteableRegistry.Entry(Content.class, ContentString.NAME, ContentString::new), + new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceObject.NAME, ToolChoiceObject::new), + new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceString.NAME, ToolChoiceString::new) + ); + } + + public static UnifiedCompletionRequest of(List messages) { + return new UnifiedCompletionRequest(messages, null, null, null, null, null, null, null); + } + + public UnifiedCompletionRequest(StreamInput in) throws IOException { + this( + in.readCollectionAsImmutableList(Message::new), + in.readOptionalString(), + in.readOptionalVLong(), + in.readOptionalStringCollectionAsList(), + in.readOptionalFloat(), + in.readOptionalNamedWriteable(ToolChoice.class), + in.readOptionalCollectionAsList(Tool::new), + in.readOptionalFloat() + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(messages); + out.writeOptionalString(model); + out.writeOptionalVLong(maxCompletionTokens); + out.writeOptionalStringCollection(stop); + out.writeOptionalFloat(temperature); + out.writeOptionalNamedWriteable(toolChoice); + out.writeOptionalCollection(tools); + out.writeOptionalFloat(topP); + } + + public record Message(Content content, String role, @Nullable 
String name, @Nullable String toolCallId, List toolCalls) + implements + Writeable { + + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + Message.class.getSimpleName(), + args -> new Message((Content) args[0], (String) args[1], (String) args[2], (String) args[3], (List) args[4]) + ); + + static { + PARSER.declareField(constructorArg(), (p, c) -> parseContent(p), new ParseField("content"), ObjectParser.ValueType.VALUE_ARRAY); + PARSER.declareString(constructorArg(), new ParseField("role")); + PARSER.declareString(optionalConstructorArg(), new ParseField("name")); + PARSER.declareString(optionalConstructorArg(), new ParseField("tool_call_id")); + PARSER.declareObjectArray(optionalConstructorArg(), ToolCall.PARSER::apply, new ParseField("tool_calls")); + } + + private static Content parseContent(XContentParser parser) throws IOException { + var token = parser.currentToken(); + if (token == XContentParser.Token.START_ARRAY) { + var parsedContentObjects = XContentParserUtils.parseList(parser, (p) -> ContentObject.PARSER.apply(p, null)); + return new ContentObjects(parsedContentObjects); + } else if (token == XContentParser.Token.VALUE_STRING) { + return ContentString.of(parser); + } + + throw new XContentParseException("Expected an array start token or a value string token but found token [" + token + "]"); + } + + public Message(StreamInput in) throws IOException { + this( + in.readNamedWriteable(Content.class), + in.readString(), + in.readOptionalString(), + in.readOptionalString(), + in.readOptionalCollectionAsList(ToolCall::new) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeNamedWriteable(content); + out.writeString(role); + out.writeOptionalString(name); + out.writeOptionalString(toolCallId); + out.writeOptionalCollection(toolCalls); + } + } + + public record ContentObjects(List contentObjects) implements Content, NamedWriteable { + + public static 
final String NAME = "content_objects"; + + public ContentObjects(StreamInput in) throws IOException { + this(in.readCollectionAsImmutableList(ContentObject::new)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(contentObjects); + } + + @Override + public String getWriteableName() { + return NAME; + } + } + + public record ContentObject(String text, String type) implements Writeable { + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + ContentObject.class.getSimpleName(), + args -> new ContentObject((String) args[0], (String) args[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("text")); + PARSER.declareString(constructorArg(), new ParseField("type")); + } + + public ContentObject(StreamInput in) throws IOException { + this(in.readString(), in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(text); + out.writeString(type); + } + + public String toString() { + return text + ":" + type; + } + + } + + public record ContentString(String content) implements Content, NamedWriteable { + public static final String NAME = "content_string"; + + public static ContentString of(XContentParser parser) throws IOException { + var content = parser.text(); + return new ContentString(content); + } + + public ContentString(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(content); + } + + @Override + public String getWriteableName() { + return NAME; + } + + public String toString() { + return content; + } + } + + public record ToolCall(String id, FunctionField function, String type) implements Writeable { + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + ToolCall.class.getSimpleName(), + args -> new ToolCall((String) args[0], (FunctionField) args[1], (String) args[2]) 
+ ); + + static { + PARSER.declareString(constructorArg(), new ParseField("id")); + PARSER.declareObject(constructorArg(), FunctionField.PARSER::apply, new ParseField("function")); + PARSER.declareString(constructorArg(), new ParseField("type")); + } + + public ToolCall(StreamInput in) throws IOException { + this(in.readString(), new FunctionField(in), in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + function.writeTo(out); + out.writeString(type); + } + + public record FunctionField(String arguments, String name) implements Writeable { + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "tool_call_function_field", + args -> new FunctionField((String) args[0], (String) args[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("arguments")); + PARSER.declareString(constructorArg(), new ParseField("name")); + } + + public FunctionField(StreamInput in) throws IOException { + this(in.readString(), in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(arguments); + out.writeString(name); + } + } + } + + private static ToolChoice parseToolChoice(XContentParser parser) throws IOException { + var token = parser.currentToken(); + if (token == XContentParser.Token.START_OBJECT) { + return ToolChoiceObject.PARSER.apply(parser, null); + } else if (token == XContentParser.Token.VALUE_STRING) { + return ToolChoiceString.of(parser); + } + + throw new XContentParseException("Unsupported token [" + token + "]"); + } + + public sealed interface ToolChoice extends NamedWriteable permits ToolChoiceObject, ToolChoiceString {} + + public record ToolChoiceObject(String type, FunctionField function) implements ToolChoice, NamedWriteable { + + public static final String NAME = "tool_choice_object"; + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + 
ToolChoiceObject.class.getSimpleName(), + args -> new ToolChoiceObject((String) args[0], (FunctionField) args[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("type")); + PARSER.declareObject(constructorArg(), FunctionField.PARSER::apply, new ParseField("function")); + } + + public ToolChoiceObject(StreamInput in) throws IOException { + this(in.readString(), new FunctionField(in)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + function.writeTo(out); + } + + @Override + public String getWriteableName() { + return NAME; + } + + public record FunctionField(String name) implements Writeable { + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "tool_choice_function_field", + args -> new FunctionField((String) args[0]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("name")); + } + + public FunctionField(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(name); + } + } + } + + public record ToolChoiceString(String value) implements ToolChoice, NamedWriteable { + public static final String NAME = "tool_choice_string"; + + public static ToolChoiceString of(XContentParser parser) throws IOException { + var content = parser.text(); + return new ToolChoiceString(content); + } + + public ToolChoiceString(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(value); + } + + @Override + public String getWriteableName() { + return NAME; + } + } + + public record Tool(String type, FunctionField function) implements Writeable { + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + Tool.class.getSimpleName(), + args -> new Tool((String) args[0], (FunctionField) args[1]) + ); + + static { + 
PARSER.declareString(constructorArg(), new ParseField("type")); + PARSER.declareObject(constructorArg(), FunctionField.PARSER::apply, new ParseField("function")); + } + + public Tool(StreamInput in) throws IOException { + this(in.readString(), new FunctionField(in)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + function.writeTo(out); + } + + public record FunctionField( + @Nullable String description, + String name, + @Nullable Map parameters, + @Nullable Boolean strict + ) implements Writeable { + + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "tool_function_field", + args -> new FunctionField((String) args[0], (String) args[1], (Map) args[2], (Boolean) args[3]) + ); + + static { + PARSER.declareString(optionalConstructorArg(), new ParseField("description")); + PARSER.declareString(constructorArg(), new ParseField("name")); + PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.mapOrdered(), new ParseField("parameters")); + PARSER.declareBoolean(optionalConstructorArg(), new ParseField("strict")); + } + + public FunctionField(StreamInput in) throws IOException { + this(in.readOptionalString(), in.readString(), in.readGenericMap(), in.readOptionalBoolean()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(description); + out.writeString(name); + out.writeGenericMap(parameters); + out.writeOptionalBoolean(strict); + } + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java index b4f91f68b8bb7..7cd7bce4db187 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java @@ -9,6 +9,8 @@ package 
org.elasticsearch.test; +import com.carrotsearch.randomizedtesting.RandomizedTest; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.admin.cluster.remote.RemoteInfoRequest; @@ -108,6 +110,11 @@ public final void startClusters() throws Exception { MockTransportService.TestPlugin.class, getTestTransportPlugin() ); + // We are going to initialize multiple clusters concurrently, but there is a race condition around the lazy initialization of test + // groups in GroupEvaluator across multiple threads. See https://github.com/randomizedtesting/randomizedtesting/issues/311. + // Calling isNightly before parallelizing is enough to work around that issue. + @SuppressWarnings("unused") + boolean nightly = RandomizedTest.isNightly(); runInParallel(clusterAliases.size(), i -> { String clusterAlias = clusterAliases.get(i); final String clusterName = clusterAlias.equals(LOCAL_CLUSTER) ? "main-cluster" : clusterAlias;
diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java
index d983fc854bdfd..a71f61740e17b 100644
--- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java
+++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java
@@ -1205,10 +1205,30 @@ public static SecureString randomSecureStringOfLength(int codeUnits) {
         return new SecureString(randomAlpha.toCharArray());
     }
 
+    /**
+     * Returns {@code null} when {@code randomBoolean()} is true, otherwise the result of {@code randomAlphaOfLength(codeUnits)}.
+     */
-    public static String randomNullOrAlphaOfLength(int codeUnits) {
+    public static String randomAlphaOfLengthOrNull(int codeUnits) {
         return randomBoolean() ? null : randomAlphaOfLength(codeUnits);
     }
 
+    /** Returns {@code null} when {@code randomBoolean()} is true, otherwise the result of {@code randomLong()}. */
+    public static Long randomLongOrNull() {
+        return randomBoolean() ? null : randomLong();
+    }
+
+    /**
+     * Returns {@code null} when {@code randomBoolean()} is true, otherwise the result of {@code randomNonNegativeLong()}.
+     * NOTE(review): despite "Positive" in the name this delegates to the non-negative helper, so 0 is presumably a possible
+     * result - confirm against randomNonNegativeLong().
+     */
+    public static Long randomPositiveLongOrNull() {
+        return randomBoolean() ? null : randomNonNegativeLong();
+    }
+
+    /** Returns {@code null} when {@code randomBoolean()} is true, otherwise the result of {@code randomInt()}. */
+    public static Integer randomIntOrNull() {
+        return randomBoolean() ? null : randomInt();
+    }
+
+    /**
+     * Returns {@code null} when {@code randomBoolean()} is true, otherwise the result of {@code randomNonNegativeInt()}.
+     * NOTE(review): despite "Positive" in the name this delegates to the non-negative helper, so 0 is presumably a possible
+     * result - confirm against randomNonNegativeInt().
+     */
+    public static Integer randomPositiveIntOrNull() {
+        return randomBoolean() ? null : randomNonNegativeInt();
+    }
+
+    /** Returns {@code null} when {@code randomBoolean()} is true, otherwise the result of {@code randomFloat()}. */
+    public static Float randomFloatOrNull() {
+        return randomBoolean() ? null : randomFloat();
+    }
+
     /**
      * Creates a valid random identifier such as node id or index name
      */
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java new file mode 100644 index 0000000000000..e426574c52ce6 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.action;
+
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.inference.TaskType;
+
+import java.io.IOException;
+
+/**
+ * Common base class for inference action requests. It adds no state of its own; subclasses expose whether the request
+ * streams, which task type it targets and which inference entity it addresses.
+ */
+public abstract class BaseInferenceActionRequest extends ActionRequest {
+
+    public BaseInferenceActionRequest() {
+        super();
+    }
+
+    /** Deserializing constructor; delegates to {@link ActionRequest#ActionRequest(StreamInput)}. */
+    public BaseInferenceActionRequest(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    /** @return whether this request is a streaming request */
+    public abstract boolean isStreaming();
+
+    /** @return the task type of this request */
+    public abstract TaskType getTaskType();
+
+    /** @return the id of the inference entity this request addresses */
+    public abstract String getInferenceEntityId();
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index a19edd5a08162..f88909ba4208e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -10,7 +10,6 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; -import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; @@ -54,7 +53,7 @@ public InferenceAction() { super(NAME); } - public static class Request extends ActionRequest { + public static class Request extends BaseInferenceActionRequest { public static final TimeValue DEFAULT_TIMEOUT = TimeValue.timeValueSeconds(30); public static final ParseField INPUT = new ParseField("input"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java new file mode 100644 index 0000000000000..8d121463fb465 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class UnifiedCompletionAction extends ActionType { + public static final UnifiedCompletionAction INSTANCE = new UnifiedCompletionAction(); + public static final String NAME = "cluster:monitor/xpack/inference/unified"; + + public UnifiedCompletionAction() { + super(NAME); + } + + public static class Request extends BaseInferenceActionRequest { + public static Request parseRequest(String inferenceEntityId, TaskType taskType, TimeValue timeout, XContentParser parser) + throws IOException { + var unifiedRequest = UnifiedCompletionRequest.PARSER.apply(parser, null); + return new Request(inferenceEntityId, taskType, unifiedRequest, timeout); + } + + private final String inferenceEntityId; + private final TaskType taskType; + private final UnifiedCompletionRequest unifiedCompletionRequest; + private final TimeValue timeout; + + public Request(String inferenceEntityId, 
TaskType taskType, UnifiedCompletionRequest unifiedCompletionRequest, TimeValue timeout) { + this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId); + this.taskType = Objects.requireNonNull(taskType); + this.unifiedCompletionRequest = Objects.requireNonNull(unifiedCompletionRequest); + this.timeout = Objects.requireNonNull(timeout); + } + + public Request(StreamInput in) throws IOException { + super(in); + this.inferenceEntityId = in.readString(); + this.taskType = TaskType.fromStream(in); + this.unifiedCompletionRequest = new UnifiedCompletionRequest(in); + this.timeout = in.readTimeValue(); + } + + public TaskType getTaskType() { + return taskType; + } + + public String getInferenceEntityId() { + return inferenceEntityId; + } + + public UnifiedCompletionRequest getUnifiedCompletionRequest() { + return unifiedCompletionRequest; + } + + /** + * The Unified API only supports streaming so we always return true here. + * @return true + */ + public boolean isStreaming() { + return true; + } + + public TimeValue getTimeout() { + return timeout; + } + + @Override + public ActionRequestValidationException validate() { + if (unifiedCompletionRequest == null || unifiedCompletionRequest.messages() == null) { + var e = new ActionRequestValidationException(); + e.addValidationError("Field [messages] cannot be null"); + return e; + } + + if (unifiedCompletionRequest.messages().isEmpty()) { + var e = new ActionRequestValidationException(); + e.addValidationError("Field [messages] cannot be an empty array"); + return e; + } + + if (taskType.isAnyOrSame(TaskType.COMPLETION) == false) { + var e = new ActionRequestValidationException(); + e.addValidationError("Field [taskType] must be [completion]"); + return e; + } + + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(inferenceEntityId); + taskType.writeTo(out); + unifiedCompletionRequest.writeTo(out); + out.writeTimeValue(timeout); + } + 
+ @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(inferenceEntityId, request.inferenceEntityId) + && taskType == request.taskType + && Objects.equals(unifiedCompletionRequest, request.unifiedCompletionRequest) + && Objects.equals(timeout, request.timeout); + } + + @Override + public int hashCode() { + return Objects.hash(inferenceEntityId, taskType, unifiedCompletionRequest, timeout); + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java new file mode 100644 index 0000000000000..90038c67036c4 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java @@ -0,0 +1,329 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.ToXContent; + +import java.io.IOException; +import java.util.Collections; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Flow; + +/** + * Chat Completion results that only contain a Flow.Publisher. 
+ */ +public record StreamingUnifiedChatCompletionResults(Flow.Publisher publisher) + implements + InferenceServiceResults { /* Streaming-only result type: apart from isStreaming() and the record accessor, every InferenceServiceResults method overridden below throws UnsupportedOperationException. */ + + public static final String NAME = "chat_completion_chunk"; + /* The FIELD constants below are the JSON field names written by the toXContentChunked methods in this file. */ public static final String MODEL_FIELD = "model"; + public static final String OBJECT_FIELD = "object"; + public static final String USAGE_FIELD = "usage"; + public static final String INDEX_FIELD = "index"; + public static final String ID_FIELD = "id"; + public static final String FUNCTION_NAME_FIELD = "name"; + public static final String FUNCTION_ARGUMENTS_FIELD = "arguments"; + public static final String FUNCTION_FIELD = "function"; + public static final String CHOICES_FIELD = "choices"; + public static final String DELTA_FIELD = "delta"; + public static final String CONTENT_FIELD = "content"; + public static final String REFUSAL_FIELD = "refusal"; + public static final String ROLE_FIELD = "role"; + /* NOTE(review): the only private constant in this group; confirm the narrower visibility is intentional. */ private static final String TOOL_CALLS_FIELD = "tool_calls"; + public static final String FINISH_REASON_FIELD = "finish_reason"; + public static final String COMPLETION_TOKENS_FIELD = "completion_tokens"; + public static final String TOTAL_TOKENS_FIELD = "total_tokens"; + public static final String PROMPT_TOKENS_FIELD = "prompt_tokens"; + public static final String TYPE_FIELD = "type"; + + /* Always reports true. */ @Override + public boolean isStreaming() { + return true; + } + + /* Not implemented for this type; always throws. The same holds for the five overrides that follow. 
*/ @Override + public List transformToCoordinationFormat() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public List transformToLegacyFormat() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public Map asMap() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public String getWriteableName() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public Iterator 
toXContentChunked(ToXContent.Params params) { + throw new UnsupportedOperationException("Not implemented"); + } + + /* Holds a deque of chunks; serialization flat-maps chunks.iterator() to the toXContentChunked(params) output of each chunk. */ public record Results(Deque chunks) implements ChunkedToXContent { + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + return Iterators.concat(Iterators.flatMap(chunks.iterator(), c -> c.toXContentChunked(params))); + } + } + + /* One chunk of a streamed chat completion, carrying id, choices, model, object and usage. */ public static class ChatCompletionChunk implements ChunkedToXContent { + private final String id; + + public String getId() { + return id; + } + + public List getChoices() { + return choices; + } + + public String getModel() { + return model; + } + + public String getObject() { + return object; + } + + public Usage getUsage() { + return usage; + } + + private final List choices; + private final String model; + private final String object; + private final ChatCompletionChunk.Usage usage; + + public ChatCompletionChunk(String id, List choices, String model, String object, ChatCompletionChunk.Usage usage) { + this.id = id; + this.choices = choices; + this.model = model; + this.object = object; + this.usage = usage; + } + + /* Writes an object with id, the choices array, model, object and the usage object, in that order; choices and usage are skipped when null, the other three are always written. 
*/ @Override + public Iterator toXContentChunked(ToXContent.Params params) { + + Iterator choicesIterator = Collections.emptyIterator(); + if (choices != null) { + choicesIterator = Iterators.concat( + ChunkedToXContentHelper.startArray(CHOICES_FIELD), + Iterators.flatMap(choices.iterator(), c -> c.toXContentChunked(params)), + ChunkedToXContentHelper.endArray() + ); + } + + Iterator usageIterator = Collections.emptyIterator(); + if (usage != null) { + usageIterator = Iterators.concat( + ChunkedToXContentHelper.startObject(USAGE_FIELD), + ChunkedToXContentHelper.field(COMPLETION_TOKENS_FIELD, usage.completionTokens()), + ChunkedToXContentHelper.field(PROMPT_TOKENS_FIELD, usage.promptTokens()), + ChunkedToXContentHelper.field(TOTAL_TOKENS_FIELD, usage.totalTokens()), + ChunkedToXContentHelper.endObject() + ); + } + + return Iterators.concat( + ChunkedToXContentHelper.startObject(), + 
ChunkedToXContentHelper.field(ID_FIELD, id), + choicesIterator, + ChunkedToXContentHelper.field(MODEL_FIELD, model), + ChunkedToXContentHelper.field(OBJECT_FIELD, object), + usageIterator, + ChunkedToXContentHelper.endObject() + ); + } + + /* One entry of the choices array. */ public record Choice(ChatCompletionChunk.Choice.Delta delta, String finishReason, int index) { + + /* + choices: Array<{ + delta: { ... }; + finish_reason: string | null; + index: number; + }>; + */ + /* finish_reason is written only when non-null (optionalField). NOTE(review): delta is dereferenced without a null check. */ public Iterator toXContentChunked(ToXContent.Params params) { + return Iterators.concat( + ChunkedToXContentHelper.startObject(), + delta.toXContentChunked(params), + ChunkedToXContentHelper.optionalField(FINISH_REASON_FIELD, finishReason), + ChunkedToXContentHelper.field(INDEX_FIELD, index), + ChunkedToXContentHelper.endObject() + ); + } + + public static class Delta { + private final String content; + private final String refusal; + private final String role; + /* NOTE(review): not final, unlike the sibling fields, although the code shown here assigns it only in the constructor. */ private List toolCalls; + + public Delta(String content, String refusal, String role, List toolCalls) { + this.content = content; + this.refusal = refusal; + this.role = role; + this.toolCalls = toolCalls; + } + + /* + delta: { + content?: string | null; + refusal?: string | null; + role?: 'system' | 'user' | 'assistant' | 'tool'; + tool_calls?: Array<{ ... 
}>; + }; + */ + /* Writes the delta object; content, refusal and role are each skipped when null, and tool_calls is written only for a non-null, non-empty list. */ public Iterator toXContentChunked(ToXContent.Params params) { + var xContent = Iterators.concat( + ChunkedToXContentHelper.startObject(DELTA_FIELD), + ChunkedToXContentHelper.optionalField(CONTENT_FIELD, content), + ChunkedToXContentHelper.optionalField(REFUSAL_FIELD, refusal), + ChunkedToXContentHelper.optionalField(ROLE_FIELD, role) + ); + + if (toolCalls != null && toolCalls.isEmpty() == false) { + xContent = Iterators.concat( + xContent, + ChunkedToXContentHelper.startArray(TOOL_CALLS_FIELD), + Iterators.flatMap(toolCalls.iterator(), t -> t.toXContentChunked(params)), + ChunkedToXContentHelper.endArray() + ); + } + xContent = Iterators.concat(xContent, ChunkedToXContentHelper.endObject()); + return xContent; + + } + + public String getContent() { + return content; + } + + public String getRefusal() { + return refusal; + } + + public String getRole() { + return role; + } + + public List getToolCalls() { + return toolCalls; + } + + public static class ToolCall { + private final int index; + private final String id; + /* NOTE(review): public and non-final, unlike the sibling fields, even though getFunction() exists. 
*/ public ChatCompletionChunk.Choice.Delta.ToolCall.Function function; + private final String type; + + public ToolCall(int index, String id, ChatCompletionChunk.Choice.Delta.ToolCall.Function function, String type) { + this.index = index; + this.id = id; + this.function = function; + this.type = type; + } + + public int getIndex() { + return index; + } + + public String getId() { + return id; + } + + public ChatCompletionChunk.Choice.Delta.ToolCall.Function getFunction() { + return function; + } + + public String getType() { + return type; + } + + /* + index: number; + id?: string; + function?: { + arguments?: string; + name?: string; + }; + type?: 'function'; + */ + /* Writes index, then id and the function object when present, then type. NOTE(review): type uses field() rather than optionalField() although the schema comment above marks it optional; confirm how a null type should be serialized. */ public Iterator toXContentChunked(ToXContent.Params params) { + var content = Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field(INDEX_FIELD, index), + ChunkedToXContentHelper.optionalField(ID_FIELD, id) + ); + + if (function != null) 
{ + content = Iterators.concat( + content, + ChunkedToXContentHelper.startObject(FUNCTION_FIELD), + ChunkedToXContentHelper.optionalField(FUNCTION_ARGUMENTS_FIELD, function.getArguments()), + ChunkedToXContentHelper.optionalField(FUNCTION_NAME_FIELD, function.getName()), + ChunkedToXContentHelper.endObject() + ); + } + + content = Iterators.concat( + content, + ChunkedToXContentHelper.field(TYPE_FIELD, type), + ChunkedToXContentHelper.endObject() + ); + return content; + } + + /* Immutable holder for the arguments and name of a tool call function. */ public static class Function { + private final String arguments; + private final String name; + + public Function(String arguments, String name) { + this.arguments = arguments; + this.name = name; + } + + public String getArguments() { + return arguments; + } + + public String getName() { + return name; + } + } + } + } + } + + /* Token counts written under the usage object. */ public record Usage(int completionTokens, int promptTokens, int totalTokens) {} + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index a9ca5e6da8720..01c0ff88be222 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -41,8 +41,7 @@ protected InferenceAction.Request createTestInstance() { return new InferenceAction.Request( randomFrom(TaskType.values()), randomAlphaOfLength(6), - // null, - randomNullOrAlphaOfLength(10), + randomAlphaOfLengthOrNull(10), randomList(1, 5, () -> randomAlphaOfLength(8)), randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))), randomFrom(InputType.values()), diff --git 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java new file mode 100644 index 0000000000000..1872ac3caa230 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.Matchers.is; + +public class UnifiedCompletionActionRequestTests extends AbstractBWCWireSerializationTestCase { /* Validation and wire-serialization tests for UnifiedCompletionAction.Request. 
*/ + + public void testValidation_ReturnsException_When_UnifiedCompletionRequestMessage_Is_Null() { + var request = new UnifiedCompletionAction.Request( + "inference_id", + TaskType.COMPLETION, + UnifiedCompletionRequest.of(null), + TimeValue.timeValueSeconds(10) + ); + var exception = request.validate(); + assertThat(exception.getMessage(), is("Validation Failed: 1: Field [messages] cannot be null;")); + } + + public void testValidation_ReturnsException_When_UnifiedCompletionRequest_Is_EmptyArray() { + var request = new UnifiedCompletionAction.Request( + "inference_id", + TaskType.COMPLETION, + UnifiedCompletionRequest.of(List.of()), + TimeValue.timeValueSeconds(10) + ); + var exception = request.validate(); + assertThat(exception.getMessage(), 
is("Validation Failed: 1: Field [messages] cannot be an empty array;")); + } + + public void testValidation_ReturnsException_When_TaskType_IsNot_Completion() { + var request = new UnifiedCompletionAction.Request( + "inference_id", + TaskType.SPARSE_EMBEDDING, + UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())), + TimeValue.timeValueSeconds(10) + ); + var exception = request.validate(); + assertThat(exception.getMessage(), is("Validation Failed: 1: Field [taskType] must be [completion];")); + } + + /* TaskType.ANY passes validation, whereas SPARSE_EMBEDDING is rejected in the test above. */ public void testValidation_ReturnsNull_When_TaskType_IsAny() { + var request = new UnifiedCompletionAction.Request( + "inference_id", + TaskType.ANY, + UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())), + TimeValue.timeValueSeconds(10) + ); + assertNull(request.validate()); + } + + /* Returns the instance unchanged for every transport version. */ @Override + protected UnifiedCompletionAction.Request mutateInstanceForVersion(UnifiedCompletionAction.Request instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return UnifiedCompletionAction.Request::new; + } + + /* Draws from all TaskType values, so random instances are not restricted to ones that pass validate(). 
*/ @Override + protected UnifiedCompletionAction.Request createTestInstance() { + return new UnifiedCompletionAction.Request( + randomAlphaOfLength(10), + randomFrom(TaskType.values()), + UnifiedCompletionRequestTests.randomUnifiedCompletionRequest(), + TimeValue.timeValueMillis(randomLongBetween(1, 2048)) + ); + } + + @Override + protected UnifiedCompletionAction.Request mutateInstance(UnifiedCompletionAction.Request instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(UnifiedCompletionRequest.getNamedWriteables()); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java new file mode 100644 index 0000000000000..47a0814a584b7 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java @@ -0,0 +1,293 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class UnifiedCompletionRequestTests extends AbstractBWCWireSerializationTestCase { /* Parsing and wire-serialization tests for UnifiedCompletionRequest; also provides the static random builders defined below. */ + + /* Parses a request in which all eight constructor arguments end up non-null, with message content given as an array of objects and tool_choice as an object. 
*/ public void testParseAllFields() throws IOException { + String requestJson = """ + { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": [ + { + "text": "some text", + "type": "string" + } + ], + "name": "a name", + "tool_call_id": "100", + "tool_calls": [ + { + "id": "call_62136354", + "type": "function", + "function": { + "arguments": "{'order_id': 'order_12345'}", + "name": "get_delivery_date" + } + } + ] + } + ], + "max_completion_tokens": 100, + "stop": ["stop"], + "temperature": 0.1, + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object" + } + } + } + ], + "tool_choice": { + "type": "function", + "function": { + 
"name": "some function" + } + }, + "top_p": 0.2 + } + """; + + try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) { + var request = UnifiedCompletionRequest.PARSER.apply(parser, null); + var expected = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects( + List.of(new UnifiedCompletionRequest.ContentObject("some text", "string")) + ), + "user", + "a name", + "100", + List.of( + new UnifiedCompletionRequest.ToolCall( + "call_62136354", + new UnifiedCompletionRequest.ToolCall.FunctionField("{'order_id': 'order_12345'}", "get_delivery_date"), + "function" + ) + ) + ) + ), + "gpt-4o", + 100L, + List.of("stop"), + 0.1F, + new UnifiedCompletionRequest.ToolChoiceObject( + "function", + new UnifiedCompletionRequest.ToolChoiceObject.FunctionField("some function") + ), + List.of( + new UnifiedCompletionRequest.Tool( + "function", + new UnifiedCompletionRequest.Tool.FunctionField( + "Get the current weather in a given location", + "get_current_weather", + Map.of("type", "object"), + null + ) + ) + ), + 0.2F + ); + + assertThat(request, is(expected)); + } + } + + /* Parses the short forms: content as a plain string, stop as a single string (expected as a one-element list) and tool_choice as a string; the omitted fields are expected to be null. */ public void testParsing() throws IOException { + String requestJson = """ + { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": "What is the weather like in Boston today?" 
+ } + ], + "stop": "none", + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object" + } + } + } + ], + "tool_choice": "auto" + } + """; + + try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) { + var request = UnifiedCompletionRequest.PARSER.apply(parser, null); + var expected = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("What is the weather like in Boston today?"), + "user", + null, + null, + null + ) + ), + "gpt-4o", + null, + List.of("none"), + null, + new UnifiedCompletionRequest.ToolChoiceString("auto"), + List.of( + new UnifiedCompletionRequest.Tool( + "function", + new UnifiedCompletionRequest.Tool.FunctionField( + "Get the current weather in a given location", + "get_current_weather", + Map.of("type", "object"), + null + ) + ) + ), + null + ); + + assertThat(request, is(expected)); + } + } + + /* Random request: messages is always a list built from randomMessage, every other argument comes from an OrNull helper. */ public static UnifiedCompletionRequest randomUnifiedCompletionRequest() { + return new UnifiedCompletionRequest( + randomList(5, UnifiedCompletionRequestTests::randomMessage), + randomAlphaOfLengthOrNull(10), + randomPositiveLongOrNull(), + randomStopOrNull(), + randomFloatOrNull(), + randomToolChoiceOrNull(), + randomToolListOrNull(), + randomFloatOrNull() + ); + } + + public static UnifiedCompletionRequest.Message randomMessage() { + return new UnifiedCompletionRequest.Message( + randomContent(), + randomAlphaOfLength(10), + randomAlphaOfLengthOrNull(10), + randomAlphaOfLengthOrNull(10), + randomToolCallListOrNull() + ); + } + + /* Picks between the two content forms: a ContentString or a ContentObjects list. */ public static UnifiedCompletionRequest.Content randomContent() { + return randomBoolean() + ? 
new UnifiedCompletionRequest.ContentString(randomAlphaOfLength(10)) + : new UnifiedCompletionRequest.ContentObjects(randomList(10, UnifiedCompletionRequestTests::randomContentObject)); + } + + public static UnifiedCompletionRequest.ContentObject randomContentObject() { + return new UnifiedCompletionRequest.ContentObject(randomAlphaOfLength(10), randomAlphaOfLength(10)); + } + + public static List randomToolCallListOrNull() { + return randomBoolean() ? randomList(10, UnifiedCompletionRequestTests::randomToolCall) : null; + } + + public static UnifiedCompletionRequest.ToolCall randomToolCall() { + return new UnifiedCompletionRequest.ToolCall(randomAlphaOfLength(10), randomToolCallFunctionField(), randomAlphaOfLength(10)); + } + + public static UnifiedCompletionRequest.ToolCall.FunctionField randomToolCallFunctionField() { + return new UnifiedCompletionRequest.ToolCall.FunctionField(randomAlphaOfLength(10), randomAlphaOfLength(10)); + } + + public static List randomStopOrNull() { + return randomBoolean() ? randomStop() : null; + } + + public static List randomStop() { + return randomList(5, () -> randomAlphaOfLength(10)); + } + + public static UnifiedCompletionRequest.ToolChoice randomToolChoiceOrNull() { + return randomBoolean() ? randomToolChoice() : null; + } + + /* Picks between the string form and the object form of tool_choice. */ public static UnifiedCompletionRequest.ToolChoice randomToolChoice() { + return randomBoolean() + ? new UnifiedCompletionRequest.ToolChoiceString(randomAlphaOfLength(10)) + : new UnifiedCompletionRequest.ToolChoiceObject(randomAlphaOfLength(10), randomToolChoiceObjectFunctionField()); + } + + public static UnifiedCompletionRequest.ToolChoiceObject.FunctionField randomToolChoiceObjectFunctionField() { + return new UnifiedCompletionRequest.ToolChoiceObject.FunctionField(randomAlphaOfLength(10)); + } + + public static List randomToolListOrNull() { + return randomBoolean() ? 
randomList(10, UnifiedCompletionRequestTests::randomTool) : null; + } + + public static UnifiedCompletionRequest.Tool randomTool() { + return new UnifiedCompletionRequest.Tool(randomAlphaOfLength(10), randomToolFunctionField()); + } + + /* The third argument, which the parsing tests above fill from the parameters object, is always null here. */ public static UnifiedCompletionRequest.Tool.FunctionField randomToolFunctionField() { + return new UnifiedCompletionRequest.Tool.FunctionField( + randomAlphaOfLengthOrNull(10), + randomAlphaOfLength(10), + null, + randomOptionalBoolean() + ); + } + + /* Returns the instance unchanged for every transport version. */ @Override + protected UnifiedCompletionRequest mutateInstanceForVersion(UnifiedCompletionRequest instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return UnifiedCompletionRequest::new; + } + + @Override + protected UnifiedCompletionRequest createTestInstance() { + return randomUnifiedCompletionRequest(); + } + + @Override + protected UnifiedCompletionRequest mutateInstance(UnifiedCompletionRequest instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(UnifiedCompletionRequest.getNamedWriteables()); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java new file mode 100644 index 0000000000000..a8f569dbef9d1 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java @@ -0,0 +1,198 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ * + * this file was contributed to by a generative AI + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; + +public class StreamingUnifiedChatCompletionResultsTests extends ESTestCase { /* Checks the JSON emitted by the toXContentChunked implementations of Results, Choice and ToolCall, each built with every field populated. */ + + /* A Results holding a single fully populated chunk is expected to serialize to the chunk object shown below. */ public void testResults_toXContentChunked() throws IOException { + String expected = """ + { + "id": "chunk1", + "choices": [ + { + "delta": { + "content": "example_content", + "refusal": "example_refusal", + "role": "assistant", + "tool_calls": [ + { + "index": 1, + "id": "tool1", + "function": { + "arguments": "example_arguments", + "name": "example_function" + }, + "type": "function" + } + ] + }, + "finish_reason": "example_reason", + "index": 0 + } + ], + "model": "example_model", + "object": "example_object", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 5, + "total_tokens": 15 + } + } + """; + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + "chunk1", + List.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + "example_content", + "example_refusal", + "assistant", + List.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + 1, + "tool1", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + "example_arguments", + "example_function" + ), + 
"function" + ) + ) + ), + "example_reason", + 0 + ) + ), + "example_model", + "example_object", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(10, 5, 15) + ); + + Deque deque = new ArrayDeque<>(); + deque.add(chunk); + 
StreamingUnifiedChatCompletionResults.Results results = new StreamingUnifiedChatCompletionResults.Results(deque); + XContentBuilder builder = JsonXContent.contentBuilder(); + results.toXContentChunked(null).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + /* All whitespace is stripped from the expected text, which is safe here because no expected value contains whitespace. NOTE(review): prettyPrint() is applied only after all content has been written; presumably it no longer changes the output, otherwise this comparison could not match - confirm. The two tests below use the same pattern. */ assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); + } + + public void testChoiceToXContentChunked() throws IOException { + String expected = """ + { + "delta": { + "content": "example_content", + "refusal": "example_refusal", + "role": "assistant", + "tool_calls": [ + { + "index": 1, + "id": "tool1", + "function": { + "arguments": "example_arguments", + "name": "example_function" + }, + "type": "function" + } + ] + }, + "finish_reason": "example_reason", + "index": 0 + } + """; + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + "example_content", + "example_refusal", + "assistant", + List.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + 1, + "tool1", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + "example_arguments", + "example_function" + ), + "function" + ) + ) + ), + "example_reason", + 0 + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + choice.toXContentChunked(null).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, null); + } catch 
(IOException e) { + throw new RuntimeException(e); + } + }); + + assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); + } + + public void testToolCallToXContentChunked() throws IOException { + String expected = """ + { + "index": 1, + "id": "tool1", + "function": { + "arguments": 
"example_arguments", + "name": "example_function" + }, + "type": "function" + } + """; + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + 1, + "tool1", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + "example_arguments", + "example_function" + ), + "function" + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + toolCall.toXContentChunked(null).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java index 17579fd6368ce..eeffa1db54856 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java @@ -4175,6 +4175,7 @@ public void testInferenceUserRole() { assertTrue(role.cluster().check("cluster:monitor/xpack/inference", request, authentication)); assertTrue(role.cluster().check("cluster:monitor/xpack/inference/get", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/inference/put", request, authentication)); + /* The role under test is also expected to be allowed the unified inference action. 
*/ assertTrue(role.cluster().check("cluster:monitor/xpack/inference/unified", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/inference/delete", request, authentication)); assertTrue(role.cluster().check("cluster:monitor/xpack/ml/trained_models/deployment/infer", request, authentication)); 
assertFalse(role.cluster().check("cluster:admin/xpack/ml/trained_models/deployment/start", request, authentication)); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 86c0128a3e53c..1716057cdfe46 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -21,6 +21,9 @@ import org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; import org.junit.ClassRule; @@ -341,10 +344,21 @@ protected Deque streamInferOnMockService(String modelId, TaskTy return callAsync(endpoint, input); } + /* Posts the inputs to _inference/{taskType}/{modelId}/_unified, each as a message with role user, and returns the events collected from the response. 
*/ protected Deque unifiedCompletionInferOnMockService(String modelId, TaskType taskType, List input) + throws Exception { + var endpoint = Strings.format("_inference/%s/%s/_unified", taskType, modelId); + return callAsyncUnified(endpoint, input, "user"); + } + private Deque callAsync(String endpoint, List input) throws Exception { - var responseConsumer = new AsyncInferenceResponseConsumer(); var request = new Request("POST", endpoint); request.setJsonEntity(jsonBody(input, null)); + + return execAsyncCall(request); + } + + /* Request execution shared by callAsync and callAsyncUnified: installs the response consumer, performs the request asynchronously and returns the events of that consumer. */ private Deque execAsyncCall(Request request) throws Exception { + var responseConsumer = new AsyncInferenceResponseConsumer(); 
request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build()); var latch = new CountDownLatch(1); client().performRequestAsync(request, new ResponseListener() { @@ -362,6 +376,22 @@ public void onFailure(Exception exception) { return responseConsumer.events(); } + /* POSTs a unified-format JSON body to the endpoint via execAsyncCall. */ private Deque callAsyncUnified(String endpoint, List input, String role) throws Exception { + var request = new Request("POST", endpoint); + + request.setJsonEntity(createUnifiedJsonBody(input, role)); + return execAsyncCall(request); + } + + /* Builds an object with a single messages array holding one entry per input, each with the keys content and role. NOTE(review): the fully qualified Strings below suggests the simple name Strings resolves to a different class in this file - confirm. */ private String createUnifiedJsonBody(List input, String role) throws IOException { + var messages = input.stream().map(i -> Map.of("content", i, "role", role)).toList(); + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + builder.startObject(); + builder.field("messages", messages); + builder.endObject(); + return org.elasticsearch.common.Strings.toString(builder); + } + protected Map infer(String modelId, TaskType taskType, List input) throws IOException { var endpoint = Strings.format("_inference/%s/%s", taskType, modelId); return inferInternal(endpoint, input, null, Map.of()); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 604e1d4f553b2..2099ec8287a76 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ 
b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -11,13 +11,18 @@ import org.apache.http.util.EntityUtils; import org.elasticsearch.client.ResponseException; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import 
org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -481,6 +486,56 @@ public void testSupportedStream() throws Exception { } } + /* Registers a mock completion model, sends random UUID inputs through the unified endpoint and checks the streamed events; the model is deleted in the finally block. */ public void testUnifiedCompletionInference() throws Exception { + String modelId = "streaming"; + putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION)); + var singleModel = getModel(modelId); + assertEquals(modelId, singleModel.get("inference_id")); + assertEquals(TaskType.COMPLETION.toString(), singleModel.get("task_type")); + + var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomUUID()).toList(); + try { + var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input); + var expectedResponses = expectedResultsIterator(input); + /* Two events (EVENT and DATA) are expected per input, plus two for the trailing [DONE]. */ assertThat(events.size(), equalTo((input.size() + 1) * 2)); + events.forEach(event -> { + switch (event.name()) { + case EVENT -> assertThat(event.value(), equalToIgnoringCase("message")); + case DATA -> assertThat(event.value(), equalTo(expectedResponses.next())); + } + }); + } finally { + deleteModel(modelId); + } + } + + /* Expected DATA payloads in order: one chunk per upper-cased input, then the literal [DONE]. 
*/ private static Iterator expectedResultsIterator(List input) { + return Stream.concat(input.stream().map(String::toUpperCase).map(InferenceCrudIT::expectedResult), Stream.of("[DONE]")).iterator(); + } + + /* Expected chunk JSON for one input: id, a single choice whose delta content is the input and whose index is 0, then the fixed model and object values. */ private static String expectedResult(String input) { + try { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + builder.startObject(); + builder.field("id", "id"); + builder.startArray("choices"); + builder.startObject(); + builder.startObject("delta"); + builder.field("content", input); + 
builder.endObject(); + builder.field("index", 0); + builder.endObject(); + builder.endArray(); + builder.field("model", "gpt-4o-2024-08-06"); + builder.field("object", "chat.completion.chunk"); + builder.endObject(); + + return Strings.toString(builder); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + public void testGetZeroModels() throws IOException { var models = getModels("_all", TaskType.COMPLETION); assertThat(models, empty()); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index ae11a02d312e2..f5f682b143a72 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -31,6 +31,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; @@ -132,6 +133,16 @@ public void infer( } } + /* The unified completion API is not supported by this mock service: the listener is failed immediately with an UnsupportedOperationException. 
*/ @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + listener.onFailure(new UnsupportedOperationException("unifiedCompletionInfer not supported")); + } + @Override public void chunkedInfer( Model model, diff --git 
a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index 9320571572f0a..fa1e27005c287 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -29,6 +29,7 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; @@ -120,6 +121,16 @@ public void infer( } } + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + listener.onFailure(new UnsupportedOperationException("unifiedCompletionInfer not supported")); + } + @Override public void chunkedInfer( Model model, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index fe0223cce0323..64569fd8c5c6a 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -29,6 +29,7 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; @@ -123,6 +124,16 @@ public void infer( } } + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + listener.onFailure(new UnsupportedOperationException("unifiedCompletionInfer not supported")); + } + @Override public void chunkedInfer( Model model, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 6d7983bc8cb53..f7a05a27354ef 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -30,12 +30,14 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import
org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import java.io.IOException; import java.util.EnumSet; @@ -121,6 +123,24 @@ public void infer( } } + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + switch (model.getConfigurations().getTaskType()) { + case COMPLETION -> listener.onResponse(makeUnifiedResults(request)); + default -> listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), + RestStatus.BAD_REQUEST + ) + ); + } + } + private StreamingChatCompletionResults makeResults(List input) { var responseIter = input.stream().map(String::toUpperCase).iterator(); return new StreamingChatCompletionResults(subscriber -> { @@ -152,6 +172,59 @@ private ChunkedToXContent completionChunk(String delta) { ); } + private StreamingUnifiedChatCompletionResults makeUnifiedResults(UnifiedCompletionRequest request) { + var responseIter = request.messages().stream().map(message -> message.content().toString().toUpperCase()).iterator(); + return new StreamingUnifiedChatCompletionResults(subscriber -> { + subscriber.onSubscribe(new Flow.Subscription() { + @Override + public void request(long n) { + if (responseIter.hasNext()) { + subscriber.onNext(unifiedCompletionChunk(responseIter.next())); + } else { + subscriber.onComplete(); + } + } + + @Override + public void cancel() {} + }); + }); + } + + /* + The response format looks like this + { + "id": "chatcmpl-AarrzyuRflye7yzDF4lmVnenGmQCF", + "choices": [ + { + "delta": { + "content": " information" + }, + "index": 0 + } + ], + "model": "gpt-4o-2024-08-06", + "object": 
"chat.completion.chunk" + } + */ + private ChunkedToXContent unifiedCompletionChunk(String delta) { + return params -> Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field("id", "id"), + ChunkedToXContentHelper.startArray("choices"), + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.startObject("delta"), + ChunkedToXContentHelper.field("content", delta), + ChunkedToXContentHelper.endObject(), + ChunkedToXContentHelper.field("index", 0), + ChunkedToXContentHelper.endObject(), + ChunkedToXContentHelper.endArray(), + ChunkedToXContentHelper.field("model", "gpt-4o-2024-08-06"), + ChunkedToXContentHelper.field("object", "chat.completion.chunk"), + ChunkedToXContentHelper.endObject() + ); + } + @Override public void chunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 673b841317a3d..a4187f4c4fa90 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; @@ -137,11 +138,18 @@ public static List getNamedWriteables() { addEisNamedWriteables(namedWriteables); addAlibabaCloudSearchNamedWriteables(namedWriteables); + 
addUnifiedNamedWriteables(namedWriteables); + namedWriteables.addAll(StreamingTaskManager.namedWriteables()); return namedWriteables; } + private static void addUnifiedNamedWriteables(List namedWriteables) { + var writeables = UnifiedCompletionRequest.getNamedWriteables(); + namedWriteables.addAll(writeables); + } + private static void addAmazonBedrockNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index d7d623ab20143..148a784456361 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -51,6 +51,7 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportDeleteInferenceEndpointAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceDiagnosticsAction; @@ -59,6 +60,7 @@ import org.elasticsearch.xpack.inference.action.TransportInferenceAction; import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; +import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction; import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; import 
org.elasticsearch.xpack.inference.common.Truncator; @@ -86,6 +88,7 @@ import org.elasticsearch.xpack.inference.rest.RestInferenceAction; import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestStreamInferenceAction; +import org.elasticsearch.xpack.inference.rest.RestUnifiedCompletionInferenceAction; import org.elasticsearch.xpack.inference.rest.RestUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService; @@ -159,8 +162,9 @@ public InferencePlugin(Settings settings) { @Override public List> getActions() { - return List.of( + var availableActions = List.of( new ActionHandler<>(InferenceAction.INSTANCE, TransportInferenceAction.class), + new ActionHandler<>(GetInferenceModelAction.INSTANCE, TransportGetInferenceModelAction.class), new ActionHandler<>(PutInferenceModelAction.INSTANCE, TransportPutInferenceModelAction.class), new ActionHandler<>(UpdateInferenceModelAction.INSTANCE, TransportUpdateInferenceModelAction.class), @@ -169,6 +173,13 @@ public InferencePlugin(Settings settings) { new ActionHandler<>(GetInferenceDiagnosticsAction.INSTANCE, TransportGetInferenceDiagnosticsAction.class), new ActionHandler<>(GetInferenceServicesAction.INSTANCE, TransportGetInferenceServicesAction.class) ); + + List> conditionalActions = + UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled() + ? 
List.of(new ActionHandler<>(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class)) + : List.of(); + + return Stream.concat(availableActions.stream(), conditionalActions.stream()).toList(); } @Override @@ -183,7 +194,7 @@ public List getRestHandlers( Supplier nodesInCluster, Predicate clusterSupportsFeature ) { - return List.of( + var availableRestActions = List.of( new RestInferenceAction(), new RestStreamInferenceAction(), new RestGetInferenceModelAction(), @@ -193,6 +204,11 @@ public List getRestHandlers( new RestGetInferenceDiagnosticsAction(), new RestGetInferenceServicesAction() ); + List conditionalRestActions = UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled() + ? List.of(new RestUnifiedCompletionInferenceAction()) + : List.of(); + + return Stream.concat(availableRestActions.stream(), conditionalRestActions.stream()).toList(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnifiedCompletionFeature.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnifiedCompletionFeature.java new file mode 100644 index 0000000000000..3e13d0c1e39de --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnifiedCompletionFeature.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.common.util.FeatureFlag; + +/** + * Unified Completion feature flag. When the feature is complete, this flag will be removed. + * Enable feature via JVM option: `-Des.inference_unified_feature_flag_enabled=true`. 
+ */ +public class UnifiedCompletionFeature { + public static final FeatureFlag UNIFIED_COMPLETION_FEATURE_FLAG = new FeatureFlag("inference_unified"); + + private UnifiedCompletionFeature() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java new file mode 100644 index 0000000000000..2a0e8e1775279 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -0,0 +1,250 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import 
org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; +import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; + +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; +import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; + +public abstract class BaseTransportInferenceAction extends HandledTransportAction< + Request, + InferenceAction.Response> { + + private static final Logger log = LogManager.getLogger(BaseTransportInferenceAction.class); + private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference"; + private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]"; + private final ModelRegistry modelRegistry; + private final InferenceServiceRegistry serviceRegistry; + private final InferenceStats inferenceStats; + private final StreamingTaskManager streamingTaskManager; + + public BaseTransportInferenceAction( + String inferenceActionName, + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager, + Writeable.Reader requestReader + ) { + super(inferenceActionName, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE); + this.modelRegistry = modelRegistry; + this.serviceRegistry = serviceRegistry; + this.inferenceStats = inferenceStats; + 
this.streamingTaskManager = streamingTaskManager; + } + + @Override + protected void doExecute(Task task, Request request, ActionListener listener) { + var timer = InferenceTimer.start(); + + var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> { + var service = serviceRegistry.getService(unparsedModel.service()); + try { + validationHelper(service::isEmpty, () -> unknownServiceException(unparsedModel.service(), request.getInferenceEntityId())); + validationHelper( + () -> request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false, + () -> requestModelTaskTypeMismatchException(request.getTaskType(), unparsedModel.taskType()) + ); + validationHelper( + () -> isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel), + () -> createInvalidTaskTypeException(request, unparsedModel) + ); + } catch (Exception e) { + recordMetrics(unparsedModel, timer, e); + listener.onFailure(e); + return; + } + + var model = service.get() + .parsePersistedConfigWithSecrets( + unparsedModel.inferenceEntityId(), + unparsedModel.taskType(), + unparsedModel.settings(), + unparsedModel.secrets() + ); + inferOnServiceWithMetrics(model, request, service.get(), timer, listener); + }, e -> { + try { + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e)); + } catch (Exception metricsException) { + log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics"); + } + listener.onFailure(e); + }); + + modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener); + } + + private static void validationHelper(Supplier validationFailure, Supplier exceptionCreator) { + if (validationFailure.get()) { + throw exceptionCreator.get(); + } + } + + protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(Request request, UnparsedModel unparsedModel); + + protected abstract ElasticsearchStatusException createInvalidTaskTypeException(Request request, 
UnparsedModel unparsedModel); + + private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { + try { + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); + } catch (Exception e) { + log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics"); + } + } + + private void inferOnServiceWithMetrics( + Model model, + Request request, + InferenceService service, + InferenceTimer timer, + ActionListener listener + ) { + inferenceStats.requestCount().incrementBy(1, modelAttributes(model)); + inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> { + if (request.isStreaming()) { + var taskProcessor = streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION); + inferenceResults.publisher().subscribe(taskProcessor); + + var instrumentedStream = new PublisherWithMetrics(timer, model); + taskProcessor.subscribe(instrumentedStream); + + listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream)); + } else { + recordMetrics(model, timer, null); + listener.onResponse(new InferenceAction.Response(inferenceResults)); + } + }, e -> { + recordMetrics(model, timer, e); + listener.onFailure(e); + })); + } + + private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) { + try { + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); + } catch (Exception e) { + log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics"); + } + } + + private void inferOnService(Model model, Request request, InferenceService service, ActionListener listener) { + if (request.isStreaming() == false || service.canStream(request.getTaskType())) { + doInference(model, request, service, listener); + } else { + listener.onFailure(unsupportedStreamingTaskException(request, service)); + } + } + + protected 
abstract void doInference( + Model model, + Request request, + InferenceService service, + ActionListener listener + ); + + private ElasticsearchStatusException unsupportedStreamingTaskException(Request request, InferenceService service) { + var supportedTasks = service.supportedStreamingTasks(); + if (supportedTasks.isEmpty()) { + return new ElasticsearchStatusException( + format("Streaming is not allowed for service [%s].", service.name()), + RestStatus.METHOD_NOT_ALLOWED + ); + } else { + var validTasks = supportedTasks.stream().map(TaskType::toString).collect(Collectors.joining(",")); + return new ElasticsearchStatusException( + format( + "Streaming is not allowed for service [%s] and task [%s]. Supported tasks: [%s]", + service.name(), + request.getTaskType(), + validTasks + ), + RestStatus.METHOD_NOT_ALLOWED + ); + } + } + + private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) { + return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. 
", RestStatus.BAD_REQUEST, service, inferenceId); + } + + private static ElasticsearchStatusException requestModelTaskTypeMismatchException(TaskType requested, TaskType expected) { + return new ElasticsearchStatusException( + "Incompatible task_type, the requested type [{}] does not match the model type [{}]", + RestStatus.BAD_REQUEST, + requested, + expected + ); + } + + private class PublisherWithMetrics extends DelegatingProcessor { + + private final InferenceTimer timer; + private final Model model; + + private PublisherWithMetrics(InferenceTimer timer, Model model) { + this.timer = timer; + this.model = model; + } + + @Override + protected void next(ChunkedToXContent item) { + downstream().onNext(item); + } + + @Override + public void onError(Throwable throwable) { + recordMetrics(model, timer, throwable); + super.onError(throwable); + } + + @Override + protected void onCancel() { + recordMetrics(model, timer, null); + super.onCancel(); + } + + @Override + public void onComplete() { + recordMetrics(model, timer, null); + super.onComplete(); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index ba9ab3c133731..08e6d869a553d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -7,47 +7,22 @@ package org.elasticsearch.xpack.inference.action; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.action.support.HandledTransportAction; -import 
org.elasticsearch.common.util.concurrent.EsExecutors; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.injection.guice.Inject; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; -import org.elasticsearch.xpack.inference.common.DelegatingProcessor; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; -import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; -import java.util.stream.Collectors; - -import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; - -public class TransportInferenceAction extends HandledTransportAction { - private static final Logger log = LogManager.getLogger(TransportInferenceAction.class); - private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference"; - private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]"; - - private final ModelRegistry modelRegistry; - private final InferenceServiceRegistry serviceRegistry; - private final InferenceStats inferenceStats; - private final StreamingTaskManager streamingTaskManager; +public class TransportInferenceAction extends BaseTransportInferenceAction { @Inject public 
TransportInferenceAction( @@ -58,184 +33,44 @@ public TransportInferenceAction( InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager ) { - super(InferenceAction.NAME, transportService, actionFilters, InferenceAction.Request::new, EsExecutors.DIRECT_EXECUTOR_SERVICE); - this.modelRegistry = modelRegistry; - this.serviceRegistry = serviceRegistry; - this.inferenceStats = inferenceStats; - this.streamingTaskManager = streamingTaskManager; + super( + InferenceAction.NAME, + transportService, + actionFilters, + modelRegistry, + serviceRegistry, + inferenceStats, + streamingTaskManager, + InferenceAction.Request::new + ); } @Override - protected void doExecute(Task task, InferenceAction.Request request, ActionListener listener) { - var timer = InferenceTimer.start(); - - var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> { - var service = serviceRegistry.getService(unparsedModel.service()); - if (service.isEmpty()) { - var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()); - recordMetrics(unparsedModel, timer, e); - listener.onFailure(e); - return; - } - - if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) { - // not the wildcard task type and not the model task type - var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()); - recordMetrics(unparsedModel, timer, e); - listener.onFailure(e); - return; - } - - var model = service.get() - .parsePersistedConfigWithSecrets( - unparsedModel.inferenceEntityId(), - unparsedModel.taskType(), - unparsedModel.settings(), - unparsedModel.secrets() - ); - inferOnServiceWithMetrics(model, request, service.get(), timer, listener); - }, e -> { - try { - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e)); - } catch (Exception metricsException) { - log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics"); - 
} - listener.onFailure(e); - }); - - modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener); - } - - private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { - try { - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); - } catch (Exception e) { - log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics"); - } - } - - private void inferOnServiceWithMetrics( - Model model, - InferenceAction.Request request, - InferenceService service, - InferenceTimer timer, - ActionListener listener - ) { - inferenceStats.requestCount().incrementBy(1, modelAttributes(model)); - inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> { - if (request.isStreaming()) { - var taskProcessor = streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION); - inferenceResults.publisher().subscribe(taskProcessor); - - var instrumentedStream = new PublisherWithMetrics(timer, model); - taskProcessor.subscribe(instrumentedStream); - - listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream)); - } else { - recordMetrics(model, timer, null); - listener.onResponse(new InferenceAction.Response(inferenceResults)); - } - }, e -> { - recordMetrics(model, timer, e); - listener.onFailure(e); - })); + protected boolean isInvalidTaskTypeForInferenceEndpoint(InferenceAction.Request request, UnparsedModel unparsedModel) { + return false; } - private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) { - try { - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); - } catch (Exception e) { - log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics"); - } + @Override + protected ElasticsearchStatusException createInvalidTaskTypeException(InferenceAction.Request request, 
UnparsedModel unparsedModel) { + return null; } - private void inferOnService( + @Override + protected void doInference( Model model, InferenceAction.Request request, InferenceService service, ActionListener listener ) { - if (request.isStreaming() == false || service.canStream(request.getTaskType())) { - service.infer( - model, - request.getQuery(), - request.getInput(), - request.isStreaming(), - request.getTaskSettings(), - request.getInputType(), - request.getInferenceTimeout(), - listener - ); - } else { - listener.onFailure(unsupportedStreamingTaskException(request, service)); - } - } - - private ElasticsearchStatusException unsupportedStreamingTaskException(InferenceAction.Request request, InferenceService service) { - var supportedTasks = service.supportedStreamingTasks(); - if (supportedTasks.isEmpty()) { - return new ElasticsearchStatusException( - format("Streaming is not allowed for service [%s].", service.name()), - RestStatus.METHOD_NOT_ALLOWED - ); - } else { - var validTasks = supportedTasks.stream().map(TaskType::toString).collect(Collectors.joining(",")); - return new ElasticsearchStatusException( - format( - "Streaming is not allowed for service [%s] and task [%s]. Supported tasks: [%s]", - service.name(), - request.getTaskType(), - validTasks - ), - RestStatus.METHOD_NOT_ALLOWED - ); - } - } - - private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) { - return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. 
", RestStatus.BAD_REQUEST, service, inferenceId); - } - - private static ElasticsearchStatusException incompatibleTaskTypeException(TaskType requested, TaskType expected) { - return new ElasticsearchStatusException( - "Incompatible task_type, the requested type [{}] does not match the model type [{}]", - RestStatus.BAD_REQUEST, - requested, - expected + service.infer( + model, + request.getQuery(), + request.getInput(), + request.isStreaming(), + request.getTaskSettings(), + request.getInputType(), + request.getInferenceTimeout(), + listener ); } - - private class PublisherWithMetrics extends DelegatingProcessor { - private final InferenceTimer timer; - private final Model model; - - private PublisherWithMetrics(InferenceTimer timer, Model model) { - this.timer = timer; - this.model = model; - } - - @Override - protected void next(ChunkedToXContent item) { - downstream().onNext(item); - } - - @Override - public void onError(Throwable throwable) { - recordMetrics(model, timer, throwable); - super.onError(throwable); - } - - @Override - protected void onCancel() { - recordMetrics(model, timer, null); - super.onCancel(); - } - - @Override - public void onComplete() { - recordMetrics(model, timer, null); - super.onComplete(); - } - } - } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java new file mode 100644 index 0000000000000..f0906231d8f42 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; + +public class TransportUnifiedCompletionInferenceAction extends BaseTransportInferenceAction { + + @Inject + public TransportUnifiedCompletionInferenceAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager + ) { + super( + UnifiedCompletionAction.NAME, + transportService, + actionFilters, + modelRegistry, + serviceRegistry, + inferenceStats, + streamingTaskManager, + UnifiedCompletionAction.Request::new + ); + } + + @Override + protected boolean isInvalidTaskTypeForInferenceEndpoint(UnifiedCompletionAction.Request request, UnparsedModel unparsedModel) { + return request.getTaskType().isAnyOrSame(TaskType.COMPLETION) == false || unparsedModel.taskType() != TaskType.COMPLETION; + } + + @Override + protected ElasticsearchStatusException createInvalidTaskTypeException( + UnifiedCompletionAction.Request request, + UnparsedModel unparsedModel + 
) { + return new ElasticsearchStatusException( + "Incompatible task_type for unified API, the requested type [{}] must be one of [{}]", + RestStatus.BAD_REQUEST, + request.getTaskType(), + TaskType.COMPLETION.toString() + ); + } + + @Override + protected void doInference( + Model model, + UnifiedCompletionAction.Request request, + InferenceService service, + ActionListener listener + ) { + service.unifiedCompletionInfer(model, request.getUnifiedCompletionRequest(), null, listener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java index 03e794e42c3a2..eda3fc0f3bfdb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java @@ -9,7 +9,14 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; - +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Iterator; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; @@ -25,6 +32,33 @@ public abstract class DelegatingProcessor implements Flow.Processor private Flow.Subscriber downstream; private Flow.Subscription upstream; + public static Deque parseEvent( + Deque item, + ParseChunkFunction parseFunction, + XContentParserConfiguration parserConfig, + Logger logger + ) throws Exception { + var results = new ArrayDeque(item.size()); + for (ServerSentEvent event : item) { + if 
(ServerSentEventField.DATA == event.name() && event.hasValue()) { + try { + var delta = parseFunction.apply(parserConfig, event); + delta.forEachRemaining(results::offer); + } catch (Exception e) { + logger.warn("Failed to parse event from inference provider: {}", event); + throw e; + } + } + } + + return results; + } + + @FunctionalInterface + public interface ParseChunkFunction { + Iterator apply(XContentParserConfiguration parserConfig, ServerSentEvent event) throws IOException; + } + @Override public void subscribe(Flow.Subscriber subscriber) { if (downstream != null) { @@ -51,7 +85,7 @@ public void request(long n) { if (isClosed.get()) { downstream.onComplete(); } else if (upstream != null) { - upstream.request(n); + upstreamRequest(n); } else { pendingRequests.accumulateAndGet(n, Long::sum); } @@ -67,6 +101,13 @@ public void cancel() { }; } + /** + * Guaranteed to be called when the upstream is set and this processor had not been closed. + */ + protected void upstreamRequest(long n) { + upstream.request(n); + } + protected void onCancel() {} @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java index 4e97554b56445..b43e5ab70e2f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java @@ -12,7 +12,6 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import 
org.elasticsearch.xpack.inference.external.http.sender.RequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -34,13 +33,7 @@ public SingleInputSenderExecutableAction( @Override public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { - if (inferenceInputs instanceof DocumentsOnlyInput == false) { - listener.onFailure(new ElasticsearchStatusException("Invalid inference input type", RestStatus.INTERNAL_SERVER_ERROR)); - return; - } - - var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs; - if (docsOnlyInput.getInputs().size() > 1) { + if (inferenceInputs.inputSize() > 1) { listener.onFailure( new ElasticsearchStatusException(requestTypeForInputValidationError + " only accepts 1 input", RestStatus.BAD_REQUEST) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java index 9c83264b5581f..bd5c53d589df0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java @@ -26,7 +26,7 @@ * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the openai model type. 
*/ public class OpenAiActionCreator implements OpenAiActionVisitor { - private static final String COMPLETION_ERROR_PREFIX = "OpenAI chat completions"; + public static final String COMPLETION_ERROR_PREFIX = "OpenAI chat completions"; private final Sender sender; private final ServiceComponents serviceComponents; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java index a0a44e62f9f73..e7a960f1316f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java @@ -69,7 +69,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List input = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + List input = inferenceInputs.castTo(ChatCompletionInput.class).getInputs(); AlibabaCloudSearchCompletionRequest request = new AlibabaCloudSearchCompletionRequest(account, input, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java index 69a5c665feb86..3929585a0745d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java @@ -44,10 +44,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, docsInput); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, inputs); var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout, stream); var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java index 5418b3dd9840b..6d4aeb9e31bac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java @@ -46,10 +46,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(docsInput, model, stream); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = 
chatCompletionInput.stream(); + AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(inputs, model, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java index 21cec68b14a49..affd2e3a7760e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java @@ -41,10 +41,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, docsInput, stream); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, inputs, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java index d036559ec3dcb..c2f5f3e9db5ed 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java @@ -46,10 +46,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(docsInput, model, stream); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(inputs, model, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java new file mode 100644 index 0000000000000..928da95d9c2f0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import java.util.List; +import java.util.Objects; + +/** + * This class encapsulates the input text passed by the request and indicates whether the response should be streamed. 
+ * The main difference between this class and {@link UnifiedChatInput} is this should only be used for + * {@link org.elasticsearch.inference.TaskType#COMPLETION} originating through the + * {@link org.elasticsearch.inference.InferenceService#infer} code path. These are requests sent to the + * API without using the _unified route. + */ +public class ChatCompletionInput extends InferenceInputs { + private final List input; + + public ChatCompletionInput(List input) { + this(input, false); + } + + public ChatCompletionInput(List input, boolean stream) { + super(stream); + this.input = Objects.requireNonNull(input); + } + + public List getInputs() { + return this.input; + } + + public int inputSize() { + return input.size(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java index ae46fbe0fef87..40cd03c87664e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java @@ -50,10 +50,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - CohereCompletionRequest request = new CohereCompletionRequest(docsInput, model, stream); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + CohereCompletionRequest request = new CohereCompletionRequest(inputs, model, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, 
hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java index 8cf411d84c932..3feb79d3de6cc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java @@ -14,30 +14,28 @@ public class DocumentsOnlyInput extends InferenceInputs { public static DocumentsOnlyInput of(InferenceInputs inferenceInputs) { if (inferenceInputs instanceof DocumentsOnlyInput == false) { - throw createUnsupportedTypeException(inferenceInputs); + throw createUnsupportedTypeException(inferenceInputs, DocumentsOnlyInput.class); } return (DocumentsOnlyInput) inferenceInputs; } private final List input; - private final boolean stream; public DocumentsOnlyInput(List input) { this(input, false); } public DocumentsOnlyInput(List input, boolean stream) { - super(); + super(stream); this.input = Objects.requireNonNull(input); - this.stream = stream; } public List getInputs() { return this.input; } - public boolean stream() { - return stream; + public int inputSize() { + return input.size(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java index abe50c6fae3f9..0097f9c08ea21 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java @@ -51,7 +51,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(DocumentsOnlyInput.of(inferenceInputs), model); + GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest( + inferenceInputs.castTo(ChatCompletionInput.class), + model + ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java index dd241857ef0c4..e85ea6f1d9b35 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java @@ -10,7 +10,29 @@ import org.elasticsearch.common.Strings; public abstract class InferenceInputs { - public static IllegalArgumentException createUnsupportedTypeException(InferenceInputs inferenceInputs) { - return new IllegalArgumentException(Strings.format("Unsupported inference inputs type: [%s]", inferenceInputs.getClass())); + private final boolean stream; + + public InferenceInputs(boolean stream) { + this.stream = stream; + } + + public static IllegalArgumentException createUnsupportedTypeException(InferenceInputs inferenceInputs, Class clazz) { + return new IllegalArgumentException( + Strings.format("Unable to convert inference inputs type: [%s] to [%s]", inferenceInputs.getClass(), clazz) + ); } + + public T castTo(Class clazz) { + if (clazz.isInstance(this) == false) { + throw 
createUnsupportedTypeException(this, clazz); + } + + return clazz.cast(this); + } + + public boolean stream() { + return stream; + } + + public abstract int inputSize(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java index cea89332e5bf0..4d730be6aa6bd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java @@ -15,7 +15,7 @@ import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler; -import org.elasticsearch.xpack.inference.external.request.openai.OpenAiChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; @@ -25,8 +25,8 @@ public class OpenAiCompletionRequestManager extends OpenAiRequestManager { private static final Logger logger = LogManager.getLogger(OpenAiCompletionRequestManager.class); - private static final ResponseHandler HANDLER = createCompletionHandler(); + static final String USER_ROLE = "user"; public static OpenAiCompletionRequestManager of(OpenAiChatCompletionModel model, ThreadPool threadPool) { return new OpenAiCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); @@ -35,7 +35,7 @@ public static 
OpenAiCompletionRequestManager of(OpenAiChatCompletionModel model, private final OpenAiChatCompletionModel model; private OpenAiCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPool threadPool) { - super(threadPool, model, OpenAiChatCompletionRequest::buildDefaultUri); + super(threadPool, model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri); this.model = Objects.requireNonNull(model); } @@ -46,10 +46,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(docsInput, model, stream); + var chatCompletionInputs = inferenceInputs.castTo(ChatCompletionInput.class); + var request = new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(chatCompletionInputs, USER_ROLE), model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java new file mode 100644 index 0000000000000..3b0f770e3e061 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; + +import java.util.Objects; +import java.util.function.Supplier; + +public class OpenAiUnifiedCompletionRequestManager extends OpenAiRequestManager { + + private static final Logger logger = LogManager.getLogger(OpenAiUnifiedCompletionRequestManager.class); + + private static final ResponseHandler HANDLER = createCompletionHandler(); + + public static OpenAiUnifiedCompletionRequestManager of(OpenAiChatCompletionModel model, ThreadPool threadPool) { + return new OpenAiUnifiedCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + } + + private final OpenAiChatCompletionModel model; + + private OpenAiUnifiedCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPool threadPool) { + super(threadPool, model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri); + this.model = Objects.requireNonNull(model); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + + OpenAiUnifiedChatCompletionRequest request = new 
OpenAiUnifiedChatCompletionRequest( + inferenceInputs.castTo(UnifiedChatInput.class), + model + ); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } + + private static ResponseHandler createCompletionHandler() { + return new OpenAiUnifiedChatCompletionResponseHandler("openai completion", OpenAiChatCompletionResponseEntity::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java index 50bb77b307db3..5af5245ac5b40 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java @@ -14,7 +14,7 @@ public class QueryAndDocsInputs extends InferenceInputs { public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) { if (inferenceInputs instanceof QueryAndDocsInputs == false) { - throw createUnsupportedTypeException(inferenceInputs); + throw createUnsupportedTypeException(inferenceInputs, QueryAndDocsInputs.class); } return (QueryAndDocsInputs) inferenceInputs; @@ -22,17 +22,15 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) { private final String query; private final List chunks; - private final boolean stream; public QueryAndDocsInputs(String query, List chunks) { this(query, chunks, false); } public QueryAndDocsInputs(String query, List chunks, boolean stream) { - super(); + super(stream); this.query = Objects.requireNonNull(query); this.chunks = Objects.requireNonNull(chunks); - this.stream = stream; } public String getQuery() { @@ -43,8 +41,7 @@ public List getChunks() { return chunks; } - public boolean stream() { - return stream; + public int inputSize() { + 
return chunks.size(); } - } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java new file mode 100644 index 0000000000000..f89fa1ee37a6f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.UnifiedCompletionRequest; + +import java.util.List; +import java.util.Objects; + +/** + * This class encapsulates the unified request. + * The main difference between this class and {@link ChatCompletionInput} is this should only be used for + * {@link org.elasticsearch.inference.TaskType#COMPLETION} originating through the + * {@link org.elasticsearch.inference.InferenceService#unifiedCompletionInfer(Model, UnifiedCompletionRequest, TimeValue, ActionListener)} + * code path. These are requests sent to the API with the _unified route. 
+ */ +public class UnifiedChatInput extends InferenceInputs { + private final UnifiedCompletionRequest request; + + public UnifiedChatInput(UnifiedCompletionRequest request, boolean stream) { + super(stream); + this.request = Objects.requireNonNull(request); + } + + public UnifiedChatInput(ChatCompletionInput completionInput, String roleValue) { + this(completionInput.getInputs(), roleValue, completionInput.stream()); + } + + public UnifiedChatInput(List inputs, String roleValue, boolean stream) { + this(UnifiedCompletionRequest.of(convertToMessages(inputs, roleValue)), stream); + } + + private static List convertToMessages(List inputs, String roleValue) { + return inputs.stream() + .map( + value -> new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(value), + roleValue, + null, + null, + null + ) + ) + .toList(); + } + + public UnifiedCompletionRequest getRequest() { + return request; + } + + public int inputSize() { + return request.messages().size(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java index 6e006fe255956..48c8132035b50 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java @@ -18,10 +18,8 @@ import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.inference.common.DelegatingProcessor; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; -import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; import java.io.IOException; -import java.util.ArrayDeque; import java.util.Collections; import 
java.util.Deque; import java.util.Iterator; @@ -115,19 +113,7 @@ public class OpenAiStreamingProcessor extends DelegatingProcessor item) throws Exception { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); - - var results = new ArrayDeque(item.size()); - for (ServerSentEvent event : item) { - if (ServerSentEventField.DATA == event.name() && event.hasValue()) { - try { - var delta = parse(parserConfig, event); - delta.forEachRemaining(results::offer); - } catch (Exception e) { - log.warn("Failed to parse event from inference provider: {}", event); - throw e; - } - } - } + var results = parseEvent(item, OpenAiStreamingProcessor::parse, parserConfig, log); if (results.isEmpty()) { upstream().request(1); @@ -136,7 +122,7 @@ protected void next(Deque item) throws Exception { } } - private Iterator parse(XContentParserConfiguration parserConfig, ServerSentEvent event) + private static Iterator parse(XContentParserConfiguration parserConfig, ServerSentEvent event) throws IOException { if (DONE_MESSAGE.equalsIgnoreCase(event.value())) { return Collections.emptyIterator(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..fce2556efc5e0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionResponseHandler.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.openai; + +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; + +import java.util.concurrent.Flow; + +public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { + public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction); + } + + @Override + public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { + var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); + var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); + + flow.subscribe(serverSentEventProcessor); + serverSentEventProcessor.subscribe(openAiProcessor); + return new StreamingUnifiedChatCompletionResults(openAiProcessor); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java new file mode 100644 index 0000000000000..599d71df3dcfa --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java @@ -0,0 +1,287 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.openai; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Collections; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.LinkedBlockingDeque; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; + +public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor, ChunkedToXContent> { + public static final String FUNCTION_FIELD = "function"; + private static final Logger logger = LogManager.getLogger(OpenAiUnifiedStreamingProcessor.class); + + private static final String CHOICES_FIELD = "choices"; + private static final String DELTA_FIELD = "delta"; + private static final String CONTENT_FIELD = "content"; + private static final String DONE_MESSAGE = "[done]"; + private static final String REFUSAL_FIELD = "refusal"; + private static final String TOOL_CALLS_FIELD = "tool_calls"; + 
public static final String ROLE_FIELD = "role"; + public static final String FINISH_REASON_FIELD = "finish_reason"; + public static final String INDEX_FIELD = "index"; + public static final String OBJECT_FIELD = "object"; + public static final String MODEL_FIELD = "model"; + public static final String ID_FIELD = "id"; + public static final String CHOICE_FIELD = "choice"; + public static final String USAGE_FIELD = "usage"; + public static final String TYPE_FIELD = "type"; + public static final String NAME_FIELD = "name"; + public static final String ARGUMENTS_FIELD = "arguments"; + public static final String COMPLETION_TOKENS_FIELD = "completion_tokens"; + public static final String PROMPT_TOKENS_FIELD = "prompt_tokens"; + public static final String TOTAL_TOKENS_FIELD = "total_tokens"; + + private final Deque buffer = new LinkedBlockingDeque<>(); + + @Override + protected void upstreamRequest(long n) { + if (buffer.isEmpty()) { + super.upstreamRequest(n); + } else { + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll()))); + } + } + + @Override + protected void next(Deque item) throws Exception { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + var results = parseEvent(item, OpenAiUnifiedStreamingProcessor::parse, parserConfig, logger); + + if (results.isEmpty()) { + upstream().request(1); + } else if (results.size() == 1) { + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results)); + } else { + // results > 1, but openai spec only wants 1 chunk per SSE event + var firstItem = singleItem(results.poll()); + while (results.isEmpty() == false) { + buffer.offer(results.poll()); + } + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem)); + } + } + + private static Iterator parse( + XContentParserConfiguration parserConfig, + ServerSentEvent event + ) throws IOException { + if 
(DONE_MESSAGE.equalsIgnoreCase(event.value())) { + return Collections.emptyIterator(); + } + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = ChatCompletionChunkParser.parse(jsonParser); + + return Collections.singleton(chunk).iterator(); + } + } + + public static class ChatCompletionChunkParser { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "chat_completion_chunk", + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + (String) args[0], + (List) args[1], + (String) args[2], + (String) args[3], + (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage) args[4] + ) + ); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(ID_FIELD)); + PARSER.declareObjectArray( + ConstructingObjectParser.constructorArg(), + (p, c) -> ChatCompletionChunkParser.ChoiceParser.parse(p), + new ParseField(CHOICES_FIELD) + ); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(MODEL_FIELD)); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(OBJECT_FIELD)); + PARSER.declareObjectOrNull( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> ChatCompletionChunkParser.UsageParser.parse(p), + null, + new ParseField(USAGE_FIELD) + ); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + private static class ChoiceParser { + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + CHOICE_FIELD, + true, + args 
-> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta) args[0], + (String) args[1], + (int) args[2] + ) + ); + + static { + PARSER.declareObject( + ConstructingObjectParser.constructorArg(), + (p, c) -> ChatCompletionChunkParser.DeltaParser.parse(p), + new ParseField(DELTA_FIELD) + ); + PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(FINISH_REASON_FIELD)); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(INDEX_FIELD)); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice parse(XContentParser parser) { + return PARSER.apply(parser, null); + } + } + + private static class DeltaParser { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser< + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta, + Void> PARSER = new ConstructingObjectParser<>( + DELTA_FIELD, + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + (String) args[0], + (String) args[1], + (String) args[2], + (List) args[3] + ) + ); + + static { + PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(CONTENT_FIELD)); + PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(REFUSAL_FIELD)); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ROLE_FIELD)); + PARSER.declareObjectArray( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> ChatCompletionChunkParser.ToolCallParser.parse(p), + new ParseField(TOOL_CALLS_FIELD) + ); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta parse(XContentParser parser) + throws IOException { + return PARSER.parse(parser, null); + } + } + + private static class ToolCallParser { + private static final ConstructingObjectParser< + 
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall, + Void> PARSER = new ConstructingObjectParser<>( + "tool_call", + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + (int) args[0], + (String) args[1], + (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function) args[2], + (String) args[3] + ) + ); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(INDEX_FIELD)); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ID_FIELD)); + PARSER.declareObject( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> ChatCompletionChunkParser.FunctionParser.parse(p), + new ParseField(FUNCTION_FIELD) + ); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(TYPE_FIELD)); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall parse(XContentParser parser) + throws IOException { + return PARSER.parse(parser, null); + } + } + + private static class FunctionParser { + private static final ConstructingObjectParser< + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function, + Void> PARSER = new ConstructingObjectParser<>( + FUNCTION_FIELD, + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + (String) args[0], + (String) args[1] + ) + ); + + static { + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ARGUMENTS_FIELD)); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(NAME_FIELD)); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function parse( + XContentParser parser + ) throws IOException { + return PARSER.parse(parser, null); + } + } + + private static class UsageParser { + private 
static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + USAGE_FIELD, + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage((int) args[0], (int) args[1], (int) args[2]) + ); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(COMPLETION_TOKENS_FIELD)); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(PROMPT_TOKENS_FIELD)); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(TOTAL_TOKENS_FIELD)); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + } + + private Deque singleItem( + StreamingUnifiedChatCompletionResults.ChatCompletionChunk result + ) { + var deque = new ArrayDeque(1); + deque.offer(result); + return deque; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java index 80770d63ef139..b1af18d03dda4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import 
org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; @@ -27,13 +27,13 @@ public class GoogleAiStudioCompletionRequest implements GoogleAiStudioRequest { private static final String ALT_PARAM = "alt"; private static final String SSE_VALUE = "sse"; - private final DocumentsOnlyInput input; + private final ChatCompletionInput input; private final LazyInitializable uri; private final GoogleAiStudioCompletionModel model; - public GoogleAiStudioCompletionRequest(DocumentsOnlyInput input, GoogleAiStudioCompletionModel model) { + public GoogleAiStudioCompletionRequest(ChatCompletionInput input, GoogleAiStudioCompletionModel model) { this.input = Objects.requireNonNull(input); this.model = Objects.requireNonNull(model); this.uri = new LazyInitializable<>(() -> model.uri(input.stream())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java deleted file mode 100644 index 867a7ca80cbcb..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.external.request.openai; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; - -public class OpenAiChatCompletionRequestEntity implements ToXContentObject { - - private static final String MESSAGES_FIELD = "messages"; - private static final String MODEL_FIELD = "model"; - - private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; - - private static final String ROLE_FIELD = "role"; - private static final String USER_FIELD = "user"; - private static final String CONTENT_FIELD = "content"; - private static final String STREAM_FIELD = "stream"; - - private final List messages; - private final String model; - - private final String user; - private final boolean stream; - - public OpenAiChatCompletionRequestEntity(List messages, String model, String user, boolean stream) { - Objects.requireNonNull(messages); - Objects.requireNonNull(model); - - this.messages = messages; - this.model = model; - this.user = user; - this.stream = stream; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.startArray(MESSAGES_FIELD); - { - for (String message : messages) { - builder.startObject(); - - { - builder.field(ROLE_FIELD, USER_FIELD); - builder.field(CONTENT_FIELD, message); - } - - builder.endObject(); - } - } - builder.endArray(); - - builder.field(MODEL_FIELD, model); - builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); - - if (Strings.isNullOrEmpty(user) == false) { - builder.field(USER_FIELD, user); - } - - if (stream) { - builder.field(STREAM_FIELD, true); - } - - builder.endObject(); - - return builder; - } -} diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java similarity index 80% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java index 99a025e70d003..2e6bdb748fd33 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java @@ -13,6 +13,7 @@ import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.openai.OpenAiAccount; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; @@ -21,24 +22,21 @@ import java.net.URI; import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; -import java.util.List; import java.util.Objects; import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader; -public class OpenAiChatCompletionRequest implements OpenAiRequest { +public class OpenAiUnifiedChatCompletionRequest implements OpenAiRequest { private final OpenAiAccount account; - private final List input; private final OpenAiChatCompletionModel model; - private final boolean 
stream; + private final UnifiedChatInput unifiedChatInput; - public OpenAiChatCompletionRequest(List input, OpenAiChatCompletionModel model, boolean stream) { - this.account = OpenAiAccount.of(model, OpenAiChatCompletionRequest::buildDefaultUri); - this.input = Objects.requireNonNull(input); + public OpenAiUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) { + this.account = OpenAiAccount.of(model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri); + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); this.model = Objects.requireNonNull(model); - this.stream = stream; } @Override @@ -46,9 +44,7 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(account.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString( - new OpenAiChatCompletionRequestEntity(input, model.getServiceSettings().modelId(), model.getTaskSettings().user(), stream) - ).getBytes(StandardCharsets.UTF_8) + Strings.toString(new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model)).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); @@ -87,7 +83,7 @@ public String getInferenceEntityId() { @Override public boolean isStreaming() { - return stream; + return unifiedChatInput.stream(); } public static URI buildDefaultUri() throws URISyntaxException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..50339bf851f7d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -0,0 +1,185 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.openai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; + +import java.io.IOException; +import java.util.Objects; + +public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObject { + + public static final String NAME_FIELD = "name"; + public static final String TOOL_CALL_ID_FIELD = "tool_call_id"; + public static final String TOOL_CALLS_FIELD = "tool_calls"; + public static final String ID_FIELD = "id"; + public static final String FUNCTION_FIELD = "function"; + public static final String ARGUMENTS_FIELD = "arguments"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String PARAMETERS_FIELD = "parameters"; + public static final String STRICT_FIELD = "strict"; + public static final String TOP_P_FIELD = "top_p"; + public static final String USER_FIELD = "user"; + public static final String STREAM_FIELD = "stream"; + private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; + private static final String MODEL_FIELD = "model"; + public static final String MESSAGES_FIELD = "messages"; + private static final String ROLE_FIELD = "role"; + private static final String CONTENT_FIELD = "content"; + private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; + private static final String STOP_FIELD = "stop"; + private static final String TEMPERATURE_FIELD = "temperature"; + private static final String 
TOOL_CHOICE_FIELD = "tool_choice"; + private static final String TOOL_FIELD = "tools"; + private static final String TEXT_FIELD = "text"; + private static final String TYPE_FIELD = "type"; + private static final String STREAM_OPTIONS_FIELD = "stream_options"; + private static final String INCLUDE_USAGE_FIELD = "include_usage"; + + private final UnifiedCompletionRequest unifiedRequest; + private final boolean stream; + private final OpenAiChatCompletionModel model; + + public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) { + Objects.requireNonNull(unifiedChatInput); + + this.unifiedRequest = unifiedChatInput.getRequest(); + this.stream = unifiedChatInput.stream(); + this.model = Objects.requireNonNull(model); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startArray(MESSAGES_FIELD); + { + for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) { + builder.startObject(); + { + switch (message.content()) { + case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content()); + case UnifiedCompletionRequest.ContentObjects contentObjects -> { + builder.startArray(CONTENT_FIELD); + for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) { + builder.startObject(); + builder.field(TEXT_FIELD, contentObject.text()); + builder.field(TYPE_FIELD, contentObject.type()); + builder.endObject(); + } + builder.endArray(); + } + } + + builder.field(ROLE_FIELD, message.role()); + if (message.name() != null) { + builder.field(NAME_FIELD, message.name()); + } + if (message.toolCallId() != null) { + builder.field(TOOL_CALL_ID_FIELD, message.toolCallId()); + } + if (message.toolCalls() != null) { + builder.startArray(TOOL_CALLS_FIELD); + for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) { + 
builder.startObject(); + { + builder.field(ID_FIELD, toolCall.id()); + builder.startObject(FUNCTION_FIELD); + { + builder.field(ARGUMENTS_FIELD, toolCall.function().arguments()); + builder.field(NAME_FIELD, toolCall.function().name()); + } + builder.endObject(); + builder.field(TYPE_FIELD, toolCall.type()); + } + builder.endObject(); + } + builder.endArray(); + } + } + builder.endObject(); + } + } + builder.endArray(); + + builder.field(MODEL_FIELD, model.getServiceSettings().modelId()); + if (unifiedRequest.maxCompletionTokens() != null) { + builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens()); + } + + builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); + + if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) { + builder.field(STOP_FIELD, unifiedRequest.stop()); + } + if (unifiedRequest.temperature() != null) { + builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature()); + } + if (unifiedRequest.toolChoice() != null) { + if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceString) { + builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) unifiedRequest.toolChoice()).value()); + } else if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceObject) { + builder.startObject(TOOL_CHOICE_FIELD); + { + builder.field(TYPE_FIELD, ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).type()); + builder.startObject(FUNCTION_FIELD); + { + builder.field( + NAME_FIELD, + ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).function().name() + ); + } + builder.endObject(); + } + builder.endObject(); + } + } + if (unifiedRequest.tools() != null && unifiedRequest.tools().isEmpty() == false) { + builder.startArray(TOOL_FIELD); + for (UnifiedCompletionRequest.Tool t : unifiedRequest.tools()) { + builder.startObject(); + { + builder.field(TYPE_FIELD, t.type()); + builder.startObject(FUNCTION_FIELD); + { + 
builder.field(DESCRIPTION_FIELD, t.function().description()); + builder.field(NAME_FIELD, t.function().name()); + builder.field(PARAMETERS_FIELD, t.function().parameters()); + if (t.function().strict() != null) { + builder.field(STRICT_FIELD, t.function().strict()); + } + } + builder.endObject(); + } + builder.endObject(); + } + builder.endArray(); + } + if (unifiedRequest.topP() != null) { + builder.field(TOP_P_FIELD, unifiedRequest.topP()); + } + + if (Strings.isNullOrEmpty(model.getTaskSettings().user()) == false) { + builder.field(USER_FIELD, model.getTaskSettings().user()); + } + + builder.field(STREAM_FIELD, stream); + if (stream) { + builder.startObject(STREAM_OPTIONS_FIELD); + builder.field(INCLUDE_USAGE_FIELD, true); + builder.endObject(); + } + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java index e72e68052f648..d911158e82296 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestChannel; @@ -21,27 +22,32 @@ import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID; abstract class BaseInferenceAction extends BaseRestHandler { - @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - String inferenceEntityId; - TaskType taskType; + static Params parseParams(RestRequest restRequest) { if 
(restRequest.hasParam(INFERENCE_ID)) { - inferenceEntityId = restRequest.param(INFERENCE_ID); - taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID)); + var inferenceEntityId = restRequest.param(INFERENCE_ID); + var taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID)); + return new Params(inferenceEntityId, taskType); } else { - inferenceEntityId = restRequest.param(TASK_TYPE_OR_INFERENCE_ID); - taskType = TaskType.ANY; + return new Params(restRequest.param(TASK_TYPE_OR_INFERENCE_ID), TaskType.ANY); } + } + + record Params(String inferenceEntityId, TaskType taskType) {} + + static TimeValue parseTimeout(RestRequest restRequest) { + return restRequest.paramAsTime(InferenceAction.Request.TIMEOUT.getPreferredName(), InferenceAction.Request.DEFAULT_TIMEOUT); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + var params = parseParams(restRequest); InferenceAction.Request.Builder requestBuilder; try (var parser = restRequest.contentParser()) { - requestBuilder = InferenceAction.Request.parseRequest(inferenceEntityId, taskType, parser); + requestBuilder = InferenceAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), parser); } - var inferTimeout = restRequest.paramAsTime( - InferenceAction.Request.TIMEOUT.getPreferredName(), - InferenceAction.Request.DEFAULT_TIMEOUT - ); + var inferTimeout = parseTimeout(restRequest); requestBuilder.setInferenceTimeout(inferTimeout); var request = prepareInferenceRequest(requestBuilder); return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java index 55d6443b43c03..c46f211bb26af 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java @@ -30,6 +30,12 @@ public final class Paths { + "}/{" + INFERENCE_ID + "}/_stream"; + static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/_unified"; + static final String UNIFIED_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" + + TASK_TYPE_OR_INFERENCE_ID + + "}/{" + + INFERENCE_ID + + "}/_unified"; private Paths() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java new file mode 100644 index 0000000000000..5c71b560a6b9d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.rest; + +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.Scope; +import org.elasticsearch.rest.ServerlessScope; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_INFERENCE_ID_PATH; +import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_TASK_TYPE_INFERENCE_ID_PATH; + +@ServerlessScope(Scope.PUBLIC) +public class RestUnifiedCompletionInferenceAction extends BaseRestHandler { + @Override + public String getName() { + return "unified_inference_action"; + } + + @Override + public List routes() { + return List.of(new Route(POST, UNIFIED_INFERENCE_ID_PATH), new Route(POST, UNIFIED_TASK_TYPE_INFERENCE_ID_PATH)); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + var params = BaseInferenceAction.parseParams(restRequest); + + var inferTimeout = BaseInferenceAction.parseTimeout(restRequest); + + UnifiedCompletionAction.Request request; + try (var parser = restRequest.contentParser()) { + request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser); + } + + return channel -> client.execute(UnifiedCompletionAction.INSTANCE, request, new ServerSentEventsRestActionListener(channel)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 8e2dac1ef9db2..e9b75e9ec7796 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -7,9 +7,11 @@ package org.elasticsearch.xpack.inference.services; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.InferenceService; @@ -17,11 +19,15 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import java.io.IOException; import java.util.EnumSet; @@ -61,11 +67,31 @@ public void infer( ActionListener listener ) { init(); - if (query != null) { - doInfer(model, new QueryAndDocsInputs(query, input, stream), taskSettings, inputType, timeout, listener); - } else { - doInfer(model, new DocumentsOnlyInput(input, stream), taskSettings, inputType, timeout, listener); - } + var inferenceInput = createInput(model, input, query, stream); + doInfer(model, inferenceInput, taskSettings, inputType, timeout, listener); + 
} + + private static InferenceInputs createInput(Model model, List input, @Nullable String query, boolean stream) { + return switch (model.getTaskType()) { + case COMPLETION -> new ChatCompletionInput(input, stream); + case RERANK -> new QueryAndDocsInputs(query, input, stream); + case TEXT_EMBEDDING, SPARSE_EMBEDDING -> new DocumentsOnlyInput(input, stream); + default -> throw new ElasticsearchStatusException( + Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()), + RestStatus.BAD_REQUEST + ); + }; + } + + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + init(); + doUnifiedCompletionInfer(model, new UnifiedChatInput(request, true), timeout, listener); } @Override @@ -92,6 +118,13 @@ protected abstract void doInfer( ActionListener listener ); + protected abstract void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ); + protected abstract void doChunkedInfer( Model model, DocumentsOnlyInput inputs, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index ec4b8d9bb4d3d..7d05bac363fb1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -776,5 +776,9 @@ public static T nonNullOrDefault(@Nullable T requestValue, @Nullable T origi return requestValue == null ? 
originalSettingsValue : requestValue; } + public static void throwUnsupportedUnifiedCompletionOperation(String serviceName) { + throw new UnsupportedOperationException(Strings.format("The %s service does not support unified completion", serviceName)); + } + private ServiceUtils() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 5adc2a11b19d9..ffd26b9ac534d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -37,6 +37,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; @@ -57,14 +58,13 @@ import java.util.Map; import java.util.stream.Stream; -import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; -import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceFields.EMBEDDING_MAX_BATCH_SIZE; import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings.HOST; import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings.HTTP_SCHEMA_NAME; @@ -261,6 +261,16 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta ); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 48b3c3df03e11..d224e50bb650d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -40,6 +40,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import 
org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -64,6 +65,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD; @@ -89,6 +91,16 @@ public AmazonBedrockService( this.amazonBedrockSender = amazonBedrockFactory.createSender(); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index b3d503de8e3eb..f1840af18779f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -32,6 +32,7 @@ import 
org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -52,6 +53,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; public class AnthropicService extends SenderService { public static final String NAME = "anthropic"; @@ -192,6 +194,16 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index bba331fc0b5df..f8ea11e4b15a5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -38,6 +38,7 @@ import 
org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -63,6 +64,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.ENDPOINT_TYPE_FIELD; import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.PROVIDER_FIELD; import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TARGET_FIELD; @@ -81,6 +83,16 @@ public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents super(factory, serviceComponents); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 16c94dfa9ad94..a38c265d2613c 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -36,6 +36,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -58,6 +59,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.API_VERSION; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.DEPLOYMENT_ID; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.RESOURCE_NAME; @@ -233,6 +235,16 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doInfer( Model model, diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index b3d8b3b6efce3..ccb8d79dacd6c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -58,6 +59,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.EMBEDDING_MAX_BATCH_SIZE; public class CohereService extends SenderService { @@ -232,6 +234,16 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public void doInfer( Model model, diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index b256861e7dd27..fe8ee52eb8816 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -38,6 +38,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -57,6 +58,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; public class ElasticInferenceService extends SenderService { @@ -76,6 +78,16 @@ public ElasticInferenceService( this.elasticInferenceServiceComponents = elasticInferenceServiceComponents; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doInfer( Model model, diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 0e64842f873d3..5f613d6be5869 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -31,6 +31,7 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.inference.configuration.SettingsConfigurationSelectOption; @@ -77,6 +78,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.MODEL_ID; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_THREADS; @@ -578,6 +580,16 @@ private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomE ); } + @Override + public void 
unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public void infer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 57a8a66a3f3a6..b681722a82136 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -39,6 +39,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.GoogleAiStudioEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -64,6 +65,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioServiceFields.EMBEDDING_MAX_BATCH_SIZE; public class GoogleAiStudioService extends SenderService { @@ -282,9 +284,8 @@ protected 
void doInfer( ) { if (model instanceof GoogleAiStudioCompletionModel completionModel) { var requestManager = new GoogleAiStudioCompletionRequestManager(completionModel, getServiceComponents().threadPool()); - var docsOnly = DocumentsOnlyInput.of(inputs); var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage( - completionModel.uri(docsOnly.stream()), + completionModel.uri(inputs.stream()), "Google AI Studio completion" ); var action = new SingleInputSenderExecutableAction( @@ -308,6 +309,16 @@ protected void doInfer( } } + /** Unified chat completion hook for this service. The body only calls throwUnsupportedUnifiedCompletionOperation(NAME) and never uses model, inputs, timeout or listener. NOTE(review): presumably the helper throws to signal the operation is unsupported - confirm in ServiceUtils. */ @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 857d475499aae..87a2d98dca92c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -35,6 +35,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -57,6 +58,7 @@ import static
org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.EMBEDDING_MAX_BATCH_SIZE; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID; @@ -206,6 +208,16 @@ protected void doInfer( action.execute(inputs, timeout, listener); } + /** Unified chat completion hook for this service. The body only calls throwUnsupportedUnifiedCompletionOperation(NAME) and never uses model, inputs, timeout or listener. NOTE(review): presumably the helper throws to signal the operation is unsupported - confirm in ServiceUtils. */ @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 51cca72f26054..b74ec01cd76e7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import
org.elasticsearch.inference.SettingsConfiguration; @@ -31,6 +32,7 @@ import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; @@ -47,6 +49,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; public class HuggingFaceService extends HuggingFaceBaseService { public static final String NAME = "hugging_face"; @@ -139,6 +142,16 @@ protected void doChunkedInfer( } } + /** Unified chat completion hook for this service. The body only calls throwUnsupportedUnifiedCompletionOperation(NAME) and never uses model, inputs, timeout or listener. NOTE(review): presumably the helper throws to signal the operation is unsupported - confirm in ServiceUtils. */ @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 75920efa251f2..5b038781b96af 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -36,6 +36,7 @@
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService; @@ -49,6 +50,7 @@ import java.util.Map; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL; public class HuggingFaceElserService extends HuggingFaceBaseService { @@ -81,6 +83,16 @@ protected HuggingFaceModel createModel( }; } + /** Unified chat completion hook for this service. The body only calls throwUnsupportedUnifiedCompletionOperation(NAME) and never uses model, inputs, timeout or listener. NOTE(review): presumably the helper throws to signal the operation is unsupported - confirm in ServiceUtils. */ @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 981a3e95808ef..cc66d5fd7ee74 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -37,6 +37,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import
org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -57,6 +58,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL; import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.API_VERSION; import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.EMBEDDING_MAX_BATCH_SIZE; @@ -276,6 +278,16 @@ protected void doInfer( action.execute(input, timeout, listener); } + /** Unified chat completion hook for this service. The body only calls throwUnsupportedUnifiedCompletionOperation(NAME) and never uses model, inputs, timeout or listener. NOTE(review): presumably the helper throws to signal the operation is unsupported - confirm in ServiceUtils. */ @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index fe0edb851902b..881e7d36f2a21 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -36,6 +36,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -58,6 +59,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD; public class MistralService extends SenderService { @@ -88,6 +90,16 @@ protected void doInfer( } } + /** Unified chat completion hook for this service. The body only calls throwUnsupportedUnifiedCompletionOperation(NAME) and never uses model, inputs, timeout or listener. NOTE(review): presumably the helper throws to signal the operation is unsupported - confirm in ServiceUtils. */ @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 20ff1c617d21f..7b51b068708ca 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -32,10 +32,13 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.OpenAiUnifiedCompletionRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -53,6 +56,8 @@ import java.util.Map; import java.util.Set; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator.COMPLETION_ERROR_PREFIX; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -257,6 +262,28 @@ public void doInfer( action.execute(inputs, timeout, listener); } + /** Executes a unified chat completion request against OpenAI. If model is not an OpenAiChatCompletionModel the listener is failed with createInvalidModelException(model) and nothing is sent. Otherwise a per-request model is derived via OpenAiChatCompletionModel.of(openAiModel, inputs.getRequest()), an OpenAiUnifiedCompletionRequestManager is created for it, and a SenderExecutableAction is executed with the given inputs, timeout and listener. */ @Override + public void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof OpenAiChatCompletionModel == false) { +
listener.onFailure(createInvalidModelException(model)); + return; + } + + OpenAiChatCompletionModel openAiModel = (OpenAiChatCompletionModel) model; + + /* request-scoped copy of the model: the request is passed in so its values can override the stored settings */ var overriddenModel = OpenAiChatCompletionModel.of(openAiModel, inputs.getRequest()); + var requestCreator = OpenAiUnifiedCompletionRequestManager.of(overriddenModel, getServiceComponents().threadPool()); + var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getServiceSettings().uri(), COMPLETION_ERROR_PREFIX); + var action = new SenderExecutableAction(getSender(), requestCreator, errorMessage); + + action.execute(inputs, timeout, listener); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java index e721cd2955cf3..7d79d64b3a771 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; @@ -24,6 +25,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER; @@ -38,6 +40,26 @@ public static
OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, Map< return new OpenAiChatCompletionModel(model, OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); } + /** Builds a new OpenAiChatCompletionModel from an existing model and a unified completion request. The new service settings use request.model() as the model id when it is non-null and fall back to the original model id otherwise; uri, organization id, max input tokens and rate limit settings are copied from the original service settings. Inference entity id, task type, service name, task settings and secret settings are taken from the original model unchanged. */ public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, UnifiedCompletionRequest request) { + var originalModelServiceSettings = model.getServiceSettings(); + var overriddenServiceSettings = new OpenAiChatCompletionServiceSettings( + Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()), + originalModelServiceSettings.uri(), + originalModelServiceSettings.organizationId(), + originalModelServiceSettings.maxInputTokens(), + originalModelServiceSettings.rateLimitSettings() + ); + + return new OpenAiChatCompletionModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + overriddenServiceSettings, + model.getTaskSettings(), + model.getSecretSettings() + ); + } + public OpenAiChatCompletionModel( String inferenceEntityId, TaskType taskType, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java index 8029d8579baba..7ef7f85d71a6a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java @@ -48,5 +48,4 @@ public static OpenAiChatCompletionRequestTaskSettings fromMap(Map TaskType.fromStringOrStatusException(null)); + assertThat(exception.getMessage(), Matchers.is("Task type must not be null")); + + exception = expectThrows(ElasticsearchStatusException.class, () ->
TaskType.fromStringOrStatusException("blah")); + assertThat(exception.getMessage(), Matchers.is("Unknown task_type [blah]")); + + assertThat(TaskType.fromStringOrStatusException("any"), Matchers.is(TaskType.ANY)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java index 5abb9000f4d04..9395ae222e9ba 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java @@ -19,6 +19,7 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.common.Truncator; @@ -160,9 +161,11 @@ public static Model getInvalidModel(String inferenceEntityId, String serviceName var mockConfigs = mock(ModelConfigurations.class); when(mockConfigs.getInferenceEntityId()).thenReturn(inferenceEntityId); when(mockConfigs.getService()).thenReturn(serviceName); + when(mockConfigs.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); var mockModel = mock(Model.class); when(mockModel.getConfigurations()).thenReturn(mockConfigs); + when(mockModel.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); return mockModel; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java new file mode 100644 index 0000000000000..47f3a0e0b57aa --- /dev/null +++
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -0,0 +1,364 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.Flow; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static 
org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.assertArg; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public abstract class BaseTransportInferenceActionTestCase extends ESTestCase { + private ModelRegistry modelRegistry; + private StreamingTaskManager streamingTaskManager; + private BaseTransportInferenceAction action; + + protected static final String serviceId = "serviceId"; + protected static final TaskType taskType = TaskType.COMPLETION; + protected static final String inferenceId = "inferenceEntityId"; + protected InferenceServiceRegistry serviceRegistry; + protected InferenceStats inferenceStats; + + @Before + public void setUp() throws Exception { + super.setUp(); + TransportService transportService = mock(); + ActionFilters actionFilters = mock(); + modelRegistry = mock(); + serviceRegistry = mock(); + inferenceStats = new InferenceStats(mock(), mock()); + streamingTaskManager = mock(); + action = createAction(transportService, actionFilters, modelRegistry, serviceRegistry, inferenceStats, streamingTaskManager); + } + + protected abstract BaseTransportInferenceAction createAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager + ); + + protected abstract Request createRequest(); + + public void testMetricsAfterModelRegistryError() { + var expectedException = new IllegalStateException("hello"); + var expectedError = expectedException.getClass().getSimpleName(); + + doAnswer(ans -> { + ActionListener listener = ans.getArgument(1); + listener.onFailure(expectedException); + return null; + 
}).when(modelRegistry).getModelWithSecrets(any(), any()); + + var listener = doExecute(taskType); + verify(listener).onFailure(same(expectedException)); + + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), nullValue()); + assertThat(attributes.get("task_type"), nullValue()); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + protected ActionListener doExecute(TaskType taskType) { + return doExecute(taskType, false); + } + + protected ActionListener doExecute(TaskType taskType, boolean stream) { + Request request = createRequest(); + when(request.getInferenceEntityId()).thenReturn(inferenceId); + when(request.getTaskType()).thenReturn(taskType); + when(request.isStreaming()).thenReturn(stream); + ActionListener listener = mock(); + action.doExecute(mock(), request, listener); + return listener; + } + + public void testMetricsAfterMissingService() { + mockModelRegistry(taskType); + + when(serviceRegistry.getService(any())).thenReturn(Optional.empty()); + + var listener = doExecute(taskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat(e.getMessage(), is("Unknown service [" + serviceId + "] for model [" + inferenceId + "]. 
")); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + protected void mockModelRegistry(TaskType expectedTaskType) { + var unparsedModel = new UnparsedModel(inferenceId, expectedTaskType, serviceId, Map.of(), Map.of()); + doAnswer(ans -> { + ActionListener listener = ans.getArgument(1); + listener.onResponse(unparsedModel); + return null; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + } + + public void testMetricsAfterUnknownTaskType() { + var modelTaskType = TaskType.RERANK; + var requestTaskType = TaskType.SPARSE_EMBEDDING; + mockModelRegistry(modelTaskType); + when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); + + var listener = doExecute(requestTaskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is( + "Incompatible task_type, the requested type [" + + requestTaskType + + "] does not match the model type [" + + modelTaskType + + "]" + ) + ); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(modelTaskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), 
is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + public void testMetricsAfterInferError() { + var expectedException = new IllegalStateException("hello"); + var expectedError = expectedException.getClass().getSimpleName(); + mockService(listener -> listener.onFailure(expectedException)); + + var listener = doExecute(taskType); + + verify(listener).onFailure(same(expectedException)); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testMetricsAfterStreamUnsupported() { + var expectedStatus = RestStatus.METHOD_NOT_ALLOWED; + var expectedError = String.valueOf(expectedStatus.getStatus()); + mockService(l -> {}); + + var listener = doExecute(taskType, true); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + var ese = (ElasticsearchStatusException) e; + assertThat(ese.getMessage(), is("Streaming is not allowed for service [" + serviceId + "].")); + assertThat(ese.status(), is(expectedStatus)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(expectedStatus.getStatus())); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testMetricsAfterInferSuccess() { + mockService(listener -> listener.onResponse(mock())); + + var listener = doExecute(taskType); + + verify(listener).onResponse(any()); + 
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } + + public void testMetricsAfterStreamInferSuccess() { + mockStreamResponse(Flow.Subscriber::onComplete); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } + + public void testMetricsAfterStreamInferFailure() { + var expectedException = new IllegalStateException("hello"); + var expectedError = expectedException.getClass().getSimpleName(); + mockStreamResponse(subscriber -> { + subscriber.subscribe(mock()); + subscriber.onError(expectedException); + }); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testMetricsAfterStreamCancel() { + var response = mockStreamResponse(s -> s.onSubscribe(mock())); + response.subscribe(new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + subscription.cancel(); + } + + @Override + public void onNext(ChunkedToXContent item) { + + } + + @Override + public void onError(Throwable throwable) { + + } + + @Override + 
public void onComplete() { + + } + }); + + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } + + protected Flow.Publisher mockStreamResponse(Consumer> action) { + mockService(true, Set.of(), listener -> { + Flow.Processor taskProcessor = mock(); + doAnswer(innerAns -> { + action.accept(innerAns.getArgument(0)); + return null; + }).when(taskProcessor).subscribe(any()); + when(streamingTaskManager.create(any(), any())).thenReturn(taskProcessor); + var inferenceServiceResults = mock(InferenceServiceResults.class); + when(inferenceServiceResults.publisher()).thenReturn(mock()); + listener.onResponse(inferenceServiceResults); + }); + + var listener = doExecute(taskType, true); + var captor = ArgumentCaptor.forClass(InferenceAction.Response.class); + verify(listener).onResponse(captor.capture()); + assertTrue(captor.getValue().isStreaming()); + assertNotNull(captor.getValue().publisher()); + return captor.getValue().publisher(); + } + + protected void mockService(Consumer> listenerAction) { + mockService(false, Set.of(), listenerAction); + } + + protected void mockService( + boolean stream, + Set supportedStreamingTasks, + Consumer> listenerAction + ) { + InferenceService service = mock(); + Model model = mockModel(); + when(service.parsePersistedConfigWithSecrets(any(), any(), any(), any())).thenReturn(model); + when(service.name()).thenReturn(serviceId); + + when(service.canStream(any())).thenReturn(stream); + when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks); + doAnswer(ans -> { + listenerAction.accept(ans.getArgument(7)); + return null; + }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), 
any(), any()); + doAnswer(ans -> { + listenerAction.accept(ans.getArgument(3)); + return null; + }).when(service).unifiedCompletionInfer(any(), any(), any(), any()); + mockModelAndServiceRegistry(service); + } + + protected Model mockModel() { + Model model = mock(); + ModelConfigurations modelConfigurations = mock(); + when(modelConfigurations.getService()).thenReturn(serviceId); + when(model.getConfigurations()).thenReturn(modelConfigurations); + when(model.getTaskType()).thenReturn(taskType); + when(model.getServiceSettings()).thenReturn(mock()); + return model; + } + + protected void mockModelAndServiceRegistry(InferenceService service) { + var unparsedModel = new UnparsedModel(inferenceId, taskType, serviceId, Map.of(), Map.of()); + doAnswer(ans -> { + ActionListener listener = ans.getArgument(1); + listener.onResponse(unparsedModel); + return null; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + + when(serviceRegistry.getService(any())).thenReturn(Optional.of(service)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java index 0ed9cbf56b3fa..e54175cb27009 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java @@ -7,66 +7,28 @@ package org.elasticsearch.xpack.inference.action; -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; -import 
org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; -import org.junit.Before; -import org.mockito.ArgumentCaptor; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.Flow; -import java.util.function.Consumer; - -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.isA; -import static org.hamcrest.Matchers.nullValue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.assertArg; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -public class TransportInferenceActionTests extends ESTestCase { - private static final String serviceId = "serviceId"; - private static final TaskType taskType = TaskType.COMPLETION; - private static final String inferenceId = "inferenceEntityId"; - private ModelRegistry modelRegistry; - private InferenceServiceRegistry serviceRegistry; - private InferenceStats inferenceStats; - private StreamingTaskManager streamingTaskManager; - private TransportInferenceAction action; +public class TransportInferenceActionTests extends 
BaseTransportInferenceActionTestCase { - @Before - public void setUp() throws Exception { - super.setUp(); - TransportService transportService = mock(); - ActionFilters actionFilters = mock(); - modelRegistry = mock(); - serviceRegistry = mock(); - inferenceStats = new InferenceStats(mock(), mock()); - streamingTaskManager = mock(); - action = new TransportInferenceAction( + @Override + protected BaseTransportInferenceAction createAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager + ) { + return new TransportInferenceAction( transportService, actionFilters, modelRegistry, @@ -76,279 +38,8 @@ public void setUp() throws Exception { ); } - public void testMetricsAfterModelRegistryError() { - var expectedException = new IllegalStateException("hello"); - var expectedError = expectedException.getClass().getSimpleName(); - - doAnswer(ans -> { - ActionListener listener = ans.getArgument(1); - listener.onFailure(expectedException); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - - var listener = doExecute(taskType); - verify(listener).onFailure(same(expectedException)); - - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), nullValue()); - assertThat(attributes.get("task_type"), nullValue()); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), nullValue()); - assertThat(attributes.get("error.type"), is(expectedError)); - })); - } - - private ActionListener doExecute(TaskType taskType) { - return doExecute(taskType, false); - } - - private ActionListener doExecute(TaskType taskType, boolean stream) { - InferenceAction.Request request = mock(); - when(request.getInferenceEntityId()).thenReturn(inferenceId); - 
when(request.getTaskType()).thenReturn(taskType); - when(request.isStreaming()).thenReturn(stream); - ActionListener listener = mock(); - action.doExecute(mock(), request, listener); - return listener; - } - - public void testMetricsAfterMissingService() { - mockModelRegistry(taskType); - - when(serviceRegistry.getService(any())).thenReturn(Optional.empty()); - - var listener = doExecute(taskType); - - verify(listener).onFailure(assertArg(e -> { - assertThat(e, isA(ElasticsearchStatusException.class)); - assertThat(e.getMessage(), is("Unknown service [" + serviceId + "] for model [" + inferenceId + "]. ")); - assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); - })); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); - assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); - })); - } - - private void mockModelRegistry(TaskType expectedTaskType) { - var unparsedModel = new UnparsedModel(inferenceId, expectedTaskType, serviceId, Map.of(), Map.of()); - doAnswer(ans -> { - ActionListener listener = ans.getArgument(1); - listener.onResponse(unparsedModel); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - } - - public void testMetricsAfterUnknownTaskType() { - var modelTaskType = TaskType.RERANK; - var requestTaskType = TaskType.SPARSE_EMBEDDING; - mockModelRegistry(modelTaskType); - when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); - - var listener = doExecute(requestTaskType); - - verify(listener).onFailure(assertArg(e -> { - assertThat(e, isA(ElasticsearchStatusException.class)); - assertThat( - e.getMessage(), - is( - "Incompatible task_type, the 
requested type [" - + requestTaskType - + "] does not match the model type [" - + modelTaskType - + "]" - ) - ); - assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); - })); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(modelTaskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); - assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); - })); - } - - public void testMetricsAfterInferError() { - var expectedException = new IllegalStateException("hello"); - var expectedError = expectedException.getClass().getSimpleName(); - mockService(listener -> listener.onFailure(expectedException)); - - var listener = doExecute(taskType); - - verify(listener).onFailure(same(expectedException)); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), nullValue()); - assertThat(attributes.get("error.type"), is(expectedError)); - })); - } - - public void testMetricsAfterStreamUnsupported() { - var expectedStatus = RestStatus.METHOD_NOT_ALLOWED; - var expectedError = String.valueOf(expectedStatus.getStatus()); - mockService(l -> {}); - - var listener = doExecute(taskType, true); - - verify(listener).onFailure(assertArg(e -> { - assertThat(e, isA(ElasticsearchStatusException.class)); - var ese = (ElasticsearchStatusException) e; - assertThat(ese.getMessage(), is("Streaming is not allowed for service [" + serviceId + "].")); - assertThat(ese.status(), is(expectedStatus)); - })); - 
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(expectedStatus.getStatus())); - assertThat(attributes.get("error.type"), is(expectedError)); - })); - } - - public void testMetricsAfterInferSuccess() { - mockService(listener -> listener.onResponse(mock())); - - var listener = doExecute(taskType); - - verify(listener).onResponse(any()); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(200)); - assertThat(attributes.get("error.type"), nullValue()); - })); - } - - public void testMetricsAfterStreamInferSuccess() { - mockStreamResponse(Flow.Subscriber::onComplete); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(200)); - assertThat(attributes.get("error.type"), nullValue()); - })); - } - - public void testMetricsAfterStreamInferFailure() { - var expectedException = new IllegalStateException("hello"); - var expectedError = expectedException.getClass().getSimpleName(); - mockStreamResponse(subscriber -> { - subscriber.subscribe(mock()); - subscriber.onError(expectedException); - }); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), 
is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), nullValue()); - assertThat(attributes.get("error.type"), is(expectedError)); - })); - } - - public void testMetricsAfterStreamCancel() { - var response = mockStreamResponse(s -> s.onSubscribe(mock())); - response.subscribe(new Flow.Subscriber<>() { - @Override - public void onSubscribe(Flow.Subscription subscription) { - subscription.cancel(); - } - - @Override - public void onNext(ChunkedToXContent item) { - - } - - @Override - public void onError(Throwable throwable) { - - } - - @Override - public void onComplete() { - - } - }); - - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(200)); - assertThat(attributes.get("error.type"), nullValue()); - })); - } - - private Flow.Publisher mockStreamResponse(Consumer> action) { - mockService(true, Set.of(), listener -> { - Flow.Processor taskProcessor = mock(); - doAnswer(innerAns -> { - action.accept(innerAns.getArgument(0)); - return null; - }).when(taskProcessor).subscribe(any()); - when(streamingTaskManager.create(any(), any())).thenReturn(taskProcessor); - var inferenceServiceResults = mock(InferenceServiceResults.class); - when(inferenceServiceResults.publisher()).thenReturn(mock()); - listener.onResponse(inferenceServiceResults); - }); - - var listener = doExecute(taskType, true); - var captor = ArgumentCaptor.forClass(InferenceAction.Response.class); - verify(listener).onResponse(captor.capture()); - assertTrue(captor.getValue().isStreaming()); - assertNotNull(captor.getValue().publisher()); - return captor.getValue().publisher(); - } - - private void mockService(Consumer> listenerAction) { - mockService(false, Set.of(), 
listenerAction); - } - - private void mockService( - boolean stream, - Set supportedStreamingTasks, - Consumer> listenerAction - ) { - InferenceService service = mock(); - Model model = mockModel(); - when(service.parsePersistedConfigWithSecrets(any(), any(), any(), any())).thenReturn(model); - when(service.name()).thenReturn(serviceId); - - when(service.canStream(any())).thenReturn(stream); - when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks); - doAnswer(ans -> { - listenerAction.accept(ans.getArgument(7)); - return null; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); - mockModelAndServiceRegistry(service); - } - - private Model mockModel() { - Model model = mock(); - ModelConfigurations modelConfigurations = mock(); - when(modelConfigurations.getService()).thenReturn(serviceId); - when(model.getConfigurations()).thenReturn(modelConfigurations); - when(model.getTaskType()).thenReturn(taskType); - when(model.getServiceSettings()).thenReturn(mock()); - return model; - } - - private void mockModelAndServiceRegistry(InferenceService service) { - var unparsedModel = new UnparsedModel(inferenceId, taskType, serviceId, Map.of(), Map.of()); - doAnswer(ans -> { - ActionListener listener = ans.getArgument(1); - listener.onResponse(unparsedModel); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - - when(serviceRegistry.getService(any())).thenReturn(Optional.of(service)); + @Override + protected InferenceAction.Request createRequest() { + return mock(); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java new file mode 100644 index 0000000000000..4c943599ce523 --- /dev/null +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; + +import java.util.Optional; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.assertArg; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class TransportUnifiedCompletionActionTests extends BaseTransportInferenceActionTestCase { + + @Override + protected BaseTransportInferenceAction createAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager + ) { + return new TransportUnifiedCompletionInferenceAction( + transportService, + actionFilters, + 
modelRegistry, + serviceRegistry, + inferenceStats, + streamingTaskManager + ); + } + + @Override + protected UnifiedCompletionAction.Request createRequest() { + return mock(); + } + + public void testThrows_IncompatibleTaskTypeException_WhenUsingATextEmbeddingInferenceEndpoint() { + var modelTaskType = TaskType.TEXT_EMBEDDING; + var requestTaskType = TaskType.TEXT_EMBEDDING; + mockModelRegistry(modelTaskType); + when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); + + var listener = doExecute(requestTaskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Incompatible task_type for unified API, the requested type [" + requestTaskType + "] must be one of [completion]") + ); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(modelTaskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + public void testThrows_IncompatibleTaskTypeException_WhenUsingRequestIsAny_ModelIsTextEmbedding() { + var modelTaskType = TaskType.ANY; + var requestTaskType = TaskType.TEXT_EMBEDDING; + mockModelRegistry(modelTaskType); + when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); + + var listener = doExecute(requestTaskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Incompatible task_type for unified API, the requested type [" + requestTaskType + "] must be one of [completion]") + ); + 
assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(modelTaskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + public void testMetricsAfterUnifiedInferSuccess_WithRequestTaskTypeAny() { + mockModelRegistry(TaskType.COMPLETION); + mockService(listener -> listener.onResponse(mock())); + + var listener = doExecute(TaskType.ANY); + + verify(listener).onResponse(any()); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java index d4ab9b1f1e19a..9e7c58b0ca79e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java @@ -61,25 +61,11 @@ public void testOneInputIsValid() { assertTrue("Test failed to call listener.", testRan.get()); } - public void testInvalidInputType() { - var badInput = 
mock(InferenceInputs.class); - var actualException = new AtomicReference(); - - executableAction.execute( - badInput, - mock(TimeValue.class), - ActionListener.wrap(shouldNotSucceed -> fail("Test failed."), actualException::set) - ); - - assertThat(actualException.get(), notNullValue()); - assertThat(actualException.get().getMessage(), is("Invalid inference input type")); - assertThat(actualException.get(), instanceOf(ElasticsearchStatusException.class)); - assertThat(((ElasticsearchStatusException) actualException.get()).status(), is(RestStatus.INTERNAL_SERVER_ERROR)); - } - public void testMoreThanOneInput() { var badInput = mock(DocumentsOnlyInput.class); - when(badInput.getInputs()).thenReturn(List.of("one", "two")); + var input = List.of("one", "two"); + when(badInput.getInputs()).thenReturn(input); + when(badInput.inputSize()).thenReturn(input.size()); var actualException = new AtomicReference(); executableAction.execute( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java index 87d3a82b4aae6..e7543aa6ba9e5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockMockRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import 
org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; @@ -130,7 +131,7 @@ public void testCompletionRequestAction() throws IOException { ); var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test input string")))); @@ -163,7 +164,7 @@ public void testChatCompletionRequestAction_HandlesException() throws IOExceptio ); var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(sender.sendCount(), is(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java index a3114300c5ddc..f0de37ceaaf98 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java @@ -20,7 +20,7 @@ import org.elasticsearch.xcontent.XContentType; 
import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.request.anthropic.AnthropicRequestUtils; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -49,6 +49,7 @@ import static org.mockito.Mockito.mock; public class AnthropicActionCreatorTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -103,7 +104,7 @@ public void testCreate_ChatCompletionModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -168,7 +169,7 @@ public void testCreate_ChatCompletionModel_FailsFromInvalidResponseFormat() thro var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat( diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java index fca2e316af17f..2065a726b7589 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java @@ -27,7 +27,7 @@ import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.AnthropicCompletionRequestManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -113,7 +113,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -149,7 +149,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - 
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -170,7 +170,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -187,7 +187,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -229,7 +229,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java index 8792234102a94..210fab457de10 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -160,7 +161,7 @@ public void testChatCompletionRequestAction() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java index 45a2fb0954c79..7e1e3e55caed8 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java @@ -24,6 +24,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils; @@ -475,7 +476,7 @@ public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOExcept var action = actionCreator.create(model, taskSettingsWithUserOverride); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -531,7 +532,7 @@ public void testInfer_AzureOpenAiCompletionModel_WithoutUser() throws IOExceptio var action = actionCreator.create(model, requestTaskSettingsWithoutUser); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -589,7 +590,7 @@ public void testInfer_AzureOpenAiCompletionModel_FailsFromInvalidResponseFormat( var action = 
actionCreator.create(model, requestTaskSettingsWithoutUser); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java index 4c7683c882816..dca12dfda9c98 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java @@ -26,7 +26,7 @@ import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.AzureOpenAiCompletionRequestManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils; @@ -111,7 +111,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction("resource", "deployment", "apiversion", user, apiKey, sender, "id"); PlainActionFuture 
listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -142,7 +142,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id"); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -163,7 +163,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id"); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -177,7 +177,7 @@ public void testExecute_ThrowsException() { var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id"); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java index 9ec34e7d8e5c5..3a512de25a39c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -197,7 +198,7 @@ public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOExcep var action = actionCreator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -257,7 +258,7 @@ public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOEx var action = actionCreator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = 
listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java index ba839e0d7c5e9..c5871adb34864 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java @@ -26,8 +26,8 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.CohereCompletionRequestManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.request.cohere.CohereUtils; @@ -120,7 +120,7 @@ public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IO var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -181,7 +181,7 @@ public void testExecute_ReturnsSuccessfulResponse_WithoutModelSpecified() throws var action = 
createAction(getUrl(webServer), "secret", null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -214,7 +214,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -235,7 +235,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -256,7 +256,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(null, "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -270,7 
+270,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -284,7 +284,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { var action = createAction(null, "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -334,7 +334,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java index 72b5ffa45a0dd..ff17bbf66e02a 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java @@ -25,7 +25,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.GoogleAiStudioCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -128,7 +128,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("input")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("input")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -159,7 +159,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ 
-180,7 +180,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -197,7 +197,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -260,7 +260,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java index b6d7eb673b7f0..fe076eb721ea2 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -330,7 +331,7 @@ public void testCreate_OpenAiChatCompletionModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -345,11 +346,12 @@ public void testCreate_OpenAiChatCompletionModel() throws IOException { assertThat(request.getHeader(ORGANIZATION_HEADER), equalTo("org")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("overridden_user")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } @@ -393,7 +395,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOExceptio var action = 
actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -408,10 +410,11 @@ public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOExceptio assertThat(request.getHeader(ORGANIZATION_HEADER), equalTo("org")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(3)); + assertThat(requestMap.size(), is(4)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } @@ -455,7 +458,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IO var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -470,11 +473,12 @@ public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IO assertNull(request.getHeader(ORGANIZATION_HEADER)); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("overridden_user")); assertThat(requestMap.get("n"), is(1)); + 
assertThat(requestMap.get("stream"), is(false)); } } @@ -523,7 +527,7 @@ public void testCreate_OpenAiChatCompletionModel_FailsFromInvalidResponseFormat( var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat( @@ -542,11 +546,12 @@ public void testCreate_OpenAiChatCompletionModel_FailsFromInvalidResponseFormat( assertNull(webServer.requests().get(0).getHeader(ORGANIZATION_HEADER)); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("overridden_user")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java index d84b2b5bb324a..ba74d2ab42c21 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java @@ -27,7 +27,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import 
org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.OpenAiCompletionRequestManager; @@ -119,7 +119,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -134,11 +134,12 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { assertThat(request.getHeader(ORGANIZATION_HEADER), equalTo("org")); var requestMap = entityAsMap(request.getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("user")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } @@ -159,7 +160,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + 
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -180,7 +181,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -201,7 +202,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(null, "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -215,7 +216,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -229,7 +230,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { var action = createAction(null, "org", "secret", "model", "user", 
sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -273,7 +274,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java index e68beaf4c1eb5..929aefeeef6b9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java @@ -12,6 +12,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import 
org.elasticsearch.xpack.inference.external.http.sender.RequestManager; @@ -67,8 +68,15 @@ public void send( ActionListener listener ) { sendCounter++; - var docsInput = (DocumentsOnlyInput) inferenceInputs; - inputs.add(docsInput.getInputs()); + if (inferenceInputs instanceof DocumentsOnlyInput docsInput) { + inputs.add(docsInput.getInputs()); + } else if (inferenceInputs instanceof ChatCompletionInput chatCompletionInput) { + inputs.add(chatCompletionInput.getInputs()); + } else { + throw new IllegalArgumentException( + "Invalid inference inputs received in mock sender: " + inferenceInputs.getClass().getSimpleName() + ); + } if (results.isEmpty()) { listener.onFailure(new ElasticsearchException("No results found")); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java index 7fa8a09d5bf12..a8f37aedcece3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockChatCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -107,7 +108,7 @@ public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws 
PlainActionFuture listener = new PlainActionFuture<>(); var requestManager = new AmazonBedrockChatCompletionRequestManager(model, threadPool, new TimeValue(30, TimeUnit.SECONDS)); - sender.send(requestManager, new DocumentsOnlyInput(List.of("abc")), null, listener); + sender.send(requestManager, new ChatCompletionInput(List.of("abc")), null, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test response text")))); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java new file mode 100644 index 0000000000000..f0da67a982374 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; + +import java.util.List; + +public class InferenceInputsTests extends ESTestCase { + public void testCastToSucceeds() { + InferenceInputs inputs = new DocumentsOnlyInput(List.of(), false); + assertThat(inputs.castTo(DocumentsOnlyInput.class), Matchers.instanceOf(DocumentsOnlyInput.class)); + + var emptyRequest = new UnifiedCompletionRequest(List.of(), null, null, null, null, null, null, null); + assertThat(new UnifiedChatInput(emptyRequest, false).castTo(UnifiedChatInput.class), Matchers.instanceOf(UnifiedChatInput.class)); + assertThat( + new QueryAndDocsInputs("hello", List.of(), false).castTo(QueryAndDocsInputs.class), + Matchers.instanceOf(QueryAndDocsInputs.class) + ); + } + + public void testCastToFails() { + InferenceInputs inputs = new DocumentsOnlyInput(List.of(), false); + var exception = expectThrows(IllegalArgumentException.class, () -> inputs.castTo(QueryAndDocsInputs.class)); + assertThat( + exception.getMessage(), + Matchers.containsString( + Strings.format("Unable to convert inference inputs type: [%s] to [%s]", DocumentsOnlyInput.class, QueryAndDocsInputs.class) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java new file mode 100644 index 0000000000000..42e1b18168aec --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; + +import java.util.List; + +public class UnifiedChatInputTests extends ESTestCase { + + public void testConvertsStringInputToMessages() { + var a = new UnifiedChatInput(List.of("hello", "awesome"), "a role", true); + + assertThat(a.inputSize(), Matchers.is(2)); + assertThat( + a.getRequest(), + Matchers.is( + UnifiedCompletionRequest.of( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("hello"), + "a role", + null, + null, + null + ), + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("awesome"), + "a role", + null, + null, + null + ) + ) + ) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java new file mode 100644 index 0000000000000..0f127998f9c54 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java @@ -0,0 +1,383 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.openai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; + +import java.io.IOException; +import java.util.List; + +public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase { + + public void testJsonLiteral() { + String json = """ + { + "id": "example_id", + "choices": [ + { + "delta": { + "content": "example_content", + "refusal": null, + "role": "assistant", + "tool_calls": [ + { + "index": 1, + "id": "tool_call_id", + "function": { + "arguments": "example_arguments", + "name": "example_function_name" + }, + "type": "function" + } + ] + }, + "finish_reason": "stop", + "index": 0 + } + ], + "model": "example_model", + "object": "chat.completion.chunk", + "usage": { + "completion_tokens": 50, + "prompt_tokens": 20, + "total_tokens": 70 + } + } + """; + // Parse the JSON + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertEquals("example_id", chunk.getId()); + assertEquals("example_model", chunk.getModel()); + assertEquals("chat.completion.chunk", chunk.getObject()); + assertNotNull(chunk.getUsage()); + assertEquals(50, chunk.getUsage().completionTokens()); + assertEquals(20, 
chunk.getUsage().promptTokens()); + assertEquals(70, chunk.getUsage().totalTokens()); + + List choices = chunk.getChoices(); + assertEquals(1, choices.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); + assertEquals("example_content", choice.delta().getContent()); + assertNull(choice.delta().getRefusal()); + assertEquals("assistant", choice.delta().getRole()); + assertEquals("stop", choice.finishReason()); + assertEquals(0, choice.index()); + + List toolCalls = choice.delta().getToolCalls(); + assertEquals(1, toolCalls.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); + assertEquals(1, toolCall.getIndex()); + assertEquals("tool_call_id", toolCall.getId()); + assertEquals("example_function_name", toolCall.getFunction().getName()); + assertEquals("example_arguments", toolCall.getFunction().getArguments()); + assertEquals("function", toolCall.getType()); + } catch (IOException e) { + fail(); + } + } + + public void testJsonLiteralCornerCases() { + String json = """ + { + "id": "example_id", + "choices": [ + { + "delta": { + "content": null, + "refusal": null, + "role": "assistant", + "tool_calls": [] + }, + "finish_reason": null, + "index": 0 + }, + { + "delta": { + "content": "example_content", + "refusal": "example_refusal", + "role": "user", + "tool_calls": [ + { + "index": 1, + "function": { + "name": "example_function_name" + }, + "type": "function" + } + ] + }, + "finish_reason": "stop", + "index": 1 + } + ], + "model": "example_model", + "object": "chat.completion.chunk", + "usage": null + } + """; + // Parse the JSON + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = 
OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertEquals("example_id", chunk.getId()); + assertEquals("example_model", chunk.getModel()); + assertEquals("chat.completion.chunk", chunk.getObject()); + assertNull(chunk.getUsage()); + + List choices = chunk.getChoices(); + assertEquals(2, choices.size()); + + // First choice assertions + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice firstChoice = choices.get(0); + assertNull(firstChoice.delta().getContent()); + assertNull(firstChoice.delta().getRefusal()); + assertEquals("assistant", firstChoice.delta().getRole()); + assertTrue(firstChoice.delta().getToolCalls().isEmpty()); + assertNull(firstChoice.finishReason()); + assertEquals(0, firstChoice.index()); + + // Second choice assertions + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice secondChoice = choices.get(1); + assertEquals("example_content", secondChoice.delta().getContent()); + assertEquals("example_refusal", secondChoice.delta().getRefusal()); + assertEquals("user", secondChoice.delta().getRole()); + assertEquals("stop", secondChoice.finishReason()); + assertEquals(1, secondChoice.index()); + + List toolCalls = secondChoice.delta() + .getToolCalls(); + assertEquals(1, toolCalls.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); + assertEquals(1, toolCall.getIndex()); + assertNull(toolCall.getId()); + assertEquals("example_function_name", toolCall.getFunction().getName()); + assertNull(toolCall.getFunction().getArguments()); + assertEquals("function", toolCall.getType()); + } catch (IOException e) { + fail(); + } + } + + public void testOpenAiUnifiedStreamingProcessorParsing() throws IOException { + // Generate random values for the JSON fields + int toolCallIndex = randomIntBetween(0, 10); + String toolCallId = randomAlphaOfLength(5); + String toolCallFunctionName = 
randomAlphaOfLength(8); + String toolCallFunctionArguments = randomAlphaOfLength(10); + String toolCallType = "function"; + String toolCallJson = createToolCallJson(toolCallIndex, toolCallId, toolCallFunctionName, toolCallFunctionArguments, toolCallType); + + String choiceContent = randomAlphaOfLength(10); + String choiceRole = randomFrom("system", "user", "assistant", "tool"); + String choiceFinishReason = randomFrom("stop", "length", "tool_calls", "content_filter", "function_call", null); + int choiceIndex = randomIntBetween(0, 10); + String choiceJson = createChoiceJson(choiceContent, null, choiceRole, toolCallJson, choiceFinishReason, choiceIndex); + + int usageCompletionTokens = randomIntBetween(1, 100); + int usagePromptTokens = randomIntBetween(1, 100); + int usageTotalTokens = randomIntBetween(1, 200); + String usageJson = createUsageJson(usageCompletionTokens, usagePromptTokens, usageTotalTokens); + + String chatCompletionChunkId = randomAlphaOfLength(10); + String chatCompletionChunkModel = randomAlphaOfLength(5); + String chatCompletionChunkJson = createChatCompletionChunkJson( + chatCompletionChunkId, + choiceJson, + chatCompletionChunkModel, + "chat.completion.chunk", + usageJson + ); + + // Parse the JSON + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, chatCompletionChunkJson)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertEquals(chatCompletionChunkId, chunk.getId()); + assertEquals(chatCompletionChunkModel, chunk.getModel()); + assertEquals("chat.completion.chunk", chunk.getObject()); + assertNotNull(chunk.getUsage()); + assertEquals(usageCompletionTokens, chunk.getUsage().completionTokens()); + 
assertEquals(usagePromptTokens, chunk.getUsage().promptTokens()); + assertEquals(usageTotalTokens, chunk.getUsage().totalTokens()); + + List choices = chunk.getChoices(); + assertEquals(1, choices.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); + assertEquals(choiceContent, choice.delta().getContent()); + assertNull(choice.delta().getRefusal()); + assertEquals(choiceRole, choice.delta().getRole()); + assertEquals(choiceFinishReason, choice.finishReason()); + assertEquals(choiceIndex, choice.index()); + + List toolCalls = choice.delta().getToolCalls(); + assertEquals(1, toolCalls.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); + assertEquals(toolCallIndex, toolCall.getIndex()); + assertEquals(toolCallId, toolCall.getId()); + assertEquals(toolCallFunctionName, toolCall.getFunction().getName()); + assertEquals(toolCallFunctionArguments, toolCall.getFunction().getArguments()); + assertEquals(toolCallType, toolCall.getType()); + } + } + + public void testOpenAiUnifiedStreamingProcessorParsingWithNullFields() throws IOException { + // JSON with null fields + int choiceIndex = randomIntBetween(0, 10); + String choiceJson = createChoiceJson(null, null, null, "", null, choiceIndex); + + String chatCompletionChunkId = randomAlphaOfLength(10); + String chatCompletionChunkModel = randomAlphaOfLength(5); + String chatCompletionChunkJson = createChatCompletionChunkJson( + chatCompletionChunkId, + choiceJson, + chatCompletionChunkModel, + "chat.completion.chunk", + null + ); + + // Parse the JSON + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, chatCompletionChunkJson)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = 
OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertEquals(chatCompletionChunkId, chunk.getId()); + assertEquals(chatCompletionChunkModel, chunk.getModel()); + assertEquals("chat.completion.chunk", chunk.getObject()); + assertNull(chunk.getUsage()); + + List choices = chunk.getChoices(); + assertEquals(1, choices.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); + assertNull(choice.delta().getContent()); + assertNull(choice.delta().getRefusal()); + assertNull(choice.delta().getRole()); + assertNull(choice.finishReason()); + assertEquals(choiceIndex, choice.index()); + assertTrue(choice.delta().getToolCalls().isEmpty()); + } + } + + private String createToolCallJson(int index, String id, String functionName, String functionArguments, String type) { + return Strings.format(""" + { + "index": %d, + "id": "%s", + "function": { + "name": "%s", + "arguments": "%s" + }, + "type": "%s" + } + """, index, id, functionName, functionArguments, type); + } + + private String createChoiceJson(String content, String refusal, String role, String toolCallsJson, String finishReason, int index) { + if (role == null) { + return Strings.format( + """ + { + "delta": { + "content": %s, + "refusal": %s, + "tool_calls": [%s] + }, + "finish_reason": %s, + "index": %d + } + """, + content != null ? "\"" + content + "\"" : "null", + refusal != null ? "\"" + refusal + "\"" : "null", + toolCallsJson, + finishReason != null ? "\"" + finishReason + "\"" : "null", + index + ); + } else { + return Strings.format( + """ + { + "delta": { + "content": %s, + "refusal": %s, + "role": %s, + "tool_calls": [%s] + }, + "finish_reason": %s, + "index": %d + } + """, + content != null ? "\"" + content + "\"" : "null", + refusal != null ? "\"" + refusal + "\"" : "null", + role != null ? "\"" + role + "\"" : "null", + toolCallsJson, + finishReason != null ? 
"\"" + finishReason + "\"" : "null", + index + ); + } + } + + private String createChatCompletionChunkJson(String id, String choicesJson, String model, String object, String usageJson) { + if (usageJson != null) { + return Strings.format(""" + { + "id": "%s", + "choices": [%s], + "model": "%s", + "object": "%s", + "usage": %s + } + """, id, choicesJson, model, object, usageJson); + } else { + return Strings.format(""" + { + "id": "%s", + "choices": [%s], + "model": "%s", + "object": "%s" + } + """, id, choicesJson, model, object); + } + } + + private String createUsageJson(int completionTokens, int promptTokens, int totalTokens) { + return Strings.format(""" + { + "completion_tokens": %d, + "prompt_tokens": %d, + "total_tokens": %d + } + """, completionTokens, promptTokens, totalTokens); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java index 7ffa8940ad6be..065dfee577a82 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java @@ -10,7 +10,7 @@ import org.apache.http.client.methods.HttpPost; import org.elasticsearch.common.Strings; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioCompletionRequest; import 
org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModelTests; @@ -72,7 +72,7 @@ public void testTruncationInfo_ReturnsNull() { assertNull(request.getTruncationInfo()); } - private static DocumentsOnlyInput listOf(String... input) { - return new DocumentsOnlyInput(List.of(input)); + private static ChatCompletionInput listOf(String... input) { + return new ChatCompletionInput(List.of(input)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntityTests.java deleted file mode 100644 index 9d5492f9e9516..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntityTests.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.external.request.openai; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentType; - -import java.io.IOException; -import java.util.List; - -import static org.hamcrest.CoreMatchers.is; - -public class OpenAiChatCompletionRequestEntityTests extends ESTestCase { - - public void testXContent_WritesUserWhenDefined() throws IOException { - var entity = new OpenAiChatCompletionRequestEntity(List.of("abc"), "model", "user", false); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"messages":[{"role":"user","content":"abc"}],"model":"model","n":1,"user":"user"}""")); - - } - - public void testXContent_DoesNotWriteUserWhenItIsNull() throws IOException { - var entity = new OpenAiChatCompletionRequestEntity(List.of("abc"), "model", null, false); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"messages":[{"role":"user","content":"abc"}],"model":"model","n":1}""")); - } - - public void testXContent_ThrowsIfModelIsNull() { - assertThrows(NullPointerException.class, () -> new OpenAiChatCompletionRequestEntity(List.of("abc"), null, "user", false)); - } - - public void testXContent_ThrowsIfMessagesAreNull() { - assertThrows(NullPointerException.class, () -> new OpenAiChatCompletionRequestEntity(null, "model", "user", false)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..f945c154ea234 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java @@ -0,0 +1,856 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.openai; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Random; + +import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; +import static org.hamcrest.Matchers.equalTo; + +public class OpenAiUnifiedChatCompletionRequestEntityTests extends ESTestCase { + + // 1. Basic Serialization + // Test with minimal required fields to ensure basic serialization works. 
+ public void testBasicSerialization() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); + + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "test-endpoint", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 2. Serialization with All Fields + // Test with all possible fields populated to ensure complete serialization. 
+ public void testSerializationWithAllFields() throws IOException { + // Create a message with all fields populated + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "name", + "tool_call_id", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments", "function_name"), + "type" + ) + ) + ); + + // Create a tool with all fields populated + UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( + "type", + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request with all fields populated + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + "model", + 100L, // maxCompletionTokens + Collections.singletonList("stop"), + 0.9f, // temperature + new UnifiedCompletionRequest.ToolChoiceString("tool_choice"), + Collections.singletonList(tool), + 0.8f // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user", + "name": "name", + 
"tool_call_id": "tool_call_id", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "arguments", + "name": "function_name" + }, + "type": "type" + } + ] + } + ], + "model": "model-name", + "max_completion_tokens": 100, + "n": 1, + "stop": ["stop"], + "temperature": 0.9, + "tool_choice": "tool_choice", + "tools": [ + { + "type": "type", + "function": { + "description": "Fetches the weather in the given location", + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "description": "The location to get the weather for", + "type": "string" + }, + "unit": { + "description": "The unit to return the temperature in", + "type": "string", + "enum": ["F", "C"] + } + }, + "additionalProperties": false, + "required": ["location", "unit"] + }, + "strict": true + } + } + ], + "top_p": 0.8, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + + } + + // 3. Serialization with Null Optional Fields + // Test with optional fields set to null to ensure they are correctly omitted from the output. 
+ public void testSerializationWithNullOptionalFields() throws IOException { + // Create a message with minimal required fields + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + // Create the unified request with optional fields set to null + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 4. Serialization with Empty Lists + // Test with fields that are lists set to empty lists to ensure they are correctly serialized. 
+ public void testSerializationWithEmptyLists() throws IOException { + // Create a message with minimal required fields + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + Collections.emptyList() // empty toolCalls list + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request with empty lists + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + Collections.emptyList(), // empty stop list + null, // temperature + null, // toolChoice + Collections.emptyList(), // empty tools list + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user", + "tool_calls": [] + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 5. Serialization with Nested Objects + // Test with nested objects (e.g., toolCalls, toolChoice, tool) to ensure they are correctly serialized. 
+ public void testSerializationWithNestedObjects() throws IOException { + Random random = Randomness.get(); + + // Generate random values + String randomContent = "Hello, world! " + random.nextInt(1000); + String randomName = "name" + random.nextInt(1000); + String randomToolCallId = "tool_call_id" + random.nextInt(1000); + String randomArguments = "arguments" + random.nextInt(1000); + String randomFunctionName = "function_name" + random.nextInt(1000); + String randomType = "type" + random.nextInt(1000); + String randomModel = "model" + random.nextInt(1000); + String randomStop = "stop" + random.nextInt(1000); + float randomTemperature = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); + float randomTopP = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); + + // Create a message with nested toolCalls + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(randomContent), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + randomName, + randomToolCallId, + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id", + new UnifiedCompletionRequest.ToolCall.FunctionField(randomArguments, randomFunctionName), + randomType + ) + ) + ); + + // Create a tool with nested function fields + UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( + randomType, + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request with nested objects + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + randomModel, + 100L, // maxCompletionTokens + Collections.singletonList(randomStop), + randomTemperature, // temperature + new UnifiedCompletionRequest.ToolChoiceObject( + 
randomType, + new UnifiedCompletionRequest.ToolChoiceObject.FunctionField(randomFunctionName) + ), + Collections.singletonList(tool), + randomTopP // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", randomModel, null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + // Expected JSON should be dynamically generated based on random values + String expectedJson = String.format( + Locale.US, + """ + { + "messages": [ + { + "content": "%s", + "role": "user", + "name": "%s", + "tool_call_id": "%s", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "%s", + "name": "%s" + }, + "type": "%s" + } + ] + } + ], + "model": "%s", + "max_completion_tokens": 100, + "n": 1, + "stop": ["%s"], + "temperature": %.5f, + "tool_choice": { + "type": "%s", + "function": { + "name": "%s" + } + }, + "tools": [ + { + "type": "%s", + "function": { + "description": "Fetches the weather in the given location", + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "unit": { + "description": "The unit to return the temperature in", + "type": "string", + "enum": ["F", "C"] + }, + "location": { + "description": "The location to get the weather for", + "type": "string" + } + }, + "additionalProperties": false, + "required": ["location", "unit"] + }, + "strict": true + } + } + ], + "top_p": %.5f, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """, + randomContent, + randomName, + randomToolCallId, + randomArguments, + 
randomFunctionName, + randomType, + randomModel, + randomStop, + randomTemperature, + randomType, + randomFunctionName, + randomType, + randomTopP + ); + assertJsonEquals(jsonString, expectedJson); + } + + // 6. Serialization with Different Content Types + // Test with different content types in messages (e.g., ContentString, ContentObjects) to ensure they are correctly serialized. + public void testSerializationWithDifferentContentTypes() throws IOException { + Random random = Randomness.get(); + + // Generate random values for ContentString + String randomContentString = "Hello, world! " + random.nextInt(1000); + + // Generate random values for ContentObjects + String randomText = "Random text " + random.nextInt(1000); + String randomType = "type" + random.nextInt(1000); + UnifiedCompletionRequest.ContentObject contentObject = new UnifiedCompletionRequest.ContentObject(randomText, randomType); + + var contentObjectsList = new ArrayList(); + contentObjectsList.add(contentObject); + UnifiedCompletionRequest.ContentObjects contentObjects = new UnifiedCompletionRequest.ContentObjects(contentObjectsList); + + // Create messages with different content types + UnifiedCompletionRequest.Message messageWithString = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(randomContentString), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + + UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message( + contentObjects, + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(messageWithString); + messageList.add(messageWithObjects); + + // Create the unified request with both types of messages + UnifiedCompletionRequest unifiedRequest = UnifiedCompletionRequest.of(messageList); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + 
OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = String.format(Locale.US, """ + { + "messages": [ + { + "content": "%s", + "role": "user" + }, + { + "content": [ + { + "text": "%s", + "type": "%s" + } + ], + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """, randomContentString, randomText, randomType); + assertJsonEquals(jsonString, expectedJson); + } + + // 7. Serialization with Special Characters + // Test with special characters in string fields to ensure they are correctly escaped and serialized. + public void testSerializationWithSpecialCharacters() throws IOException { + // Create a message with special characters + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world! 
\n \"Special\" characters: \t \\ /"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "name\nwith\nnewlines", + "tool_call_id\twith\ttabs", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id\\with\\backslashes", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"), + "type" + ) + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world! \\n \\"Special\\" characters: \\t \\\\ /", + "role": "user", + "name": "name\\nwith\\nnewlines", + "tool_call_id": "tool_call_id\\twith\\ttabs", + "tool_calls": [ + { + "id": "id\\\\with\\\\backslashes", + "function": { + "arguments": "arguments\\"with\\"quotes", + "name": "function_name/with/slashes" + }, + "type": "type" + } + ] + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 8. 
Serialization with Boolean Fields + // Test with boolean fields (stream) set to both true and false to ensure they are correctly serialized. + public void testSerializationWithBooleanFields() throws IOException { + // Create a message with minimal required fields + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Test with stream set to true + UnifiedChatInput unifiedChatInputTrue = new UnifiedChatInput(unifiedRequest, true); + OpenAiUnifiedChatCompletionRequestEntity entityTrue = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInputTrue, model); + + XContentBuilder builderTrue = JsonXContent.contentBuilder(); + entityTrue.toXContent(builderTrue, ToXContent.EMPTY_PARAMS); + + String jsonStringTrue = Strings.toString(builderTrue); + String expectedJsonTrue = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(expectedJsonTrue, jsonStringTrue); + + // Test with stream set to false + UnifiedChatInput unifiedChatInputFalse = new UnifiedChatInput(unifiedRequest, false); + OpenAiUnifiedChatCompletionRequestEntity entityFalse = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInputFalse, model); + + XContentBuilder builderFalse = JsonXContent.contentBuilder(); 
+ entityFalse.toXContent(builderFalse, ToXContent.EMPTY_PARAMS); + + String jsonStringFalse = Strings.toString(builderFalse); + String expectedJsonFalse = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": false + } + """; + assertJsonEquals(expectedJsonFalse, jsonStringFalse); + } + + // 9. Serialization with Missing Required Fields + // Test with missing required fields to ensure appropriate exceptions are thrown. + public void testSerializationWithMissingRequiredFields() { + // Create a message with missing content (required field) + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + null, // missing content + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Attempt to serialize to XContent and expect an exception + try { + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + fail("Expected an exception due to missing required fields"); + } catch (NullPointerException | IOException e) { + // Expected exception + } + } + + // 10. 
Serialization with Mixed Valid and Invalid Data + // Test with a mix of valid and invalid data to ensure the serializer handles it gracefully. + public void testSerializationWithMixedValidAndInvalidData() throws IOException { + // Create a valid message + UnifiedCompletionRequest.Message validMessage = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Valid content"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "validName", + "validToolCallId", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "validId", + new UnifiedCompletionRequest.ToolCall.FunctionField("validArguments", "validFunctionName"), + "validType" + ) + ) + ); + + // Create an invalid message with null content + UnifiedCompletionRequest.Message invalidMessage = new UnifiedCompletionRequest.Message( + null, // invalid content + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "invalidName", + "invalidToolCallId", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "invalidId", + new UnifiedCompletionRequest.ToolCall.FunctionField("invalidArguments", "invalidFunctionName"), + "invalidType" + ) + ) + ); + var messageList = new ArrayList(); + messageList.add(validMessage); + messageList.add(invalidMessage); + // Create the unified request with both valid and invalid messages + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + "model-name", + 100L, // maxCompletionTokens + Collections.singletonList("stop"), + 0.9f, // temperature + new UnifiedCompletionRequest.ToolChoiceString("tool_choice"), + Collections.singletonList( + new UnifiedCompletionRequest.Tool( + "type", + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ) + ), + 0.8f // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + 
OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent and verify + try { + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + fail("Expected an exception due to invalid data"); + } catch (NullPointerException | IOException e) { + // Expected exception + } + } + + public static Map createParameters() { + Map parameters = new LinkedHashMap<>(); + parameters.put("type", "object"); + + Map properties = new HashMap<>(); + + Map location = new HashMap<>(); + location.put("type", "string"); + location.put("description", "The location to get the weather for"); + properties.put("location", location); + + Map unit = new HashMap<>(); + unit.put("type", "string"); + unit.put("description", "The unit to return the temperature in"); + unit.put("enum", new String[] { "F", "C" }); + properties.put("unit", unit); + + parameters.put("properties", properties); + parameters.put("additionalProperties", false); + parameters.put("required", new String[] { "location", "unit" }); + + return parameters; + } + + private void assertJsonEquals(String actual, String expected) throws IOException { + try ( + var actualParser = createParser(JsonXContent.jsonXContent, actual); + var expectedParser = createParser(JsonXContent.jsonXContent, expected) + ) { + assertThat(actualParser.mapOrdered(), equalTo(expectedParser.mapOrdered())); + } + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java similarity index 75% rename from 
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java index b6ebfd02941f3..2be12c9b12e0b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests; import java.io.IOException; @@ -20,16 +21,16 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; -import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiChatCompletionRequest.buildDefaultUri; +import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest.buildDefaultUri; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -public class OpenAiChatCompletionRequestTests extends ESTestCase { +public class OpenAiUnifiedChatCompletionRequestTests extends ESTestCase { public void testCreateRequest_WithUrlOrganizationUserDefined() throws IOException { - var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user"); + var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user", true); var 
httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -41,15 +42,27 @@ public void testCreateRequest_WithUrlOrganizationUserDefined() throws IOExceptio assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org")); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(4)); + assertRequestMapWithUser(requestMap, "user"); + } + + private void assertRequestMapWithoutUser(Map requestMap) { + assertRequestMapWithUser(requestMap, null); + } + + private void assertRequestMapWithUser(Map requestMap, @Nullable String user) { + assertThat(requestMap, aMapWithSize(user != null ? 6 : 5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); - assertThat(requestMap.get("user"), is("user")); + if (user != null) { + assertThat(requestMap.get("user"), is(user)); + } assertThat(requestMap.get("n"), is(1)); + assertTrue((Boolean) requestMap.get("stream")); + assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true))); } public void testCreateRequest_WithDefaultUrl() throws URISyntaxException, IOException { - var request = createRequest(null, "org", "secret", "abc", "model", "user"); + var request = createRequest(null, "org", "secret", "abc", "model", "user", true); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -61,33 +74,27 @@ public void testCreateRequest_WithDefaultUrl() throws URISyntaxException, IOExce assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org")); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(4)); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); - assertThat(requestMap.get("model"), is("model")); - assertThat(requestMap.get("user"), 
is("user")); - assertThat(requestMap.get("n"), is(1)); + assertRequestMapWithUser(requestMap, "user"); + } public void testCreateRequest_WithDefaultUrlAndWithoutUserOrganization() throws URISyntaxException, IOException { - var request = createRequest(null, null, "secret", "abc", "model", null); + var request = createRequest(null, null, "secret", "abc", "model", null, true); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); - assertThat(httpPost.getURI().toString(), is(OpenAiChatCompletionRequest.buildDefaultUri().toString())); + assertThat(httpPost.getURI().toString(), is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString())); assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); assertNull(httpPost.getLastHeader(ORGANIZATION_HEADER)); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(3)); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); - assertThat(requestMap.get("model"), is("model")); - assertThat(requestMap.get("n"), is(1)); + assertRequestMapWithoutUser(requestMap); } - public void testCreateRequest_WithStreaming() throws URISyntaxException, IOException { + public void testCreateRequest_WithStreaming() throws IOException { var request = createRequest(null, null, "secret", "abc", "model", null, true); var httpRequest = request.createHttpRequest(); @@ -99,29 +106,31 @@ public void testCreateRequest_WithStreaming() throws URISyntaxException, IOExcep } public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException, IOException { - var request = createRequest(null, null, "secret", "abcd", "model", null); + var request = createRequest(null, null, "secret", "abcd", 
"model", null, true); var truncatedRequest = request.truncate(); - assertThat(request.getURI().toString(), is(OpenAiChatCompletionRequest.buildDefaultUri().toString())); + assertThat(request.getURI().toString(), is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString())); var httpRequest = truncatedRequest.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap, aMapWithSize(5)); // We do not truncate for OpenAi chat completions assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("n"), is(1)); + assertTrue((Boolean) requestMap.get("stream")); + assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true))); } public void testTruncationInfo_ReturnsNull() { - var request = createRequest(null, null, "secret", "abcd", "model", null); + var request = createRequest(null, null, "secret", "abcd", "model", null, true); assertNull(request.getTruncationInfo()); } - public static OpenAiChatCompletionRequest createRequest( + public static OpenAiUnifiedChatCompletionRequest createRequest( @Nullable String url, @Nullable String org, String apiKey, @@ -132,7 +141,7 @@ public static OpenAiChatCompletionRequest createRequest( return createRequest(url, org, apiKey, input, model, user, false); } - public static OpenAiChatCompletionRequest createRequest( + public static OpenAiUnifiedChatCompletionRequest createRequest( @Nullable String url, @Nullable String org, String apiKey, @@ -142,7 +151,7 @@ public static OpenAiChatCompletionRequest createRequest( boolean stream ) { var chatCompletionModel = OpenAiChatCompletionModelTests.createChatCompletionModel(url, org, apiKey, model, user); - return new 
OpenAiChatCompletionRequest(List.of(input), chatCompletionModel, stream); + return new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java index 05a8d52be5df4..5528c80066b0a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java @@ -8,11 +8,14 @@ package org.elasticsearch.xpack.inference.rest; import org.apache.lucene.util.SetOnce; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestRequestTests; import org.elasticsearch.rest.action.RestChunkedToXContentListener; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; @@ -26,6 +29,10 @@ import java.util.Map; import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseParams; +import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout; +import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID; +import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -56,6 +63,42 @@ private static String route(String param) { 
return "_route/" + param; } + public void testParseParams_ExtractsInferenceIdAndTaskType() { + var params = parseParams( + RestRequestTests.contentRestRequest("{}", Map.of(INFERENCE_ID, "id", TASK_TYPE_OR_INFERENCE_ID, TaskType.COMPLETION.toString())) + ); + assertThat(params, is(new BaseInferenceAction.Params("id", TaskType.COMPLETION))); + } + + public void testParseParams_DefaultsToTaskTypeAny_WhenInferenceId_IsMissing() { + var params = parseParams( + RestRequestTests.contentRestRequest("{}", Map.of(TASK_TYPE_OR_INFERENCE_ID, TaskType.COMPLETION.toString())) + ); + assertThat(params, is(new BaseInferenceAction.Params("completion", TaskType.ANY))); + } + + public void testParseParams_ThrowsStatusException_WhenTaskTypeIsMissing() { + var e = expectThrows( + ElasticsearchStatusException.class, + () -> parseParams(RestRequestTests.contentRestRequest("{}", Map.of(INFERENCE_ID, "id"))) + ); + assertThat(e.getMessage(), is("Task type must not be null")); + } + + public void testParseTimeout_ReturnsTimeout() { + var timeout = parseTimeout( + RestRequestTests.contentRestRequest("{}", Map.of(InferenceAction.Request.TIMEOUT.getPreferredName(), "4s")) + ); + + assertThat(timeout, is(TimeValue.timeValueSeconds(4))); + } + + public void testParseTimeout_ReturnsDefaultTimeout() { + var timeout = parseTimeout(RestRequestTests.contentRestRequest("{}", Map.of())); + + assertThat(timeout, is(TimeValue.timeValueSeconds(30))); + } + public void testUsesDefaultTimeout() { SetOnce executeCalled = new SetOnce<>(); verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java new file mode 100644 index 0000000000000..5acfe67b175df --- /dev/null +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rest; + +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.rest.AbstractRestChannel; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; +import org.elasticsearch.test.rest.FakeRestRequest; +import org.elasticsearch.test.rest.RestActionTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.junit.Before; + +import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +public class RestUnifiedCompletionInferenceActionTests extends RestActionTestCase { + + @Before + public void setUpAction() { + controller().registerHandler(new RestUnifiedCompletionInferenceAction()); + } + + public void testStreamIsTrue() { + SetOnce<Boolean> executeCalled = new SetOnce<>(); + verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { + assertThat(actionRequest, instanceOf(UnifiedCompletionAction.Request.class)); + + var request = (UnifiedCompletionAction.Request) actionRequest; + assertThat(request.isStreaming(), is(true)); + + executeCalled.set(true); + return createResponse(); + })); + + var requestBody = """ + { + 
"messages": [ + { + "content": "abc", + "role": "user" + } + ] + } + """; + + RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath("_inference/completion/test/_unified") + .withContent(new BytesArray(requestBody), XContentType.JSON) + .build(); + + final SetOnce<RestResponse> responseSetOnce = new SetOnce<>(); + dispatchRequest(inferenceRequest, new AbstractRestChannel(inferenceRequest, true) { + @Override + public void sendResponse(RestResponse response) { + responseSetOnce.set(response); + } + }); + + // the response content will be null when there is no error + assertNull(responseSetOnce.get().content()); + assertThat(executeCalled.get(), equalTo(true)); + } + + private void dispatchRequest(final RestRequest request, final RestChannel channel) { + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + controller().dispatchRequest(request, channel, threadContext); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 47a96bf78dda1..6768583598b2d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.junit.After; import org.junit.Before; @@ -119,6 +120,14 @@ protected void doInfer( } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + 
ActionListener<InferenceServiceResults> listener + ) {} + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 76b5d6fee2c59..159b77789482d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; @@ -920,6 +921,68 @@ public void testInfer_SendsRequest() throws IOException { } } + public void testUnifiedCompletionInfer() throws Exception { + // The escapes are because the streaming response must be on a single line + String responseJson = """ + data: {\ + "id":"12345",\ + "object":"chat.completion.chunk",\ + "created":123456789,\ + "model":"gpt-4o-mini",\ + "system_fingerprint": "123456789",\ + "choices":[\ + {\ + "index":0,\ + "delta":{\ + "content":"hello, world"\ + },\ + "logprobs":null,\ + "finish_reason":"stop"\ + }\ + ],\ + "usage":{\ + "prompt_tokens": 16,\ + "completion_tokens": 28,\ + "total_tokens": 44,\ + "prompt_tokens_details": {\ + "cached_tokens": 0,\ + "audio_tokens": 0\ + },\ + "completion_tokens_details": {\ + "reasoning_tokens": 0,\ + "audio_tokens": 0,\ + "accepted_prediction_tokens": 0,\ + "rejected_prediction_tokens": 0\ + }\ + }\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of( + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null, null) + ) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(""" + {"id":"12345","choices":[{"delta":{"content":"hello, world"},"finish_reason":"stop","index":0}],""" + """ + "model":"gpt-4o-mini","object":"chat.completion.chunk",""" + """ + "usage":{"completion_tokens":28,"prompt_tokens":16,"total_tokens":44}}"""); + } + } + public void testInfer_StreamRequest() throws Exception { String responseJson = """ data: {\ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java index ab1786f0a5843..e7ac4cf879e92 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java @@ -10,9 +10,11 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; import 
org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import java.util.List; import java.util.Map; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionRequestTaskSettingsTests.getChatCompletionRequestTaskSettingsMap; @@ -42,10 +44,48 @@ public void testOverrideWith_EmptyMap() { public void testOverrideWith_NullMap() { var model = createChatCompletionModel("url", "org", "api_key", "model_name", null); - var overriddenModel = OpenAiChatCompletionModel.of(model, null); + var overriddenModel = OpenAiChatCompletionModel.of(model, (Map<String, Object>) null); assertThat(overriddenModel, sameInstance(model)); } + public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { + var model = createChatCompletionModel("url", "org", "api_key", "model_name", "user"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + assertThat( + OpenAiChatCompletionModel.of(model, request), + is(createChatCompletionModel("url", "org", "api_key", "different_model", "user")) + ); + } + + public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { + var model = createChatCompletionModel("url", "org", "api_key", "model_name", "user"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null, null)), + null, // not overriding model + null, + null, + null, + null, + null, + null + ); + + assertThat( + OpenAiChatCompletionModel.of(model, request), + is(createChatCompletionModel("url", "org", "api_key", "model_name", "user")) + ); + } + public static OpenAiChatCompletionModel createChatCompletionModel( String url, @Nullable String org, diff --git 
a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 8df10037affdb..c91314716cf9e 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -386,6 +386,7 @@ public class Constants { "cluster:monitor/xpack/esql/stats/dist", "cluster:monitor/xpack/inference", "cluster:monitor/xpack/inference/get", + "cluster:monitor/xpack/inference/unified", "cluster:monitor/xpack/inference/diagnostics/get", "cluster:monitor/xpack/inference/services/get", "cluster:monitor/xpack/info",