-
Notifications
You must be signed in to change notification settings - Fork 24.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ML] Adding new trained model allocation service (#75778)
Adds a new service for trained model allocation to nodes. Initially, this only supports PyTorch models and simply allocates to nodes with the ML roles. Design is fairly simple: - A master node service runs allowing for new allocations to be created/updated/deleted from cluster state - A node service runs listening to updates referencing the local node + any models it may have allocated and updates accordingly. This type of service sort of splits the difference between the logic of shard allocation and persistent tasks. Neither really fully addressed the need here.
- Loading branch information
Showing
37 changed files
with
3,964 additions
and
314 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
133 changes: 133 additions & 0 deletions
133
.../main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.ml.action; | ||
|
||
import org.elasticsearch.action.ActionRequestValidationException; | ||
import org.elasticsearch.action.ActionResponse; | ||
import org.elasticsearch.action.ActionType; | ||
import org.elasticsearch.action.support.master.MasterNodeRequest; | ||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
import org.elasticsearch.common.xcontent.ConstructingObjectParser; | ||
import org.elasticsearch.common.xcontent.ParseField; | ||
import org.elasticsearch.common.xcontent.ToXContentObject; | ||
import org.elasticsearch.common.xcontent.XContentBuilder; | ||
import org.elasticsearch.common.xcontent.XContentParser; | ||
import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation; | ||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; | ||
|
||
import java.io.IOException; | ||
import java.util.Objects; | ||
|
||
public class CreateTrainedModelAllocationAction extends ActionType<CreateTrainedModelAllocationAction.Response> { | ||
public static final CreateTrainedModelAllocationAction INSTANCE = new CreateTrainedModelAllocationAction(); | ||
public static final String NAME = "cluster:internal/xpack/ml/model_allocation/create"; | ||
|
||
private CreateTrainedModelAllocationAction() { | ||
super(NAME, CreateTrainedModelAllocationAction.Response::new); | ||
} | ||
|
||
public static class Request extends MasterNodeRequest<Request> { | ||
private final StartTrainedModelDeploymentAction.TaskParams taskParams; | ||
|
||
public Request(StartTrainedModelDeploymentAction.TaskParams taskParams) { | ||
this.taskParams = ExceptionsHelper.requireNonNull(taskParams, "taskParams"); | ||
} | ||
|
||
public Request(StreamInput in) throws IOException { | ||
super(in); | ||
this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(in); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
taskParams.writeTo(out); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
Request request = (Request) o; | ||
return Objects.equals(taskParams, request.taskParams); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(taskParams); | ||
} | ||
|
||
public StartTrainedModelDeploymentAction.TaskParams getTaskParams() { | ||
return taskParams; | ||
} | ||
} | ||
|
||
public static class Response extends ActionResponse implements ToXContentObject { | ||
|
||
private static final ParseField ALLOCATION = new ParseField("allocation"); | ||
|
||
private static final ConstructingObjectParser<Response, Void> PARSER = new ConstructingObjectParser<>( | ||
"create_trained_model_allocation_response", | ||
a -> new Response((TrainedModelAllocation) a[0]) | ||
); | ||
static { | ||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> TrainedModelAllocation.fromXContent(p), ALLOCATION); | ||
} | ||
static Response fromXContent(XContentParser parser) { | ||
return PARSER.apply(parser, null); | ||
} | ||
|
||
private final TrainedModelAllocation trainedModelAllocation; | ||
|
||
public Response(TrainedModelAllocation trainedModelAllocation) { | ||
this.trainedModelAllocation = trainedModelAllocation; | ||
} | ||
|
||
public Response(StreamInput in) throws IOException { | ||
super(in); | ||
this.trainedModelAllocation = new TrainedModelAllocation(in); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
trainedModelAllocation.writeTo(out); | ||
} | ||
|
||
public TrainedModelAllocation getTrainedModelAllocation() { | ||
return trainedModelAllocation; | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field("allocation", trainedModelAllocation); | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
Response response = (Response) o; | ||
return Objects.equals(trainedModelAllocation, response.trainedModelAllocation); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(trainedModelAllocation); | ||
} | ||
} | ||
|
||
} |
70 changes: 70 additions & 0 deletions
70
.../main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.ml.action; | ||
|
||
import org.elasticsearch.action.ActionRequestValidationException; | ||
import org.elasticsearch.action.ActionType; | ||
import org.elasticsearch.action.support.master.AcknowledgedResponse; | ||
import org.elasticsearch.action.support.master.MasterNodeRequest; | ||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; | ||
|
||
import java.io.IOException; | ||
import java.util.Objects; | ||
|
||
public class DeleteTrainedModelAllocationAction extends ActionType<AcknowledgedResponse> { | ||
public static final DeleteTrainedModelAllocationAction INSTANCE = new DeleteTrainedModelAllocationAction(); | ||
public static final String NAME = "cluster:internal/xpack/ml/model_allocation/delete"; | ||
|
||
private DeleteTrainedModelAllocationAction() { | ||
super(NAME, AcknowledgedResponse::readFrom); | ||
} | ||
|
||
public static class Request extends MasterNodeRequest<Request> { | ||
private final String modelId; | ||
|
||
public Request(String modelId) { | ||
this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id"); | ||
} | ||
|
||
public Request(StreamInput in) throws IOException { | ||
super(in); | ||
this.modelId = in.readString(); | ||
} | ||
|
||
public String getModelId() { | ||
return modelId; | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(modelId); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
Request request = (Request) o; | ||
return Objects.equals(modelId, request.modelId); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(modelId); | ||
} | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
94 changes: 94 additions & 0 deletions
94
.../java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.ml.action; | ||
|
||
import org.elasticsearch.action.ActionRequestValidationException; | ||
import org.elasticsearch.action.ActionType; | ||
import org.elasticsearch.action.support.master.AcknowledgedResponse; | ||
import org.elasticsearch.action.support.master.MasterNodeRequest; | ||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason; | ||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; | ||
|
||
import java.io.IOException; | ||
import java.util.Objects; | ||
|
||
public class UpdateTrainedModelAllocationStateAction extends ActionType<AcknowledgedResponse> { | ||
public static final UpdateTrainedModelAllocationStateAction INSTANCE = new UpdateTrainedModelAllocationStateAction(); | ||
public static final String NAME = "cluster:internal/xpack/ml/model_allocation/update"; | ||
|
||
private UpdateTrainedModelAllocationStateAction() { | ||
super(NAME, AcknowledgedResponse::readFrom); | ||
} | ||
|
||
public static class Request extends MasterNodeRequest<Request> { | ||
private final String nodeId; | ||
private final String modelId; | ||
private final RoutingStateAndReason routingState; | ||
|
||
public Request(String nodeId, String modelId, RoutingStateAndReason routingState) { | ||
this.nodeId = ExceptionsHelper.requireNonNull(nodeId, "node_id"); | ||
this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id"); | ||
this.routingState = ExceptionsHelper.requireNonNull(routingState, "routing_state"); | ||
} | ||
|
||
public Request(StreamInput in) throws IOException { | ||
super(in); | ||
this.nodeId = in.readString(); | ||
this.modelId = in.readString(); | ||
this.routingState = new RoutingStateAndReason(in); | ||
} | ||
|
||
public String getNodeId() { | ||
return nodeId; | ||
} | ||
|
||
public String getModelId() { | ||
return modelId; | ||
} | ||
|
||
public RoutingStateAndReason getRoutingState() { | ||
return routingState; | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(nodeId); | ||
out.writeString(modelId); | ||
routingState.writeTo(out); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
Request request = (Request) o; | ||
return Objects.equals(nodeId, request.nodeId) | ||
&& Objects.equals(modelId, request.modelId) | ||
&& Objects.equals(routingState, request.routingState); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(nodeId, modelId, routingState); | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return "Request{" + "nodeId='" + nodeId + '\'' + ", modelId='" + modelId + '\'' + ", routingState=" + routingState + '}'; | ||
} | ||
} | ||
|
||
} |
24 changes: 24 additions & 0 deletions
24
...e/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationState.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.ml.inference.allocation; | ||
|
||
import java.util.Locale; | ||
|
||
public enum AllocationState { | ||
STARTED, | ||
STOPPING; | ||
|
||
public static AllocationState fromString(String value) { | ||
return valueOf(value.toUpperCase(Locale.ROOT)); | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return name().toLowerCase(Locale.ROOT); | ||
} | ||
} |
Oops, something went wrong.