diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
index 5cf789707c58c..3f3eb5218afed 100644
--- a/.idea/inspectionProfiles/Project_Default.xml
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -5,5 +5,6 @@
+
\ No newline at end of file
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationAction.java
new file mode 100644
index 0000000000000..ccf567e76a559
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationAction.java
@@ -0,0 +1,133 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.master.MasterNodeRequest;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Action type registered under {@link #NAME}. Its request carries deployment task parameters
+ * and its response carries a {@link TrainedModelAllocation}.
+ */
+public class CreateTrainedModelAllocationAction extends ActionType<CreateTrainedModelAllocationAction.Response> {
+    public static final CreateTrainedModelAllocationAction INSTANCE = new CreateTrainedModelAllocationAction();
+    public static final String NAME = "cluster:internal/xpack/ml/model_allocation/create";
+
+    private CreateTrainedModelAllocationAction() {
+        super(NAME, CreateTrainedModelAllocationAction.Response::new);
+    }
+
+    /**
+     * Master node request wrapping the task parameters of a trained model deployment.
+     */
+    public static class Request extends MasterNodeRequest<Request> {
+        private final StartTrainedModelDeploymentAction.TaskParams taskParams;
+
+        /**
+         * @param taskParams deployment task parameters; checked with {@code ExceptionsHelper.requireNonNull}
+         */
+        public Request(StartTrainedModelDeploymentAction.TaskParams taskParams) {
+            this.taskParams = ExceptionsHelper.requireNonNull(taskParams, "taskParams");
+        }
+
+        /**
+         * Wire constructor: reads the parent request state followed by the task parameters,
+         * mirroring {@link #writeTo(StreamOutput)}.
+         */
+        public Request(StreamInput in) throws IOException {
+            super(in);
+            this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(in);
+        }
+
+        /**
+         * @return always {@code null}; no request-level validation is performed here
+         */
+        @Override
+        public ActionRequestValidationException validate() {
+            return null;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            taskParams.writeTo(out);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Request request = (Request) o;
+            return Objects.equals(taskParams, request.taskParams);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(taskParams);
+        }
+
+        public StartTrainedModelDeploymentAction.TaskParams getTaskParams() {
+            return taskParams;
+        }
+    }
+
+    /**
+     * Response holding a {@link TrainedModelAllocation}, rendered under the {@code allocation} field.
+     */
+    public static class Response extends ActionResponse implements ToXContentObject {
+
+        private static final ParseField ALLOCATION = new ParseField("allocation");
+
+        private static final ConstructingObjectParser<Response, Void> PARSER = new ConstructingObjectParser<>(
+            "create_trained_model_allocation_response",
+            a -> new Response((TrainedModelAllocation) a[0])
+        );
+        static {
+            PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> TrainedModelAllocation.fromXContent(p), ALLOCATION);
+        }
+
+        /**
+         * Parses a response from its XContent form using {@link #PARSER}.
+         */
+        static Response fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        private final TrainedModelAllocation trainedModelAllocation;
+
+        public Response(TrainedModelAllocation trainedModelAllocation) {
+            this.trainedModelAllocation = trainedModelAllocation;
+        }
+
+        /**
+         * Wire constructor: reads the allocation written by {@link #writeTo(StreamOutput)}.
+         */
+        public Response(StreamInput in) throws IOException {
+            super(in);
+            this.trainedModelAllocation = new TrainedModelAllocation(in);
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            trainedModelAllocation.writeTo(out);
+        }
+
+        public TrainedModelAllocation getTrainedModelAllocation() {
+            return trainedModelAllocation;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            // Use the same ParseField the parser is declared with so both sides cannot drift apart.
+            builder.field(ALLOCATION.getPreferredName(), trainedModelAllocation);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Response response = (Response) o;
+            return Objects.equals(trainedModelAllocation, response.trainedModelAllocation);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(trainedModelAllocation);
+        }
+    }
+
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationAction.java
new file mode 100644
index 0000000000000..589ae631dece8
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationAction.java
@@ -0,0 +1,70 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.master.AcknowledgedResponse;
+import org.elasticsearch.action.support.master.MasterNodeRequest;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Action type registered under {@link #NAME}. Its request identifies a model by id and its
+ * response is an {@link AcknowledgedResponse}.
+ */
+public class DeleteTrainedModelAllocationAction extends ActionType<AcknowledgedResponse> {
+    public static final DeleteTrainedModelAllocationAction INSTANCE = new DeleteTrainedModelAllocationAction();
+    public static final String NAME = "cluster:internal/xpack/ml/model_allocation/delete";
+
+    private DeleteTrainedModelAllocationAction() {
+        super(NAME, AcknowledgedResponse::readFrom);
+    }
+
+    /**
+     * Master node request carrying the id of the model the action applies to.
+     */
+    public static class Request extends MasterNodeRequest<Request> {
+        private final String modelId;
+
+        /**
+         * @param modelId model identifier; checked with {@code ExceptionsHelper.requireNonNull}
+         */
+        public Request(String modelId) {
+            this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id");
+        }
+
+        /**
+         * Wire constructor: reads the parent request state followed by the model id,
+         * mirroring {@link #writeTo(StreamOutput)}.
+         */
+        public Request(StreamInput in) throws IOException {
+            super(in);
+            this.modelId = in.readString();
+        }
+
+        public String getModelId() {
+            return modelId;
+        }
+
+        /**
+         * @return always {@code null}; no request-level validation is performed here
+         */
+        @Override
+        public ActionRequestValidationException validate() {
+            return null;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            out.writeString(modelId);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Request request = (Request) o;
+            return Objects.equals(modelId, request.modelId);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(modelId);
+        }
+    }
+
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java
index 6e0c4d517d1b7..fb41c1d92a5ec 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java
@@ -11,11 +11,15 @@
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.MasterNodeRequest;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -31,7 +35,7 @@
import java.util.Objects;
import java.util.concurrent.TimeUnit;
-public class StartTrainedModelDeploymentAction extends ActionType {
+public class StartTrainedModelDeploymentAction extends ActionType {
public static final StartTrainedModelDeploymentAction INSTANCE = new StartTrainedModelDeploymentAction();
public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/start";
@@ -39,7 +43,7 @@ public class StartTrainedModelDeploymentAction extends ActionType implements ToXContentObject {
@@ -120,9 +124,29 @@ public String toString() {
public static class TaskParams implements PersistentTaskParams, MlTaskParams {
- public static final Version VERSION_INTRODUCED = Version.V_8_0_0;
+        // TODO: add support for other roles? If so, this may have to become an instance method.
+        // NOTE: whatever determines allocation should not be dynamically settable on the node,
+        // otherwise the allocation logic might fail.
+        /**
+         * Reports whether the given node's role set contains {@link DiscoveryNodeRole#ML_ROLE}.
+         *
+         * @param node the node to inspect
+         * @return {@code true} when the node has the ML role
+         */
+        public static boolean mayAllocateToNode(DiscoveryNode node) {
+            final boolean hasMlRole = node.getRoles().contains(DiscoveryNodeRole.ML_ROLE);
+            return hasMlRole;
+        }
+ public static final Version VERSION_INTRODUCED = Version.V_8_0_0;
private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
+        // Lenient parser (second constructor argument is true) that builds TaskParams from the
+        // model id, index and model_bytes fields, in that constructor-argument order.
+        private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
+            "trained_model_deployment_params",
+            true,
+            a -> new TaskParams((String) a[0], (String) a[1], (Long) a[2])
+        );
+        static {
+            PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
+            PARSER.declareString(ConstructingObjectParser.constructorArg(), IndexLocation.INDEX);
+            PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
+        }
+
+        /**
+         * Parses task parameters from their XContent form using {@link #PARSER}.
+         *
+         * @param parser parser positioned on the task parameters object
+         * @return the parsed task parameters
+         */
+        public static TaskParams fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
/**
* This has been found to be approximately 300MB on linux by manual testing.
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateAction.java
new file mode 100644
index 0000000000000..e1ae0e6c2258c
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateAction.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.master.AcknowledgedResponse;
+import org.elasticsearch.action.support.master.MasterNodeRequest;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Action type registered under {@link #NAME}. Its request carries a node id, a model id and a
+ * {@link RoutingStateAndReason}; its response is an {@link AcknowledgedResponse}.
+ */
+public class UpdateTrainedModelAllocationStateAction extends ActionType<AcknowledgedResponse> {
+    public static final UpdateTrainedModelAllocationStateAction INSTANCE = new UpdateTrainedModelAllocationStateAction();
+    public static final String NAME = "cluster:internal/xpack/ml/model_allocation/update";
+
+    private UpdateTrainedModelAllocationStateAction() {
+        super(NAME, AcknowledgedResponse::readFrom);
+    }
+
+    /**
+     * Master node request carrying the routing state to record for a (node, model) pair.
+     */
+    public static class Request extends MasterNodeRequest<Request> {
+        private final String nodeId;
+        private final String modelId;
+        private final RoutingStateAndReason routingState;
+
+        /**
+         * All three arguments are checked with {@code ExceptionsHelper.requireNonNull}.
+         *
+         * @param nodeId       id of the node the state refers to
+         * @param modelId      id of the model the state refers to
+         * @param routingState routing state and optional reason
+         */
+        public Request(String nodeId, String modelId, RoutingStateAndReason routingState) {
+            this.nodeId = ExceptionsHelper.requireNonNull(nodeId, "node_id");
+            this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id");
+            this.routingState = ExceptionsHelper.requireNonNull(routingState, "routing_state");
+        }
+
+        /**
+         * Wire constructor: reads fields in the order written by {@link #writeTo(StreamOutput)}.
+         */
+        public Request(StreamInput in) throws IOException {
+            super(in);
+            this.nodeId = in.readString();
+            this.modelId = in.readString();
+            this.routingState = new RoutingStateAndReason(in);
+        }
+
+        public String getNodeId() {
+            return nodeId;
+        }
+
+        public String getModelId() {
+            return modelId;
+        }
+
+        public RoutingStateAndReason getRoutingState() {
+            return routingState;
+        }
+
+        /**
+         * @return always {@code null}; no request-level validation is performed here
+         */
+        @Override
+        public ActionRequestValidationException validate() {
+            return null;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            out.writeString(nodeId);
+            out.writeString(modelId);
+            routingState.writeTo(out);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Request request = (Request) o;
+            return Objects.equals(nodeId, request.nodeId)
+                && Objects.equals(modelId, request.modelId)
+                && Objects.equals(routingState, request.routingState);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(nodeId, modelId, routingState);
+        }
+
+        @Override
+        public String toString() {
+            return "Request{" + "nodeId='" + nodeId + '\'' + ", modelId='" + modelId + '\'' + ", routingState=" + routingState + '}';
+        }
+    }
+
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationState.java
new file mode 100644
index 0000000000000..c9ef574d39f8f
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationState.java
@@ -0,0 +1,24 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import java.util.Locale;
+
+/**
+ * States of a trained model allocation. The textual form is the lower-cased constant name.
+ */
+public enum AllocationState {
+    STARTED,
+    STOPPING;
+
+    /**
+     * Resolves a constant from its name, ignoring case (upper-cased with {@link Locale#ROOT}).
+     *
+     * @param value textual state, e.g. {@code "started"}
+     * @return the matching constant
+     * @throws IllegalArgumentException if no constant has that name
+     */
+    public static AllocationState fromString(String value) {
+        final String constantName = value.toUpperCase(Locale.ROOT);
+        return AllocationState.valueOf(constantName);
+    }
+
+    /**
+     * @return the constant name lower-cased with {@link Locale#ROOT}
+     */
+    @Override
+    public String toString() {
+        final String constantName = name();
+        return constantName.toLowerCase(Locale.ROOT);
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingState.java
new file mode 100644
index 0000000000000..865a490cbf64a
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingState.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import java.util.Locale;
+
+/**
+ * Per-node routing states. The textual form is the lower-cased constant name.
+ */
+public enum RoutingState {
+    STARTING,
+    STARTED,
+    STOPPING,
+    FAILED,
+    STOPPED;
+
+    /**
+     * Resolves a constant from its name, ignoring case (upper-cased with {@link Locale#ROOT}).
+     *
+     * @param value textual state, e.g. {@code "failed"}
+     * @return the matching constant
+     * @throws IllegalArgumentException if no constant has that name
+     */
+    public static RoutingState fromString(String value) {
+        final String constantName = value.toUpperCase(Locale.ROOT);
+        return RoutingState.valueOf(constantName);
+    }
+
+    /**
+     * @return the constant name lower-cased with {@link Locale#ROOT}
+     */
+    @Override
+    public String toString() {
+        final String constantName = name();
+        return constantName.toLowerCase(Locale.ROOT);
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReason.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReason.java
new file mode 100644
index 0000000000000..c6f1ce7d71510
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReason.java
@@ -0,0 +1,96 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Pairs a {@link RoutingState} with an optional free-text reason. Serializable both over the
+ * wire ({@link Writeable}) and as XContent.
+ */
+public class RoutingStateAndReason implements ToXContentObject, Writeable {
+
+    private static final ParseField REASON = new ParseField("reason");
+    private static final ParseField ROUTING_STATE = new ParseField("routing_state");
+
+    // routing_state is a required constructor argument, reason is optional.
+    private static final ConstructingObjectParser<RoutingStateAndReason, Void> PARSER = new ConstructingObjectParser<>(
+        "trained_model_routing_state",
+        a -> new RoutingStateAndReason(RoutingState.fromString((String) a[0]), (String) a[1])
+    );
+    static {
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), ROUTING_STATE);
+        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), REASON);
+    }
+
+    /**
+     * Parses an instance from its XContent form using {@link #PARSER}.
+     */
+    public static RoutingStateAndReason fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final String reason;
+    private final RoutingState state;
+
+    /**
+     * @param state  routing state; checked with {@code ExceptionsHelper.requireNonNull}
+     * @param reason optional explanation; may be {@code null}
+     */
+    public RoutingStateAndReason(RoutingState state, String reason) {
+        this.state = ExceptionsHelper.requireNonNull(state, ROUTING_STATE);
+        this.reason = reason;
+    }
+
+    /**
+     * Wire constructor: reads the state enum and then the optional reason, in the order
+     * written by {@link #writeTo(StreamOutput)}.
+     */
+    public RoutingStateAndReason(StreamInput in) throws IOException {
+        this.state = in.readEnum(RoutingState.class);
+        this.reason = in.readOptionalString();
+    }
+
+    public String getReason() {
+        return reason;
+    }
+
+    public RoutingState getState() {
+        return state;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeEnum(state);
+        out.writeOptionalString(reason);
+    }
+
+    /**
+     * Writes {@code routing_state} always and {@code reason} only when it is non-null.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(ROUTING_STATE.getPreferredName(), state);
+        if (reason != null) {
+            builder.field(REASON.getPreferredName(), reason);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        RoutingStateAndReason that = (RoutingStateAndReason) o;
+        return Objects.equals(reason, that.reason) && state == that.state;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(reason, state);
+    }
+
+    @Override
+    public String toString() {
+        return "RoutingStateAndReason{" + "reason='" + reason + '\'' + ", state=" + state + '}';
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocation.java
new file mode 100644
index 0000000000000..f15511097d8d0
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocation.java
@@ -0,0 +1,240 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.ResourceAlreadyExistsException;
+import org.elasticsearch.ResourceNotFoundException;
+import org.elasticsearch.cluster.AbstractDiffable;
+import org.elasticsearch.cluster.Diffable;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Objects;
+
+// TODO implement better diffable logic so that whole diff does not need to be serialized if only one part changes
+/**
+ * Trained model allocation object that contains allocation options and the allocation routing table.
+ * The routing table maps a node id to that node's {@link RoutingStateAndReason}.
+ */
+public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocation>
+    implements
+        Diffable<TrainedModelAllocation>,
+        ToXContentObject {
+
+    private static final ParseField ALLOCATION_STATE = new ParseField("allocation_state");
+    private static final ParseField ROUTING_TABLE = new ParseField("routing_table");
+    private static final ParseField TASK_PARAMETERS = new ParseField("task_parameters");
+
+    // Lenient parser; constructor arguments are task_parameters, routing_table, allocation_state.
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<TrainedModelAllocation, Void> PARSER = new ConstructingObjectParser<>(
+        "trained_model_allocation",
+        true,
+        a -> new TrainedModelAllocation(
+            (StartTrainedModelDeploymentAction.TaskParams) a[0],
+            (Map<String, RoutingStateAndReason>) a[1],
+            AllocationState.fromString((String) a[2])
+        )
+    );
+    static {
+        PARSER.declareObject(
+            ConstructingObjectParser.constructorArg(),
+            (p, c) -> StartTrainedModelDeploymentAction.TaskParams.fromXContent(p),
+            TASK_PARAMETERS
+        );
+        PARSER.declareObject(
+            ConstructingObjectParser.constructorArg(),
+            (p, c) -> p.map(LinkedHashMap::new, RoutingStateAndReason::fromXContent),
+            ROUTING_TABLE
+        );
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), ALLOCATION_STATE);
+    }
+
+    private final StartTrainedModelDeploymentAction.TaskParams taskParams;
+    private final Map<String, RoutingStateAndReason> nodeRoutingTable;
+    private final AllocationState allocationState;
+
+    /**
+     * Parses an allocation from its XContent form using {@link #PARSER}.
+     */
+    public static TrainedModelAllocation fromXContent(XContentParser parser) throws IOException {
+        return PARSER.apply(parser, null);
+    }
+
+    /**
+     * All arguments are checked with {@code ExceptionsHelper.requireNonNull}. The routing table
+     * is stored as given (not copied), so callers must hand over a map they no longer mutate.
+     */
+    TrainedModelAllocation(
+        StartTrainedModelDeploymentAction.TaskParams taskParams,
+        Map<String, RoutingStateAndReason> nodeRoutingTable,
+        AllocationState allocationState
+    ) {
+        this.taskParams = ExceptionsHelper.requireNonNull(taskParams, TASK_PARAMETERS);
+        this.nodeRoutingTable = ExceptionsHelper.requireNonNull(nodeRoutingTable, ROUTING_TABLE);
+        this.allocationState = ExceptionsHelper.requireNonNull(allocationState, ALLOCATION_STATE);
+    }
+
+    /**
+     * Wire constructor: reads task parameters, the ordered routing table and the allocation
+     * state, in the order written by {@link #writeTo(StreamOutput)}.
+     */
+    public TrainedModelAllocation(StreamInput in) throws IOException {
+        this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(in);
+        this.nodeRoutingTable = in.readOrderedMap(StreamInput::readString, RoutingStateAndReason::new);
+        this.allocationState = in.readEnum(AllocationState.class);
+    }
+
+    /**
+     * @return {@code true} if the routing table has an entry for {@code nodeId}, whatever its state
+     */
+    public boolean isRoutedToNode(String nodeId) {
+        return nodeRoutingTable.containsKey(nodeId);
+    }
+
+    /**
+     * @return an unmodifiable view of the node id to routing state table
+     */
+    public Map<String, RoutingStateAndReason> getNodeRoutingTable() {
+        return Collections.unmodifiableMap(nodeRoutingTable);
+    }
+
+    public StartTrainedModelDeploymentAction.TaskParams getTaskParams() {
+        return taskParams;
+    }
+
+    public AllocationState getAllocationState() {
+        return allocationState;
+    }
+
+    /**
+     * @return ids of the nodes whose routing state is {@link RoutingState#STARTED}, in routing table order
+     */
+    public String[] getStartedNodes() {
+        return nodeRoutingTable
+            .entrySet()
+            .stream()
+            .filter(entry -> RoutingState.STARTED.equals(entry.getValue().getState()))
+            .map(Map.Entry::getKey)
+            .toArray(String[]::new);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        TrainedModelAllocation that = (TrainedModelAllocation) o;
+        return Objects.equals(nodeRoutingTable, that.nodeRoutingTable)
+            && Objects.equals(taskParams, that.taskParams)
+            && Objects.equals(allocationState, that.allocationState);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(nodeRoutingTable, taskParams, allocationState);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(TASK_PARAMETERS.getPreferredName(), taskParams);
+        builder.field(ROUTING_TABLE.getPreferredName(), nodeRoutingTable);
+        builder.field(ALLOCATION_STATE.getPreferredName(), allocationState);
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        taskParams.writeTo(out);
+        out.writeMap(nodeRoutingTable, StreamOutput::writeString, (o, w) -> w.writeTo(o));
+        out.writeEnum(allocationState);
+    }
+
+    /**
+     * Mutable builder for {@link TrainedModelAllocation}. It works on its own copy of the routing
+     * table and records in {@link #isChanged()} whether any mutator actually altered state.
+     */
+    public static class Builder {
+        private final Map<String, RoutingStateAndReason> nodeRoutingTable;
+        private final StartTrainedModelDeploymentAction.TaskParams taskParams;
+        private AllocationState allocationState;
+        private boolean isChanged;
+
+        /**
+         * Creates a builder seeded with a copy of the given allocation's state.
+         */
+        public static Builder fromAllocation(TrainedModelAllocation allocation) {
+            return new Builder(allocation.taskParams, allocation.nodeRoutingTable, allocation.allocationState);
+        }
+
+        /**
+         * Creates a builder with an empty routing table and state {@link AllocationState#STARTED}.
+         */
+        public static Builder empty(StartTrainedModelDeploymentAction.TaskParams taskParams) {
+            return new Builder(taskParams);
+        }
+
+        private Builder(
+            StartTrainedModelDeploymentAction.TaskParams taskParams,
+            Map<String, RoutingStateAndReason> nodeRoutingTable,
+            AllocationState allocationState
+        ) {
+            this.taskParams = taskParams;
+            // Copy so that builder mutations never touch the source allocation's table.
+            this.nodeRoutingTable = new LinkedHashMap<>(nodeRoutingTable);
+            this.allocationState = allocationState;
+        }
+
+        private Builder(StartTrainedModelDeploymentAction.TaskParams taskParams) {
+            this.nodeRoutingTable = new LinkedHashMap<>();
+            this.taskParams = taskParams;
+            this.allocationState = AllocationState.STARTED;
+        }
+
+        /**
+         * Adds an entry in state {@link RoutingState#STARTING} with an empty reason.
+         *
+         * @throws ResourceAlreadyExistsException if the node already has an entry
+         */
+        public Builder addNewRoutingEntry(String nodeId) {
+            if (nodeRoutingTable.containsKey(nodeId)) {
+                throw new ResourceAlreadyExistsException(
+                    "routing entry for node [{}] for model [{}] already exists", nodeId, taskParams.getModelId()
+                );
+            }
+            isChanged = true;
+            nodeRoutingTable.put(nodeId, new RoutingStateAndReason(RoutingState.STARTING, ""));
+            return this;
+        }
+
+        /**
+         * Adds an entry in state {@link RoutingState#FAILED} with the given reason.
+         *
+         * @throws ResourceAlreadyExistsException if the node already has an entry
+         */
+        public Builder addNewFailedRoutingEntry(String nodeId, String reason) {
+            if (nodeRoutingTable.containsKey(nodeId)) {
+                throw new ResourceAlreadyExistsException(
+                    "routing entry for node [{}] for model [{}] already exists", nodeId, taskParams.getModelId()
+                );
+            }
+            isChanged = true;
+            nodeRoutingTable.put(nodeId, new RoutingStateAndReason(RoutingState.FAILED, reason));
+            return this;
+        }
+
+        /**
+         * Replaces the entry of an already-routed node. A no-op (and not counted as a change)
+         * when the new state equals the current one.
+         *
+         * @throws ResourceNotFoundException if the node has no entry
+         */
+        public Builder updateExistingRoutingEntry(String nodeId, RoutingStateAndReason state) {
+            RoutingStateAndReason stateAndReason = nodeRoutingTable.get(nodeId);
+            if (stateAndReason == null) {
+                throw new ResourceNotFoundException(
+                    "routing entry for node [{}] for model [{}] does not exist", nodeId, taskParams.getModelId()
+                );
+            }
+            if (stateAndReason.equals(state)) {
+                return this;
+            }
+            nodeRoutingTable.put(nodeId, state);
+            isChanged = true;
+            return this;
+        }
+
+        /**
+         * Removes the node's entry if present; counted as a change only when something was removed.
+         */
+        public Builder removeRoutingEntry(String nodeId) {
+            if (nodeRoutingTable.remove(nodeId) != null) {
+                isChanged = true;
+            }
+            return this;
+        }
+
+        /**
+         * Moves the allocation to {@link AllocationState#STOPPING}; a no-op if already stopping.
+         */
+        public Builder stopAllocation() {
+            if (allocationState.equals(AllocationState.STOPPING)) {
+                return this;
+            }
+            isChanged = true;
+            allocationState = AllocationState.STOPPING;
+            return this;
+        }
+
+        /**
+         * @return {@code true} if any mutator has altered this builder's state since creation
+         */
+        public boolean isChanged() {
+            return isChanged;
+        }
+
+        /**
+         * Builds an allocation from the current builder state. The routing table is copied so
+         * that later mutations of this builder cannot alter the returned instance.
+         */
+        public TrainedModelAllocation build() {
+            return new TrainedModelAllocation(taskParams, new LinkedHashMap<>(nodeRoutingTable), allocationState);
+        }
+    }
+
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocation.java
index 98be0e92e85cd..85acf4df9d78a 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocation.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocation.java
@@ -47,7 +47,7 @@ public static IndexLocation fromXContentLenient(XContentParser parser) throws IO
private final String modelId;
private final String indexName;
- IndexLocation(String modelId, String indexName) {
+ public IndexLocation(String modelId, String indexName) {
this.modelId = Objects.requireNonNull(modelId);
this.indexName = Objects.requireNonNull(indexName);
}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java
new file mode 100644
index 0000000000000..ae130f1288653
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction.Request;
+
+/**
+ * Wire serialization round-trip test for {@link Request}, built from randomly generated task parameters.
+ */
+public class CreateTrainedModelAllocationActionRequestTests extends AbstractWireSerializingTestCase<Request> {
+
+    @Override
+    protected Request createTestInstance() {
+        return new Request(
+            new StartTrainedModelDeploymentAction.TaskParams(
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomNonNegativeLong()
+            )
+        );
+    }
+
+    @Override
+    protected Writeable.Reader<Request> instanceReader() {
+        return Request::new;
+    }
+
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionResponseTests.java
new file mode 100644
index 0000000000000..980f5ef050559
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionResponseTests.java
@@ -0,0 +1,33 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction.Response;
+import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocationTests;
+
+import java.io.IOException;
+
+public class CreateTrainedModelAllocationActionResponseTests extends AbstractSerializingTestCase {
+ // Plugs CreateTrainedModelAllocationAction.Response into the AbstractSerializingTestCase harness.
+ @Override
+ protected Response createTestInstance() {
+ return new Response(TrainedModelAllocationTests.randomInstance());
+ }
+ // Instances are read back through the Response constructor reference.
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return Response::new;
+ }
+ // Parsing is delegated to Response.fromXContent.
+ @Override
+ protected Response doParseInstance(XContentParser parser) throws IOException {
+ return Response.fromXContent(parser);
+ }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationActionRequestTests.java
new file mode 100644
index 0000000000000..933c60dcbf419
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationActionRequestTests.java
@@ -0,0 +1,25 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAllocationAction.Request;
+
+public class DeleteTrainedModelAllocationActionRequestTests extends AbstractWireSerializingTestCase {
+ // Plugs DeleteTrainedModelAllocationAction.Request into the AbstractWireSerializingTestCase harness.
+ @Override
+ protected Request createTestInstance() {
+ return new Request(randomAlphaOfLength(10));
+ }
+ // Instances are read back through the Request constructor reference.
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return Request::new;
+ }
+
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateActionRequestTests.java
new file mode 100644
index 0000000000000..5b7104934f121
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateActionRequestTests.java
@@ -0,0 +1,26 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction.Request;
+import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReasonTests;
+
+public class UpdateTrainedModelAllocationStateActionRequestTests extends AbstractWireSerializingTestCase {
+ // Plugs UpdateTrainedModelAllocationStateAction.Request into the AbstractWireSerializingTestCase harness.
+ @Override
+ protected Request createTestInstance() {
+ return new Request(randomAlphaOfLength(10), randomAlphaOfLength(10), RoutingStateAndReasonTests.randomInstance());
+ }
+ // Instances are read back through the Request constructor reference.
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return Request::new;
+ }
+
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStateTests.java
new file mode 100644
index 0000000000000..42a62e35fe80d
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStateTests.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.test.ESTestCase;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class AllocationStateTests extends ESTestCase {
+ // Checks that the string form of every AllocationState constant can be turned back into that constant.
+ public void testToAndFromString() {
+ // toString() followed by fromString() must yield the constant we started from.
+ for (AllocationState expected : AllocationState.values()) {
+ assertThat(AllocationState.fromString(expected.toString()), equalTo(expected));
+ }
+ }
+
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReasonTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReasonTests.java
new file mode 100644
index 0000000000000..438372248cee3
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReasonTests.java
@@ -0,0 +1,36 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+
+public class RoutingStateAndReasonTests extends AbstractSerializingTestCase {
+ // Plugs RoutingStateAndReason into the AbstractSerializingTestCase harness; randomInstance() is public for reuse by other tests.
+ public static RoutingStateAndReason randomInstance() { // random RoutingState; reason is either null or a random string
+ return new RoutingStateAndReason(randomFrom(RoutingState.values()), randomBoolean() ? null : randomAlphaOfLength(10));
+ }
+ // Parsing is delegated to RoutingStateAndReason.fromXContent.
+ @Override
+ protected RoutingStateAndReason doParseInstance(XContentParser parser) throws IOException {
+ return RoutingStateAndReason.fromXContent(parser);
+ }
+ // Instances are read back through the RoutingStateAndReason constructor reference.
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return RoutingStateAndReason::new;
+ }
+ // Test instances come from the shared randomInstance() generator.
+ @Override
+ protected RoutingStateAndReason createTestInstance() {
+ return randomInstance();
+ }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateTests.java
new file mode 100644
index 0000000000000..883339250ce24
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateTests.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.test.ESTestCase;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class RoutingStateTests extends ESTestCase {
+ // Checks that the string form of every RoutingState constant can be turned back into that constant.
+ public void testToAndFromString() {
+ // toString() followed by fromString() must yield the constant we started from.
+ for (RoutingState expected : RoutingState.values()) {
+ assertThat(RoutingState.fromString(expected.toString()), equalTo(expected));
+ }
+ }
+
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java
new file mode 100644
index 0000000000000..903c897bfd24c
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java
@@ -0,0 +1,143 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.ResourceAlreadyExistsException;
+import org.elasticsearch.ResourceNotFoundException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
+import static org.hamcrest.Matchers.is;
+
+public class TrainedModelAllocationTests extends AbstractSerializingTestCase {
+ // Serialization harness wiring plus unit tests for TrainedModelAllocation and its Builder.
+ public static TrainedModelAllocation randomInstance() {
+ TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(
+ new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomAlphaOfLength(10), randomNonNegativeLong())
+ );
+ List nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList());
+ for (String node : nodes) { // each generated node becomes either a failed or a plain routing entry, chosen at random
+ if (randomBoolean()) {
+ builder.addNewFailedRoutingEntry(node, randomAlphaOfLength(10));
+ } else {
+ builder.addNewRoutingEntry(node);
+ }
+ }
+ return builder.build();
+ }
+ // Parsing is delegated to TrainedModelAllocation.fromXContent.
+ @Override
+ protected TrainedModelAllocation doParseInstance(XContentParser parser) throws IOException {
+ return TrainedModelAllocation.fromXContent(parser);
+ }
+ // Instances are read back through the TrainedModelAllocation constructor reference.
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return TrainedModelAllocation::new;
+ }
+ // Test instances come from the public randomInstance() generator above.
+ @Override
+ protected TrainedModelAllocation createTestInstance() {
+ return randomInstance();
+ }
+ // isChanged() must stay false after a no-op removal and become true after an add or an effective removal.
+ public void testBuilderChanged() {
+ TrainedModelAllocation original = randomInstance();
+ TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.fromAllocation(original);
+ assertThat(builder.isChanged(), is(false));
+ String addingNode = "foo";
+ // The node has not been added yet, so removing it is expected to leave the builder unchanged.
+ assertUnchanged(builder, b -> b.removeRoutingEntry(addingNode));
+ // Adding the node, as either a plain or a failed entry, must register as a change.
+ if (randomBoolean()) {
+ builder.addNewRoutingEntry(addingNode);
+ } else {
+ builder.addNewFailedRoutingEntry(addingNode, "test failed");
+ }
+ assertThat(builder.isChanged(), is(true));
+ // A builder created from the freshly built allocation starts out unchanged again.
+ TrainedModelAllocation.Builder builderWithNode = TrainedModelAllocation.Builder.fromAllocation(builder.build());
+ assertThat(builderWithNode.isChanged(), is(false));
+ // This time the node is present, so removing it must register as a change.
+ builderWithNode.removeRoutingEntry(addingNode);
+ assertThat(builderWithNode.isChanged(), is(true));
+ }
+ // Adding a route (plain or failed) for a node that already has one must throw ResourceAlreadyExistsException.
+ public void testBuilderAddingExistingRoute() {
+ TrainedModelAllocation original = randomInstance();
+ TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.fromAllocation(original);
+ String addingNode = "new-node";
+ if (randomBoolean()) {
+ builder.addNewRoutingEntry(addingNode);
+ } else {
+ builder.addNewFailedRoutingEntry(addingNode, "test failed");
+ }
+ expectThrows(ResourceAlreadyExistsException.class, () -> builder.addNewFailedRoutingEntry("new-node", "anything"));
+ expectThrows(ResourceAlreadyExistsException.class, () -> builder.addNewRoutingEntry("new-node"));
+ }
+ // Updating the route of a node that was never added must throw ResourceNotFoundException.
+ public void testBuilderUpdatingMissingRoute() {
+ TrainedModelAllocation original = randomInstance();
+ TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.fromAllocation(original);
+ String addingNode = "new-node";
+ expectThrows(
+ ResourceNotFoundException.class,
+ () -> builder.updateExistingRoutingEntry(addingNode, RoutingStateAndReasonTests.randomInstance())
+ );
+ }
+ // getStartedNodes() must return exactly the two nodes moved to STARTED, not the two left in other states.
+ public void testGetStartedNodes() {
+ String startedNode1 = "started-node-1";
+ String startedNode2 = "started-node-2";
+ String nodeInAnotherState1 = "another-state-node-1";
+ String nodeInAnotherState2 = "another-state-node-2";
+ TrainedModelAllocation allocation = TrainedModelAllocation.Builder.empty(
+ new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomAlphaOfLength(10), randomNonNegativeLong())
+ )
+ .addNewRoutingEntry(startedNode1)
+ .addNewRoutingEntry(startedNode2)
+ .addNewRoutingEntry(nodeInAnotherState1)
+ .addNewRoutingEntry(nodeInAnotherState2)
+ .updateExistingRoutingEntry(startedNode1, new RoutingStateAndReason(RoutingState.STARTED, ""))
+ .updateExistingRoutingEntry(startedNode2, new RoutingStateAndReason(RoutingState.STARTED, ""))
+ .updateExistingRoutingEntry(
+ nodeInAnotherState1,
+ new RoutingStateAndReason(
+ randomFrom(RoutingState.STARTING, RoutingState.FAILED, RoutingState.STOPPED, RoutingState.STOPPING),
+ randomAlphaOfLength(10)
+ )
+ )
+ .updateExistingRoutingEntry(
+ nodeInAnotherState2,
+ new RoutingStateAndReason(
+ randomFrom(RoutingState.STARTING, RoutingState.FAILED, RoutingState.STOPPED, RoutingState.STOPPING),
+ randomAlphaOfLength(10)
+ )
+ )
+ .build();
+ assertThat(allocation.getStartedNodes(), arrayContainingInAnyOrder(startedNode1, startedNode2));
+ }
+ // Applies the given function to the builder and asserts that isChanged() is still false afterwards.
+ private static void assertUnchanged(
+ TrainedModelAllocation.Builder builder,
+ Function function
+ ) {
+ function.apply(builder);
+ assertThat(builder.isChanged(), is(false));
+ }
+
+}
diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
index a3ed3d0c51634..5b366c7e83b38 100644
--- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
+++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
@@ -15,6 +15,8 @@
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
+import org.junit.After;
+import org.junit.Before;
import java.io.IOException;
import java.util.Base64;
@@ -53,6 +55,32 @@ protected Settings restClientSettings() {
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build();
}
+ @Before
+ public void setLogging() throws IOException { // sets both ML inference loggers to TRACE via transient cluster settings
+ Request loggingSettings = new Request("PUT", "_cluster/settings");
+ loggingSettings.setJsonEntity("" +
+ "{" +
+ "\"transient\" : {\n" +
+ " \"logger.org.elasticsearch.xpack.ml.inference.allocation\" : \"TRACE\",\n" +
+ " \"logger.org.elasticsearch.xpack.ml.inference.deployment\" : \"TRACE\"\n" +
+ " }" +
+ "}");
+ client().performRequest(loggingSettings);
+ }
+
+ @After
+ public void unsetLogging() throws IOException { // counterpart of setLogging(): PUTs null for both transient logger settings
+ Request loggingSettings = new Request("PUT", "_cluster/settings");
+ loggingSettings.setJsonEntity("" +
+ "{" +
+ "\"transient\" : {\n" +
+ " \"logger.org.elasticsearch.xpack.ml.inference.allocation\" :null,\n" +
+ " \"logger.org.elasticsearch.xpack.ml.inference.deployment\" : null\n" +
+ " }" +
+ "}");
+ client().performRequest(loggingSettings);
+ }
+
private static final String MODEL_INDEX = "model_store";
private static final String MODEL_ID ="simple_model_to_evaluate";
private static final String BASE_64_ENCODED_MODEL =
@@ -92,8 +120,11 @@ public void testEvaluate() throws IOException {
createTrainedModel();
startDeployment();
try {
- Response inference = infer("my words");
- assertThat(EntityUtils.toString(inference.getEntity()), equalTo("{\"inference\":[[1.0,1.0]]}"));
+ // Adding multiple inference calls to verify different calls get routed to separate nodes
+ for (int i = 0; i < 10; i++) {
+ Response inference = infer("my words");
+ assertThat(EntityUtils.toString(inference.getEntity()), equalTo("{\"inference\":[[1.0,1.0]]}"));
+ }
} finally {
stopDeployment();
}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
index 06439fb04b107..1c4c3572e992c 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
@@ -89,6 +89,7 @@
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.CloseJobAction;
+import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction;
import org.elasticsearch.xpack.core.ml.action.DeleteCalendarAction;
import org.elasticsearch.xpack.core.ml.action.DeleteCalendarEventAction;
import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction;
@@ -100,6 +101,7 @@
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAliasAction;
+import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAllocationAction;
import org.elasticsearch.xpack.core.ml.action.GetDatafeedRunningStateAction;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
@@ -159,6 +161,7 @@
import org.elasticsearch.xpack.core.ml.action.UpdateJobAction;
import org.elasticsearch.xpack.core.ml.action.UpdateModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.UpdateProcessAction;
+import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction;
import org.elasticsearch.xpack.core.ml.action.UpgradeJobModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction;
import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction;
@@ -177,6 +180,7 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.template.TemplateUtils;
import org.elasticsearch.xpack.ml.action.TransportCloseJobAction;
+import org.elasticsearch.xpack.ml.action.TransportCreateTrainedModelAllocationAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteCalendarAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteCalendarEventAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteDataFrameAnalyticsAction;
@@ -188,6 +192,7 @@
import org.elasticsearch.xpack.ml.action.TransportDeleteModelSnapshotAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAliasAction;
+import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAllocationAction;
import org.elasticsearch.xpack.ml.action.TransportGetDatafeedRunningStateAction;
import org.elasticsearch.xpack.ml.action.TransportInferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.ml.action.TransportStartTrainedModelDeploymentAction;
@@ -247,6 +252,7 @@
import org.elasticsearch.xpack.ml.action.TransportUpdateJobAction;
import org.elasticsearch.xpack.ml.action.TransportUpdateModelSnapshotAction;
import org.elasticsearch.xpack.ml.action.TransportUpdateProcessAction;
+import org.elasticsearch.xpack.ml.action.TransportUpdateTrainedModelAllocationStateAction;
import org.elasticsearch.xpack.ml.action.TransportUpgradeJobModelSnapshotAction;
import org.elasticsearch.xpack.ml.action.TransportValidateDetectorAction;
import org.elasticsearch.xpack.ml.action.TransportValidateJobConfigAction;
@@ -275,6 +281,9 @@
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
+import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationClusterService;
+import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
+import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationService;
import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
@@ -284,6 +293,7 @@
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory;
import org.elasticsearch.xpack.ml.job.JobManager;
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
+import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier;
import org.elasticsearch.xpack.ml.job.categorization.FirstNonBlankLineCharFilter;
import org.elasticsearch.xpack.ml.job.categorization.FirstNonBlankLineCharFilterFactory;
@@ -855,6 +865,18 @@ public Collection