diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java index a495cbe9c77cc..a98f444accc27 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java @@ -28,11 +28,13 @@ public final class MlTasks { public static final String DATAFEED_TASK_NAME = "xpack/ml/datafeed"; public static final String DATA_FRAME_ANALYTICS_TASK_NAME = "xpack/ml/data_frame/analytics"; public static final String JOB_SNAPSHOT_UPGRADE_TASK_NAME = "xpack/ml/job/snapshot/upgrade"; + public static final String TRAINED_MODEL_DEPLOYMENT_TASK_NAME = "xpack/ml/trained_model/deployment"; public static final String JOB_TASK_ID_PREFIX = "job-"; public static final String DATAFEED_TASK_ID_PREFIX = "datafeed-"; public static final String DATA_FRAME_ANALYTICS_TASK_ID_PREFIX = "data_frame_analytics-"; public static final String JOB_SNAPSHOT_UPGRADE_TASK_ID_PREFIX = "job-snapshot-upgrade-"; + public static final String TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX = "trained_model_deployment-"; public static final PersistentTasksCustomMetadata.Assignment AWAITING_UPGRADE = new PersistentTasksCustomMetadata.Assignment(null, @@ -76,6 +78,10 @@ public static String dataFrameAnalyticsId(String taskId) { return taskId.substring(DATA_FRAME_ANALYTICS_TASK_ID_PREFIX.length()); } + public static String trainedModelDeploymentTaskId(String modelId) { + return TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX + modelId; + } + @Nullable public static PersistentTasksCustomMetadata.PersistentTask<?> getJobTask(String jobId, @Nullable PersistentTasksCustomMetadata tasks) { return tasks == null ? null : tasks.getTask(jobTaskId(jobId)); @@ -100,6 +106,12 @@ public static PersistentTasksCustomMetadata.PersistentTask<?> getSnapshotUpgrade return tasks == null ? 
null : tasks.getTask(snapshotUpgradeTaskId(jobId, snapshotId)); } + @Nullable + public static PersistentTasksCustomMetadata.PersistentTask<?> getDeployTrainedModelTask(String modelId, + @Nullable PersistentTasksCustomMetadata tasks) { + return tasks == null ? null : tasks.getTask(trainedModelDeploymentTaskId(modelId)); + } + /** * Note that the return value of this method does NOT take node relocations into account. * Use {@link #getJobStateModifiedForReassignments} to return a value adjusted to the most diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java new file mode 100644 index 0000000000000..fd51a40d6643d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -0,0 +1,187 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.MasterNodeRequest; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.persistent.PersistentTaskParams; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; +import java.util.concurrent.TimeUnit; + +public class StartTrainedModelDeploymentAction extends ActionType<NodeAcknowledgedResponse> { + + public static final StartTrainedModelDeploymentAction INSTANCE = new StartTrainedModelDeploymentAction(); + public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/start"; + + public static final TimeValue DEFAULT_TIMEOUT = new TimeValue(20, TimeUnit.SECONDS); + + public StartTrainedModelDeploymentAction() { + super(NAME, NodeAcknowledgedResponse::new); + } + + public static class Request extends MasterNodeRequest<Request> implements ToXContentObject { + + private static final ParseField MODEL_ID = new ParseField("model_id"); + private static final ParseField TIMEOUT = new ParseField("timeout"); + + private String modelId; + private TimeValue timeout = DEFAULT_TIMEOUT; + + public Request(String modelId) { + setModelId(modelId); + } + + public Request(StreamInput in) throws IOException { + super(in); + modelId = in.readString(); + timeout = in.readTimeValue(); + } + + public 
final void setModelId(String modelId) { + this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); + } + + public String getModelId() { + return modelId; + } + + public void setTimeout(TimeValue timeout) { + this.timeout = ExceptionsHelper.requireNonNull(timeout, TIMEOUT); + } + + public TimeValue getTimeout() { + return timeout; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelId); + out.writeTimeValue(timeout); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject().field(MODEL_ID.getPreferredName(), modelId); // full object, as ToXContentObject requires + builder.field(TIMEOUT.getPreferredName(), timeout.getStringRep()); + return builder.endObject(); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public int hashCode() { + return Objects.hash(modelId, timeout); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || obj.getClass() != getClass()) { + return false; + } + Request other = (Request) obj; + return Objects.equals(modelId, other.modelId) && Objects.equals(timeout, other.timeout); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + public static class TaskParams implements PersistentTaskParams { + + public static final Version VERSION_INTRODUCED = Version.V_8_0_0; + + private final String modelId; + + public TaskParams(String modelId) { + this.modelId = Objects.requireNonNull(modelId); + } + + public TaskParams(StreamInput in) throws IOException { + this.modelId = in.readString(); + } + + public String getModelId() { + return modelId; + } + + @Override + public String getWriteableName() { + return MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME; + } + + @Override + public Version getMinimalSupportedVersion() { + return VERSION_INTRODUCED; + } + + @Override + public void writeTo(StreamOutput out) throws 
IOException { + out.writeString(modelId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(modelId); + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + TaskParams other = (TaskParams) o; + return Objects.equals(modelId, other.modelId); + } + } + + public interface TaskMatcher { + + static boolean match(Task task, String expectedId) { + if (task instanceof TaskMatcher) { + if (Strings.isAllOrWildcard(expectedId)) { + return true; + } + String expectedDescription = MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX + expectedId; + return expectedDescription.equals(task.getDescription()); + } + return false; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopTrainedModelDeploymentAction.java new file mode 100644 index 0000000000000..2fd52f5baa5d0 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopTrainedModelDeploymentAction.java @@ -0,0 +1,161 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.tasks.BaseTasksRequest; +import org.elasticsearch.action.support.tasks.BaseTasksResponse; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class StopTrainedModelDeploymentAction extends ActionType<StopTrainedModelDeploymentAction.Response> { + + public static final StopTrainedModelDeploymentAction INSTANCE = new StopTrainedModelDeploymentAction(); + public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/stop"; + + public StopTrainedModelDeploymentAction() { + super(NAME, StopTrainedModelDeploymentAction.Response::new); + } + + public static class Request extends BaseTasksRequest<Request> implements ToXContentObject { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + public static final ParseField FORCE = new ParseField("force"); + + private String id; + private boolean allowNoMatch = true; + private boolean force; + + public Request(String id) { + setId(id); + } + + public Request(StreamInput in) throws IOException { + super(in); + id = in.readString(); + allowNoMatch = in.readBoolean(); + force = in.readBoolean(); + } + + public final void setId(String id) { + this.id = ExceptionsHelper.requireNonNull(id, TrainedModelConfig.MODEL_ID); + } + + public String getId() { + return id; + } + + public void setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + } + + public boolean 
isAllowNoMatch() { + return allowNoMatch; + } + + public void setForce(boolean force) { + this.force = force; + } + + public boolean isForce() { + return force; + } + + @Override + public boolean match(Task task) { + return StartTrainedModelDeploymentAction.TaskMatcher.match(task, id); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(id); + out.writeBoolean(allowNoMatch); + out.writeBoolean(force); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), id); + builder.field(ALLOW_NO_MATCH.getPreferredName(), allowNoMatch); + builder.field(FORCE.getPreferredName(), force); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(id, allowNoMatch, force); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + Request that = (Request) o; + return Objects.equals(id, that.id) && + allowNoMatch == that.allowNoMatch && + force == that.force; + } + } + + public static class Response extends BaseTasksResponse implements Writeable, ToXContentObject { + + private final boolean undeployed; + + public Response(boolean undeployed) { + super(null, null); + this.undeployed = undeployed; + } + + public Response(StreamInput in) throws IOException { + super(in); + undeployed = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeBoolean(undeployed); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentCommon(builder, params); + builder.field("stopped", undeployed); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + 
return Objects.hash(undeployed); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Response that = (Response) o; + return undeployed == that.undeployed; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentState.java new file mode 100644 index 0000000000000..b63b903809e3d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentState.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.deployment; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Locale; + +public enum TrainedModelDeploymentState implements Writeable { + + STARTING, STARTED, STOPPING, STOPPED; + + public static TrainedModelDeploymentState fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public static TrainedModelDeploymentState fromStream(StreamInput in) throws IOException { + return in.readEnum(TrainedModelDeploymentState.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(this); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentTaskState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentTaskState.java new file mode 100644 index 0000000000000..29641b6b5512b --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentTaskState.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.deployment; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.persistent.PersistentTaskState; +import org.elasticsearch.xpack.core.ml.MlTasks; + +import java.io.IOException; +import java.util.Objects; + +public class TrainedModelDeploymentTaskState implements PersistentTaskState { + + public static final String NAME = MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME; + + private static final ParseField STATE = new ParseField("state"); + private static final ParseField ALLOCATION_ID = new ParseField("allocation_id"); + private static final ParseField REASON = new ParseField("reason"); + + private final TrainedModelDeploymentState state; + private final long allocationId; + private final String reason; + + private static final ConstructingObjectParser<TrainedModelDeploymentTaskState, Void> PARSER = + new 
ConstructingObjectParser<>(NAME, true, + a -> new TrainedModelDeploymentTaskState((TrainedModelDeploymentState) a[0], (long) a[1], (String) a[2])); + + static { + // Parse with this task's own state enum; the constructor lambda casts a[0] to TrainedModelDeploymentState. + PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelDeploymentState::fromString, STATE); + PARSER.declareLong(ConstructingObjectParser.constructorArg(), ALLOCATION_ID); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), REASON); + } + + public static TrainedModelDeploymentTaskState fromXContent(XContentParser parser) { + try { + return PARSER.parse(parser, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public TrainedModelDeploymentTaskState(TrainedModelDeploymentState state, long allocationId, @Nullable String reason) { + this.state = Objects.requireNonNull(state); + this.allocationId = allocationId; + this.reason = reason; + } + + public TrainedModelDeploymentTaskState(StreamInput in) throws IOException { + this.state = TrainedModelDeploymentState.fromStream(in); + this.allocationId = in.readLong(); + this.reason = in.readOptionalString(); + } + + public TrainedModelDeploymentState getState() { + return state; + } + + public String getReason() { + return reason; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(STATE.getPreferredName(), state.toString()); + builder.field(ALLOCATION_ID.getPreferredName(), allocationId); + if (reason != null) { + builder.field(REASON.getPreferredName(), reason); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + state.writeTo(out); + out.writeLong(allocationId); + out.writeOptionalString(reason); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + 
TrainedModelDeploymentTaskState that = (TrainedModelDeploymentTaskState) o; + return allocationId == that.allocationId && + state == that.state && + Objects.equals(reason, that.reason); + } + + @Override + public int hashCode() { + return Objects.hash(state, allocationId, reason); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 9b36d75daedee..24455e5c9b34d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -95,6 +95,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAliasAction; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.EstimateModelMemoryAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction; @@ -141,6 +142,7 @@ import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction; +import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.UpdateCalendarJobAction; import org.elasticsearch.xpack.core.ml.action.UpdateDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.UpdateDatafeedAction; @@ -157,6 +159,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStatsNamedWriteablesProvider; import 
org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.config.JobTaskState; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; @@ -175,6 +178,7 @@ import org.elasticsearch.xpack.ml.action.TransportDeleteModelSnapshotAction; import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAliasAction; +import org.elasticsearch.xpack.ml.action.TransportStartTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.action.TransportEstimateModelMemoryAction; import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction; import org.elasticsearch.xpack.ml.action.TransportExplainDataFrameAnalyticsAction; @@ -221,6 +225,7 @@ import org.elasticsearch.xpack.ml.action.TransportStartDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportStopDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportStopDatafeedAction; +import org.elasticsearch.xpack.ml.action.TransportStopTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.action.TransportUpdateCalendarJobAction; import org.elasticsearch.xpack.ml.action.TransportUpdateDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportUpdateDatafeedAction; @@ -252,10 +257,13 @@ import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder; import org.elasticsearch.xpack.ml.inference.aggs.InternalInferenceAggregation; +import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import 
org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcessFactory; +import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier; @@ -325,10 +333,12 @@ import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction; import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAliasAction; +import org.elasticsearch.xpack.ml.rest.inference.RestStartTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction; import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAction; import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAliasAction; +import org.elasticsearch.xpack.ml.rest.inference.RestStopTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction; @@ -531,6 +541,7 @@ public Set getRoles() { private final SetOnce inferenceModelBreaker = new SetOnce<>(); private final SetOnce modelLoadingService = new SetOnce<>(); private final SetOnce mlAutoscalingDeciderService = new SetOnce<>(); + private final SetOnce deploymentManager = new SetOnce<>(); public MachineLearning(Settings settings, Path configPath) { this.settings = settings; @@ -693,6 +704,7 @@ public Collection createComponents(Client client, ClusterService 
cluster final NormalizerProcessFactory normalizerProcessFactory; final AnalyticsProcessFactory analyticsProcessFactory; final AnalyticsProcessFactory memoryEstimationProcessFactory; + final PyTorchProcessFactory pyTorchProcessFactory; if (MachineLearningField.AUTODETECT_PROCESS.get(settings)) { try { NativeController nativeController = @@ -714,6 +726,7 @@ public Collection createComponents(Client client, ClusterService cluster dataFrameAnalyticsAuditor); memoryEstimationProcessFactory = new NativeMemoryUsageEstimationProcessFactory(environment, nativeController, clusterService); + pyTorchProcessFactory = new NativePyTorchProcessFactory(environment, nativeController, clusterService); mlController = nativeController; } catch (IOException e) { // The low level cause of failure from the named pipe helper's perspective is almost never the real root cause, so @@ -733,6 +746,7 @@ public Collection createComponents(Client client, ClusterService cluster normalizerProcessFactory = (jobId, quantilesState, bucketSpan, executorService) -> new MultiplyingNormalizerProcess(1.0); analyticsProcessFactory = (jobId, analyticsProcessConfig, hasState, executorService, onProcessCrash) -> null; memoryEstimationProcessFactory = (jobId, analyticsProcessConfig, hasState, executorService, onProcessCrash) -> null; + pyTorchProcessFactory = (jobId, executorService, onProcessCrash) -> null; } NormalizerFactory normalizerFactory = new NormalizerFactory(normalizerProcessFactory, threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)); @@ -773,6 +787,7 @@ public Collection createComponents(Client client, ClusterService cluster clusterService.getNodeName(), inferenceModelBreaker.get()); this.modelLoadingService.set(modelLoadingService); + this.deploymentManager.set(new DeploymentManager(client, threadPool, pyTorchProcessFactory)); // Data frame analytics components AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager( @@ -877,7 +892,12 @@ public List> 
getPersistentTasksExecutor(ClusterServic autodetectProcessManager.get(), memoryTracker.get(), expressionResolver, - client) + client), + new TransportStartTrainedModelDeploymentAction.TaskExecutor(settings, + clusterService, + expressionResolver, + memoryTracker.get(), + deploymentManager.get()) ); } @@ -953,6 +973,8 @@ public List getRestHandlers(Settings settings, RestController restC new RestPutTrainedModelAliasAction(), new RestDeleteTrainedModelAliasAction(), new RestPreviewDataFrameAnalyticsAction(), + new RestStartTrainedModelDeploymentAction(), + new RestStopTrainedModelDeploymentAction(), // CAT Handlers new RestCatJobsAction(), new RestCatTrainedModelsAction(), @@ -1039,6 +1061,8 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(PutTrainedModelAliasAction.INSTANCE, TransportPutTrainedModelAliasAction.class), new ActionHandler<>(DeleteTrainedModelAliasAction.INSTANCE, TransportDeleteTrainedModelAliasAction.class), new ActionHandler<>(PreviewDataFrameAnalyticsAction.INSTANCE, TransportPreviewDataFrameAnalyticsAction.class), + new ActionHandler<>(StartTrainedModelDeploymentAction.INSTANCE, TransportStartTrainedModelDeploymentAction.class), + new ActionHandler<>(StopTrainedModelDeploymentAction.INSTANCE, TransportStopTrainedModelDeploymentAction.class), usageAction, infoAction); } @@ -1173,6 +1197,8 @@ public List getNamedWriteables() { StartDataFrameAnalyticsAction.TaskParams::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME, SnapshotUpgradeTaskParams::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, + StartTrainedModelDeploymentAction.TaskParams::new)); // Persistent task states namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class, JobTaskState.NAME, JobTaskState::new)); @@ -1182,6 +1208,8 @@ public List getNamedWriteables() { 
namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class, SnapshotUpgradeTaskState.NAME, SnapshotUpgradeTaskState::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class, + TrainedModelDeploymentTaskState.NAME, TrainedModelDeploymentTaskState::new)); namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); namedWriteables.addAll(new AnalysisStatsNamedWriteablesProvider().getNamedWriteables()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java new file mode 100644 index 0000000000000..13fb27df529a7 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java @@ -0,0 +1,305 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.persistent.AllocatedPersistentTask; +import org.elasticsearch.persistent.PersistentTaskParams; +import org.elasticsearch.persistent.PersistentTaskState; +import org.elasticsearch.persistent.PersistentTasksCustomMetadata; +import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; 
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager; +import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; +import org.elasticsearch.xpack.ml.job.JobNodeSelector; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; +import org.elasticsearch.xpack.ml.task.AbstractJobPersistentTasksExecutor; + +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.function.Predicate; + +public class TransportStartTrainedModelDeploymentAction + extends TransportMasterNodeAction<StartTrainedModelDeploymentAction.Request, NodeAcknowledgedResponse> { + + private static final Logger logger = LogManager.getLogger(TransportStartTrainedModelDeploymentAction.class); + + private final XPackLicenseState licenseState; + private final Client client; + private final PersistentTasksService persistentTasksService; + + @Inject + public TransportStartTrainedModelDeploymentAction(TransportService transportService, Client client, ClusterService clusterService, + ThreadPool threadPool, ActionFilters actionFilters, XPackLicenseState licenseState, + IndexNameExpressionResolver indexNameExpressionResolver, + PersistentTasksService persistentTasksService) { + super(StartTrainedModelDeploymentAction.NAME, transportService, clusterService, threadPool, actionFilters, + StartTrainedModelDeploymentAction.Request::new, indexNameExpressionResolver, NodeAcknowledgedResponse::new, + ThreadPool.Names.SAME); + this.licenseState = Objects.requireNonNull(licenseState); + this.client = Objects.requireNonNull(client); 
+ this.persistentTasksService = Objects.requireNonNull(persistentTasksService); + } + + @Override + protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Request request, ClusterState state, + ActionListener<NodeAcknowledgedResponse> listener) throws Exception { + logger.debug(() -> new ParameterizedMessage("[{}] received deploy request", request.getModelId())); + if (licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING) == false) { + listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); + return; + } + + ActionListener<PersistentTasksCustomMetadata.PersistentTask<TaskParams>> waitForDeploymentToStart = + ActionListener.wrap( + startedTask -> waitForDeploymentStarted(startedTask, request.getTimeout(), listener), + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) { + e = new ElasticsearchStatusException( + "Cannot start deployment [{}] because it has already been started", + RestStatus.CONFLICT, + e, + request.getModelId() + ); + } + listener.onFailure(e); + } + ); + + ActionListener<GetTrainedModelsAction.Response> getModelListener = ActionListener.wrap( + getModelResponse -> { + if (getModelResponse.getResources().results().size() > 1) { + listener.onFailure(ExceptionsHelper.badRequestException( + "cannot deploy more than one model at the same time; [{}] matches [{}] models", + request.getModelId(), getModelResponse.getResources().results().size())); + return; + } + persistentTasksService.sendStartRequest( + MlTasks.trainedModelDeploymentTaskId(request.getModelId()), + MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, + new TaskParams(request.getModelId()), + waitForDeploymentToStart + ); + }, + listener::onFailure + ); + + GetTrainedModelsAction.Request getModelRequest = new GetTrainedModelsAction.Request( + request.getModelId(), null, Collections.emptySet()); + client.execute(GetTrainedModelsAction.INSTANCE, getModelRequest, getModelListener); + } + + private void waitForDeploymentStarted(PersistentTasksCustomMetadata.PersistentTask<TaskParams> task, + TimeValue timeout, 
ActionListener listener) { + DeploymentStartedPredicate predicate = new DeploymentStartedPredicate(); + persistentTasksService.waitForPersistentTaskCondition(task.getId(), predicate, timeout, + new PersistentTasksService.WaitForPersistentTaskListener() { + @Override + public void onResponse(PersistentTasksCustomMetadata.PersistentTask persistentTask) { + if (predicate.exception != null) { + cancelDeploymentStart(task, predicate.exception, listener); + } else { + listener.onResponse(new NodeAcknowledgedResponse(true, predicate.node)); + } + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }); + } + + private void cancelDeploymentStart( + PersistentTasksCustomMetadata.PersistentTask persistentTask, Exception exception, + ActionListener listener) { + persistentTasksService.sendRemoveRequest(persistentTask.getId(), ActionListener.wrap( + pTask -> listener.onFailure(exception), + e -> { + logger.error( + new ParameterizedMessage("[{}] Failed to cancel persistent task that could not be assigned due to [{}]", + persistentTask.getParams().getModelId(), exception.getMessage()), + e + ); + listener.onFailure(exception); + } + )); + + } + + @Override + protected ClusterBlockException checkBlock(StartTrainedModelDeploymentAction.Request request, ClusterState state) { + // We only delegate here to PersistentTasksService, but if there is a metadata writeblock, + // then delegating to PersistentTasksService doesn't make a whole lot of sense, + // because PersistentTasksService will then fail. 
+ return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } + + private static class DeploymentStartedPredicate implements Predicate> { + + private volatile Exception exception; + private volatile String node = ""; + private volatile String assignmentExplanation; + + @Override + public boolean test(PersistentTasksCustomMetadata.PersistentTask persistentTask) { + if (persistentTask == null) { + return false; + } + + PersistentTasksCustomMetadata.Assignment assignment = persistentTask.getAssignment(); + + String reason = "__unknown__"; + + if (assignment != null) { + if (assignment.equals(JobNodeSelector.AWAITING_LAZY_ASSIGNMENT)) { + return true; + } + if (assignment.equals(PersistentTasksCustomMetadata.INITIAL_ASSIGNMENT) == false && assignment.isAssigned() == false) { + exception = new ElasticsearchStatusException("Could not start trained model deployment, allocation explanation [{}]", + RestStatus.TOO_MANY_REQUESTS, assignment.getExplanation()); + return true; + } + } + + TrainedModelDeploymentTaskState taskState = (TrainedModelDeploymentTaskState) persistentTask.getState(); + reason = taskState != null ? taskState.getReason() : reason; + TrainedModelDeploymentState deploymentState = taskState == null ? 
TrainedModelDeploymentState.STARTED : taskState.getState(); + switch (deploymentState) { + case STARTED: + node = persistentTask.getExecutorNode(); + return true; + case STARTING: + case STOPPING: + case STOPPED: + return false; + default: + exception = ExceptionsHelper.serverError("Unexpected task state [{}] with reason [{}] while waiting to be started", + taskState.getState(), reason); + return true; + } + } + } + + public static class TaskExecutor extends AbstractJobPersistentTasksExecutor { + + private final DeploymentManager manager; + + public TaskExecutor(Settings settings, ClusterService clusterService, IndexNameExpressionResolver expressionResolver, + MlMemoryTracker memoryTracker, DeploymentManager manager) { + super(MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, + MachineLearning.UTILITY_THREAD_POOL_NAME, + settings, + clusterService, + memoryTracker, + expressionResolver); + this.manager = Objects.requireNonNull(manager); + } + + @Override + protected AllocatedPersistentTask createTask( + long id, String type, String action, TaskId parentTaskId, + PersistentTasksCustomMetadata.PersistentTask persistentTask, + Map headers) { + return new TrainedModelDeploymentTask(id, type, action, parentTaskId, headers, persistentTask.getParams()); + } + + @Override + public PersistentTasksCustomMetadata.Assignment getAssignment(TaskParams params, ClusterState clusterState) { + JobNodeSelector jobNodeSelector = + new JobNodeSelector( + clusterState, + params.getModelId(), + MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, + memoryTracker, + 0, + node -> nodeFilter(node, params)); + PersistentTasksCustomMetadata.Assignment assignment = jobNodeSelector.selectNode( + maxOpenJobs, + Integer.MAX_VALUE, + maxMachineMemoryPercent, + maxNodeMemory, + false, + useAutoMemoryPercentage + ); + return assignment; + } + + public static String nodeFilter(DiscoveryNode node, TaskParams params) { + String id = params.getModelId(); + + if (node.getVersion().before(TaskParams.VERSION_INTRODUCED)) 
{ + return "Not opening job [" + id + "] on node [" + JobNodeSelector.nodeNameAndVersion(node) + + "], because the data frame analytics requires a node of version [" + + TaskParams.VERSION_INTRODUCED + "] or higher"; + } + + return null; + } + + @Override + protected void nodeOperation(AllocatedPersistentTask task, TaskParams params, PersistentTaskState state) { + TrainedModelDeploymentTask trainedModelDeploymentTask = (TrainedModelDeploymentTask) task; + trainedModelDeploymentTask.setDeploymentManager(manager); + + TrainedModelDeploymentTaskState deployingState = new TrainedModelDeploymentTaskState( + TrainedModelDeploymentState.STARTING, task.getAllocationId(), null); + task.updatePersistentTaskState(deployingState, ActionListener.wrap( + response -> manager.startDeployment(trainedModelDeploymentTask), + task::markAsFailed + )); + } + + @Override + protected String[] indicesOfInterest(TaskParams params) { + return new String[] { + InferenceIndexConstants.INDEX_PATTERN + }; + } + + @Override + protected String getJobId(TaskParams params) { + return params.getModelId(); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java new file mode 100644 index 0000000000000..6d9722eef9aa5 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java @@ -0,0 +1,201 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionListenerResponseHandler; +import org.elasticsearch.action.FailedNodeException; +import org.elasticsearch.action.TaskOperationFailure; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.tasks.TransportTasksAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.discovery.MasterNotDiscoveredException; +import org.elasticsearch.persistent.PersistentTasksCustomMetadata; +import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; + +import java.util.Collections; +import java.util.List; +import java.util.Set; + +public class 
TransportStopTrainedModelDeploymentAction extends TransportTasksAction { + + private static final Logger logger = LogManager.getLogger(TransportStopTrainedModelDeploymentAction.class); + + private final Client client; + private final ThreadPool threadPool; + private final PersistentTasksService persistentTasksService; + + @Inject + public TransportStopTrainedModelDeploymentAction(String actionName, ClusterService clusterService, TransportService transportService, + ActionFilters actionFilters, Client client, ThreadPool threadPool, + PersistentTasksService persistentTasksService) { + super(actionName, clusterService, transportService, actionFilters, StopTrainedModelDeploymentAction.Request::new, + StopTrainedModelDeploymentAction.Response::new, StopTrainedModelDeploymentAction.Response::new, ThreadPool.Names.SAME); + this.client = client; + this.threadPool = threadPool; + this.persistentTasksService = persistentTasksService; + } + + @Override + protected void doExecute(Task task, StopTrainedModelDeploymentAction.Request request, + ActionListener listener) { + ClusterState state = clusterService.state(); + DiscoveryNodes nodes = state.nodes(); + if (nodes.isLocalNodeElectedMaster() == false) { + redirectToMasterNode(nodes.getMasterNode(), request, listener); + return; + } + + logger.debug("[{}] Received request to undeploy", request.getId()); + + ActionListener getModelListener = ActionListener.wrap( + getModelsResponse -> { + List models = getModelsResponse.getResources().results(); + if (models.isEmpty()) { + listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); + return; + } + if (models.size() > 1) { + listener.onFailure(ExceptionsHelper.badRequestException("cannot undeploy multiple models at the same time")); + return; + } + + ClusterState clusterState = clusterService.state(); + PersistentTasksCustomMetadata tasks = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE); + PersistentTasksCustomMetadata.PersistentTask 
deployTrainedModelTask = + MlTasks.getDeployTrainedModelTask(request.getId(), tasks); + if (deployTrainedModelTask == null) { + listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); + return; + } + normalUndeploy(task, deployTrainedModelTask, request, listener); + }, + listener::onFailure + ); + + GetTrainedModelsAction.Request getModelRequest = new GetTrainedModelsAction.Request( + request.getId(), null, Collections.emptySet()); + getModelRequest.setAllowNoResources(request.isAllowNoMatch()); + client.execute(GetTrainedModelsAction.INSTANCE, getModelRequest, getModelListener); + } + + private void redirectToMasterNode(DiscoveryNode masterNode, StopTrainedModelDeploymentAction.Request request, + ActionListener listener) { + if (masterNode == null) { + listener.onFailure(new MasterNotDiscoveredException()); + } else { + transportService.sendRequest(masterNode, actionName, request, + new ActionListenerResponseHandler<>(listener, StopTrainedModelDeploymentAction.Response::new)); + } + } + + private void normalUndeploy(Task task, PersistentTasksCustomMetadata.PersistentTask deployTrainedModelTask, + StopTrainedModelDeploymentAction.Request request, + ActionListener listener) { + request.setNodes(deployTrainedModelTask.getExecutorNode()); + + ActionListener finalListener = ActionListener.wrap( + r -> waitForTaskRemoved(Collections.singleton(deployTrainedModelTask.getId()), request, r, listener), + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof FailedNodeException) { + // A node has dropped out of the cluster since we started executing the requests. + // Since undeploying an already undeployed trained model is not an error we can try again. + // The tasks that were running on the node that dropped out of the cluster + // will just have their persistent tasks cancelled. Tasks that were stopped + // by the previous attempt will be noops in the subsequent attempt. 
+ doExecute(task, request, listener); + } else { + listener.onFailure(e); + } + } + ); + + super.doExecute(task, request, finalListener); + } + + void waitForTaskRemoved(Set taskIds, StopTrainedModelDeploymentAction.Request request, + StopTrainedModelDeploymentAction.Response response, + ActionListener listener) { + persistentTasksService.waitForPersistentTasksCondition(persistentTasks -> + persistentTasks.findTasks(MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, t -> taskIds.contains(t.getId())).isEmpty(), + request.getTimeout(), ActionListener.wrap( + booleanResponse -> { + listener.onResponse(response); + }, + listener::onFailure + ) + ); + } + + @Override + protected StopTrainedModelDeploymentAction.Response newResponse(StopTrainedModelDeploymentAction.Request request, + List tasks, + List taskOperationFailures, + List failedNodeExceptions) { + if (taskOperationFailures.isEmpty() == false) { + throw org.elasticsearch.ExceptionsHelper.convertToElastic(taskOperationFailures.get(0).getCause()); + } else if (failedNodeExceptions.isEmpty() == false) { + throw org.elasticsearch.ExceptionsHelper.convertToElastic(failedNodeExceptions.get(0)); + } else { + return new StopTrainedModelDeploymentAction.Response(true); + } + } + + @Override + protected void taskOperation(StopTrainedModelDeploymentAction.Request request, TrainedModelDeploymentTask task, + ActionListener listener) { + TrainedModelDeploymentTaskState undeployingState = new TrainedModelDeploymentTaskState( + TrainedModelDeploymentState.STOPPING, task.getAllocationId(), "api"); + task.updatePersistentTaskState(undeployingState, ActionListener.wrap( + updatedTask -> { + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + + @Override + protected void doRun() throws Exception { + task.stop("undeploy_trained_model (api)"); + listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); 
+ } + }); + }, + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + // the task has disappeared so must have stopped + listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); + } else { + listener.onFailure(e); + } + } + )); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java new file mode 100644 index 0000000000000..1425266ef9461 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.deployment; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.get.GetRequest; +import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.client.Client; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcess; +import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import 
java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; + +public class DeploymentManager { + + private static final Logger logger = LogManager.getLogger(DeploymentManager.class); + + private final Client client; + private final PyTorchProcessFactory pyTorchProcessFactory; + private final ExecutorService executorServiceForDeployment; + private final ExecutorService executorServiceForProcess; + private final ConcurrentMap processContextByAllocation = new ConcurrentHashMap<>(); + + public DeploymentManager(Client client, ThreadPool threadPool, PyTorchProcessFactory pyTorchProcessFactory) { + this.client = Objects.requireNonNull(client); + this.pyTorchProcessFactory = Objects.requireNonNull(pyTorchProcessFactory); + this.executorServiceForDeployment = threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME); + this.executorServiceForProcess = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); + } + + public void startDeployment(TrainedModelDeploymentTask task) { + executorServiceForDeployment.execute(() -> doStartDeployment(task)); + } + + private void doStartDeployment(TrainedModelDeploymentTask task) { + logger.debug("[{}] Starting model deployment", task.getModelId()); + + ProcessContext processContext = new ProcessContext(task.getModelId()); + + if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) { + throw ExceptionsHelper.serverError("[{}] Could not create process as one already exists", task.getModelId()); + } + + processContext.startProcess(); + + try { + processContext.loadModel(); + } catch (IOException e) { + logger.error(new ParameterizedMessage("[{}] error loading model", task.getModelId()), e); + } + + TrainedModelDeploymentTaskState startedState = new TrainedModelDeploymentTaskState( + TrainedModelDeploymentState.STARTED, task.getAllocationId(), null); + 
task.updatePersistentTaskState(startedState, ActionListener.wrap( + response -> logger.info("[{}] trained model deployment started", task.getModelId()), + task::markAsFailed + )); + } + + public void stopDeployment(TrainedModelDeploymentTask task) { + ProcessContext processContext; + synchronized (processContextByAllocation) { + processContext = processContextByAllocation.get(task.getAllocationId()); + } + if (processContext != null) { + logger.debug("[{}] Stopping deployment", task.getModelId()); + processContext.killProcess(); + } else { + logger.debug("[{}] No process context to stop", task.getModelId()); + } + } + + class ProcessContext { + + private final String modelId; + private final SetOnce process = new SetOnce<>(); + + ProcessContext(String modelId) { + this.modelId = Objects.requireNonNull(modelId); + } + + synchronized void startProcess() { + process.set(pyTorchProcessFactory.createProcess(modelId, executorServiceForProcess, onProcessCrash())); + } + + synchronized void killProcess() { + if (process.get() == null) { + return; + } + try { + process.get().kill(true); + } catch (IOException e) { + logger.error(new ParameterizedMessage("[{}] Failed to kill process", modelId), e); + } + } + + private Consumer onProcessCrash() { + return reason -> { + logger.error("[{}] process crashed due to reason [{}]", modelId, reason); + }; + } + + void loadModel() throws IOException { + // Here we should be reading the model location from the deployment config. + // Hardcoding this for the prototype. 
+ String index = "test-models"; + String docId = "simple-model"; + + GetResponse modelGetResponse = client.get(new GetRequest(index, docId)).actionGet(); + if (modelGetResponse.isExists() == false) { + throw ExceptionsHelper.badRequestException("[{}] no model was found", modelId); + } + Map sourceAsMap = modelGetResponse.getSourceAsMap(); + int modelSizeAfterUnbase64 = (int) sourceAsMap.get("size"); + String modelBase64 = (String) sourceAsMap.get("model"); + process.get().loadModel(modelBase64, modelSizeAfterUnbase64); + } + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java new file mode 100644 index 0000000000000..aa4f7f8d93998 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.deployment; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.persistent.AllocatedPersistentTask; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams; + +import java.util.Map; + +public class TrainedModelDeploymentTask extends AllocatedPersistentTask implements StartTrainedModelDeploymentAction.TaskMatcher { + + private static final Logger logger = LogManager.getLogger(TrainedModelDeploymentTask.class); + + private final TaskParams params; + private volatile boolean isStopping; + private volatile DeploymentManager manager; + + public TrainedModelDeploymentTask(long id, String type, String action, TaskId parentTask, Map headers, + TaskParams taskParams) { + super(id, type, action, MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX + taskParams.getModelId(), parentTask, headers); + this.params = taskParams; + } + + public String getModelId() { + return params.getModelId(); + } + + public void stop(String reason) { + isStopping = true; + logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason); + + assert manager != null : "manager should not be unset when stop is called"; + manager.stopDeployment(this); + markAsCompleted(); + } + + public void setDeploymentManager(DeploymentManager manager) { + this.manager = manager; + } + + @Override + protected void onCancelled() { + String reason = getReasonCancelled(); + stop(reason); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java new file mode 100644 index 0000000000000..fdb65304f5571 --- /dev/null +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import org.elasticsearch.xpack.ml.process.AbstractNativeProcess; +import org.elasticsearch.xpack.ml.process.NativeController; +import org.elasticsearch.xpack.ml.process.ProcessPipes; +import org.elasticsearch.xpack.ml.process.writer.LengthEncodedWriter; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.Base64; +import java.util.List; +import java.util.function.Consumer; + +public class NativePyTorchProcess extends AbstractNativeProcess implements PyTorchProcess { + + private static final String NAME = "pytorch_inference"; + + protected NativePyTorchProcess(String jobId, NativeController nativeController, ProcessPipes processPipes, int numberOfFields, + List filesToDelete, Consumer onProcessCrash) { + super(jobId, nativeController, processPipes, numberOfFields, filesToDelete, onProcessCrash); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public void persistState() throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public void persistState(long snapshotTimestampMs, String snapshotId, String snapshotDescription) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public void loadModel(String modelBase64, int modelSizeAfterUnbase64) throws IOException { + byte[] modelBytes = Base64.getDecoder().decode(modelBase64.getBytes(StandardCharsets.UTF_8)); + try (OutputStream restoreStream = processRestoreStream()) { + 
LengthEncodedWriter lengthEncodedWriter = new LengthEncodedWriter(restoreStream); + lengthEncodedWriter.writeNumFields(modelSizeAfterUnbase64); + restoreStream.write(modelBytes); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java new file mode 100644 index 0000000000000..6c9b050e8e0e6 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.core.internal.io.IOUtils; +import org.elasticsearch.env.Environment; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.process.NativeController; +import org.elasticsearch.xpack.ml.process.ProcessPipes; +import org.elasticsearch.xpack.ml.utils.NamedPipeHelper; + +import java.io.IOException; +import java.nio.file.Path; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; + +public class NativePyTorchProcessFactory implements PyTorchProcessFactory { + + private static final Logger 
logger = LogManager.getLogger(NativePyTorchProcessFactory.class); + + private static final NamedPipeHelper NAMED_PIPE_HELPER = new NamedPipeHelper(); + + private final Environment env; + private final NativeController nativeController; + private volatile Duration processConnectTimeout; + + public NativePyTorchProcessFactory(Environment env, + NativeController nativeController, + ClusterService clusterService) { + this.env = Objects.requireNonNull(env); + this.nativeController = Objects.requireNonNull(nativeController); + setProcessConnectTimeout(MachineLearning.PROCESS_CONNECT_TIMEOUT.get(env.settings())); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.PROCESS_CONNECT_TIMEOUT, + this::setProcessConnectTimeout); + } + + void setProcessConnectTimeout(TimeValue processConnectTimeout) { + this.processConnectTimeout = Duration.ofMillis(processConnectTimeout.getMillis()); + } + + @Override + public PyTorchProcess createProcess(String modelId, ExecutorService executorService, Consumer onProcessCrash) { + List filesToDelete = new ArrayList<>(); + ProcessPipes processPipes = new ProcessPipes( + env, + NAMED_PIPE_HELPER, + processConnectTimeout, + PyTorchBuilder.PROCESS_NAME, + modelId, + null, + false, + true, + true, + true, + false + ); + + executeProcess(processPipes, filesToDelete); + + NativePyTorchProcess process = new NativePyTorchProcess(modelId, nativeController, processPipes, 0, filesToDelete, onProcessCrash); + + try { + process.start(executorService); + } catch(IOException | EsRejectedExecutionException e) { + String msg = "Failed to connect to pytorch process for job " + modelId; + logger.error(msg); + try { + IOUtils.close(process); + } catch (IOException ioe) { + logger.error("Can't close pytorch process", ioe); + } + throw ExceptionsHelper.serverError(msg, e); + } + return process; + } + + private void executeProcess(ProcessPipes processPipes, List filesToDelete) { + PyTorchBuilder pyTorchBuilder = new 
PyTorchBuilder(env::tmpFile, nativeController, processPipes, filesToDelete); + try { + pyTorchBuilder.build(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (IOException e) { + throw ExceptionsHelper.serverError("Failed to launch PyTorch process"); + } + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java new file mode 100644 index 0000000000000..60f50a8f70507 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import org.elasticsearch.xpack.ml.process.NativeController; +import org.elasticsearch.xpack.ml.process.ProcessPipes; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +public class PyTorchBuilder { + + public static final String PROCESS_NAME = "pytorch_inference"; + private static final String PROCESS_PATH = "./" + PROCESS_NAME; + + private final Supplier tempDirPathSupplier; + private final NativeController nativeController; + private final ProcessPipes processPipes; + private final List filesToDelete; + + public PyTorchBuilder(Supplier tempDirPathSupplier, NativeController nativeController, ProcessPipes processPipes, + List filesToDelete) { + this.tempDirPathSupplier = Objects.requireNonNull(tempDirPathSupplier); + this.nativeController = Objects.requireNonNull(nativeController); + this.processPipes = 
Objects.requireNonNull(processPipes);
+        this.filesToDelete = Objects.requireNonNull(filesToDelete);
+    }
+
+    /**
+     * Builds the command line, lets the process pipes add their arguments to it and hands the
+     * result to the native controller.
+     *
+     * @throws IOException          propagated from the native controller
+     * @throws InterruptedException propagated from the native controller
+     */
+    public void build() throws IOException, InterruptedException {
+        List<String> command = buildCommand();
+        processPipes.addArgs(command);
+        nativeController.startProcess(command);
+    }
+
+    /**
+     * Returns a new mutable list holding only the program path; {@link #build()} passes the
+     * list to {@code ProcessPipes#addArgs} before launching.
+     */
+    private List<String> buildCommand() {
+        List<String> command = new ArrayList<>();
+        command.add(PROCESS_PATH);
+        return command;
+    }
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java
new file mode 100644
index 0000000000000..72c3e9b8af0d8
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java
@@ -0,0 +1,17 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.pytorch.process;
+
+import org.elasticsearch.xpack.ml.process.NativeProcess;
+
+import java.io.IOException;
+
+/** A {@link NativeProcess} that can additionally be given a model to load. */
+public interface PyTorchProcess extends NativeProcess {
+
+    /**
+     * Sends a model to the process.
+     *
+     * @param modelBase64            the model, presumably Base64-encoded as the name suggests - confirm with the implementation
+     * @param modelSizeAfterUnbase64 presumably the size of the model once decoded - confirm with the implementation
+     * @throws IOException if the model cannot be passed to the process
+     */
+    void loadModel(String modelBase64, int modelSizeAfterUnbase64) throws IOException;
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessFactory.java
new file mode 100644
index 0000000000000..4d22a80f433a5
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessFactory.java
@@ -0,0 +1,16 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. 
Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.pytorch.process;
+
+import java.util.concurrent.ExecutorService;
+import java.util.function.Consumer;
+
+/** Factory abstraction for obtaining a {@link PyTorchProcess} for a model. */
+public interface PyTorchProcessFactory {
+
+    /**
+     * Creates a process for the given model.
+     *
+     * @param modelId         id of the model the process is created for
+     * @param executorService executor made available to the created process
+     * @param onProcessCrash  crash callback; NOTE(review): declared without a type argument - confirm the intended one
+     * @return the created process
+     */
+    PyTorchProcess createProcess(
+        String modelId,
+        ExecutorService executorService,
+        Consumer onProcessCrash
+    );
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessManager.java
new file mode 100644
index 0000000000000..c812e490217ed
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessManager.java
@@ -0,0 +1,24 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.pytorch.process;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+/**
+ * Placeholder manager for PyTorch processes: both the constructor and {@link #start(String)}
+ * currently have empty bodies.
+ */
+public class PyTorchProcessManager {
+
+    private static final Logger logger = LogManager.getLogger(PyTorchProcessManager.class);
+
+    public PyTorchProcessManager() {
+        // nothing to initialise
+    }
+
+    /**
+     * Currently does nothing.
+     *
+     * @param taskId id of the task; unused by the present empty body
+     */
+    public void start(String taskId) {
+        // no behaviour implemented
+    }
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java
new file mode 100644
index 0000000000000..be6dc1fb7615d
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java
@@ -0,0 +1,44 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.rest.inference;
+
+import org.elasticsearch.client.node.NodeClient;
+import org.elasticsearch.rest.BaseRestHandler;
+import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.action.RestToXContentListener;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.ml.MachineLearning;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+import static org.elasticsearch.rest.RestRequest.Method.POST;
+
+/**
+ * REST handler that turns a {@code POST .../trained_models/deployment/{model_id}/_start} call into a
+ * {@link StartTrainedModelDeploymentAction} request.
+ */
+public class RestStartTrainedModelDeploymentAction extends BaseRestHandler {
+
+    @Override
+    public String getName() {
+        return "xpack_ml_start_trained_models_deployment_action";
+    }
+
+    /** Returns the single POST route served by this handler. */
+    @Override
+    public List<Route> routes() {
+        return Collections.singletonList(
+            new Route(POST,
+                MachineLearning.BASE_PATH + "trained_models/deployment/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_start"));
+    }
+
+    /**
+     * Reads the model id path parameter, wraps it in a start request and returns a consumer that
+     * executes the request and writes the response to the channel.
+     */
+    @Override
+    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
+        String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
+        StartTrainedModelDeploymentAction.Request request = new StartTrainedModelDeploymentAction.Request(modelId);
+        return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));
+    }
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStopTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStopTrainedModelDeploymentAction.java
new file mode 100644
index 0000000000000..c92751811a911
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStopTrainedModelDeploymentAction.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.rest.inference;
+
+import org.elasticsearch.client.node.NodeClient;
+import org.elasticsearch.rest.BaseRestHandler;
+import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.action.RestToXContentListener;
+import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+import static org.elasticsearch.rest.RestRequest.Method.POST;
+import static org.elasticsearch.xpack.ml.MachineLearning.BASE_PATH;
+
+/**
+ * REST handler that turns a {@code POST .../trained_models/deployment/{model_id}/_stop} call into a
+ * {@link StopTrainedModelDeploymentAction} request.
+ */
+public class RestStopTrainedModelDeploymentAction extends BaseRestHandler {
+
+    @Override
+    public String getName() {
+        return "xpack_ml_stop_trained_models_deployment_action";
+    }
+
+    /** Returns the single POST route served by this handler. */
+    @Override
+    public List<Route> routes() {
+        return Collections.singletonList(
+            new Route(
+                POST,
+                BASE_PATH + "trained_models/deployment/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_stop")
+        );
+    }
+
+    /**
+     * Reads the model id path parameter, wraps it in a stop request and returns a consumer that
+     * executes the request and writes the response to the channel.
+     */
+    @Override
+    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
+        String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
+        StopTrainedModelDeploymentAction.Request request = new StopTrainedModelDeploymentAction.Request(modelId);
+        return channel -> client.execute(StopTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));
+    }
+}