Skip to content

Commit

Permalink
[ML] Adding new trained model allocation service (#75778)
Browse files Browse the repository at this point in the history
Adds a new service for trained model allocation to nodes. 

Initially, this only supports PyTorch models and simply allocates
to nodes with the ML roles. 

Design is fairly simple:

 - A master node service runs allowing for new allocations to be created/updated/deleted from cluster state
 - A node service runs listening to updates referencing the local node + any models it may have allocated and updates accordingly.

This type of service sort of splits the difference between the logic of shard allocation and persistent tasks. Neither really fully addressed the need here.
  • Loading branch information
benwtrent authored Aug 3, 2021
1 parent 10a1d27 commit b11c15b
Show file tree
Hide file tree
Showing 37 changed files with 3,964 additions and 314 deletions.
1 change: 1 addition & 0 deletions .idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.MasterNodeRequest;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;

public class CreateTrainedModelAllocationAction extends ActionType<CreateTrainedModelAllocationAction.Response> {
public static final CreateTrainedModelAllocationAction INSTANCE = new CreateTrainedModelAllocationAction();
public static final String NAME = "cluster:internal/xpack/ml/model_allocation/create";

private CreateTrainedModelAllocationAction() {
super(NAME, CreateTrainedModelAllocationAction.Response::new);
}

public static class Request extends MasterNodeRequest<Request> {
private final StartTrainedModelDeploymentAction.TaskParams taskParams;

public Request(StartTrainedModelDeploymentAction.TaskParams taskParams) {
this.taskParams = ExceptionsHelper.requireNonNull(taskParams, "taskParams");
}

public Request(StreamInput in) throws IOException {
super(in);
this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(in);
}

@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
taskParams.writeTo(out);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Request request = (Request) o;
return Objects.equals(taskParams, request.taskParams);
}

@Override
public int hashCode() {
return Objects.hash(taskParams);
}

public StartTrainedModelDeploymentAction.TaskParams getTaskParams() {
return taskParams;
}
}

public static class Response extends ActionResponse implements ToXContentObject {

private static final ParseField ALLOCATION = new ParseField("allocation");

private static final ConstructingObjectParser<Response, Void> PARSER = new ConstructingObjectParser<>(
"create_trained_model_allocation_response",
a -> new Response((TrainedModelAllocation) a[0])
);
static {
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> TrainedModelAllocation.fromXContent(p), ALLOCATION);
}
static Response fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private final TrainedModelAllocation trainedModelAllocation;

public Response(TrainedModelAllocation trainedModelAllocation) {
this.trainedModelAllocation = trainedModelAllocation;
}

public Response(StreamInput in) throws IOException {
super(in);
this.trainedModelAllocation = new TrainedModelAllocation(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
trainedModelAllocation.writeTo(out);
}

public TrainedModelAllocation getTrainedModelAllocation() {
return trainedModelAllocation;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("allocation", trainedModelAllocation);
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Response response = (Response) o;
return Objects.equals(trainedModelAllocation, response.trainedModelAllocation);
}

@Override
public int hashCode() {
return Objects.hash(trainedModelAllocation);
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.support.master.MasterNodeRequest;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;

public class DeleteTrainedModelAllocationAction extends ActionType<AcknowledgedResponse> {
public static final DeleteTrainedModelAllocationAction INSTANCE = new DeleteTrainedModelAllocationAction();
public static final String NAME = "cluster:internal/xpack/ml/model_allocation/delete";

private DeleteTrainedModelAllocationAction() {
super(NAME, AcknowledgedResponse::readFrom);
}

public static class Request extends MasterNodeRequest<Request> {
private final String modelId;

public Request(String modelId) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id");
}

public Request(StreamInput in) throws IOException {
super(in);
this.modelId = in.readString();
}

public String getModelId() {
return modelId;
}

@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(modelId);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Request request = (Request) o;
return Objects.equals(modelId, request.modelId);
}

@Override
public int hashCode() {
return Objects.hash(modelId);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.MasterNodeRequest;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
Expand All @@ -31,15 +35,15 @@
import java.util.Objects;
import java.util.concurrent.TimeUnit;

public class StartTrainedModelDeploymentAction extends ActionType<NodeAcknowledgedResponse> {
public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedModelAllocationAction.Response> {

public static final StartTrainedModelDeploymentAction INSTANCE = new StartTrainedModelDeploymentAction();
public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/start";

public static final TimeValue DEFAULT_TIMEOUT = new TimeValue(20, TimeUnit.SECONDS);

public StartTrainedModelDeploymentAction() {
super(NAME, NodeAcknowledgedResponse::new);
super(NAME, CreateTrainedModelAllocationAction.Response::new);
}

public static class Request extends MasterNodeRequest<Request> implements ToXContentObject {
Expand Down Expand Up @@ -120,9 +124,29 @@ public String toString() {

public static class TaskParams implements PersistentTaskParams, MlTaskParams {

public static final Version VERSION_INTRODUCED = Version.V_8_0_0;
// TODO add support for other roles? If so, it may have to be an instance method...
// NOTE, whatever determines allocation should not be dynamically set on the node
// Otherwise allocation logic might fail
public static boolean mayAllocateToNode(DiscoveryNode node) {
return node.getRoles().contains(DiscoveryNodeRole.ML_ROLE);
}

public static final Version VERSION_INTRODUCED = Version.V_8_0_0;
private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
"trained_model_deployment_params",
true,
a -> new TaskParams((String)a[0], (String)a[1], (Long)a[2])
);
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
PARSER.declareString(ConstructingObjectParser.constructorArg(), IndexLocation.INDEX);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
}

public static TaskParams fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

/**
* This has been found to be approximately 300MB on linux by manual testing.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.support.master.MasterNodeRequest;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;

public class UpdateTrainedModelAllocationStateAction extends ActionType<AcknowledgedResponse> {
public static final UpdateTrainedModelAllocationStateAction INSTANCE = new UpdateTrainedModelAllocationStateAction();
public static final String NAME = "cluster:internal/xpack/ml/model_allocation/update";

private UpdateTrainedModelAllocationStateAction() {
super(NAME, AcknowledgedResponse::readFrom);
}

public static class Request extends MasterNodeRequest<Request> {
private final String nodeId;
private final String modelId;
private final RoutingStateAndReason routingState;

public Request(String nodeId, String modelId, RoutingStateAndReason routingState) {
this.nodeId = ExceptionsHelper.requireNonNull(nodeId, "node_id");
this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id");
this.routingState = ExceptionsHelper.requireNonNull(routingState, "routing_state");
}

public Request(StreamInput in) throws IOException {
super(in);
this.nodeId = in.readString();
this.modelId = in.readString();
this.routingState = new RoutingStateAndReason(in);
}

public String getNodeId() {
return nodeId;
}

public String getModelId() {
return modelId;
}

public RoutingStateAndReason getRoutingState() {
return routingState;
}

@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(nodeId);
out.writeString(modelId);
routingState.writeTo(out);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Request request = (Request) o;
return Objects.equals(nodeId, request.nodeId)
&& Objects.equals(modelId, request.modelId)
&& Objects.equals(routingState, request.routingState);
}

@Override
public int hashCode() {
return Objects.hash(nodeId, modelId, routingState);
}

@Override
public String toString() {
return "Request{" + "nodeId='" + nodeId + '\'' + ", modelId='" + modelId + '\'' + ", routingState=" + routingState + '}';
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.allocation;

import java.util.Locale;

public enum AllocationState {
STARTED,
STOPPING;

public static AllocationState fromString(String value) {
return valueOf(value.toUpperCase(Locale.ROOT));
}

@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}
}
Loading

0 comments on commit b11c15b

Please sign in to comment.