Skip to content

Commit

Permalink
[ML] Start and stop model deployments (#70713)
Browse files Browse the repository at this point in the history
Initial start/stop trained model deployment actions.
  • Loading branch information
dimitris-athanasiou authored Mar 29, 2021
1 parent 6249dbb commit 8ba697b
Show file tree
Hide file tree
Showing 18 changed files with 1,596 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ public final class MlTasks {
public static final String DATAFEED_TASK_NAME = "xpack/ml/datafeed";
public static final String DATA_FRAME_ANALYTICS_TASK_NAME = "xpack/ml/data_frame/analytics";
public static final String JOB_SNAPSHOT_UPGRADE_TASK_NAME = "xpack/ml/job/snapshot/upgrade";
public static final String TRAINED_MODEL_DEPLOYMENT_TASK_NAME = "xpack/ml/trained_model/deployment";

public static final String JOB_TASK_ID_PREFIX = "job-";
public static final String DATAFEED_TASK_ID_PREFIX = "datafeed-";
public static final String DATA_FRAME_ANALYTICS_TASK_ID_PREFIX = "data_frame_analytics-";
public static final String JOB_SNAPSHOT_UPGRADE_TASK_ID_PREFIX = "job-snapshot-upgrade-";
public static final String TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX = "trained_model_deployment-";

public static final PersistentTasksCustomMetadata.Assignment AWAITING_UPGRADE =
new PersistentTasksCustomMetadata.Assignment(null,
Expand Down Expand Up @@ -76,6 +78,10 @@ public static String dataFrameAnalyticsId(String taskId) {
return taskId.substring(DATA_FRAME_ANALYTICS_TASK_ID_PREFIX.length());
}

public static String trainedModelDeploymentTaskId(String modelId) {
return TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX + modelId;
}

@Nullable
public static PersistentTasksCustomMetadata.PersistentTask<?> getJobTask(String jobId, @Nullable PersistentTasksCustomMetadata tasks) {
return tasks == null ? null : tasks.getTask(jobTaskId(jobId));
Expand All @@ -100,6 +106,12 @@ public static PersistentTasksCustomMetadata.PersistentTask<?> getSnapshotUpgrade
return tasks == null ? null : tasks.getTask(snapshotUpgradeTaskId(jobId, snapshotId));
}

@Nullable
public static PersistentTasksCustomMetadata.PersistentTask<?> getDeployTrainedModelTask(String modelId,
@Nullable PersistentTasksCustomMetadata tasks) {
return tasks == null ? null : tasks.getTask(trainedModelDeploymentTaskId(modelId));
}

/**
* Note that the return value of this method does NOT take node relocations into account.
* Use {@link #getJobStateModifiedForReassignments} to return a value adjusted to the most
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.MasterNodeRequest;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.persistent.PersistentTaskParams;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

public class StartTrainedModelDeploymentAction extends ActionType<NodeAcknowledgedResponse> {

public static final StartTrainedModelDeploymentAction INSTANCE = new StartTrainedModelDeploymentAction();
public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/start";

public static final TimeValue DEFAULT_TIMEOUT = new TimeValue(20, TimeUnit.SECONDS);

public StartTrainedModelDeploymentAction() {
super(NAME, NodeAcknowledgedResponse::new);
}

public static class Request extends MasterNodeRequest<Request> implements ToXContentObject {

private static final ParseField MODEL_ID = new ParseField("model_id");
private static final ParseField TIMEOUT = new ParseField("timeout");

private String modelId;
private TimeValue timeout = DEFAULT_TIMEOUT;

public Request(String modelId) {
setModelId(modelId);
}

public Request(StreamInput in) throws IOException {
super(in);
modelId = in.readString();
timeout = in.readTimeValue();
}

public final void setModelId(String modelId) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
}

public String getModelId() {
return modelId;
}

public void setTimeout(TimeValue timeout) {
this.timeout = ExceptionsHelper.requireNonNull(timeout, TIMEOUT);
}

public TimeValue getTimeout() {
return timeout;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(modelId);
out.writeTimeValue(timeout);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(MODEL_ID.getPreferredName(), modelId);
builder.field(TIMEOUT.getPreferredName(), timeout.getStringRep());
return builder;
}

@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public int hashCode() {
return Objects.hash(modelId, timeout);
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || obj.getClass() != getClass()) {
return false;
}
Request other = (Request) obj;
return Objects.equals(modelId, other.modelId) && Objects.equals(timeout, other.timeout);
}

@Override
public String toString() {
return Strings.toString(this);
}
}

public static class TaskParams implements PersistentTaskParams {

public static final Version VERSION_INTRODUCED = Version.V_8_0_0;

private final String modelId;

public TaskParams(String modelId) {
this.modelId = Objects.requireNonNull(modelId);
}

public TaskParams(StreamInput in) throws IOException {
this.modelId = in.readString();
}

public String getModelId() {
return modelId;
}

@Override
public String getWriteableName() {
return MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME;
}

@Override
public Version getMinimalSupportedVersion() {
return VERSION_INTRODUCED;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(modelId);
}

@Override
public boolean equals(Object o) {
if (o == this) return true;
if (o == null || getClass() != o.getClass()) return false;

TaskParams other = (TaskParams) o;
return Objects.equals(modelId, other.modelId);
}
}

public interface TaskMatcher {

static boolean match(Task task, String expectedId) {
if (task instanceof TaskMatcher) {
if (Strings.isAllOrWildcard(expectedId)) {
return true;
}
String expectedDescription = MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX + expectedId;
return expectedDescription.equals(task.getDescription());
}
return false;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.tasks.BaseTasksRequest;
import org.elasticsearch.action.support.tasks.BaseTasksResponse;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;

public class StopTrainedModelDeploymentAction extends ActionType<StopTrainedModelDeploymentAction.Response> {

public static final StopTrainedModelDeploymentAction INSTANCE = new StopTrainedModelDeploymentAction();
public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/stop";

public StopTrainedModelDeploymentAction() {
super(NAME, StopTrainedModelDeploymentAction.Response::new);
}

public static class Request extends BaseTasksRequest<Request> implements ToXContentObject {

public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
public static final ParseField FORCE = new ParseField("force");

private String id;
private boolean allowNoMatch = true;
private boolean force;

public Request(String id) {
setId(id);
}

public Request(StreamInput in) throws IOException {
super(in);
id = in.readString();
allowNoMatch = in.readBoolean();
force = in.readBoolean();
}

public final void setId(String id) {
this.id = ExceptionsHelper.requireNonNull(id, TrainedModelConfig.MODEL_ID);
}

public String getId() {
return id;
}

public void setAllowNoMatch(boolean allowNoMatch) {
this.allowNoMatch = allowNoMatch;
}

public boolean isAllowNoMatch() {
return allowNoMatch;
}

public void setForce(boolean force) {
this.force = force;
}

public boolean isForce() {
return force;
}

@Override
public boolean match(Task task) {
return StartTrainedModelDeploymentAction.TaskMatcher.match(task, id);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(id);
out.writeBoolean(allowNoMatch);
out.writeBoolean(force);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), id);
builder.field(ALLOW_NO_MATCH.getPreferredName(), allowNoMatch);
builder.field(FORCE.getPreferredName(), force);
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(id, allowNoMatch, force);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

Request that = (Request) o;
return Objects.equals(id, that.id) &&
allowNoMatch == that.allowNoMatch &&
force == that.force;
}
}

public static class Response extends BaseTasksResponse implements Writeable, ToXContentObject {

private final boolean undeployed;

public Response(boolean undeployed) {
super(null, null);
this.undeployed = undeployed;
}

public Response(StreamInput in) throws IOException {
super(in);
undeployed = in.readBoolean();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeBoolean(undeployed);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
toXContentCommon(builder, params);
builder.field("stopped", undeployed);
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(undeployed);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Response that = (Response) o;
return undeployed == that.undeployed;
}
}
}
Loading

0 comments on commit 8ba697b

Please sign in to comment.