From 4badd50ad5d68e799b3b8912a1e0a061f3e0e414 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 26 Mar 2021 13:23:17 +0200 Subject: [PATCH] [ML] Infer against model deployment This adds a temporary API for doing inference against a trained model deployment. --- .../elasticsearch/xpack/core/ml/MlTasks.java | 4 +- .../InferTrainedModelDeploymentAction.java | 133 ++++++++++++++++++ .../inference/deployment/PyTorchResult.java | 129 +++++++++++++++++ .../preprocessing/CustomWordEmbedding.java | 28 +--- .../xpack/core/ml/utils/MlParserUtils.java | 53 +++++++ .../xpack/ml/MachineLearning.java | 5 + ...portInferTrainedModelDeploymentAction.java | 79 +++++++++++ ...sportStopTrainedModelDeploymentAction.java | 6 +- .../deployment/DeploymentManager.java | 39 ++++- .../TrainedModelDeploymentTask.java | 6 + .../pytorch/process/NativePyTorchProcess.java | 38 +++++ .../pytorch/process/PyTorchProcess.java | 9 ++ .../process/PyTorchResultProcessor.java | 80 +++++++++++ .../ml/process/AbstractNativeProcess.java | 2 +- ...RestInferTrainedModelDeploymentAction.java | 55 ++++++++ 15 files changed, 632 insertions(+), 34 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/PyTorchResult.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MlParserUtils.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelDeploymentAction.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java index a98f444accc27..cd76959a1c803 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java @@ -107,8 +107,8 @@ public static PersistentTasksCustomMetadata.PersistentTask getSnapshotUpgrade } @Nullable - public static PersistentTasksCustomMetadata.PersistentTask getDeployTrainedModelTask(String modelId, - @Nullable PersistentTasksCustomMetadata tasks) { + public static PersistentTasksCustomMetadata.PersistentTask getTrainedModelDeploymentTask( + String modelId, @Nullable PersistentTasksCustomMetadata tasks) { return tasks == null ? null : tasks.getTask(trainedModelDeploymentTaskId(modelId)); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java new file mode 100644 index 0000000000000..f18d10ef3cacf --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java @@ -0,0 +1,133 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.tasks.BaseTasksRequest; +import org.elasticsearch.action.support.tasks.BaseTasksResponse; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.core.ml.inference.deployment.PyTorchResult; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +public class InferTrainedModelDeploymentAction extends ActionType { + + public static final InferTrainedModelDeploymentAction INSTANCE = new InferTrainedModelDeploymentAction(); + + // TODO Review security level + public static final String NAME = "cluster:monitor/xpack/ml/trained_models/deployment/infer"; + + public InferTrainedModelDeploymentAction() { + super(NAME, InferTrainedModelDeploymentAction.Response::new); + } + + public static class Request extends BaseTasksRequest implements ToXContentObject { + + private static final ParseField DEPLOYMENT_ID = new ParseField("deployment_id"); + public static final ParseField INPUTS = new ParseField("inputs"); + + private static final ObjectParser PARSER = new ObjectParser<>("infer_trained_model_request", Request::new); + + static { + PARSER.declareString((request, deploymentId) -> request.deploymentId = deploymentId, DEPLOYMENT_ID); + PARSER.declareDoubleArray(Request::setInputs, INPUTS); + } + + public static Request parseRequest(String deploymentId, XContentParser parser) { + Request request = PARSER.apply(parser, null); + if (deploymentId != null) { + request.deploymentId = deploymentId; + } + return request; + } + + private String deploymentId; + private double[] inputs; + + private Request() { + } + + public Request(String deploymentId) { + this.deploymentId = Objects.requireNonNull(deploymentId); + } + + public Request(StreamInput in) throws IOException { + super(in); + deploymentId = in.readString(); + inputs = in.readDoubleArray(); + } + + public String getDeploymentId() { + return deploymentId; + } + + public void setInputs(List inputs) { + ExceptionsHelper.requireNonNull(inputs, INPUTS); + this.inputs = inputs.stream().mapToDouble(d -> d).toArray(); + } + + public double[] getInputs() { + return inputs; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(deploymentId); + out.writeDoubleArray(inputs); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field(DEPLOYMENT_ID.getPreferredName(), deploymentId); + builder.array(INPUTS.getPreferredName(), inputs); + builder.endObject(); + return builder; + } + + @Override + public boolean match(Task task) { + return StartTrainedModelDeploymentAction.TaskMatcher.match(task, deploymentId); + } + } + + public static class Response extends BaseTasksResponse implements Writeable, ToXContentObject { + + private final 
PyTorchResult result; + + public Response(PyTorchResult result) { + super(Collections.emptyList(), Collections.emptyList()); + this.result = Objects.requireNonNull(result); + } + + public Response(StreamInput in) throws IOException { + super(in); + result = new PyTorchResult(in); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + result.toXContent(builder, params); + return builder; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/PyTorchResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/PyTorchResult.java new file mode 100644 index 0000000000000..3d776ee7a41c2 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/PyTorchResult.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.deployment; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.MlParserUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +/* + * TODO This does not necessarily belong in core. Will have to reconsider + * once we figure the format we store inference results in client calls. 
+*/ +public class PyTorchResult implements ToXContentObject, Writeable { + + private static final ParseField REQUEST_ID = new ParseField("request_id"); + private static final ParseField INFERENCE = new ParseField("inference"); + private static final ParseField ERROR = new ParseField("error"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("pytorch_result", + a -> new PyTorchResult((String) a[0], (double[][]) a[1], (String) a[2])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), REQUEST_ID); + PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> { + List> listOfListOfDoubles = MlParserUtils.parseArrayOfArrays( + INFERENCE.getPreferredName(), XContentParser::doubleValue, p); + double[][] primitiveDoubles = new double[listOfListOfDoubles.size()][]; + for (int i = 0; i < listOfListOfDoubles.size(); i++) { + List row = listOfListOfDoubles.get(i); + double[] primitiveRow = new double[row.size()]; + for (int j = 0; j < row.size(); j++) { + primitiveRow[j] = row.get(j); + } + primitiveDoubles[i] = primitiveRow; + } + return primitiveDoubles; + }, + INFERENCE, + ObjectParser.ValueType.VALUE_ARRAY + ); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), ERROR); + } + + private final String requestId; + private final double[][] inference; + private final String error; + + public PyTorchResult(String requestId, @Nullable double[][] inference, @Nullable String error) { + this.requestId = Objects.requireNonNull(requestId); + this.inference = inference; + this.error = error; + } + + public PyTorchResult(StreamInput in) throws IOException { + requestId = in.readString(); + boolean hasInference = in.readBoolean(); + if (hasInference) { + inference = in.readArray(StreamInput::readDoubleArray, length -> new double[length][]); + } else { + inference = null; + } + error = in.readOptionalString(); + } + + public String getRequestId() { + return requestId; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(REQUEST_ID.getPreferredName(), requestId); + if (inference != null) { + builder.field(INFERENCE.getPreferredName(), inference); + } + if (error != null) { + builder.field(ERROR.getPreferredName(), error); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(requestId); + if (inference == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeArray(StreamOutput::writeDoubleArray, inference); + } + out.writeOptionalString(error); + } + + @Override + public int hashCode() { + return Objects.hash(requestId, Arrays.hashCode(inference), error); + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + + PyTorchResult that = (PyTorchResult) other; + return Objects.equals(requestId, that.requestId) + && Objects.equals(inference, that.inference) + && Objects.equals(error, that.error); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java index 3e42b397f1089..1ddedec7cb7c0 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.core.ml.inference.preprocessing; import org.apache.lucene.util.RamUsageEstimator; -import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -22,6 +21,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor; +import org.elasticsearch.xpack.core.ml.utils.MlParserUtils; import java.io.IOException; import java.util.ArrayList; @@ -63,7 +63,7 @@ private static ConstructingObjectParser { - List> listOfListOfShorts = parseArrays(EMBEDDING_QUANT_SCALES.getPreferredName(), + List> listOfListOfShorts = MlParserUtils.parseArrayOfArrays(EMBEDDING_QUANT_SCALES.getPreferredName(), XContentParser::shortValue, p); short[][] primitiveShorts = new short[listOfListOfShorts.size()][]; @@ -99,30 +99,6 @@ private static ConstructingObjectParser List> parseArrays(String fieldName, - CheckedFunction fromParser, - XContentParser p) throws IOException { - if (p.currentToken() != XContentParser.Token.START_ARRAY) { - throw new IllegalArgumentException("unexpected token [" + p.currentToken() + "] for [" + fieldName + "]"); - } - List> values = new ArrayList<>(); - while(p.nextToken() != XContentParser.Token.END_ARRAY) { - if (p.currentToken() != XContentParser.Token.START_ARRAY) { - throw new IllegalArgumentException("unexpected token [" + p.currentToken() + "] for [" + fieldName + "]"); - } - List innerList = new ArrayList<>(); - while(p.nextToken() != XContentParser.Token.END_ARRAY) { - if(p.currentToken().isValue() == false) { - throw new IllegalStateException("expected non-null value but got [" + p.currentToken() + "] " + - "for [" + fieldName + "]"); - } - innerList.add(fromParser.apply(p)); - } - values.add(innerList); - } - return values; - } - public static CustomWordEmbedding fromXContentStrict(XContentParser parser) { return STRICT_PARSER.apply(parser, PreProcessorParseContext.DEFAULT); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MlParserUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MlParserUtils.java new file mode 100644 index 0000000000000..3afc3db3d1ead --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MlParserUtils.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.utils; + +import org.elasticsearch.common.CheckedFunction; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public final class MlParserUtils { + + private MlParserUtils() {} + + /** + * Parses an array of arrays of the given type + * + * @param fieldName the field name + * @param valueParser the parser to use for the inner array values + * @param parser the outer parser + * @param the type of the values of the inner array + * @return a list of lists representing the array of arrays + * @throws IOException an exception if parsing fails + */ + public static List> parseArrayOfArrays(String fieldName, CheckedFunction valueParser, + XContentParser parser) throws IOException { + if (parser.currentToken() != XContentParser.Token.START_ARRAY) { + throw new IllegalArgumentException("unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]"); + } + List> values = new ArrayList<>(); + while(parser.nextToken() != XContentParser.Token.END_ARRAY) { + if (parser.currentToken() != XContentParser.Token.START_ARRAY) { + throw new IllegalArgumentException("unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]"); + } + List innerList = new ArrayList<>(); + while(parser.nextToken() != XContentParser.Token.END_ARRAY) { + if(parser.currentToken().isValue() == false) { + throw new IllegalStateException("expected non-null value but got [" + parser.currentToken() + "] " + + "for [" + fieldName + "]"); + } + innerList.add(valueParser.apply(parser)); + } + values.add(innerList); + } + return values; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 7bb7f6a86234e..d8213e709f540 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -97,6 +97,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAliasAction; +import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.EstimateModelMemoryAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; @@ -182,6 +183,7 @@ import org.elasticsearch.xpack.ml.action.TransportDeleteModelSnapshotAction; import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAliasAction; +import org.elasticsearch.xpack.ml.action.TransportInferTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.action.TransportStartTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.action.TransportEstimateModelMemoryAction; import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction; @@ -338,6 +340,7 @@ import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction; import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAliasAction; +import org.elasticsearch.xpack.ml.rest.inference.RestInferTrainedModelDeploymentAction; import 
org.elasticsearch.xpack.ml.rest.inference.RestStartTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction; @@ -981,6 +984,7 @@ public List getRestHandlers(Settings settings, RestController restC new RestPreviewDataFrameAnalyticsAction(), new RestStartTrainedModelDeploymentAction(), new RestStopTrainedModelDeploymentAction(), + new RestInferTrainedModelDeploymentAction(), // CAT Handlers new RestCatJobsAction(), new RestCatTrainedModelsAction(), @@ -1070,6 +1074,7 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(SetResetModeAction.INSTANCE, TransportSetResetModeAction.class), new ActionHandler<>(StartTrainedModelDeploymentAction.INSTANCE, TransportStartTrainedModelDeploymentAction.class), new ActionHandler<>(StopTrainedModelDeploymentAction.INSTANCE, TransportStopTrainedModelDeploymentAction.class), + new ActionHandler<>(InferTrainedModelDeploymentAction.INSTANCE, TransportInferTrainedModelDeploymentAction.class), usageAction, infoAction); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java new file mode 100644 index 0000000000000..c3b31f8e62ebe --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.FailedNodeException; +import org.elasticsearch.action.TaskOperationFailure; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.tasks.TransportTasksAction; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.persistent.PersistentTasksCustomMetadata; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; + +import java.util.List; + +public class TransportInferTrainedModelDeploymentAction extends TransportTasksAction { + + @Inject + public TransportInferTrainedModelDeploymentAction(ClusterService clusterService, TransportService transportService, + ActionFilters actionFilters) { + super(InferTrainedModelDeploymentAction.NAME, clusterService, transportService, actionFilters, + InferTrainedModelDeploymentAction.Request::new, InferTrainedModelDeploymentAction.Response::new, + InferTrainedModelDeploymentAction.Response::new, ThreadPool.Names.SAME); + } + + @Override + protected void doExecute(Task task, InferTrainedModelDeploymentAction.Request request, + ActionListener listener) { + String deploymentId = request.getDeploymentId(); + // We need to check whether there is at least an assigned task here, otherwise we cannot redirect to the + // node running the job task. 
+ PersistentTasksCustomMetadata tasks = clusterService.state().getMetadata().custom(PersistentTasksCustomMetadata.TYPE); + PersistentTasksCustomMetadata.PersistentTask deploymentTask = MlTasks.getTrainedModelDeploymentTask(deploymentId, tasks); + if (deploymentTask == null || deploymentTask.isAssigned() == false) { + String message = "Cannot perform requested action because deployment [" + deploymentId + "] is not started"; + listener.onFailure(ExceptionsHelper.conflictStatusException(message)); + } else { + request.setNodes(deploymentTask.getExecutorNode()); + super.doExecute(task, request, listener); + } + } + + @Override + protected InferTrainedModelDeploymentAction.Response newResponse(InferTrainedModelDeploymentAction.Request request, + List tasks, + List taskOperationFailures, + List failedNodeExceptions) { + if (taskOperationFailures.isEmpty() == false) { + throw org.elasticsearch.ExceptionsHelper.convertToElastic(taskOperationFailures.get(0).getCause()); + } else if (failedNodeExceptions.isEmpty() == false) { + throw org.elasticsearch.ExceptionsHelper.convertToElastic(failedNodeExceptions.get(0)); + } else { + return tasks.get(0); + } + } + + @Override + protected void taskOperation(InferTrainedModelDeploymentAction.Request request, TrainedModelDeploymentTask task, + ActionListener listener) { + task.infer(request.getInputs(), + ActionListener.wrap( + pyTorchResult -> listener.onResponse(new InferTrainedModelDeploymentAction.Response(pyTorchResult)), + listener::onFailure) + ); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java index 6d9722eef9aa5..d9f586905e913 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java @@ -53,10 +53,10 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct private final PersistentTasksService persistentTasksService; @Inject - public TransportStopTrainedModelDeploymentAction(String actionName, ClusterService clusterService, TransportService transportService, + public TransportStopTrainedModelDeploymentAction(ClusterService clusterService, TransportService transportService, ActionFilters actionFilters, Client client, ThreadPool threadPool, PersistentTasksService persistentTasksService) { - super(actionName, clusterService, transportService, actionFilters, StopTrainedModelDeploymentAction.Request::new, + super(StopTrainedModelDeploymentAction.NAME, clusterService, transportService, actionFilters, StopTrainedModelDeploymentAction.Request::new, StopTrainedModelDeploymentAction.Response::new, StopTrainedModelDeploymentAction.Response::new, ThreadPool.Names.SAME); this.client = client; this.threadPool = threadPool; @@ -90,7 +90,7 @@ protected void doExecute(Task task, StopTrainedModelDeploymentAction.Request req ClusterState clusterState = clusterService.state(); PersistentTasksCustomMetadata tasks = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE); PersistentTasksCustomMetadata.PersistentTask deployTrainedModelTask = - MlTasks.getDeployTrainedModelTask(request.getId(), tasks); + MlTasks.getTrainedModelDeploymentTask(request.getId(), tasks); if (deployTrainedModelTask == null) { listener.onResponse(new 
StopTrainedModelDeploymentAction.Response(true)); return; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 1425266ef9461..35a90bb092006 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -15,13 +15,16 @@ import org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.client.Client; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.inference.deployment.PyTorchResult; import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState; import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcess; import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory; +import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor; import java.io.IOException; import java.util.Map; @@ -69,6 +72,8 @@ private void doStartDeployment(TrainedModelDeploymentTask task) { logger.error(new ParameterizedMessage("[{}] error loading model", task.getModelId()), e); } + executorServiceForProcess.execute(() -> processContext.resultProcessor.process(processContext.process.get())); + TrainedModelDeploymentTaskState startedState = new TrainedModelDeploymentTaskState( TrainedModelDeploymentState.STARTED, task.getAllocationId(), null); task.updatePersistentTaskState(startedState, ActionListener.wrap( @@ -84,26 +89,56 @@ public void stopDeployment(TrainedModelDeploymentTask task) { } if (processContext != null) { logger.debug("[{}] Stopping deployment", task.getModelId()); - processContext.killProcess(); + processContext.stopProcess(); } else { logger.debug("[{}] No process context to stop", task.getModelId()); } } + public void infer(TrainedModelDeploymentTask task, double[] inputs, ActionListener listener) { + ProcessContext processContext = processContextByAllocation.get(task.getAllocationId()); + try { + String requestId = processContext.process.get().writeInferenceRequest(inputs); + waitForResult(processContext, requestId, listener); + } catch (IOException e) { + logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.modelId), e); + listener.onFailure(ExceptionsHelper.serverError("error writing to process", e)); + return; + } + } + + private void waitForResult(ProcessContext processContext, String requestId, ActionListener listener) { + try { + // TODO the timeout value should come from the action + TimeValue timeout = TimeValue.timeValueSeconds(5); + PyTorchResult pyTorchResult = processContext.resultProcessor.waitForResult(requestId, timeout); + if (pyTorchResult == null) { + listener.onFailure(ExceptionsHelper.serverError("no result was produced within timeout value [{}]", timeout)); + } else { + listener.onResponse(pyTorchResult); + } + } catch (InterruptedException e) { + listener.onFailure(e); + } + } + class ProcessContext { private final String modelId; private final SetOnce process = new SetOnce<>(); + private final 
PyTorchResultProcessor resultProcessor; ProcessContext(String modelId) { this.modelId = Objects.requireNonNull(modelId); + resultProcessor = new PyTorchResultProcessor(modelId); } synchronized void startProcess() { process.set(pyTorchProcessFactory.createProcess(modelId, executorServiceForProcess, onProcessCrash())); } - synchronized void killProcess() { + synchronized void stopProcess() { + resultProcessor.stop(); if (process.get() == null) { return; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java index aa4f7f8d93998..b3ccf32fa37fa 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -9,11 +9,13 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.persistent.AllocatedPersistentTask; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams; +import org.elasticsearch.xpack.core.ml.inference.deployment.PyTorchResult; import java.util.Map; @@ -53,4 +55,8 @@ protected void onCancelled() { String reason = getReasonCancelled(); stop(reason); } + + public void infer(double[] inputs, ActionListener listener) { + manager.infer(this, inputs, listener); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java index fdb65304f5571..2347e657049ec 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java @@ -7,26 +7,41 @@ package org.elasticsearch.xpack.ml.inference.pytorch.process; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.xpack.core.ml.inference.deployment.PyTorchResult; import org.elasticsearch.xpack.ml.process.AbstractNativeProcess; import org.elasticsearch.xpack.ml.process.NativeController; import org.elasticsearch.xpack.ml.process.ProcessPipes; +import org.elasticsearch.xpack.ml.process.ProcessResultsParser; import org.elasticsearch.xpack.ml.process.writer.LengthEncodedWriter; import java.io.IOException; import java.io.OutputStream; import java.nio.charset.StandardCharsets; import java.nio.file.Path; +import java.util.Arrays; import java.util.Base64; +import java.util.Iterator; import java.util.List; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; public class NativePyTorchProcess extends AbstractNativeProcess implements PyTorchProcess { + private static final Logger logger = LogManager.getLogger(NativePyTorchProcess.class); + private static final String NAME = "pytorch_inference"; + private static AtomicLong ms_RequestId = new AtomicLong(1); + + private final ProcessResultsParser 
resultsParser; + protected NativePyTorchProcess(String jobId, NativeController nativeController, ProcessPipes processPipes, int numberOfFields, List filesToDelete, Consumer onProcessCrash) { super(jobId, nativeController, processPipes, numberOfFields, filesToDelete, onProcessCrash); + this.resultsParser = new ProcessResultsParser<>(PyTorchResult.PARSER, NamedXContentRegistry.EMPTY); } @Override @@ -53,4 +68,27 @@ public void loadModel(String modelBase64, int modelSizeAfterUnbase64) throws IOE restoreStream.write(modelBytes); } } + + @Override + public Iterator readResults() { + return resultsParser.parseResults(processOutStream()); + } + + @Override + public String writeInferenceRequest(double[] inputs) throws IOException { + long requestId = ms_RequestId.getAndIncrement(); + String json = new StringBuilder("{") + .append("\"request_id\":\"") + .append(requestId) + .append("\",") + .append("\"inputs\":") + .append(Arrays.toString(inputs)) + .append("}\n") + .toString(); + + processInStream().write(json.getBytes(StandardCharsets.UTF_8)); + processInStream().flush(); + + return String.valueOf(requestId); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java index 72c3e9b8af0d8..dacf466ff58f2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java @@ -7,11 +7,20 @@ package org.elasticsearch.xpack.ml.inference.pytorch.process; +import org.elasticsearch.xpack.core.ml.inference.deployment.PyTorchResult; import org.elasticsearch.xpack.ml.process.NativeProcess; import java.io.IOException; +import java.util.Iterator; public interface PyTorchProcess extends NativeProcess { void loadModel(String modelBase64, int modelSizeAfterUnbase64) throws IOException; + + Iterator readResults(); + + /** + * Writes an inference request to the process and returns the request id + */ + String writeInferenceRequest(double[] inputs) throws IOException; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java new file mode 100644 index 0000000000000..7e1ff9d9fbdf2 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.xpack.core.ml.inference.deployment.PyTorchResult; + +import java.util.Iterator; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +public class PyTorchResultProcessor { + + private static final Logger logger = LogManager.getLogger(PyTorchResultProcessor.class); + + private final ConcurrentMap pendingResults = new ConcurrentHashMap<>(); + + private final String deploymentId; + private volatile boolean isStopping; + + public PyTorchResultProcessor(String deploymentId) { + this.deploymentId = Objects.requireNonNull(deploymentId); + } + + public void process(PyTorchProcess process) { + try { + Iterator iterator = process.readResults(); + while (iterator.hasNext()) { + PyTorchResult result = iterator.next(); + logger.debug(() -> new ParameterizedMessage("[{}] Parsed result with id [{}]", deploymentId, result.getRequestId())); + PendingResult pendingResult = pendingResults.get(result.getRequestId()); + if (pendingResult == null) { + logger.debug(() -> new ParameterizedMessage("[{}] no pending result for [{}]", deploymentId, result.getRequestId())); + } else { + pendingResult.result = result; + pendingResult.latch.countDown(); + } + } + } catch (Exception e) { + if (isStopping) { + // No need to report error as we're stopping + } else { + logger.error(new ParameterizedMessage("[{}] Error processing results", deploymentId), e); + } + } + logger.debug(() -> new ParameterizedMessage("[{}] Results processing finished", deploymentId)); + } + + public PyTorchResult waitForResult(String requestId, TimeValue timeout) throws InterruptedException { + PendingResult pendingResult = pendingResults.computeIfAbsent(requestId, k -> new PendingResult()); + try { + if (pendingResult.latch.await(timeout.millis(), TimeUnit.MILLISECONDS)) { + return pendingResult.result; + } + } finally { + pendingResults.remove(requestId); + } + return null; + } + + public void stop() { + isStopping = true; + } + + private static class PendingResult { + private volatile PyTorchResult result; + private final CountDownLatch latch = new CountDownLatch(1); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/AbstractNativeProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/AbstractNativeProcess.java index 0c640e8dd7ebe..faa194f2be470 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/AbstractNativeProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/AbstractNativeProcess.java @@ -288,7 +288,7 @@ protected InputStream processOutStream() { } @Nullable - private OutputStream processInStream() { + protected OutputStream processInStream() { return processInStream.get(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelDeploymentAction.java new file mode 100644 index 0000000000000..a1d711f245359 --- /dev/null +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelDeploymentAction.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.ml.MachineLearning.BASE_PATH; + +public class RestInferTrainedModelDeploymentAction extends BaseRestHandler { + + @Override + public String getName() { + return "xpack_ml_infer_trained_models_deployment_action"; + } + + @Override + public List routes() { + return Collections.singletonList( + new Route( + POST, + BASE_PATH + "trained_models/deployment/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_infer") + ); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String deploymentId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + InferTrainedModelDeploymentAction.Request request; + if (restRequest.hasContentOrSourceParam()) { + XContentParser parser = restRequest.contentOrSourceParamParser(); + request = InferTrainedModelDeploymentAction.Request.parseRequest(deploymentId, parser); + } else { + throw ExceptionsHelper.badRequestException("requires body"); + } + + return channel -> client.execute(InferTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +}
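For reviewers, a minimal usage sketch (not part of the patch): it shows how a caller could exercise the new action through a node Client, based on the Request/Response classes and the REST route added above. The class name InferDeploymentExample and the deployment id are placeholders, and it is assumed here that MachineLearning.BASE_PATH resolves to "/_ml/", in which case the equivalent REST call would be POST /_ml/trained_models/deployment/<model_id>/_infer with a body like {"inputs": [0.1, 0.2, 0.3]}.

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;

import java.util.Arrays;
import java.util.stream.Collectors;

class InferDeploymentExample {

    // Sends one inference request to a started deployment; on success the listener
    // receives the Response wrapping the PyTorchResult parsed from the native process.
    static void infer(Client client, String deploymentId, double[] inputs) {
        InferTrainedModelDeploymentAction.Request request =
            new InferTrainedModelDeploymentAction.Request(deploymentId);
        // Request.setInputs expects boxed doubles.
        request.setInputs(Arrays.stream(inputs).boxed().collect(Collectors.toList()));

        client.execute(InferTrainedModelDeploymentAction.INSTANCE, request, ActionListener.wrap(
            response -> {
                // The response serializes the pytorch_result fields: request_id plus
                // either the inference matrix or an error message.
            },
            e -> {
                // The transport action replies with a conflict status when the deployment
                // task is not assigned to a node, i.e. the deployment is not started.
            }
        ));
    }
}

Because requests are routed to the node holding the deployment task, callers should start the deployment first and treat the conflict response as retryable once the deployment has been assigned.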