From af9f6082490dfbb60f7e09c0aef4488fdac639da Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 9 Mar 2021 14:07:19 +0200 Subject: [PATCH 1/7] [ML] Start and stop model deployments --- .../elasticsearch/xpack/core/ml/MlTasks.java | 12 ++ .../ml/action/DeployTrainedModelAction.java | 170 +++++++++++++++ .../ml/action/UndeployTrainedModelAction.java | 161 ++++++++++++++ .../deployment/DeployTrainedModelState.java | 38 ++++ .../DeployTrainedModelTaskState.java | 104 +++++++++ .../xpack/ml/MachineLearning.java | 30 ++- .../TransportDeployTrainedModelAction.java | 201 ++++++++++++++++++ .../TransportUndeployTrainedModelAction.java | 200 +++++++++++++++++ .../deployment/DeployTrainedModelTask.java | 56 +++++ .../deployment/DeploymentManager.java | 95 +++++++++ .../pytorch/process/NativePyTorchProcess.java | 42 ++++ .../process/NativePyTorchProcessFactory.java | 103 +++++++++ .../pytorch/process/PyTorchBuilder.java | 49 +++++ .../pytorch/process/PyTorchProcess.java | 14 ++ .../process/PyTorchProcessFactory.java | 16 ++ .../process/PyTorchProcessManager.java | 24 +++ .../RestDeployTrainedModelAction.java | 44 ++++ .../RestUndeployTrainedModelAction.java | 46 ++++ 18 files changed, 1404 insertions(+), 1 deletion(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeployTrainedModelAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UndeployTrainedModelAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelState.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelTaskState.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeployTrainedModelAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUndeployTrainedModelAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeployTrainedModelTask.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessFactory.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessManager.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeployTrainedModelAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestUndeployTrainedModelAction.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java index a495cbe9c77cc..239638cd35205 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java @@ -28,11 +28,13 @@ public final class MlTasks { public static final String DATAFEED_TASK_NAME = "xpack/ml/datafeed"; public static final String DATA_FRAME_ANALYTICS_TASK_NAME = "xpack/ml/data_frame/analytics"; public static final String JOB_SNAPSHOT_UPGRADE_TASK_NAME = "xpack/ml/job/snapshot/upgrade"; + public static final String DEPLOY_TRAINED_MODEL_TASK_NAME = "xpack/ml/trained_models/deploy"; public static final String JOB_TASK_ID_PREFIX = "job-"; public static final String DATAFEED_TASK_ID_PREFIX = "datafeed-"; public static final String DATA_FRAME_ANALYTICS_TASK_ID_PREFIX = "data_frame_analytics-"; public static final String JOB_SNAPSHOT_UPGRADE_TASK_ID_PREFIX = "job-snapshot-upgrade-"; + public static final String DEPLOY_TRAINED_MODEL_TASK_ID_PREFIX = "deploy_trained_model-"; public static final PersistentTasksCustomMetadata.Assignment AWAITING_UPGRADE = new PersistentTasksCustomMetadata.Assignment(null, @@ -76,6 +78,10 @@ public static String dataFrameAnalyticsId(String taskId) { return taskId.substring(DATA_FRAME_ANALYTICS_TASK_ID_PREFIX.length()); } + public static String deployTrainedModelTaskId(String modelId) { + return DEPLOY_TRAINED_MODEL_TASK_ID_PREFIX + modelId; + } + @Nullable public static PersistentTasksCustomMetadata.PersistentTask getJobTask(String jobId, @Nullable PersistentTasksCustomMetadata tasks) { return tasks == null ? null : tasks.getTask(jobTaskId(jobId)); @@ -100,6 +106,12 @@ public static PersistentTasksCustomMetadata.PersistentTask getSnapshotUpgrade return tasks == null ? null : tasks.getTask(snapshotUpgradeTaskId(jobId, snapshotId)); } + @Nullable + public static PersistentTasksCustomMetadata.PersistentTask getDeployTrainedModelTask(String modelId, + @Nullable PersistentTasksCustomMetadata tasks) { + return tasks == null ? null : tasks.getTask(deployTrainedModelTaskId(modelId)); + } + /** * Note that the return value of this method does NOT take node relocations into account. * Use {@link #getJobStateModifiedForReassignments} to return a value adjusted to the most diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeployTrainedModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeployTrainedModelAction.java new file mode 100644 index 0000000000000..c1a6e7fc1ddf9 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeployTrainedModelAction.java @@ -0,0 +1,170 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.MasterNodeRequest; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.persistent.PersistentTaskParams; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class DeployTrainedModelAction extends ActionType { + + public static final DeployTrainedModelAction INSTANCE = new DeployTrainedModelAction(); + public static final String NAME = "cluster:admin/xpack/ml/inference/trained_model/deploy"; + + public DeployTrainedModelAction() { + super(NAME, NodeAcknowledgedResponse::new); + } + + public static class Request extends MasterNodeRequest implements ToXContentObject { + + private static final ParseField MODEL_ID = new ParseField("model_id"); + + private String modelId; + + public Request(String modelId) { + setModelId(modelId); + } + + public Request(StreamInput in) throws IOException { + super(in); + modelId = in.readString(); + } + + public final void setModelId(String modelId) { + this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); + } + + public String getModelId() { + return modelId; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID.getPreferredName(), modelId); + return builder; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public int hashCode() { + return Objects.hash(modelId); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || obj.getClass() != getClass()) { + return false; + } + Request other = (Request) obj; + return Objects.equals(modelId, other.modelId); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + public static class TaskParams implements PersistentTaskParams { + + public static final Version VERSION_INTRODUCED = Version.V_7_13_0; + + private final String modelId; + + public TaskParams(String modelId) { + this.modelId = Objects.requireNonNull(modelId); + } + + public TaskParams(StreamInput in) throws IOException { + this.modelId = in.readString(); + } + + public String getModelId() { + return modelId; + } + + @Override + public String getWriteableName() { + return MlTasks.DEPLOY_TRAINED_MODEL_TASK_NAME; + } + + @Override + public Version getMinimalSupportedVersion() { + return VERSION_INTRODUCED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(modelId); + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + TaskParams other = (TaskParams) o; + return Objects.equals(modelId, other.modelId); + } + } + + public interface TaskMatcher { + + static boolean match(Task task, String expectedId) { + if (task instanceof TaskMatcher) { + if (Strings.isAllOrWildcard(expectedId)) { + return true; + } + String expectedDescription = MlTasks.DEPLOY_TRAINED_MODEL_TASK_ID_PREFIX + expectedId; + return expectedDescription.equals(task.getDescription()); + } + return false; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UndeployTrainedModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UndeployTrainedModelAction.java new file mode 100644 index 0000000000000..76d1e690e900d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UndeployTrainedModelAction.java @@ -0,0 +1,161 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.tasks.BaseTasksRequest; +import org.elasticsearch.action.support.tasks.BaseTasksResponse; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class UndeployTrainedModelAction extends ActionType { + + public static final UndeployTrainedModelAction INSTANCE = new UndeployTrainedModelAction(); + public static final String NAME = "cluster:admin/xpack/ml/trained_models/undeploy"; + + public UndeployTrainedModelAction() { + super(NAME, UndeployTrainedModelAction.Response::new); + } + + public static class Request extends BaseTasksRequest implements ToXContentObject { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + public static final ParseField FORCE = new ParseField("force"); + + private String id; + private boolean allowNoMatch = true; + private boolean force; + + public Request(String id) { + setId(id); + } + + public Request(StreamInput in) throws IOException { + super(in); + id = in.readString(); + allowNoMatch = in.readBoolean(); + force = in.readBoolean(); + } + + public final void setId(String id) { + this.id = ExceptionsHelper.requireNonNull(id, TrainedModelConfig.MODEL_ID); + } + + public String getId() { + return id; + } + + public void setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + } + + public boolean isAllowNoMatch() { + return allowNoMatch; + } + + public void setForce(boolean force) { + this.force = force; + } + + public boolean isForce() { + return force; + } + + @Override + public boolean match(Task task) { + return DeployTrainedModelAction.TaskMatcher.match(task, id); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(id); + out.writeBoolean(allowNoMatch); + out.writeBoolean(force); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), id); + builder.field(ALLOW_NO_MATCH.getPreferredName(), allowNoMatch); + builder.field(FORCE.getPreferredName(), force); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(id, allowNoMatch, force); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + Request that = (Request) o; + return Objects.equals(id, that.id) && + allowNoMatch == that.allowNoMatch && + force == that.force; + } + } + + public static class Response extends BaseTasksResponse implements Writeable, ToXContentObject { + + private final boolean undeployed; + + public Response(boolean undeployed) { + super(null, null); + this.undeployed = undeployed; + } + + public Response(StreamInput in) throws IOException { + super(in); + undeployed = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeBoolean(undeployed); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentCommon(builder, params); + builder.field("undeployed", undeployed); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(undeployed); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Response that = (Response) o; + return undeployed == that.undeployed; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelState.java new file mode 100644 index 0000000000000..03c136e0dcee3 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelState.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.deployment; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Locale; + +public enum DeployTrainedModelState implements Writeable { + + DEPLOYING, DEPLOYED, UNDEPLOYING, UNDEPLOYED; + + public static DeployTrainedModelState fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public static DeployTrainedModelState fromStream(StreamInput in) throws IOException { + return in.readEnum(DeployTrainedModelState.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(this); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelTaskState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelTaskState.java new file mode 100644 index 0000000000000..d98dc2c0fa14f --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelTaskState.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.deployment; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.persistent.PersistentTaskState; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; + +import java.io.IOException; +import java.util.Objects; + +public class DeployTrainedModelTaskState implements PersistentTaskState { + + public static final String NAME = MlTasks.DEPLOY_TRAINED_MODEL_TASK_NAME; + + private static ParseField STATE = new ParseField("state"); + private static ParseField ALLOCATION_ID = new ParseField("allocation_id"); + private static ParseField REASON = new ParseField("reason"); + + private final DeployTrainedModelState state; + private final long allocationId; + private final String reason; + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, true, + a -> new DeployTrainedModelTaskState((DeployTrainedModelState) a[0], (long) a[1], (String) a[2])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsState::fromString, STATE); + PARSER.declareLong(ConstructingObjectParser.constructorArg(), ALLOCATION_ID); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), REASON); + } + + public static DeployTrainedModelTaskState fromXContent(XContentParser parser) { + try { + return PARSER.parse(parser, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public DeployTrainedModelTaskState(DeployTrainedModelState state, long allocationId, @Nullable String reason) { + this.state = Objects.requireNonNull(state); + this.allocationId = allocationId; + this.reason = reason; + } + + public DeployTrainedModelTaskState(StreamInput in) throws IOException { + this.state = DeployTrainedModelState.fromStream(in); + this.allocationId = in.readLong(); + this.reason = in.readOptionalString(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(STATE.getPreferredName(), state.toString()); + builder.field(ALLOCATION_ID.getPreferredName(), allocationId); + if (reason != null) { + builder.field(REASON.getPreferredName(), reason); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + state.writeTo(out); + out.writeLong(allocationId); + out.writeOptionalString(reason); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DeployTrainedModelTaskState that = (DeployTrainedModelTaskState) o; + return allocationId == that.allocationId && + state == that.state && + Objects.equals(reason, that.reason); + } + + @Override + public int hashCode() { + return Objects.hash(state, allocationId, reason); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 9b36d75daedee..38b3d0af7f8e6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -95,6 +95,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAliasAction; +import org.elasticsearch.xpack.core.ml.action.DeployTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.EstimateModelMemoryAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction; @@ -141,6 +142,7 @@ import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction; +import org.elasticsearch.xpack.core.ml.action.UndeployTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.UpdateCalendarJobAction; import org.elasticsearch.xpack.core.ml.action.UpdateDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.UpdateDatafeedAction; @@ -157,6 +159,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStatsNamedWriteablesProvider; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.deployment.DeployTrainedModelTaskState; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.config.JobTaskState; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; @@ -175,6 +178,7 @@ import org.elasticsearch.xpack.ml.action.TransportDeleteModelSnapshotAction; import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAliasAction; +import org.elasticsearch.xpack.ml.action.TransportDeployTrainedModelAction; import org.elasticsearch.xpack.ml.action.TransportEstimateModelMemoryAction; import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction; import org.elasticsearch.xpack.ml.action.TransportExplainDataFrameAnalyticsAction; @@ -221,6 +225,7 @@ import org.elasticsearch.xpack.ml.action.TransportStartDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportStopDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportStopDatafeedAction; +import org.elasticsearch.xpack.ml.action.TransportUndeployTrainedModelAction; import org.elasticsearch.xpack.ml.action.TransportUpdateCalendarJobAction; import org.elasticsearch.xpack.ml.action.TransportUpdateDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportUpdateDatafeedAction; @@ -252,10 +257,13 @@ import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder; import org.elasticsearch.xpack.ml.inference.aggs.InternalInferenceAggregation; +import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcessFactory; +import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier; @@ -325,10 +333,12 @@ import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction; import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAliasAction; +import org.elasticsearch.xpack.ml.rest.inference.RestDeployTrainedModelAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction; import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAction; import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAliasAction; +import org.elasticsearch.xpack.ml.rest.inference.RestUndeployTrainedModelAction; import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction; @@ -531,6 +541,7 @@ public Set getRoles() { private final SetOnce inferenceModelBreaker = new SetOnce<>(); private final SetOnce modelLoadingService = new SetOnce<>(); private final SetOnce mlAutoscalingDeciderService = new SetOnce<>(); + private final SetOnce deploymentManager = new SetOnce<>(); public MachineLearning(Settings settings, Path configPath) { this.settings = settings; @@ -693,6 +704,7 @@ public Collection createComponents(Client client, ClusterService cluster final NormalizerProcessFactory normalizerProcessFactory; final AnalyticsProcessFactory analyticsProcessFactory; final AnalyticsProcessFactory memoryEstimationProcessFactory; + final PyTorchProcessFactory pyTorchProcessFactory; if (MachineLearningField.AUTODETECT_PROCESS.get(settings)) { try { NativeController nativeController = @@ -714,6 +726,7 @@ public Collection createComponents(Client client, ClusterService cluster dataFrameAnalyticsAuditor); memoryEstimationProcessFactory = new NativeMemoryUsageEstimationProcessFactory(environment, nativeController, clusterService); + pyTorchProcessFactory = new NativePyTorchProcessFactory(environment, nativeController, clusterService); mlController = nativeController; } catch (IOException e) { // The low level cause of failure from the named pipe helper's perspective is almost never the real root cause, so @@ -733,6 +746,7 @@ public Collection createComponents(Client client, ClusterService cluster normalizerProcessFactory = (jobId, quantilesState, bucketSpan, executorService) -> new MultiplyingNormalizerProcess(1.0); analyticsProcessFactory = (jobId, analyticsProcessConfig, hasState, executorService, onProcessCrash) -> null; memoryEstimationProcessFactory = (jobId, analyticsProcessConfig, hasState, executorService, onProcessCrash) -> null; + pyTorchProcessFactory = (jobId, executorService, onProcessCrash) -> null; } NormalizerFactory normalizerFactory = new NormalizerFactory(normalizerProcessFactory, threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)); @@ -773,6 +787,7 @@ public Collection createComponents(Client client, ClusterService cluster clusterService.getNodeName(), inferenceModelBreaker.get()); this.modelLoadingService.set(modelLoadingService); + this.deploymentManager.set(new DeploymentManager(threadPool, pyTorchProcessFactory)); // Data frame analytics components AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager( @@ -877,7 +892,12 @@ public List> getPersistentTasksExecutor(ClusterServic autodetectProcessManager.get(), memoryTracker.get(), expressionResolver, - client) + client), + new TransportDeployTrainedModelAction.TaskExecutor(settings, + clusterService, + expressionResolver, + memoryTracker.get(), + deploymentManager.get()) ); } @@ -953,6 +973,8 @@ public List getRestHandlers(Settings settings, RestController restC new RestPutTrainedModelAliasAction(), new RestDeleteTrainedModelAliasAction(), new RestPreviewDataFrameAnalyticsAction(), + new RestDeployTrainedModelAction(), + new RestUndeployTrainedModelAction(), // CAT Handlers new RestCatJobsAction(), new RestCatTrainedModelsAction(), @@ -1039,6 +1061,8 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(PutTrainedModelAliasAction.INSTANCE, TransportPutTrainedModelAliasAction.class), new ActionHandler<>(DeleteTrainedModelAliasAction.INSTANCE, TransportDeleteTrainedModelAliasAction.class), new ActionHandler<>(PreviewDataFrameAnalyticsAction.INSTANCE, TransportPreviewDataFrameAnalyticsAction.class), + new ActionHandler<>(DeployTrainedModelAction.INSTANCE, TransportDeployTrainedModelAction.class), + new ActionHandler<>(UndeployTrainedModelAction.INSTANCE, TransportUndeployTrainedModelAction.class), usageAction, infoAction); } @@ -1173,6 +1197,8 @@ public List getNamedWriteables() { StartDataFrameAnalyticsAction.TaskParams::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME, SnapshotUpgradeTaskParams::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.DEPLOY_TRAINED_MODEL_TASK_NAME, + DeployTrainedModelAction.TaskParams::new)); // Persistent task states namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class, JobTaskState.NAME, JobTaskState::new)); @@ -1182,6 +1208,8 @@ public List getNamedWriteables() { namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class, SnapshotUpgradeTaskState.NAME, SnapshotUpgradeTaskState::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class, + DeployTrainedModelTaskState.NAME, DeployTrainedModelTaskState::new)); namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); namedWriteables.addAll(new AnalysisStatsNamedWriteablesProvider().getNamedWriteables()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeployTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeployTrainedModelAction.java new file mode 100644 index 0000000000000..aff8b142d7f1b --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeployTrainedModelAction.java @@ -0,0 +1,201 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.persistent.AllocatedPersistentTask; +import org.elasticsearch.persistent.PersistentTaskState; +import org.elasticsearch.persistent.PersistentTasksCustomMetadata; +import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.DeployTrainedModelAction; +import org.elasticsearch.xpack.core.ml.action.DeployTrainedModelAction.TaskParams; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse; +import org.elasticsearch.xpack.core.ml.inference.deployment.DeployTrainedModelState; +import org.elasticsearch.xpack.core.ml.inference.deployment.DeployTrainedModelTaskState; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.deployment.DeployTrainedModelTask; +import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager; +import org.elasticsearch.xpack.ml.job.JobNodeSelector; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; +import org.elasticsearch.xpack.ml.task.AbstractJobPersistentTasksExecutor; + +import java.util.Collections; +import java.util.Map; +import java.util.Objects; + +public class TransportDeployTrainedModelAction + extends TransportMasterNodeAction { + + private static final Logger logger = LogManager.getLogger(TransportDeployTrainedModelAction.class); + + private final XPackLicenseState licenseState; + private final Client client; + private final PersistentTasksService persistentTasksService; + + @Inject + public TransportDeployTrainedModelAction(TransportService transportService, Client client, ClusterService clusterService, + ThreadPool threadPool, ActionFilters actionFilters, XPackLicenseState licenseState, + IndexNameExpressionResolver indexNameExpressionResolver, + PersistentTasksService persistentTasksService) { + super(DeployTrainedModelAction.NAME, transportService, clusterService, threadPool, actionFilters, + DeployTrainedModelAction.Request::new, indexNameExpressionResolver, NodeAcknowledgedResponse::new, ThreadPool.Names.SAME); + this.licenseState = Objects.requireNonNull(licenseState); + this.client = Objects.requireNonNull(client); + this.persistentTasksService = Objects.requireNonNull(persistentTasksService); + } + + @Override + protected void masterOperation(Task task, DeployTrainedModelAction.Request request, ClusterState state, + ActionListener listener) throws Exception { + logger.debug(() -> new ParameterizedMessage("[{}] received deploy request", request.getModelId())); + if (licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING) == false) { + listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); + return; + } + + ActionListener getModelListener = ActionListener.wrap( + getModelResponse -> { + if (getModelResponse.getResources().results().size() > 1) { + listener.onFailure(ExceptionsHelper.badRequestException( + "cannot deploy more than one models at the same time; [{}] matches [{}] models]", + request.getModelId(), getModelResponse.getResources().results().size())); + return; + } + persistentTasksService.sendStartRequest( + MlTasks.deployTrainedModelTaskId(request.getModelId()), + MlTasks.DEPLOY_TRAINED_MODEL_TASK_NAME, + new TaskParams(request.getModelId()), + ActionListener.wrap( + response -> listener.onResponse(new NodeAcknowledgedResponse(true, "")), + listener::onFailure + ) + ); + }, + listener::onFailure + ); + + GetTrainedModelsAction.Request getModelRequest = new GetTrainedModelsAction.Request( + request.getModelId(), null, Collections.emptySet()); + client.execute(GetTrainedModelsAction.INSTANCE, getModelRequest, getModelListener); + } + + @Override + protected ClusterBlockException checkBlock(DeployTrainedModelAction.Request request, ClusterState state) { + // We only delegate here to PersistentTasksService, but if there is a metadata writeblock, + // then delegating to PersistentTasksService doesn't make a whole lot of sense, + // because PersistentTasksService will then fail. + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } + + public static class TaskExecutor extends AbstractJobPersistentTasksExecutor { + + private final DeploymentManager manager; + + public TaskExecutor(Settings settings, ClusterService clusterService, IndexNameExpressionResolver expressionResolver, + MlMemoryTracker memoryTracker, DeploymentManager manager) { + super(MlTasks.DEPLOY_TRAINED_MODEL_TASK_NAME, + MachineLearning.UTILITY_THREAD_POOL_NAME, + settings, + clusterService, + memoryTracker, + expressionResolver); + this.manager = Objects.requireNonNull(manager); + } + + @Override + protected AllocatedPersistentTask createTask( + long id, String type, String action, TaskId parentTaskId, + PersistentTasksCustomMetadata.PersistentTask persistentTask, + Map headers) { + return new DeployTrainedModelTask(id, type, action, parentTaskId, headers, persistentTask.getParams()); + } + + @Override + public PersistentTasksCustomMetadata.Assignment getAssignment(TaskParams params, ClusterState clusterState) { + JobNodeSelector jobNodeSelector = + new JobNodeSelector( + clusterState, + params.getModelId(), + MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, + memoryTracker, + 0, + node -> nodeFilter(node, params)); + PersistentTasksCustomMetadata.Assignment assignment = jobNodeSelector.selectNode( + maxOpenJobs, + Integer.MAX_VALUE, + maxMachineMemoryPercent, + maxNodeMemory, + false, + useAutoMemoryPercentage + ); + return assignment; + } + + public static String nodeFilter(DiscoveryNode node, TaskParams params) { + String id = params.getModelId(); + + if (node.getVersion().before(TaskParams.VERSION_INTRODUCED)) { + return "Not opening job [" + id + "] on node [" + JobNodeSelector.nodeNameAndVersion(node) + + "], because the data frame analytics requires a node of version [" + + TaskParams.VERSION_INTRODUCED + "] or higher"; + } + + return null; + } + + @Override + protected void nodeOperation(AllocatedPersistentTask task, TaskParams params, PersistentTaskState state) { + DeployTrainedModelTask deployTrainedModelTask = (DeployTrainedModelTask) task; + deployTrainedModelTask.setDeploymentManager(manager); + + DeployTrainedModelTaskState deployingState = new DeployTrainedModelTaskState( + DeployTrainedModelState.DEPLOYING, task.getAllocationId(), null); + task.updatePersistentTaskState(deployingState, ActionListener.wrap( + response -> manager.deployModel(deployTrainedModelTask), + task::markAsFailed + )); + } + + @Override + protected String[] indicesOfInterest(TaskParams params) { + return new String[] { + InferenceIndexConstants.INDEX_PATTERN + }; + } + + @Override + protected String getJobId(TaskParams params) { + return params.getModelId(); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUndeployTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUndeployTrainedModelAction.java new file mode 100644 index 0000000000000..d7bcca5739cfd --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUndeployTrainedModelAction.java @@ -0,0 +1,200 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionListenerResponseHandler; +import org.elasticsearch.action.FailedNodeException; +import org.elasticsearch.action.TaskOperationFailure; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.tasks.TransportTasksAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.discovery.MasterNotDiscoveredException; +import org.elasticsearch.persistent.PersistentTasksCustomMetadata; +import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.UndeployTrainedModelAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.deployment.DeployTrainedModelState; +import org.elasticsearch.xpack.core.ml.inference.deployment.DeployTrainedModelTaskState; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.deployment.DeployTrainedModelTask; + +import java.util.Collections; +import java.util.List; +import java.util.Set; + +public class TransportUndeployTrainedModelAction extends TransportTasksAction { + + private static final Logger logger = LogManager.getLogger(TransportUndeployTrainedModelAction.class); + + private final Client client; + private final ThreadPool threadPool; + private final PersistentTasksService persistentTasksService; + + @Inject + public TransportUndeployTrainedModelAction(String actionName, ClusterService clusterService, TransportService transportService, + ActionFilters actionFilters, Client client, ThreadPool threadPool, + PersistentTasksService persistentTasksService) { + super(actionName, clusterService, transportService, actionFilters, UndeployTrainedModelAction.Request::new, + UndeployTrainedModelAction.Response::new, UndeployTrainedModelAction.Response::new, ThreadPool.Names.SAME); + this.client = client; + this.threadPool = threadPool; + this.persistentTasksService = persistentTasksService; + } + + @Override + protected void doExecute(Task task, UndeployTrainedModelAction.Request request, + ActionListener listener) { + ClusterState state = clusterService.state(); + DiscoveryNodes nodes = state.nodes(); + if (nodes.isLocalNodeElectedMaster() == false) { + redirectToMasterNode(nodes.getMasterNode(), request, listener); + return; + } + + logger.debug("[{}] Received request to undeploy", request.getId()); + + ActionListener getModelListener = ActionListener.wrap( + getModelsResponse -> { + List models = getModelsResponse.getResources().results(); + if (models.isEmpty()) { + listener.onResponse(new UndeployTrainedModelAction.Response(true)); + return; + } + if (models.size() > 1) { + listener.onFailure(ExceptionsHelper.badRequestException("cannot undeploy multiple models at the same time")); + return; + } + + ClusterState clusterState = clusterService.state(); + PersistentTasksCustomMetadata tasks = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE); + PersistentTasksCustomMetadata.PersistentTask deployTrainedModelTask = + MlTasks.getDeployTrainedModelTask(request.getId(), tasks); + if (deployTrainedModelTask == null) { + listener.onResponse(new UndeployTrainedModelAction.Response(true)); + return; + } + normalUndeploy(task, deployTrainedModelTask, request, listener); + }, + listener::onFailure + ); + + GetTrainedModelsAction.Request getModelRequest = new GetTrainedModelsAction.Request( + request.getId(), null, Collections.emptySet()); + getModelRequest.setAllowNoResources(request.isAllowNoMatch()); + client.execute(GetTrainedModelsAction.INSTANCE, getModelRequest, getModelListener); + } + + private void redirectToMasterNode(DiscoveryNode masterNode, UndeployTrainedModelAction.Request request, + ActionListener listener) { + if (masterNode == null) { + listener.onFailure(new MasterNotDiscoveredException()); + } else { + transportService.sendRequest(masterNode, actionName, request, + new ActionListenerResponseHandler<>(listener, UndeployTrainedModelAction.Response::new)); + } + } + + private void normalUndeploy(Task task, PersistentTasksCustomMetadata.PersistentTask deployTrainedModelTask, + UndeployTrainedModelAction.Request request, ActionListener listener) { + request.setNodes(deployTrainedModelTask.getExecutorNode()); + + ActionListener finalListener = ActionListener.wrap( + r -> waitForTaskRemoved(Collections.singleton(deployTrainedModelTask.getId()), request, r, listener), + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof FailedNodeException) { + // A node has dropped out of the cluster since we started executing the requests. + // Since undeploying an already undeployed trained model is not an error we can try again. + // The tasks that were running on the node that dropped out of the cluster + // will just have their persistent tasks cancelled. Tasks that were stopped + // by the previous attempt will be noops in the subsequent attempt. + doExecute(task, request, listener); + } else { + listener.onFailure(e); + } + } + ); + + super.doExecute(task, request, finalListener); + } + + void waitForTaskRemoved(Set taskIds, UndeployTrainedModelAction.Request request, + UndeployTrainedModelAction.Response response, + ActionListener listener) { + persistentTasksService.waitForPersistentTasksCondition(persistentTasks -> + persistentTasks.findTasks(MlTasks.DEPLOY_TRAINED_MODEL_TASK_NAME, t -> taskIds.contains(t.getId())).isEmpty(), + request.getTimeout(), ActionListener.wrap( + booleanResponse -> { + listener.onResponse(response); + }, + listener::onFailure + ) + ); + } + + @Override + protected UndeployTrainedModelAction.Response newResponse(UndeployTrainedModelAction.Request request, + List tasks, + List taskOperationFailures, + List failedNodeExceptions) { + if (taskOperationFailures.isEmpty() == false) { + throw org.elasticsearch.ExceptionsHelper.convertToElastic(taskOperationFailures.get(0).getCause()); + } else if (failedNodeExceptions.isEmpty() == false) { + throw org.elasticsearch.ExceptionsHelper.convertToElastic(failedNodeExceptions.get(0)); + } else { + return new UndeployTrainedModelAction.Response(true); + } + } + + @Override + protected void taskOperation(UndeployTrainedModelAction.Request request, DeployTrainedModelTask task, + ActionListener listener) { + DeployTrainedModelTaskState undeployingState = new DeployTrainedModelTaskState( + DeployTrainedModelState.UNDEPLOYING, task.getAllocationId(), "api"); + task.updatePersistentTaskState(undeployingState, ActionListener.wrap( + updatedTask -> { + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + + @Override + protected void doRun() throws Exception { + task.stop("undeploy_trained_model (api)"); + listener.onResponse(new UndeployTrainedModelAction.Response(true)); + } + }); + }, + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + // the task has disappeared so must have stopped + listener.onResponse(new UndeployTrainedModelAction.Response(true)); + } else { + listener.onFailure(e); + } + } + )); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeployTrainedModelTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeployTrainedModelTask.java new file mode 100644 index 0000000000000..03055e622464d --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeployTrainedModelTask.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.deployment; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.persistent.AllocatedPersistentTask; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.DeployTrainedModelAction; +import org.elasticsearch.xpack.core.ml.action.DeployTrainedModelAction.TaskParams; + +import java.util.Map; + +public class DeployTrainedModelTask extends AllocatedPersistentTask implements DeployTrainedModelAction.TaskMatcher { + + private static final Logger logger = LogManager.getLogger(DeployTrainedModelTask.class); + + private final TaskParams params; + private volatile boolean isStopping; + private volatile DeploymentManager manager; + + public DeployTrainedModelTask(long id, String type, String action, TaskId parentTask, Map headers, + TaskParams taskParams) { + super(id, type, action, MlTasks.DEPLOY_TRAINED_MODEL_TASK_ID_PREFIX + taskParams.getModelId(), parentTask, headers); + this.params = taskParams; + } + + public String getModelId() { + return params.getModelId(); + } + + public void stop(String reason) { + isStopping = true; + logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason); + + assert manager != null : "manager should not be unset when stop is called"; + manager.undeployModel(this); + markAsCompleted(); + } + + public void setDeploymentManager(DeploymentManager manager) { + this.manager = manager; + } + + @Override + protected void onCancelled() { + String reason = getReasonCancelled(); + stop(reason); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java new file mode 100644 index 0000000000000..6537f1ff6ed1b --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -0,0 +1,95 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.deployment; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcess; +import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory; + +import java.io.IOException; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; + +public class DeploymentManager { + + private static final Logger logger = LogManager.getLogger(DeploymentManager.class); + + private final PyTorchProcessFactory pyTorchProcessFactory; + private final ExecutorService executorServiceForProcess; + private final ConcurrentMap processContextByAllocation = new ConcurrentHashMap<>(); + + public DeploymentManager(ThreadPool threadPool, PyTorchProcessFactory pyTorchProcessFactory) { + this.pyTorchProcessFactory = Objects.requireNonNull(pyTorchProcessFactory); + this.executorServiceForProcess = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); + } + + public void deployModel(DeployTrainedModelTask task) { + logger.info("[{}] Deploying model", task.getModelId()); + + ProcessContext processContext = new ProcessContext(task.getModelId()); + + if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) { + throw ExceptionsHelper.serverError("[{}] Could not create process as one already exists", task.getModelId()); + } + + processContext.startProcess(); + } + + public void undeployModel(DeployTrainedModelTask task) { + ProcessContext processContext; + synchronized (processContextByAllocation) { + processContext = processContextByAllocation.get(task.getAllocationId()); + } + if (processContext != null) { + logger.debug("[{}] Undeploying model", task.getModelId()); + processContext.killProcess(); + } else { + logger.debug("[{}] No process context to stop", task.getModelId()); + } + } + + class ProcessContext { + + private final String modelId; + private final SetOnce process = new SetOnce<>(); + + ProcessContext(String modelId) { + this.modelId = Objects.requireNonNull(modelId); + } + + synchronized void startProcess() { + process.set(pyTorchProcessFactory.createProcess(modelId, executorServiceForProcess, onProcessCrash())); + } + + synchronized void killProcess() { + if (process.get() == null) { + return; + } + try { + process.get().kill(true); + } catch (IOException e) { + logger.error(new ParameterizedMessage("[{}] Failed to kill process", modelId), e); + } + } + + private Consumer onProcessCrash() { + return reason -> { + logger.error("[{}] process crashed due to reason [{}]", modelId, reason); + }; + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java new file mode 100644 index 0000000000000..4064f34475bc8 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import org.elasticsearch.xpack.ml.process.AbstractNativeProcess; +import org.elasticsearch.xpack.ml.process.NativeController; +import org.elasticsearch.xpack.ml.process.ProcessPipes; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import java.util.function.Consumer; + +public class NativePyTorchProcess extends AbstractNativeProcess implements PyTorchProcess { + + private static final String NAME = "pytorch_inference"; + + protected NativePyTorchProcess(String jobId, NativeController nativeController, ProcessPipes processPipes, int numberOfFields, + List filesToDelete, Consumer onProcessCrash) { + super(jobId, nativeController, processPipes, numberOfFields, filesToDelete, onProcessCrash); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public void persistState() throws IOException { + // Nothing to persist + } + + @Override + public void persistState(long snapshotTimestampMs, String snapshotId, String snapshotDescription) throws IOException { + // Nothing to persist + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java new file mode 100644 index 0000000000000..6c9b050e8e0e6 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.core.internal.io.IOUtils; +import org.elasticsearch.env.Environment; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.process.NativeController; +import org.elasticsearch.xpack.ml.process.ProcessPipes; +import org.elasticsearch.xpack.ml.utils.NamedPipeHelper; + +import java.io.IOException; +import java.nio.file.Path; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; + +public class NativePyTorchProcessFactory implements PyTorchProcessFactory { + + private static final Logger logger = LogManager.getLogger(NativePyTorchProcessFactory.class); + + private static final NamedPipeHelper NAMED_PIPE_HELPER = new NamedPipeHelper(); + + private final Environment env; + private final NativeController nativeController; + private volatile Duration processConnectTimeout; + + public NativePyTorchProcessFactory(Environment env, + NativeController nativeController, + ClusterService clusterService) { + this.env = Objects.requireNonNull(env); + this.nativeController = Objects.requireNonNull(nativeController); + setProcessConnectTimeout(MachineLearning.PROCESS_CONNECT_TIMEOUT.get(env.settings())); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.PROCESS_CONNECT_TIMEOUT, + this::setProcessConnectTimeout); + } + + void setProcessConnectTimeout(TimeValue processConnectTimeout) { + this.processConnectTimeout = Duration.ofMillis(processConnectTimeout.getMillis()); + } + + @Override + public PyTorchProcess createProcess(String modelId, ExecutorService executorService, Consumer onProcessCrash) { + List filesToDelete = new ArrayList<>(); + ProcessPipes processPipes = new ProcessPipes( + env, + NAMED_PIPE_HELPER, + processConnectTimeout, + PyTorchBuilder.PROCESS_NAME, + modelId, + null, + false, + true, + true, + true, + false + ); + + executeProcess(processPipes, filesToDelete); + + NativePyTorchProcess process = new NativePyTorchProcess(modelId, nativeController, processPipes, 0, filesToDelete, onProcessCrash); + + try { + process.start(executorService); + } catch(IOException | EsRejectedExecutionException e) { + String msg = "Failed to connect to pytorch process for job " + modelId; + logger.error(msg); + try { + IOUtils.close(process); + } catch (IOException ioe) { + logger.error("Can't close pytorch process", ioe); + } + throw ExceptionsHelper.serverError(msg, e); + } + return process; + } + + private void executeProcess(ProcessPipes processPipes, List filesToDelete) { + PyTorchBuilder pyTorchBuilder = new PyTorchBuilder(env::tmpFile, nativeController, processPipes, filesToDelete); + try { + pyTorchBuilder.build(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (IOException e) { + throw ExceptionsHelper.serverError("Failed to launch PyTorch process"); + } + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java new file mode 100644 index 0000000000000..60f50a8f70507 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import org.elasticsearch.xpack.ml.process.NativeController; +import org.elasticsearch.xpack.ml.process.ProcessPipes; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +public class PyTorchBuilder { + + public static final String PROCESS_NAME = "pytorch_inference"; + private static final String PROCESS_PATH = "./" + PROCESS_NAME; + + private final Supplier tempDirPathSupplier; + private final NativeController nativeController; + private final ProcessPipes processPipes; + private final List filesToDelete; + + public PyTorchBuilder(Supplier tempDirPathSupplier, NativeController nativeController, ProcessPipes processPipes, + List filesToDelete) { + this.tempDirPathSupplier = Objects.requireNonNull(tempDirPathSupplier); + this.nativeController = Objects.requireNonNull(nativeController); + this.processPipes = Objects.requireNonNull(processPipes); + this.filesToDelete = Objects.requireNonNull(filesToDelete); + } + + public void build() throws IOException, InterruptedException { + List command = buildCommand(); + processPipes.addArgs(command); + nativeController.startProcess(command); + } + + private List buildCommand() { + List command = new ArrayList<>(); + command.add(PROCESS_PATH); + return command; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java new file mode 100644 index 0000000000000..91c4669bc0160 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java @@ -0,0 +1,14 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import org.elasticsearch.xpack.ml.process.NativeProcess; + +public interface PyTorchProcess extends NativeProcess { + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessFactory.java new file mode 100644 index 0000000000000..4d22a80f433a5 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessFactory.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; + +public interface PyTorchProcessFactory { + + PyTorchProcess createProcess(String modelId, ExecutorService executorService, Consumer onProcessCrash); +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessManager.java new file mode 100644 index 0000000000000..c812e490217ed --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessManager.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +public class PyTorchProcessManager { + + private static final Logger logger = LogManager.getLogger(PyTorchProcessManager.class); + + public PyTorchProcessManager() { + + } + + public void start(String taskId) { + + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeployTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeployTrainedModelAction.java new file mode 100644 index 0000000000000..f11e302f9ad8e --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeployTrainedModelAction.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.DeployTrainedModelAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.POST; + +public class RestDeployTrainedModelAction extends BaseRestHandler { + + @Override + public String getName() { + return "xpack_ml_deploy_trained_model_action"; + } + + @Override + public List routes() { + return Collections.singletonList( + new Route(POST, + MachineLearning.BASE_PATH + "trained_models/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_deploy")); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + DeployTrainedModelAction.Request request = new DeployTrainedModelAction.Request(modelId); + return channel -> client.execute(DeployTrainedModelAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestUndeployTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestUndeployTrainedModelAction.java new file mode 100644 index 0000000000000..826d910c30bf6 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestUndeployTrainedModelAction.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.UndeployTrainedModelAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.ml.MachineLearning.BASE_PATH; + +public class RestUndeployTrainedModelAction extends BaseRestHandler { + + @Override + public String getName() { + return "xpack_ml_undeploy_trained_model_action"; + } + + @Override + public List routes() { + return Collections.singletonList( + new Route( + POST, + BASE_PATH + "trained_models/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_undeploy") + ); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + UndeployTrainedModelAction.Request request = new UndeployTrainedModelAction.Request(modelId); + return channel -> client.execute(UndeployTrainedModelAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} From bc8b898c41063f58170fa6f307d43decd65e5313 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 23 Mar 2021 12:40:06 +0200 Subject: [PATCH 2/7] Wait for deployment started --- .../ml/action/DeployTrainedModelAction.java | 21 ++++- .../DeployTrainedModelTaskState.java | 8 ++ .../TransportDeployTrainedModelAction.java | 94 ++++++++++++++++++- .../deployment/DeploymentManager.java | 10 ++ 4 files changed, 127 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeployTrainedModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeployTrainedModelAction.java index c1a6e7fc1ddf9..05a065b7be1b1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeployTrainedModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeployTrainedModelAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.persistent.PersistentTaskParams; @@ -25,12 +26,15 @@ import java.io.IOException; import java.util.Objects; +import java.util.concurrent.TimeUnit; public class DeployTrainedModelAction extends ActionType { public static final DeployTrainedModelAction INSTANCE = new DeployTrainedModelAction(); public static final String NAME = "cluster:admin/xpack/ml/inference/trained_model/deploy"; + public static final TimeValue DEFAULT_TIMEOUT = new TimeValue(20, TimeUnit.SECONDS); + public DeployTrainedModelAction() { super(NAME, NodeAcknowledgedResponse::new); } @@ -38,8 +42,10 @@ public DeployTrainedModelAction() { public static class Request extends MasterNodeRequest implements ToXContentObject { private static final ParseField MODEL_ID = new ParseField("model_id"); + private static final ParseField TIMEOUT = new ParseField("timeout"); private String modelId; + private TimeValue timeout = DEFAULT_TIMEOUT; public Request(String modelId) { setModelId(modelId); @@ -48,6 +54,7 @@ public Request(String modelId) { public Request(StreamInput in) throws IOException { super(in); modelId = in.readString(); + timeout = in.readTimeValue(); } public final void setModelId(String modelId) { @@ -58,15 +65,25 @@ public String getModelId() { return modelId; } + public void setTimeout(TimeValue timeout) { + this.timeout = timeout; + } + + public TimeValue getTimeout() { + return timeout; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(modelId); + out.writeTimeValue(timeout); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.field(MODEL_ID.getPreferredName(), modelId); + builder.field(TIMEOUT.getPreferredName(), timeout.getStringRep()); return builder; } @@ -77,7 +94,7 @@ public ActionRequestValidationException validate() { @Override public int hashCode() { - return Objects.hash(modelId); + return Objects.hash(modelId, timeout); } @Override @@ -89,7 +106,7 @@ public boolean equals(Object obj) { return false; } Request other = (Request) obj; - return Objects.equals(modelId, other.modelId); + return Objects.equals(modelId, other.modelId) && Objects.equals(timeout, other.timeout); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelTaskState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelTaskState.java index d98dc2c0fa14f..2d199165cb8c7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelTaskState.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelTaskState.java @@ -63,6 +63,14 @@ public DeployTrainedModelTaskState(StreamInput in) throws IOException { this.reason = in.readOptionalString(); } + public DeployTrainedModelState getState() { + return state; + } + + public String getReason() { + return reason; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeployTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeployTrainedModelAction.java index aff8b142d7f1b..3a85c1a0e6fde 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeployTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeployTrainedModelAction.java @@ -10,6 +10,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.master.TransportMasterNodeAction; @@ -22,12 +24,15 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.persistent.AllocatedPersistentTask; +import org.elasticsearch.persistent.PersistentTaskParams; import org.elasticsearch.persistent.PersistentTaskState; import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; @@ -52,6 +57,7 @@ import java.util.Collections; import java.util.Map; import java.util.Objects; +import java.util.function.Predicate; public class TransportDeployTrainedModelAction extends TransportMasterNodeAction { @@ -83,6 +89,22 @@ protected void masterOperation(Task task, DeployTrainedModelAction.Request reque return; } + ActionListener> waitForDeploymentToStart = + ActionListener.wrap( + startedTask -> waitForDeploymentStarted(startedTask, request.getTimeout(), listener), + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) { + e = new ElasticsearchStatusException( + "Cannot start deployment [{}] because it has already been started", + RestStatus.CONFLICT, + e, + request.getModelId() + ); + } + listener.onFailure(e); + } + ); + ActionListener getModelListener = ActionListener.wrap( getModelResponse -> { if (getModelResponse.getResources().results().size() > 1) { @@ -95,10 +117,7 @@ protected void masterOperation(Task task, DeployTrainedModelAction.Request reque MlTasks.deployTrainedModelTaskId(request.getModelId()), MlTasks.DEPLOY_TRAINED_MODEL_TASK_NAME, new TaskParams(request.getModelId()), - ActionListener.wrap( - response -> listener.onResponse(new NodeAcknowledgedResponse(true, "")), - listener::onFailure - ) + waitForDeploymentToStart ); }, listener::onFailure @@ -109,6 +128,27 @@ protected void masterOperation(Task task, DeployTrainedModelAction.Request reque client.execute(GetTrainedModelsAction.INSTANCE, getModelRequest, getModelListener); } + private void waitForDeploymentStarted(PersistentTasksCustomMetadata.PersistentTask task, + TimeValue timeout, ActionListener listener) { + DeploymentStartedPredicate predicate = new DeploymentStartedPredicate(); + persistentTasksService.waitForPersistentTaskCondition(task.getId(), predicate, timeout, + new PersistentTasksService.WaitForPersistentTaskListener() { + @Override + public void onResponse(PersistentTasksCustomMetadata.PersistentTask persistentTask) { + if (predicate.exception != null) { + + } else { + listener.onResponse(new NodeAcknowledgedResponse(true, predicate.node)); + } + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }); + } + @Override protected ClusterBlockException checkBlock(DeployTrainedModelAction.Request request, ClusterState state) { // We only delegate here to PersistentTasksService, but if there is a metadata writeblock, @@ -117,6 +157,52 @@ protected ClusterBlockException checkBlock(DeployTrainedModelAction.Request requ return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); } + private static class DeploymentStartedPredicate implements Predicate> { + + private volatile Exception exception; + private volatile String node = ""; + private volatile String assignmentExplanation; + + @Override + public boolean test(PersistentTasksCustomMetadata.PersistentTask persistentTask) { + if (persistentTask == null) { + return false; + } + + PersistentTasksCustomMetadata.Assignment assignment = persistentTask.getAssignment(); + + String reason = "__unknown__"; + + if (assignment != null) { + if (assignment.equals(JobNodeSelector.AWAITING_LAZY_ASSIGNMENT)) { + return true; + } + if (assignment.equals(PersistentTasksCustomMetadata.INITIAL_ASSIGNMENT) == false && assignment.isAssigned() == false) { + exception = new ElasticsearchStatusException("Could not start trained model deployment, allocation explanation [{}]", + RestStatus.TOO_MANY_REQUESTS, assignment.getExplanation()); + return true; + } + } + + DeployTrainedModelTaskState taskState = (DeployTrainedModelTaskState) persistentTask.getState(); + reason = taskState != null ? taskState.getReason() : reason; + DeployTrainedModelState deploymentState = taskState == null ? DeployTrainedModelState.DEPLOYED : taskState.getState(); + switch (deploymentState) { + case DEPLOYED: + node = persistentTask.getExecutorNode(); + return true; + case DEPLOYING: + case UNDEPLOYING: + case UNDEPLOYED: + return false; + default: + exception = ExceptionsHelper.serverError("Unexpected task state [{}] with reason [{}] while waiting to be started", + taskState.getState(), reason); + return true; + } + } + } + public static class TaskExecutor extends AbstractJobPersistentTasksExecutor { private final DeploymentManager manager; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 6537f1ff6ed1b..b450ebcf91ed5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -11,7 +11,10 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.inference.deployment.DeployTrainedModelState; +import org.elasticsearch.xpack.core.ml.inference.deployment.DeployTrainedModelTaskState; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcess; @@ -47,6 +50,13 @@ public void deployModel(DeployTrainedModelTask task) { } processContext.startProcess(); + + DeployTrainedModelTaskState startedState = new DeployTrainedModelTaskState( + DeployTrainedModelState.DEPLOYED, task.getAllocationId(), null); + task.updatePersistentTaskState(startedState, ActionListener.wrap( + response -> logger.info("[{}] trained model deployment started", task.getModelId()), + task::markAsFailed + )); } public void undeployModel(DeployTrainedModelTask task) { From 7cc739ff285ef9e12294e6d4211d9a8846d01d50 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 23 Mar 2021 12:55:26 +0200 Subject: [PATCH 3/7] Rename to start/stop trained model deployment --- ...=> StartTrainedModelDeploymentAction.java} | 8 +-- ... => StopTrainedModelDeploymentAction.java} | 14 ++-- ....java => TrainedModelDeploymentState.java} | 10 +-- ...a => TrainedModelDeploymentTaskState.java} | 20 +++--- .../xpack/ml/MachineLearning.java | 28 ++++---- ...ortStartTrainedModelDeploymentAction.java} | 57 +++++++-------- ...portStopTrainedModelDeploymentAction.java} | 71 ++++++++++--------- .../deployment/DeploymentManager.java | 12 ++-- ...k.java => TrainedModelDeploymentTask.java} | 12 ++-- ...estStartTrainedModelDeploymentAction.java} | 12 ++-- ...RestStopTrainedModelDeploymentAction.java} | 12 ++-- 11 files changed, 129 insertions(+), 127 deletions(-) rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/{DeployTrainedModelAction.java => StartTrainedModelDeploymentAction.java} (94%) rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/{UndeployTrainedModelAction.java => StopTrainedModelDeploymentAction.java} (90%) rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/{DeployTrainedModelState.java => TrainedModelDeploymentState.java} (71%) rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/{DeployTrainedModelTaskState.java => TrainedModelDeploymentTaskState.java} (79%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/{TransportDeployTrainedModelAction.java => TransportStartTrainedModelDeploymentAction.java} (81%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/{TransportUndeployTrainedModelAction.java => TransportStopTrainedModelDeploymentAction.java} (67%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/{DeployTrainedModelTask.java => TrainedModelDeploymentTask.java} (71%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/{RestDeployTrainedModelAction.java => RestStartTrainedModelDeploymentAction.java} (65%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/{RestUndeployTrainedModelAction.java => RestStopTrainedModelDeploymentAction.java} (67%) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeployTrainedModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java similarity index 94% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeployTrainedModelAction.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java index 05a065b7be1b1..e2b90bad16e3d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeployTrainedModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -28,14 +28,14 @@ import java.util.Objects; import java.util.concurrent.TimeUnit; -public class DeployTrainedModelAction extends ActionType { +public class StartTrainedModelDeploymentAction extends ActionType { - public static final DeployTrainedModelAction INSTANCE = new DeployTrainedModelAction(); - public static final String NAME = "cluster:admin/xpack/ml/inference/trained_model/deploy"; + public static final StartTrainedModelDeploymentAction INSTANCE = new StartTrainedModelDeploymentAction(); + public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/start"; public static final TimeValue DEFAULT_TIMEOUT = new TimeValue(20, TimeUnit.SECONDS); - public DeployTrainedModelAction() { + public StartTrainedModelDeploymentAction() { super(NAME, NodeAcknowledgedResponse::new); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UndeployTrainedModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopTrainedModelDeploymentAction.java similarity index 90% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UndeployTrainedModelAction.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopTrainedModelDeploymentAction.java index 76d1e690e900d..2fd52f5baa5d0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UndeployTrainedModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopTrainedModelDeploymentAction.java @@ -23,13 +23,13 @@ import java.io.IOException; import java.util.Objects; -public class UndeployTrainedModelAction extends ActionType { +public class StopTrainedModelDeploymentAction extends ActionType { - public static final UndeployTrainedModelAction INSTANCE = new UndeployTrainedModelAction(); - public static final String NAME = "cluster:admin/xpack/ml/trained_models/undeploy"; + public static final StopTrainedModelDeploymentAction INSTANCE = new StopTrainedModelDeploymentAction(); + public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/stop"; - public UndeployTrainedModelAction() { - super(NAME, UndeployTrainedModelAction.Response::new); + public StopTrainedModelDeploymentAction() { + super(NAME, StopTrainedModelDeploymentAction.Response::new); } public static class Request extends BaseTasksRequest implements ToXContentObject { @@ -78,7 +78,7 @@ public boolean isForce() { @Override public boolean match(Task task) { - return DeployTrainedModelAction.TaskMatcher.match(task, id); + return StartTrainedModelDeploymentAction.TaskMatcher.match(task, id); } @Override @@ -140,7 +140,7 @@ public void writeTo(StreamOutput out) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); toXContentCommon(builder, params); - builder.field("undeployed", undeployed); + builder.field("stopped", undeployed); builder.endObject(); return builder; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentState.java similarity index 71% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelState.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentState.java index 03c136e0dcee3..b63b903809e3d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelState.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentState.java @@ -14,16 +14,16 @@ import java.io.IOException; import java.util.Locale; -public enum DeployTrainedModelState implements Writeable { +public enum TrainedModelDeploymentState implements Writeable { - DEPLOYING, DEPLOYED, UNDEPLOYING, UNDEPLOYED; + STARTING, STARTED, STOPPING, STOPPED; - public static DeployTrainedModelState fromString(String name) { + public static TrainedModelDeploymentState fromString(String name) { return valueOf(name.trim().toUpperCase(Locale.ROOT)); } - public static DeployTrainedModelState fromStream(StreamInput in) throws IOException { - return in.readEnum(DeployTrainedModelState.class); + public static TrainedModelDeploymentState fromStream(StreamInput in) throws IOException { + return in.readEnum(TrainedModelDeploymentState.class); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelTaskState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentTaskState.java similarity index 79% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelTaskState.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentTaskState.java index 2d199165cb8c7..4fd7beb5feed0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/DeployTrainedModelTaskState.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentTaskState.java @@ -21,7 +21,7 @@ import java.io.IOException; import java.util.Objects; -public class DeployTrainedModelTaskState implements PersistentTaskState { +public class TrainedModelDeploymentTaskState implements PersistentTaskState { public static final String NAME = MlTasks.DEPLOY_TRAINED_MODEL_TASK_NAME; @@ -29,13 +29,13 @@ public class DeployTrainedModelTaskState implements PersistentTaskState { private static ParseField ALLOCATION_ID = new ParseField("allocation_id"); private static ParseField REASON = new ParseField("reason"); - private final DeployTrainedModelState state; + private final TrainedModelDeploymentState state; private final long allocationId; private final String reason; - private static final ConstructingObjectParser PARSER = + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, true, - a -> new DeployTrainedModelTaskState((DeployTrainedModelState) a[0], (long) a[1], (String) a[2])); + a -> new TrainedModelDeploymentTaskState((TrainedModelDeploymentState) a[0], (long) a[1], (String) a[2])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsState::fromString, STATE); @@ -43,7 +43,7 @@ public class DeployTrainedModelTaskState implements PersistentTaskState { PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), REASON); } - public static DeployTrainedModelTaskState fromXContent(XContentParser parser) { + public static TrainedModelDeploymentTaskState fromXContent(XContentParser parser) { try { return PARSER.parse(parser, null); } catch (IOException e) { @@ -51,19 +51,19 @@ public static DeployTrainedModelTaskState fromXContent(XContentParser parser) { } } - public DeployTrainedModelTaskState(DeployTrainedModelState state, long allocationId, @Nullable String reason) { + public TrainedModelDeploymentTaskState(TrainedModelDeploymentState state, long allocationId, @Nullable String reason) { this.state = Objects.requireNonNull(state); this.allocationId = allocationId; this.reason = reason; } - public DeployTrainedModelTaskState(StreamInput in) throws IOException { - this.state = DeployTrainedModelState.fromStream(in); + public TrainedModelDeploymentTaskState(StreamInput in) throws IOException { + this.state = TrainedModelDeploymentState.fromStream(in); this.allocationId = in.readLong(); this.reason = in.readOptionalString(); } - public DeployTrainedModelState getState() { + public TrainedModelDeploymentState getState() { return state; } @@ -99,7 +99,7 @@ public void writeTo(StreamOutput out) throws IOException { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - DeployTrainedModelTaskState that = (DeployTrainedModelTaskState) o; + TrainedModelDeploymentTaskState that = (TrainedModelDeploymentTaskState) o; return allocationId == that.allocationId && state == that.state && Objects.equals(reason, that.reason); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 38b3d0af7f8e6..8a7192332cb28 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -95,7 +95,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAliasAction; -import org.elasticsearch.xpack.core.ml.action.DeployTrainedModelAction; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.EstimateModelMemoryAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction; @@ -142,7 +142,7 @@ import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction; -import org.elasticsearch.xpack.core.ml.action.UndeployTrainedModelAction; +import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.UpdateCalendarJobAction; import org.elasticsearch.xpack.core.ml.action.UpdateDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.UpdateDatafeedAction; @@ -159,7 +159,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStatsNamedWriteablesProvider; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; -import org.elasticsearch.xpack.core.ml.inference.deployment.DeployTrainedModelTaskState; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.config.JobTaskState; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; @@ -178,7 +178,7 @@ import org.elasticsearch.xpack.ml.action.TransportDeleteModelSnapshotAction; import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAliasAction; -import org.elasticsearch.xpack.ml.action.TransportDeployTrainedModelAction; +import org.elasticsearch.xpack.ml.action.TransportStartTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.action.TransportEstimateModelMemoryAction; import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction; import org.elasticsearch.xpack.ml.action.TransportExplainDataFrameAnalyticsAction; @@ -225,7 +225,7 @@ import org.elasticsearch.xpack.ml.action.TransportStartDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportStopDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportStopDatafeedAction; -import org.elasticsearch.xpack.ml.action.TransportUndeployTrainedModelAction; +import org.elasticsearch.xpack.ml.action.TransportStopTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.action.TransportUpdateCalendarJobAction; import org.elasticsearch.xpack.ml.action.TransportUpdateDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportUpdateDatafeedAction; @@ -333,12 +333,12 @@ import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction; import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAliasAction; -import org.elasticsearch.xpack.ml.rest.inference.RestDeployTrainedModelAction; +import org.elasticsearch.xpack.ml.rest.inference.RestStartTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction; import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAction; import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAliasAction; -import org.elasticsearch.xpack.ml.rest.inference.RestUndeployTrainedModelAction; +import org.elasticsearch.xpack.ml.rest.inference.RestStopTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction; @@ -893,7 +893,7 @@ public List> getPersistentTasksExecutor(ClusterServic memoryTracker.get(), expressionResolver, client), - new TransportDeployTrainedModelAction.TaskExecutor(settings, + new TransportStartTrainedModelDeploymentAction.TaskExecutor(settings, clusterService, expressionResolver, memoryTracker.get(), @@ -973,8 +973,8 @@ public List getRestHandlers(Settings settings, RestController restC new RestPutTrainedModelAliasAction(), new RestDeleteTrainedModelAliasAction(), new RestPreviewDataFrameAnalyticsAction(), - new RestDeployTrainedModelAction(), - new RestUndeployTrainedModelAction(), + new RestStartTrainedModelDeploymentAction(), + new RestStopTrainedModelDeploymentAction(), // CAT Handlers new RestCatJobsAction(), new RestCatTrainedModelsAction(), @@ -1061,8 +1061,8 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(PutTrainedModelAliasAction.INSTANCE, TransportPutTrainedModelAliasAction.class), new ActionHandler<>(DeleteTrainedModelAliasAction.INSTANCE, TransportDeleteTrainedModelAliasAction.class), new ActionHandler<>(PreviewDataFrameAnalyticsAction.INSTANCE, TransportPreviewDataFrameAnalyticsAction.class), - new ActionHandler<>(DeployTrainedModelAction.INSTANCE, TransportDeployTrainedModelAction.class), - new ActionHandler<>(UndeployTrainedModelAction.INSTANCE, TransportUndeployTrainedModelAction.class), + new ActionHandler<>(StartTrainedModelDeploymentAction.INSTANCE, TransportStartTrainedModelDeploymentAction.class), + new ActionHandler<>(StopTrainedModelDeploymentAction.INSTANCE, TransportStopTrainedModelDeploymentAction.class), usageAction, infoAction); } @@ -1198,7 +1198,7 @@ public List getNamedWriteables() { namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME, SnapshotUpgradeTaskParams::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.DEPLOY_TRAINED_MODEL_TASK_NAME, - DeployTrainedModelAction.TaskParams::new)); + StartTrainedModelDeploymentAction.TaskParams::new)); // Persistent task states namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class, JobTaskState.NAME, JobTaskState::new)); @@ -1209,7 +1209,7 @@ public List getNamedWriteables() { SnapshotUpgradeTaskState.NAME, SnapshotUpgradeTaskState::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class, - DeployTrainedModelTaskState.NAME, DeployTrainedModelTaskState::new)); + TrainedModelDeploymentTaskState.NAME, TrainedModelDeploymentTaskState::new)); namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); namedWriteables.addAll(new AnalysisStatsNamedWriteablesProvider().getNamedWriteables()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeployTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java similarity index 81% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeployTrainedModelAction.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java index 3a85c1a0e6fde..f1fa737c24416 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeployTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java @@ -39,16 +39,16 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.ml.MlTasks; -import org.elasticsearch.xpack.core.ml.action.DeployTrainedModelAction; -import org.elasticsearch.xpack.core.ml.action.DeployTrainedModelAction.TaskParams; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse; -import org.elasticsearch.xpack.core.ml.inference.deployment.DeployTrainedModelState; -import org.elasticsearch.xpack.core.ml.inference.deployment.DeployTrainedModelTaskState; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; -import org.elasticsearch.xpack.ml.inference.deployment.DeployTrainedModelTask; +import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager; import org.elasticsearch.xpack.ml.job.JobNodeSelector; import org.elasticsearch.xpack.ml.process.MlMemoryTracker; @@ -59,29 +59,30 @@ import java.util.Objects; import java.util.function.Predicate; -public class TransportDeployTrainedModelAction - extends TransportMasterNodeAction { +public class TransportStartTrainedModelDeploymentAction + extends TransportMasterNodeAction { - private static final Logger logger = LogManager.getLogger(TransportDeployTrainedModelAction.class); + private static final Logger logger = LogManager.getLogger(TransportStartTrainedModelDeploymentAction.class); private final XPackLicenseState licenseState; private final Client client; private final PersistentTasksService persistentTasksService; @Inject - public TransportDeployTrainedModelAction(TransportService transportService, Client client, ClusterService clusterService, - ThreadPool threadPool, ActionFilters actionFilters, XPackLicenseState licenseState, - IndexNameExpressionResolver indexNameExpressionResolver, - PersistentTasksService persistentTasksService) { - super(DeployTrainedModelAction.NAME, transportService, clusterService, threadPool, actionFilters, - DeployTrainedModelAction.Request::new, indexNameExpressionResolver, NodeAcknowledgedResponse::new, ThreadPool.Names.SAME); + public TransportStartTrainedModelDeploymentAction(TransportService transportService, Client client, ClusterService clusterService, + ThreadPool threadPool, ActionFilters actionFilters, XPackLicenseState licenseState, + IndexNameExpressionResolver indexNameExpressionResolver, + PersistentTasksService persistentTasksService) { + super(StartTrainedModelDeploymentAction.NAME, transportService, clusterService, threadPool, actionFilters, + StartTrainedModelDeploymentAction.Request::new, indexNameExpressionResolver, NodeAcknowledgedResponse::new, + ThreadPool.Names.SAME); this.licenseState = Objects.requireNonNull(licenseState); this.client = Objects.requireNonNull(client); this.persistentTasksService = Objects.requireNonNull(persistentTasksService); } @Override - protected void masterOperation(Task task, DeployTrainedModelAction.Request request, ClusterState state, + protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Request request, ClusterState state, ActionListener listener) throws Exception { logger.debug(() -> new ParameterizedMessage("[{}] received deploy request", request.getModelId())); if (licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING) == false) { @@ -150,7 +151,7 @@ public void onFailure(Exception e) { } @Override - protected ClusterBlockException checkBlock(DeployTrainedModelAction.Request request, ClusterState state) { + protected ClusterBlockException checkBlock(StartTrainedModelDeploymentAction.Request request, ClusterState state) { // We only delegate here to PersistentTasksService, but if there is a metadata writeblock, // then delegating to PersistentTasksService doesn't make a whole lot of sense, // because PersistentTasksService will then fail. @@ -184,16 +185,16 @@ public boolean test(PersistentTasksCustomMetadata.PersistentTask persistentTa } } - DeployTrainedModelTaskState taskState = (DeployTrainedModelTaskState) persistentTask.getState(); + TrainedModelDeploymentTaskState taskState = (TrainedModelDeploymentTaskState) persistentTask.getState(); reason = taskState != null ? taskState.getReason() : reason; - DeployTrainedModelState deploymentState = taskState == null ? DeployTrainedModelState.DEPLOYED : taskState.getState(); + TrainedModelDeploymentState deploymentState = taskState == null ? TrainedModelDeploymentState.STARTED : taskState.getState(); switch (deploymentState) { - case DEPLOYED: + case STARTED: node = persistentTask.getExecutorNode(); return true; - case DEPLOYING: - case UNDEPLOYING: - case UNDEPLOYED: + case STARTING: + case STOPPING: + case STOPPED: return false; default: exception = ExceptionsHelper.serverError("Unexpected task state [{}] with reason [{}] while waiting to be started", @@ -223,7 +224,7 @@ protected AllocatedPersistentTask createTask( long id, String type, String action, TaskId parentTaskId, PersistentTasksCustomMetadata.PersistentTask persistentTask, Map headers) { - return new DeployTrainedModelTask(id, type, action, parentTaskId, headers, persistentTask.getParams()); + return new TrainedModelDeploymentTask(id, type, action, parentTaskId, headers, persistentTask.getParams()); } @Override @@ -261,13 +262,13 @@ public static String nodeFilter(DiscoveryNode node, TaskParams params) { @Override protected void nodeOperation(AllocatedPersistentTask task, TaskParams params, PersistentTaskState state) { - DeployTrainedModelTask deployTrainedModelTask = (DeployTrainedModelTask) task; - deployTrainedModelTask.setDeploymentManager(manager); + TrainedModelDeploymentTask trainedModelDeploymentTask = (TrainedModelDeploymentTask) task; + trainedModelDeploymentTask.setDeploymentManager(manager); - DeployTrainedModelTaskState deployingState = new DeployTrainedModelTaskState( - DeployTrainedModelState.DEPLOYING, task.getAllocationId(), null); + TrainedModelDeploymentTaskState deployingState = new TrainedModelDeploymentTaskState( + TrainedModelDeploymentState.STARTING, task.getAllocationId(), null); task.updatePersistentTaskState(deployingState, ActionListener.wrap( - response -> manager.deployModel(deployTrainedModelTask), + response -> manager.deployModel(trainedModelDeploymentTask), task::markAsFailed )); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUndeployTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java similarity index 67% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUndeployTrainedModelAction.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java index d7bcca5739cfd..2719ded55a994 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUndeployTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java @@ -31,41 +31,41 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; -import org.elasticsearch.xpack.core.ml.action.UndeployTrainedModelAction; +import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.deployment.DeployTrainedModelState; -import org.elasticsearch.xpack.core.ml.inference.deployment.DeployTrainedModelTaskState; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; -import org.elasticsearch.xpack.ml.inference.deployment.DeployTrainedModelTask; +import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; import java.util.Collections; import java.util.List; import java.util.Set; -public class TransportUndeployTrainedModelAction extends TransportTasksAction { +public class TransportStopTrainedModelDeploymentAction extends TransportTasksAction { - private static final Logger logger = LogManager.getLogger(TransportUndeployTrainedModelAction.class); + private static final Logger logger = LogManager.getLogger(TransportStopTrainedModelDeploymentAction.class); private final Client client; private final ThreadPool threadPool; private final PersistentTasksService persistentTasksService; @Inject - public TransportUndeployTrainedModelAction(String actionName, ClusterService clusterService, TransportService transportService, - ActionFilters actionFilters, Client client, ThreadPool threadPool, - PersistentTasksService persistentTasksService) { - super(actionName, clusterService, transportService, actionFilters, UndeployTrainedModelAction.Request::new, - UndeployTrainedModelAction.Response::new, UndeployTrainedModelAction.Response::new, ThreadPool.Names.SAME); + public TransportStopTrainedModelDeploymentAction(String actionName, ClusterService clusterService, TransportService transportService, + ActionFilters actionFilters, Client client, ThreadPool threadPool, + PersistentTasksService persistentTasksService) { + super(actionName, clusterService, transportService, actionFilters, StopTrainedModelDeploymentAction.Request::new, + StopTrainedModelDeploymentAction.Response::new, StopTrainedModelDeploymentAction.Response::new, ThreadPool.Names.SAME); this.client = client; this.threadPool = threadPool; this.persistentTasksService = persistentTasksService; } @Override - protected void doExecute(Task task, UndeployTrainedModelAction.Request request, - ActionListener listener) { + protected void doExecute(Task task, StopTrainedModelDeploymentAction.Request request, + ActionListener listener) { ClusterState state = clusterService.state(); DiscoveryNodes nodes = state.nodes(); if (nodes.isLocalNodeElectedMaster() == false) { @@ -79,7 +79,7 @@ protected void doExecute(Task task, UndeployTrainedModelAction.Request request, getModelsResponse -> { List models = getModelsResponse.getResources().results(); if (models.isEmpty()) { - listener.onResponse(new UndeployTrainedModelAction.Response(true)); + listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); return; } if (models.size() > 1) { @@ -92,7 +92,7 @@ protected void doExecute(Task task, UndeployTrainedModelAction.Request request, PersistentTasksCustomMetadata.PersistentTask deployTrainedModelTask = MlTasks.getDeployTrainedModelTask(request.getId(), tasks); if (deployTrainedModelTask == null) { - listener.onResponse(new UndeployTrainedModelAction.Response(true)); + listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); return; } normalUndeploy(task, deployTrainedModelTask, request, listener); @@ -106,21 +106,22 @@ protected void doExecute(Task task, UndeployTrainedModelAction.Request request, client.execute(GetTrainedModelsAction.INSTANCE, getModelRequest, getModelListener); } - private void redirectToMasterNode(DiscoveryNode masterNode, UndeployTrainedModelAction.Request request, - ActionListener listener) { + private void redirectToMasterNode(DiscoveryNode masterNode, StopTrainedModelDeploymentAction.Request request, + ActionListener listener) { if (masterNode == null) { listener.onFailure(new MasterNotDiscoveredException()); } else { transportService.sendRequest(masterNode, actionName, request, - new ActionListenerResponseHandler<>(listener, UndeployTrainedModelAction.Response::new)); + new ActionListenerResponseHandler<>(listener, StopTrainedModelDeploymentAction.Response::new)); } } private void normalUndeploy(Task task, PersistentTasksCustomMetadata.PersistentTask deployTrainedModelTask, - UndeployTrainedModelAction.Request request, ActionListener listener) { + StopTrainedModelDeploymentAction.Request request, + ActionListener listener) { request.setNodes(deployTrainedModelTask.getExecutorNode()); - ActionListener finalListener = ActionListener.wrap( + ActionListener finalListener = ActionListener.wrap( r -> waitForTaskRemoved(Collections.singleton(deployTrainedModelTask.getId()), request, r, listener), e -> { if (ExceptionsHelper.unwrapCause(e) instanceof FailedNodeException) { @@ -139,9 +140,9 @@ private void normalUndeploy(Task task, PersistentTasksCustomMetadata.PersistentT super.doExecute(task, request, finalListener); } - void waitForTaskRemoved(Set taskIds, UndeployTrainedModelAction.Request request, - UndeployTrainedModelAction.Response response, - ActionListener listener) { + void waitForTaskRemoved(Set taskIds, StopTrainedModelDeploymentAction.Request request, + StopTrainedModelDeploymentAction.Response response, + ActionListener listener) { persistentTasksService.waitForPersistentTasksCondition(persistentTasks -> persistentTasks.findTasks(MlTasks.DEPLOY_TRAINED_MODEL_TASK_NAME, t -> taskIds.contains(t.getId())).isEmpty(), request.getTimeout(), ActionListener.wrap( @@ -154,24 +155,24 @@ void waitForTaskRemoved(Set taskIds, UndeployTrainedModelAction.Request } @Override - protected UndeployTrainedModelAction.Response newResponse(UndeployTrainedModelAction.Request request, - List tasks, - List taskOperationFailures, - List failedNodeExceptions) { + protected StopTrainedModelDeploymentAction.Response newResponse(StopTrainedModelDeploymentAction.Request request, + List tasks, + List taskOperationFailures, + List failedNodeExceptions) { if (taskOperationFailures.isEmpty() == false) { throw org.elasticsearch.ExceptionsHelper.convertToElastic(taskOperationFailures.get(0).getCause()); } else if (failedNodeExceptions.isEmpty() == false) { throw org.elasticsearch.ExceptionsHelper.convertToElastic(failedNodeExceptions.get(0)); } else { - return new UndeployTrainedModelAction.Response(true); + return new StopTrainedModelDeploymentAction.Response(true); } } @Override - protected void taskOperation(UndeployTrainedModelAction.Request request, DeployTrainedModelTask task, - ActionListener listener) { - DeployTrainedModelTaskState undeployingState = new DeployTrainedModelTaskState( - DeployTrainedModelState.UNDEPLOYING, task.getAllocationId(), "api"); + protected void taskOperation(StopTrainedModelDeploymentAction.Request request, TrainedModelDeploymentTask task, + ActionListener listener) { + TrainedModelDeploymentTaskState undeployingState = new TrainedModelDeploymentTaskState( + TrainedModelDeploymentState.STOPPING, task.getAllocationId(), "api"); task.updatePersistentTaskState(undeployingState, ActionListener.wrap( updatedTask -> { threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(new AbstractRunnable() { @@ -183,14 +184,14 @@ public void onFailure(Exception e) { @Override protected void doRun() throws Exception { task.stop("undeploy_trained_model (api)"); - listener.onResponse(new UndeployTrainedModelAction.Response(true)); + listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); } }); }, e -> { if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { // the task has disappeared so must have stopped - listener.onResponse(new UndeployTrainedModelAction.Response(true)); + listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); } else { listener.onFailure(e); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index b450ebcf91ed5..4ed6bfb05033f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -13,8 +13,8 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.ml.inference.deployment.DeployTrainedModelState; -import org.elasticsearch.xpack.core.ml.inference.deployment.DeployTrainedModelTaskState; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcess; @@ -40,7 +40,7 @@ public DeploymentManager(ThreadPool threadPool, PyTorchProcessFactory pyTorchPro this.executorServiceForProcess = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); } - public void deployModel(DeployTrainedModelTask task) { + public void deployModel(TrainedModelDeploymentTask task) { logger.info("[{}] Deploying model", task.getModelId()); ProcessContext processContext = new ProcessContext(task.getModelId()); @@ -51,15 +51,15 @@ public void deployModel(DeployTrainedModelTask task) { processContext.startProcess(); - DeployTrainedModelTaskState startedState = new DeployTrainedModelTaskState( - DeployTrainedModelState.DEPLOYED, task.getAllocationId(), null); + TrainedModelDeploymentTaskState startedState = new TrainedModelDeploymentTaskState( + TrainedModelDeploymentState.STARTED, task.getAllocationId(), null); task.updatePersistentTaskState(startedState, ActionListener.wrap( response -> logger.info("[{}] trained model deployment started", task.getModelId()), task::markAsFailed )); } - public void undeployModel(DeployTrainedModelTask task) { + public void undeployModel(TrainedModelDeploymentTask task) { ProcessContext processContext; synchronized (processContextByAllocation) { processContext = processContextByAllocation.get(task.getAllocationId()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeployTrainedModelTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java similarity index 71% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeployTrainedModelTask.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java index 03055e622464d..955afd14ebcaa 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeployTrainedModelTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -12,21 +12,21 @@ import org.elasticsearch.persistent.AllocatedPersistentTask; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.xpack.core.ml.MlTasks; -import org.elasticsearch.xpack.core.ml.action.DeployTrainedModelAction; -import org.elasticsearch.xpack.core.ml.action.DeployTrainedModelAction.TaskParams; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams; import java.util.Map; -public class DeployTrainedModelTask extends AllocatedPersistentTask implements DeployTrainedModelAction.TaskMatcher { +public class TrainedModelDeploymentTask extends AllocatedPersistentTask implements StartTrainedModelDeploymentAction.TaskMatcher { - private static final Logger logger = LogManager.getLogger(DeployTrainedModelTask.class); + private static final Logger logger = LogManager.getLogger(TrainedModelDeploymentTask.class); private final TaskParams params; private volatile boolean isStopping; private volatile DeploymentManager manager; - public DeployTrainedModelTask(long id, String type, String action, TaskId parentTask, Map headers, - TaskParams taskParams) { + public TrainedModelDeploymentTask(long id, String type, String action, TaskId parentTask, Map headers, + TaskParams taskParams) { super(id, type, action, MlTasks.DEPLOY_TRAINED_MODEL_TASK_ID_PREFIX + taskParams.getModelId(), parentTask, headers); this.params = taskParams; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeployTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java similarity index 65% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeployTrainedModelAction.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java index f11e302f9ad8e..be6dc1fb7615d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeployTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java @@ -11,7 +11,7 @@ import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestToXContentListener; -import org.elasticsearch.xpack.core.ml.action.DeployTrainedModelAction; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.ml.MachineLearning; @@ -21,24 +21,24 @@ import static org.elasticsearch.rest.RestRequest.Method.POST; -public class RestDeployTrainedModelAction extends BaseRestHandler { +public class RestStartTrainedModelDeploymentAction extends BaseRestHandler { @Override public String getName() { - return "xpack_ml_deploy_trained_model_action"; + return "xpack_ml_start_trained_models_deployment_action"; } @Override public List routes() { return Collections.singletonList( new Route(POST, - MachineLearning.BASE_PATH + "trained_models/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_deploy")); + MachineLearning.BASE_PATH + "trained_models/deployment/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_start")); } @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); - DeployTrainedModelAction.Request request = new DeployTrainedModelAction.Request(modelId); - return channel -> client.execute(DeployTrainedModelAction.INSTANCE, request, new RestToXContentListener<>(channel)); + StartTrainedModelDeploymentAction.Request request = new StartTrainedModelDeploymentAction.Request(modelId); + return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel)); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestUndeployTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStopTrainedModelDeploymentAction.java similarity index 67% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestUndeployTrainedModelAction.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStopTrainedModelDeploymentAction.java index 826d910c30bf6..c92751811a911 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestUndeployTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStopTrainedModelDeploymentAction.java @@ -11,7 +11,7 @@ import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestToXContentListener; -import org.elasticsearch.xpack.core.ml.action.UndeployTrainedModelAction; +import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import java.io.IOException; @@ -21,11 +21,11 @@ import static org.elasticsearch.rest.RestRequest.Method.POST; import static org.elasticsearch.xpack.ml.MachineLearning.BASE_PATH; -public class RestUndeployTrainedModelAction extends BaseRestHandler { +public class RestStopTrainedModelDeploymentAction extends BaseRestHandler { @Override public String getName() { - return "xpack_ml_undeploy_trained_model_action"; + return "xpack_ml_stop_trained_models_deployment_action"; } @Override @@ -33,14 +33,14 @@ public List routes() { return Collections.singletonList( new Route( POST, - BASE_PATH + "trained_models/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_undeploy") + BASE_PATH + "trained_models/deployment/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_stop") ); } @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); - UndeployTrainedModelAction.Request request = new UndeployTrainedModelAction.Request(modelId); - return channel -> client.execute(UndeployTrainedModelAction.INSTANCE, request, new RestToXContentListener<>(channel)); + StopTrainedModelDeploymentAction.Request request = new StopTrainedModelDeploymentAction.Request(modelId); + return channel -> client.execute(StopTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel)); } } From 9f389249aa281b147cae8568d64b073a06c0a834 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 23 Mar 2021 16:03:54 +0200 Subject: [PATCH 4/7] More renamings plus cancelling of start action when assignment fails --- .../elasticsearch/xpack/core/ml/MlTasks.java | 10 +++--- .../StartTrainedModelDeploymentAction.java | 4 +-- .../TrainedModelDeploymentTaskState.java | 2 +- .../xpack/ml/MachineLearning.java | 2 +- ...portStartTrainedModelDeploymentAction.java | 33 ++++++++++++++----- ...sportStopTrainedModelDeploymentAction.java | 2 +- .../TrainedModelDeploymentTask.java | 2 +- 7 files changed, 36 insertions(+), 19 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java index 239638cd35205..a98f444accc27 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java @@ -28,13 +28,13 @@ public final class MlTasks { public static final String DATAFEED_TASK_NAME = "xpack/ml/datafeed"; public static final String DATA_FRAME_ANALYTICS_TASK_NAME = "xpack/ml/data_frame/analytics"; public static final String JOB_SNAPSHOT_UPGRADE_TASK_NAME = "xpack/ml/job/snapshot/upgrade"; - public static final String DEPLOY_TRAINED_MODEL_TASK_NAME = "xpack/ml/trained_models/deploy"; + public static final String TRAINED_MODEL_DEPLOYMENT_TASK_NAME = "xpack/ml/trained_model/deployment"; public static final String JOB_TASK_ID_PREFIX = "job-"; public static final String DATAFEED_TASK_ID_PREFIX = "datafeed-"; public static final String DATA_FRAME_ANALYTICS_TASK_ID_PREFIX = "data_frame_analytics-"; public static final String JOB_SNAPSHOT_UPGRADE_TASK_ID_PREFIX = "job-snapshot-upgrade-"; - public static final String DEPLOY_TRAINED_MODEL_TASK_ID_PREFIX = "deploy_trained_model-"; + public static final String TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX = "trained_model_deployment-"; public static final PersistentTasksCustomMetadata.Assignment AWAITING_UPGRADE = new PersistentTasksCustomMetadata.Assignment(null, @@ -78,8 +78,8 @@ public static String dataFrameAnalyticsId(String taskId) { return taskId.substring(DATA_FRAME_ANALYTICS_TASK_ID_PREFIX.length()); } - public static String deployTrainedModelTaskId(String modelId) { - return DEPLOY_TRAINED_MODEL_TASK_ID_PREFIX + modelId; + public static String trainedModelDeploymentTaskId(String modelId) { + return TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX + modelId; } @Nullable @@ -109,7 +109,7 @@ public static PersistentTasksCustomMetadata.PersistentTask getSnapshotUpgrade @Nullable public static PersistentTasksCustomMetadata.PersistentTask getDeployTrainedModelTask(String modelId, @Nullable PersistentTasksCustomMetadata tasks) { - return tasks == null ? null : tasks.getTask(deployTrainedModelTaskId(modelId)); + return tasks == null ? null : tasks.getTask(trainedModelDeploymentTaskId(modelId)); } /** diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java index e2b90bad16e3d..81d9083753bab 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -135,7 +135,7 @@ public String getModelId() { @Override public String getWriteableName() { - return MlTasks.DEPLOY_TRAINED_MODEL_TASK_NAME; + return MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME; } @Override @@ -178,7 +178,7 @@ static boolean match(Task task, String expectedId) { if (Strings.isAllOrWildcard(expectedId)) { return true; } - String expectedDescription = MlTasks.DEPLOY_TRAINED_MODEL_TASK_ID_PREFIX + expectedId; + String expectedDescription = MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX + expectedId; return expectedDescription.equals(task.getDescription()); } return false; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentTaskState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentTaskState.java index 4fd7beb5feed0..29641b6b5512b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentTaskState.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentTaskState.java @@ -23,7 +23,7 @@ public class TrainedModelDeploymentTaskState implements PersistentTaskState { - public static final String NAME = MlTasks.DEPLOY_TRAINED_MODEL_TASK_NAME; + public static final String NAME = MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME; private static ParseField STATE = new ParseField("state"); private static ParseField ALLOCATION_ID = new ParseField("allocation_id"); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 8a7192332cb28..08dc1c69d4a43 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -1197,7 +1197,7 @@ public List getNamedWriteables() { StartDataFrameAnalyticsAction.TaskParams::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME, SnapshotUpgradeTaskParams::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.DEPLOY_TRAINED_MODEL_TASK_NAME, + namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, StartTrainedModelDeploymentAction.TaskParams::new)); // Persistent task states diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java index f1fa737c24416..6a797c666d3e3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java @@ -39,17 +39,17 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.ml.MlTasks; -import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; -import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams; import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState; import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; -import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager; +import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; import org.elasticsearch.xpack.ml.job.JobNodeSelector; import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import org.elasticsearch.xpack.ml.task.AbstractJobPersistentTasksExecutor; @@ -115,8 +115,8 @@ protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Requ return; } persistentTasksService.sendStartRequest( - MlTasks.deployTrainedModelTaskId(request.getModelId()), - MlTasks.DEPLOY_TRAINED_MODEL_TASK_NAME, + MlTasks.trainedModelDeploymentTaskId(request.getModelId()), + MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, new TaskParams(request.getModelId()), waitForDeploymentToStart ); @@ -137,7 +137,7 @@ private void waitForDeploymentStarted(PersistentTasksCustomMetadata.PersistentTa @Override public void onResponse(PersistentTasksCustomMetadata.PersistentTask persistentTask) { if (predicate.exception != null) { - + cancelDeploymentStart(task, predicate.exception, listener); } else { listener.onResponse(new NodeAcknowledgedResponse(true, predicate.node)); } @@ -150,6 +150,23 @@ public void onFailure(Exception e) { }); } + private void cancelDeploymentStart( + PersistentTasksCustomMetadata.PersistentTask persistentTask, Exception exception, + ActionListener listener) { + persistentTasksService.sendRemoveRequest(persistentTask.getId(), ActionListener.wrap( + pTask -> listener.onFailure(exception), + e -> { + logger.error( + new ParameterizedMessage("[{}] Failed to cancel persistent task that could not be assigned due to [{}]", + persistentTask.getParams().getModelId(), exception.getMessage()), + e + ); + listener.onFailure(exception); + } + )); + + } + @Override protected ClusterBlockException checkBlock(StartTrainedModelDeploymentAction.Request request, ClusterState state) { // We only delegate here to PersistentTasksService, but if there is a metadata writeblock, @@ -210,7 +227,7 @@ public static class TaskExecutor extends AbstractJobPersistentTasksExecutor nodeFilter(node, params)); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java index 2719ded55a994..6d9722eef9aa5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java @@ -144,7 +144,7 @@ void waitForTaskRemoved(Set taskIds, StopTrainedModelDeploymentAction.Re StopTrainedModelDeploymentAction.Response response, ActionListener listener) { persistentTasksService.waitForPersistentTasksCondition(persistentTasks -> - persistentTasks.findTasks(MlTasks.DEPLOY_TRAINED_MODEL_TASK_NAME, t -> taskIds.contains(t.getId())).isEmpty(), + persistentTasks.findTasks(MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, t -> taskIds.contains(t.getId())).isEmpty(), request.getTimeout(), ActionListener.wrap( booleanResponse -> { listener.onResponse(response); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java index 955afd14ebcaa..b27caad45a02f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -27,7 +27,7 @@ public class TrainedModelDeploymentTask extends AllocatedPersistentTask implemen public TrainedModelDeploymentTask(long id, String type, String action, TaskId parentTask, Map headers, TaskParams taskParams) { - super(id, type, action, MlTasks.DEPLOY_TRAINED_MODEL_TASK_ID_PREFIX + taskParams.getModelId(), parentTask, headers); + super(id, type, action, MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX + taskParams.getModelId(), parentTask, headers); this.params = taskParams; } From 072f12c12b39804be98675b684d112d991d98ffe Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 23 Mar 2021 16:05:48 +0200 Subject: [PATCH 5/7] Some more renaming --- .../TransportStartTrainedModelDeploymentAction.java | 2 +- .../xpack/ml/inference/deployment/DeploymentManager.java | 8 ++++---- .../inference/deployment/TrainedModelDeploymentTask.java | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java index 6a797c666d3e3..13fb27df529a7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java @@ -285,7 +285,7 @@ protected void nodeOperation(AllocatedPersistentTask task, TaskParams params, Pe TrainedModelDeploymentTaskState deployingState = new TrainedModelDeploymentTaskState( TrainedModelDeploymentState.STARTING, task.getAllocationId(), null); task.updatePersistentTaskState(deployingState, ActionListener.wrap( - response -> manager.deployModel(trainedModelDeploymentTask), + response -> manager.startDeployment(trainedModelDeploymentTask), task::markAsFailed )); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 4ed6bfb05033f..2776a4c0d167e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -40,8 +40,8 @@ public DeploymentManager(ThreadPool threadPool, PyTorchProcessFactory pyTorchPro this.executorServiceForProcess = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); } - public void deployModel(TrainedModelDeploymentTask task) { - logger.info("[{}] Deploying model", task.getModelId()); + public void startDeployment(TrainedModelDeploymentTask task) { + logger.info("[{}] Starting model deployment", task.getModelId()); ProcessContext processContext = new ProcessContext(task.getModelId()); @@ -59,13 +59,13 @@ public void deployModel(TrainedModelDeploymentTask task) { )); } - public void undeployModel(TrainedModelDeploymentTask task) { + public void stopDeployment(TrainedModelDeploymentTask task) { ProcessContext processContext; synchronized (processContextByAllocation) { processContext = processContextByAllocation.get(task.getAllocationId()); } if (processContext != null) { - logger.debug("[{}] Undeploying model", task.getModelId()); + logger.debug("[{}] Stopping deployment", task.getModelId()); processContext.killProcess(); } else { logger.debug("[{}] No process context to stop", task.getModelId()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java index b27caad45a02f..aa4f7f8d93998 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -40,7 +40,7 @@ public void stop(String reason) { logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason); assert manager != null : "manager should not be unset when stop is called"; - manager.undeployModel(this); + manager.stopDeployment(this); markAsCompleted(); } From 6fe7992170b2787bab06a40ebbcd6cb4d7036af0 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Wed, 24 Mar 2021 13:03:48 +0200 Subject: [PATCH 6/7] Load model from hardcoded doc --- .../xpack/ml/MachineLearning.java | 2 +- .../deployment/DeploymentManager.java | 37 ++++++++++++++++++- .../pytorch/process/NativePyTorchProcess.java | 18 ++++++++- .../pytorch/process/PyTorchProcess.java | 3 ++ 4 files changed, 56 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 08dc1c69d4a43..24455e5c9b34d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -787,7 +787,7 @@ public Collection createComponents(Client client, ClusterService cluster clusterService.getNodeName(), inferenceModelBreaker.get()); this.modelLoadingService.set(modelLoadingService); - this.deploymentManager.set(new DeploymentManager(threadPool, pyTorchProcessFactory)); + this.deploymentManager.set(new DeploymentManager(client, threadPool, pyTorchProcessFactory)); // Data frame analytics components AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 2776a4c0d167e..9a2e595bf0cbd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -12,6 +12,9 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.get.GetRequest; +import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.client.Client; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState; import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; @@ -21,6 +24,7 @@ import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory; import java.io.IOException; +import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -31,16 +35,24 @@ public class DeploymentManager { private static final Logger logger = LogManager.getLogger(DeploymentManager.class); + private final Client client; private final PyTorchProcessFactory pyTorchProcessFactory; + private final ExecutorService executorServiceForDeployment; private final ExecutorService executorServiceForProcess; private final ConcurrentMap processContextByAllocation = new ConcurrentHashMap<>(); - public DeploymentManager(ThreadPool threadPool, PyTorchProcessFactory pyTorchProcessFactory) { + public DeploymentManager(Client client, ThreadPool threadPool, PyTorchProcessFactory pyTorchProcessFactory) { + this.client = Objects.requireNonNull(client); this.pyTorchProcessFactory = Objects.requireNonNull(pyTorchProcessFactory); + this.executorServiceForDeployment = threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME); this.executorServiceForProcess = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); } public void startDeployment(TrainedModelDeploymentTask task) { + executorServiceForDeployment.execute(() -> doStartDeployment(task)); + } + + private void doStartDeployment(TrainedModelDeploymentTask task) { logger.info("[{}] Starting model deployment", task.getModelId()); ProcessContext processContext = new ProcessContext(task.getModelId()); @@ -51,6 +63,12 @@ public void startDeployment(TrainedModelDeploymentTask task) { processContext.startProcess(); + try { + processContext.loadModel(); + } catch (IOException e) { + logger.error(new ParameterizedMessage("[{}] error loading model", task.getModelId()), e); + } + TrainedModelDeploymentTaskState startedState = new TrainedModelDeploymentTaskState( TrainedModelDeploymentState.STARTED, task.getAllocationId(), null); task.updatePersistentTaskState(startedState, ActionListener.wrap( @@ -101,5 +119,22 @@ private Consumer onProcessCrash() { logger.error("[{}] process crashed due to reason [{}]", modelId, reason); }; } + + void loadModel() throws IOException { + // Here we should be reading the model location from the deployment config. + // Hardcoding this for the prototype. + String index = "test-models"; + String docId = "simple-model"; + + GetResponse modelGetResponse = client.get(new GetRequest(index, docId)).actionGet(); + if (modelGetResponse.isExists() == false) { + throw ExceptionsHelper.badRequestException("[{}] no model was found", modelId); + } + Map sourceAsMap = modelGetResponse.getSourceAsMap(); + int modelSizeAfterUnbase64 = (int) sourceAsMap.get("size"); + String modelBase64 = (String) sourceAsMap.get("model"); + process.get().loadModel(modelBase64, modelSizeAfterUnbase64); + } } + } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java index 4064f34475bc8..fdb65304f5571 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java @@ -10,9 +10,13 @@ import org.elasticsearch.xpack.ml.process.AbstractNativeProcess; import org.elasticsearch.xpack.ml.process.NativeController; import org.elasticsearch.xpack.ml.process.ProcessPipes; +import org.elasticsearch.xpack.ml.process.writer.LengthEncodedWriter; import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; import java.nio.file.Path; +import java.util.Base64; import java.util.List; import java.util.function.Consumer; @@ -32,11 +36,21 @@ public String getName() { @Override public void persistState() throws IOException { - // Nothing to persist + throw new UnsupportedOperationException(); } @Override public void persistState(long snapshotTimestampMs, String snapshotId, String snapshotDescription) throws IOException { - // Nothing to persist + throw new UnsupportedOperationException(); + } + + @Override + public void loadModel(String modelBase64, int modelSizeAfterUnbase64) throws IOException { + byte[] modelBytes = Base64.getDecoder().decode(modelBase64.getBytes(StandardCharsets.UTF_8)); + try (OutputStream restoreStream = processRestoreStream()) { + LengthEncodedWriter lengthEncodedWriter = new LengthEncodedWriter(restoreStream); + lengthEncodedWriter.writeNumFields(modelSizeAfterUnbase64); + restoreStream.write(modelBytes); + } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java index 91c4669bc0160..72c3e9b8af0d8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java @@ -9,6 +9,9 @@ import org.elasticsearch.xpack.ml.process.NativeProcess; +import java.io.IOException; + public interface PyTorchProcess extends NativeProcess { + void loadModel(String modelBase64, int modelSizeAfterUnbase64) throws IOException; } From fa57945f9298522f104bbeae724696591b69bcbe Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 29 Mar 2021 16:34:30 +0300 Subject: [PATCH 7/7] Address review comments --- .../core/ml/action/StartTrainedModelDeploymentAction.java | 4 ++-- .../xpack/ml/inference/deployment/DeploymentManager.java | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java index 81d9083753bab..fd51a40d6643d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -66,7 +66,7 @@ public String getModelId() { } public void setTimeout(TimeValue timeout) { - this.timeout = timeout; + this.timeout = ExceptionsHelper.requireNonNull(timeout, TIMEOUT); } public TimeValue getTimeout() { @@ -117,7 +117,7 @@ public String toString() { public static class TaskParams implements PersistentTaskParams { - public static final Version VERSION_INTRODUCED = Version.V_7_13_0; + public static final Version VERSION_INTRODUCED = Version.V_8_0_0; private final String modelId; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 9a2e595bf0cbd..1425266ef9461 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -53,7 +53,7 @@ public void startDeployment(TrainedModelDeploymentTask task) { } private void doStartDeployment(TrainedModelDeploymentTask task) { - logger.info("[{}] Starting model deployment", task.getModelId()); + logger.debug("[{}] Starting model deployment", task.getModelId()); ProcessContext processContext = new ProcessContext(task.getModelId());