diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
index 5cf789707c58c..3f3eb5218afed 100644
--- a/.idea/inspectionProfiles/Project_Default.xml
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -5,5 +5,6 @@
+
\ No newline at end of file
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationAction.java
new file mode 100644
index 0000000000000..ccf567e76a559
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationAction.java
@@ -0,0 +1,133 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.master.MasterNodeRequest;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class CreateTrainedModelAllocationAction extends ActionType<CreateTrainedModelAllocationAction.Response> {
+    public static final CreateTrainedModelAllocationAction INSTANCE = new CreateTrainedModelAllocationAction();
+    public static final String NAME = "cluster:internal/xpack/ml/model_allocation/create";
+
+    private CreateTrainedModelAllocationAction() {
+        super(NAME, CreateTrainedModelAllocationAction.Response::new);
+    }
+
+    public static class Request extends MasterNodeRequest<Request> {
+        private final StartTrainedModelDeploymentAction.TaskParams taskParams;
+
+        public Request(StartTrainedModelDeploymentAction.TaskParams taskParams) {
+            this.taskParams = ExceptionsHelper.requireNonNull(taskParams, "taskParams");
+        }
+
+        public Request(StreamInput in) throws IOException {
+            super(in);
+            this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(in);
+        }
+
+        @Override
+        public ActionRequestValidationException validate() {
+            return null;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            taskParams.writeTo(out);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Request request = (Request) o;
+            return Objects.equals(taskParams, request.taskParams);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(taskParams);
+        }
+
+        public StartTrainedModelDeploymentAction.TaskParams getTaskParams() {
+            return taskParams;
+        }
+    }
+
+    public static class Response extends ActionResponse implements ToXContentObject {
+
+        private static final ParseField ALLOCATION = new ParseField("allocation");
+
+        private static final ConstructingObjectParser<Response, Void> PARSER = new ConstructingObjectParser<>(
+            "create_trained_model_allocation_response",
+            a -> new Response((TrainedModelAllocation) a[0])
+        );
+        static {
+            PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> TrainedModelAllocation.fromXContent(p), ALLOCATION);
+        }
+        static Response fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        private final TrainedModelAllocation trainedModelAllocation;
+
+        public Response(TrainedModelAllocation trainedModelAllocation) {
+            this.trainedModelAllocation = trainedModelAllocation;
+        }
+
+        public Response(StreamInput in) throws IOException {
+            super(in);
+            this.trainedModelAllocation = new TrainedModelAllocation(in);
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            trainedModelAllocation.writeTo(out);
+        }
+
+        public TrainedModelAllocation getTrainedModelAllocation() {
+            return trainedModelAllocation;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field("allocation", trainedModelAllocation);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Response response = (Response) o;
+            return Objects.equals(trainedModelAllocation, response.trainedModelAllocation);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(trainedModelAllocation);
+        }
+    }
+
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationAction.java
new file mode 100644
index 0000000000000..589ae631dece8
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationAction.java
@@ -0,0 +1,70 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.MasterNodeRequest; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class DeleteTrainedModelAllocationAction extends ActionType { + public static final DeleteTrainedModelAllocationAction INSTANCE = new DeleteTrainedModelAllocationAction(); + public static final String NAME = "cluster:internal/xpack/ml/model_allocation/delete"; + + private DeleteTrainedModelAllocationAction() { + super(NAME, AcknowledgedResponse::readFrom); + } + + public static class Request extends MasterNodeRequest { + private final String modelId; + + public Request(String modelId) { + this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id"); + } + + public Request(StreamInput in) throws IOException { + super(in); + this.modelId = in.readString(); + } + + public String getModelId() { + return modelId; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelId); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(modelId, request.modelId); + } + + @Override + public int hashCode() { + return Objects.hash(modelId); + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java index 6e0c4d517d1b7..fb41c1d92a5ec 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -11,11 +11,15 @@ import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.support.master.MasterNodeRequest; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.core.TimeValue; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -31,7 +35,7 @@ import java.util.Objects; import java.util.concurrent.TimeUnit; -public class StartTrainedModelDeploymentAction extends ActionType { +public class StartTrainedModelDeploymentAction extends ActionType { public static final StartTrainedModelDeploymentAction INSTANCE = new StartTrainedModelDeploymentAction(); public static final String NAME = 
"cluster:admin/xpack/ml/trained_models/deployment/start"; @@ -39,7 +43,7 @@ public class StartTrainedModelDeploymentAction extends ActionType implements ToXContentObject { @@ -120,9 +124,29 @@ public String toString() { public static class TaskParams implements PersistentTaskParams, MlTaskParams { - public static final Version VERSION_INTRODUCED = Version.V_8_0_0; + // TODO add support for other roles? If so, it may have to be an instance method... + // NOTE, whatever determines allocation should not be dynamically set on the node + // Otherwise allocation logic might fail + public static boolean mayAllocateToNode(DiscoveryNode node) { + return node.getRoles().contains(DiscoveryNodeRole.ML_ROLE); + } + public static final Version VERSION_INTRODUCED = Version.V_8_0_0; private static final ParseField MODEL_BYTES = new ParseField("model_bytes"); + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "trained_model_deployment_params", + true, + a -> new TaskParams((String)a[0], (String)a[1], (Long)a[2]) + ); + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID); + PARSER.declareString(ConstructingObjectParser.constructorArg(), IndexLocation.INDEX); + PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES); + } + + public static TaskParams fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } /** * This has been found to be approximately 300MB on linux by manual testing. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateAction.java new file mode 100644 index 0000000000000..e1ae0e6c2258c --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateAction.java @@ -0,0 +1,94 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.MasterNodeRequest; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class UpdateTrainedModelAllocationStateAction extends ActionType { + public static final UpdateTrainedModelAllocationStateAction INSTANCE = new UpdateTrainedModelAllocationStateAction(); + public static final String NAME = "cluster:internal/xpack/ml/model_allocation/update"; + + private UpdateTrainedModelAllocationStateAction() { + super(NAME, AcknowledgedResponse::readFrom); + } + + public static class Request extends MasterNodeRequest { + private final String nodeId; + private final String modelId; + private final RoutingStateAndReason routingState; + + public Request(String nodeId, String modelId, RoutingStateAndReason routingState) { + this.nodeId = ExceptionsHelper.requireNonNull(nodeId, "node_id"); + this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id"); + this.routingState = ExceptionsHelper.requireNonNull(routingState, "routing_state"); + } + + public Request(StreamInput in) throws IOException { + super(in); + this.nodeId = in.readString(); + this.modelId = in.readString(); + this.routingState = new RoutingStateAndReason(in); + } + + public String getNodeId() { + return nodeId; + } + + public String getModelId() { + return modelId; + } + + public RoutingStateAndReason getRoutingState() { + return routingState; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(nodeId); + out.writeString(modelId); + routingState.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(nodeId, request.nodeId) + && Objects.equals(modelId, request.modelId) + && Objects.equals(routingState, request.routingState); + } + + @Override + public int hashCode() { + return Objects.hash(nodeId, modelId, routingState); + } + + @Override + public String toString() { + return "Request{" + "nodeId='" + nodeId + '\'' + ", modelId='" + modelId + '\'' + ", routingState=" + routingState + '}'; + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationState.java new file mode 100644 index 0000000000000..c9ef574d39f8f --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationState.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.allocation; + +import java.util.Locale; + +public enum AllocationState { + STARTED, + STOPPING; + + public static AllocationState fromString(String value) { + return valueOf(value.toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingState.java new file mode 100644 index 0000000000000..865a490cbf64a --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingState.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.allocation; + +import java.util.Locale; + +public enum RoutingState { + STARTING, + STARTED, + STOPPING, + FAILED, + STOPPED; + + public static RoutingState fromString(String value) { + return valueOf(value.toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReason.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReason.java new file mode 100644 index 0000000000000..c6f1ce7d71510 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReason.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.allocation; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ParseField; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class RoutingStateAndReason implements ToXContentObject, Writeable { + + private static final ParseField REASON = new ParseField("reason"); + private static final ParseField ROUTING_STATE = new ParseField("routing_state"); + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "trained_model_routing_state", + a -> new RoutingStateAndReason(RoutingState.fromString((String) a[0]), (String) a[1]) + ); + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), ROUTING_STATE); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), REASON); + } + + public static RoutingStateAndReason fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final String reason; + private final RoutingState state; + + public RoutingStateAndReason(RoutingState state, String reason) { + this.state = ExceptionsHelper.requireNonNull(state, ROUTING_STATE); + this.reason = reason; + } + + public RoutingStateAndReason(StreamInput in) throws IOException { + this.state = in.readEnum(RoutingState.class); + this.reason = in.readOptionalString(); + } + + public String getReason() { + return reason; + } + + public RoutingState getState() { + return state; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(state); + out.writeOptionalString(reason); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ROUTING_STATE.getPreferredName(), state); + if (reason != null) { + builder.field(REASON.getPreferredName(), reason); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RoutingStateAndReason that = (RoutingStateAndReason) o; + return Objects.equals(reason, that.reason) && state == that.state; + } + + @Override + public int hashCode() { + return Objects.hash(reason, state); + } + + @Override + public String toString() { + return "RoutingStateAndReason{" + "reason='" + reason + '\'' + ", state=" + state + '}'; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocation.java new file mode 100644 index 0000000000000..f15511097d8d0 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocation.java @@ -0,0 +1,240 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.allocation; + +import org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.cluster.AbstractDiffable; +import org.elasticsearch.cluster.Diffable; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ParseField; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; + +// TODO implement better diffable logic so that whole diff does not need to be serialized if only one part changes +/** + * Trained model allocation object that contains allocation options and the allocation routing table + */ +public class TrainedModelAllocation extends AbstractDiffable + implements + Diffable, + ToXContentObject { + + private static final ParseField ALLOCATION_STATE = new ParseField("allocation_state"); + private static final ParseField ROUTING_TABLE = new ParseField("routing_table"); + private static final ParseField TASK_PARAMETERS = new ParseField("task_parameters"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "trained_model_allocation", + true, + a -> new TrainedModelAllocation( + (StartTrainedModelDeploymentAction.TaskParams) a[0], + (Map) a[1], + AllocationState.fromString((String)a[2]) + ) + ); + static { + PARSER.declareObject( + ConstructingObjectParser.constructorArg(), + (p, c) -> StartTrainedModelDeploymentAction.TaskParams.fromXContent(p), + TASK_PARAMETERS + ); + PARSER.declareObject( + ConstructingObjectParser.constructorArg(), + (p, c) -> p.map(LinkedHashMap::new, RoutingStateAndReason::fromXContent), + ROUTING_TABLE + ); + PARSER.declareString(ConstructingObjectParser.constructorArg(), ALLOCATION_STATE); + } + + private final StartTrainedModelDeploymentAction.TaskParams taskParams; + private final Map nodeRoutingTable; + private final AllocationState allocationState; + + public static TrainedModelAllocation fromXContent(XContentParser parser) throws IOException { + return PARSER.apply(parser, null); + } + + TrainedModelAllocation( + StartTrainedModelDeploymentAction.TaskParams taskParams, + Map nodeRoutingTable, + AllocationState allocationState + ) { + this.taskParams = ExceptionsHelper.requireNonNull(taskParams, TASK_PARAMETERS); + this.nodeRoutingTable = ExceptionsHelper.requireNonNull(nodeRoutingTable, ROUTING_TABLE); + this.allocationState = ExceptionsHelper.requireNonNull(allocationState, ALLOCATION_STATE); + } + + public TrainedModelAllocation(StreamInput in) throws IOException { + this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(in); + this.nodeRoutingTable = in.readOrderedMap(StreamInput::readString, RoutingStateAndReason::new); + this.allocationState = in.readEnum(AllocationState.class); + } + + public boolean isRoutedToNode(String nodeId) { + return 
nodeRoutingTable.containsKey(nodeId); + } + + public Map getNodeRoutingTable() { + return Collections.unmodifiableMap(nodeRoutingTable); + } + + public StartTrainedModelDeploymentAction.TaskParams getTaskParams() { + return taskParams; + } + + public AllocationState getAllocationState() { + return allocationState; + } + + public String[] getStartedNodes() { + return nodeRoutingTable + .entrySet() + .stream() + .filter(entry -> RoutingState.STARTED.equals(entry.getValue().getState())) + .map(Map.Entry::getKey) + .toArray(String[]::new); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TrainedModelAllocation that = (TrainedModelAllocation) o; + return Objects.equals(nodeRoutingTable, that.nodeRoutingTable) + && Objects.equals(taskParams, that.taskParams) + && Objects.equals(allocationState, that.allocationState); + } + + @Override + public int hashCode() { + return Objects.hash(nodeRoutingTable, taskParams, allocationState); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TASK_PARAMETERS.getPreferredName(), taskParams); + builder.field(ROUTING_TABLE.getPreferredName(), nodeRoutingTable); + builder.field(ALLOCATION_STATE.getPreferredName(), allocationState); + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + taskParams.writeTo(out); + out.writeMap(nodeRoutingTable, StreamOutput::writeString, (o, w) -> w.writeTo(o)); + out.writeEnum(allocationState); + } + + public static class Builder { + private final Map nodeRoutingTable; + private final StartTrainedModelDeploymentAction.TaskParams taskParams; + private AllocationState allocationState; + private boolean isChanged; + + public static Builder fromAllocation(TrainedModelAllocation allocation) { + return new Builder(allocation.taskParams, allocation.nodeRoutingTable, allocation.allocationState); + } + + public static Builder empty(StartTrainedModelDeploymentAction.TaskParams taskParams) { + return new Builder(taskParams); + } + + private Builder( + StartTrainedModelDeploymentAction.TaskParams taskParams, + Map nodeRoutingTable, + AllocationState allocationState + ) { + this.taskParams = taskParams; + this.nodeRoutingTable = new LinkedHashMap<>(nodeRoutingTable); + this.allocationState = allocationState; + } + + private Builder(StartTrainedModelDeploymentAction.TaskParams taskParams) { + this.nodeRoutingTable = new LinkedHashMap<>(); + this.taskParams = taskParams; + this.allocationState = AllocationState.STARTED; + } + + public Builder addNewRoutingEntry(String nodeId) { + if (nodeRoutingTable.containsKey(nodeId)) { + throw new ResourceAlreadyExistsException( + "routing entry for node [{}] for model [{}] already exists", nodeId, taskParams.getModelId() + ); + } + isChanged = true; + nodeRoutingTable.put(nodeId, new RoutingStateAndReason(RoutingState.STARTING, "")); + return this; + } + + public Builder addNewFailedRoutingEntry(String nodeId, String reason) { + if (nodeRoutingTable.containsKey(nodeId)) { + throw new ResourceAlreadyExistsException( + "routing entry for node [{}] for model [{}] already exists", nodeId, taskParams.getModelId() + ); + } + isChanged = true; + nodeRoutingTable.put(nodeId, new RoutingStateAndReason(RoutingState.FAILED, reason)); + return this; + } + + public Builder updateExistingRoutingEntry(String nodeId, RoutingStateAndReason state) { + 
RoutingStateAndReason stateAndReason = nodeRoutingTable.get(nodeId); + if (stateAndReason == null) { + throw new ResourceNotFoundException( + "routing entry for node [{}] for model [{}] does not exist", nodeId, taskParams.getModelId() + ); + } + if (stateAndReason.equals(state)) { + return this; + } + nodeRoutingTable.put(nodeId, state); + isChanged = true; + return this; + } + + public Builder removeRoutingEntry(String nodeId) { + if (nodeRoutingTable.remove(nodeId) != null) { + isChanged = true; + } + return this; + } + + public Builder stopAllocation() { + if (allocationState.equals(AllocationState.STOPPING)) { + return this; + } + isChanged = true; + allocationState = AllocationState.STOPPING; + return this; + } + + public boolean isChanged() { + return isChanged; + } + + public TrainedModelAllocation build() { + return new TrainedModelAllocation(taskParams, nodeRoutingTable, allocationState); + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocation.java index 98be0e92e85cd..85acf4df9d78a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocation.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocation.java @@ -47,7 +47,7 @@ public static IndexLocation fromXContentLenient(XContentParser parser) throws IO private final String modelId; private final String indexName; - IndexLocation(String modelId, String indexName) { + public IndexLocation(String modelId, String indexName) { this.modelId = Objects.requireNonNull(modelId); this.indexName = Objects.requireNonNull(indexName); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java new file mode 100644 index 0000000000000..ae130f1288653 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction.Request;
+
+public class CreateTrainedModelAllocationActionRequestTests extends AbstractWireSerializingTestCase<Request> {
+
+    @Override
+    protected Request createTestInstance() {
+        return new Request(
+            new StartTrainedModelDeploymentAction.TaskParams(
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomNonNegativeLong()
+            )
+        );
+    }
+
+    @Override
+    protected Writeable.Reader<Request> instanceReader() {
+        return Request::new;
+    }
+
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionResponseTests.java
new file mode 100644
index 0000000000000..980f5ef050559
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionResponseTests.java
@@ -0,0 +1,33 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction.Response;
+import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocationTests;
+
+import java.io.IOException;
+
+public class CreateTrainedModelAllocationActionResponseTests extends AbstractSerializingTestCase<Response> {
+
+    @Override
+    protected Response createTestInstance() {
+        return new Response(TrainedModelAllocationTests.randomInstance());
+    }
+
+    @Override
+    protected Writeable.Reader<Response> instanceReader() {
+        return Response::new;
+    }
+
+    @Override
+    protected Response doParseInstance(XContentParser parser) throws IOException {
+        return Response.fromXContent(parser);
+    }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationActionRequestTests.java
new file mode 100644
index 0000000000000..933c60dcbf419
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationActionRequestTests.java
@@ -0,0 +1,25 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAllocationAction.Request;
+
+public class DeleteTrainedModelAllocationActionRequestTests extends AbstractWireSerializingTestCase<Request> {
+
+    @Override
+    protected Request createTestInstance() {
+        return new Request(randomAlphaOfLength(10));
+    }
+
+    @Override
+    protected Writeable.Reader<Request> instanceReader() {
+        return Request::new;
+    }
+
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateActionRequestTests.java
new file mode 100644
index 0000000000000..5b7104934f121
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateActionRequestTests.java
@@ -0,0 +1,26 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction.Request;
+import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReasonTests;
+
+public class UpdateTrainedModelAllocationStateActionRequestTests extends AbstractWireSerializingTestCase<Request> {
+
+    @Override
+    protected Request createTestInstance() {
+        return new Request(randomAlphaOfLength(10), randomAlphaOfLength(10), RoutingStateAndReasonTests.randomInstance());
+    }
+
+    @Override
+    protected Writeable.Reader<Request> instanceReader() {
+        return Request::new;
+    }
+
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStateTests.java
new file mode 100644
index 0000000000000..42a62e35fe80d
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStateTests.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.test.ESTestCase;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class AllocationStateTests extends ESTestCase {
+
+    public void testToAndFromString() {
+        for (AllocationState state : AllocationState.values()) {
+            String value = state.toString();
+            assertThat(AllocationState.fromString(value), equalTo(state));
+        }
+    }
+
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReasonTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReasonTests.java
new file mode 100644
index 0000000000000..438372248cee3
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReasonTests.java
@@ -0,0 +1,36 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+
+public class RoutingStateAndReasonTests extends AbstractSerializingTestCase<RoutingStateAndReason> {
+
+    public static RoutingStateAndReason randomInstance() {
+        return new RoutingStateAndReason(randomFrom(RoutingState.values()), randomBoolean() ? null : randomAlphaOfLength(10));
+    }
+
+    @Override
+    protected RoutingStateAndReason doParseInstance(XContentParser parser) throws IOException {
+        return RoutingStateAndReason.fromXContent(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<RoutingStateAndReason> instanceReader() {
+        return RoutingStateAndReason::new;
+    }
+
+    @Override
+    protected RoutingStateAndReason createTestInstance() {
+        return randomInstance();
+    }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateTests.java
new file mode 100644
index 0000000000000..883339250ce24
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateTests.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.core.ml.inference.allocation; + +import org.elasticsearch.test.ESTestCase; + +import static org.hamcrest.Matchers.equalTo; + +public class RoutingStateTests extends ESTestCase { + + public void testToAndFromString() { + for (RoutingState state : RoutingState.values()) { + String value = state.toString(); + assertThat(RoutingState.fromString(value), equalTo(state)); + } + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java new file mode 100644 index 0000000000000..903c897bfd24c --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java @@ -0,0 +1,143 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.allocation; + +import org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; + +import java.io.IOException; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.arrayContainingInAnyOrder; +import static org.hamcrest.Matchers.is; + +public class TrainedModelAllocationTests extends AbstractSerializingTestCase { + + public static TrainedModelAllocation randomInstance() { + TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty( + new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomAlphaOfLength(10), randomNonNegativeLong()) + ); + List nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList()); + for (String node : nodes) { + if (randomBoolean()) { + builder.addNewFailedRoutingEntry(node, randomAlphaOfLength(10)); + } else { + builder.addNewRoutingEntry(node); + } + } + return builder.build(); + } + + @Override + protected TrainedModelAllocation doParseInstance(XContentParser parser) throws IOException { + return TrainedModelAllocation.fromXContent(parser); + } + + @Override + protected Writeable.Reader instanceReader() { + return TrainedModelAllocation::new; + } + + @Override + protected TrainedModelAllocation createTestInstance() { + return randomInstance(); + } + + public void testBuilderChanged() { + TrainedModelAllocation original = randomInstance(); + TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.fromAllocation(original); + assertThat(builder.isChanged(), is(false)); + String addingNode = "foo"; + + assertUnchanged(builder, b -> b.removeRoutingEntry(addingNode)); + + if (randomBoolean()) { + builder.addNewRoutingEntry(addingNode); + } else { + builder.addNewFailedRoutingEntry(addingNode, "test failed"); + } + assertThat(builder.isChanged(), is(true)); + + TrainedModelAllocation.Builder builderWithNode = TrainedModelAllocation.Builder.fromAllocation(builder.build()); + 
assertThat(builderWithNode.isChanged(), is(false)); + + builderWithNode.removeRoutingEntry(addingNode); + assertThat(builderWithNode.isChanged(), is(true)); + } + + public void testBuilderAddingExistingRoute() { + TrainedModelAllocation original = randomInstance(); + TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.fromAllocation(original); + String addingNode = "new-node"; + if (randomBoolean()) { + builder.addNewRoutingEntry(addingNode); + } else { + builder.addNewFailedRoutingEntry(addingNode, "test failed"); + } + expectThrows(ResourceAlreadyExistsException.class, () -> builder.addNewFailedRoutingEntry("new-node", "anything")); + expectThrows(ResourceAlreadyExistsException.class, () -> builder.addNewRoutingEntry("new-node")); + } + + public void testBuilderUpdatingMissingRoute() { + TrainedModelAllocation original = randomInstance(); + TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.fromAllocation(original); + String addingNode = "new-node"; + expectThrows( + ResourceNotFoundException.class, + () -> builder.updateExistingRoutingEntry(addingNode, RoutingStateAndReasonTests.randomInstance()) + ); + } + + public void testGetStartedNodes() { + String startedNode1 = "started-node-1"; + String startedNode2 = "started-node-2"; + String nodeInAnotherState1 = "another-state-node-1"; + String nodeInAnotherState2 = "another-state-node-2"; + TrainedModelAllocation allocation = TrainedModelAllocation.Builder.empty( + new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomAlphaOfLength(10), randomNonNegativeLong()) + ) + .addNewRoutingEntry(startedNode1) + .addNewRoutingEntry(startedNode2) + .addNewRoutingEntry(nodeInAnotherState1) + .addNewRoutingEntry(nodeInAnotherState2) + .updateExistingRoutingEntry(startedNode1, new RoutingStateAndReason(RoutingState.STARTED, "")) + .updateExistingRoutingEntry(startedNode2, new RoutingStateAndReason(RoutingState.STARTED, "")) + .updateExistingRoutingEntry( + nodeInAnotherState1, + new RoutingStateAndReason( + randomFrom(RoutingState.STARTING, RoutingState.FAILED, RoutingState.STOPPED, RoutingState.STOPPING), + randomAlphaOfLength(10) + ) + ) + .updateExistingRoutingEntry( + nodeInAnotherState2, + new RoutingStateAndReason( + randomFrom(RoutingState.STARTING, RoutingState.FAILED, RoutingState.STOPPED, RoutingState.STOPPING), + randomAlphaOfLength(10) + ) + ) + .build(); + assertThat(allocation.getStartedNodes(), arrayContainingInAnyOrder(startedNode1, startedNode2)); + } + + private static void assertUnchanged( + TrainedModelAllocation.Builder builder, + Function function + ) { + function.apply(builder); + assertThat(builder.isChanged(), is(false)); + } + +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java index a3ed3d0c51634..5b366c7e83b38 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java @@ -15,6 +15,8 @@ import org.elasticsearch.test.SecuritySettingsSourceField; import org.elasticsearch.test.rest.ESRestTestCase; import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken; +import org.junit.After; +import org.junit.Before; import 
java.io.IOException; import java.util.Base64; @@ -53,6 +55,32 @@ protected Settings restClientSettings() { return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build(); } + @Before + public void setLogging() throws IOException { + Request loggingSettings = new Request("PUT", "_cluster/settings"); + loggingSettings.setJsonEntity("" + + "{" + + "\"transient\" : {\n" + + " \"logger.org.elasticsearch.xpack.ml.inference.allocation\" : \"TRACE\",\n" + + " \"logger.org.elasticsearch.xpack.ml.inference.deployment\" : \"TRACE\"\n" + + " }" + + "}"); + client().performRequest(loggingSettings); + } + + @After + public void unsetLogging() throws IOException { + Request loggingSettings = new Request("PUT", "_cluster/settings"); + loggingSettings.setJsonEntity("" + + "{" + + "\"transient\" : {\n" + + " \"logger.org.elasticsearch.xpack.ml.inference.allocation\" :null,\n" + + " \"logger.org.elasticsearch.xpack.ml.inference.deployment\" : null\n" + + " }" + + "}"); + client().performRequest(loggingSettings); + } + private static final String MODEL_INDEX = "model_store"; private static final String MODEL_ID ="simple_model_to_evaluate"; private static final String BASE_64_ENCODED_MODEL = @@ -92,8 +120,11 @@ public void testEvaluate() throws IOException { createTrainedModel(); startDeployment(); try { - Response inference = infer("my words"); - assertThat(EntityUtils.toString(inference.getEntity()), equalTo("{\"inference\":[[1.0,1.0]]}")); + // Adding multiple inference calls to verify different calls get routed to separate nodes + for (int i = 0; i < 10; i++) { + Response inference = infer("my words"); + assertThat(EntityUtils.toString(inference.getEntity()), equalTo("{\"inference\":[[1.0,1.0]]}")); + } } finally { stopDeployment(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 06439fb04b107..1c4c3572e992c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -89,6 +89,7 @@ import org.elasticsearch.xpack.core.ml.MlStatsIndex; import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.CloseJobAction; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction; import org.elasticsearch.xpack.core.ml.action.DeleteCalendarAction; import org.elasticsearch.xpack.core.ml.action.DeleteCalendarEventAction; import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; @@ -100,6 +101,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAliasAction; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAllocationAction; import org.elasticsearch.xpack.core.ml.action.GetDatafeedRunningStateAction; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; @@ -159,6 +161,7 @@ import org.elasticsearch.xpack.core.ml.action.UpdateJobAction; import org.elasticsearch.xpack.core.ml.action.UpdateModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.UpdateProcessAction; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction; import 
org.elasticsearch.xpack.core.ml.action.UpgradeJobModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction; import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction; @@ -177,6 +180,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.template.TemplateUtils; import org.elasticsearch.xpack.ml.action.TransportCloseJobAction; +import org.elasticsearch.xpack.ml.action.TransportCreateTrainedModelAllocationAction; import org.elasticsearch.xpack.ml.action.TransportDeleteCalendarAction; import org.elasticsearch.xpack.ml.action.TransportDeleteCalendarEventAction; import org.elasticsearch.xpack.ml.action.TransportDeleteDataFrameAnalyticsAction; @@ -188,6 +192,7 @@ import org.elasticsearch.xpack.ml.action.TransportDeleteModelSnapshotAction; import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAliasAction; +import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAllocationAction; import org.elasticsearch.xpack.ml.action.TransportGetDatafeedRunningStateAction; import org.elasticsearch.xpack.ml.action.TransportInferTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.action.TransportStartTrainedModelDeploymentAction; @@ -247,6 +252,7 @@ import org.elasticsearch.xpack.ml.action.TransportUpdateJobAction; import org.elasticsearch.xpack.ml.action.TransportUpdateModelSnapshotAction; import org.elasticsearch.xpack.ml.action.TransportUpdateProcessAction; +import org.elasticsearch.xpack.ml.action.TransportUpdateTrainedModelAllocationStateAction; import org.elasticsearch.xpack.ml.action.TransportUpgradeJobModelSnapshotAction; import org.elasticsearch.xpack.ml.action.TransportValidateDetectorAction; import org.elasticsearch.xpack.ml.action.TransportValidateJobConfigAction; @@ -275,6 +281,9 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; +import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationClusterService; +import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata; +import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationService; import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; @@ -284,6 +293,7 @@ import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; +import org.elasticsearch.xpack.ml.job.NodeLoadDetector; import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier; import org.elasticsearch.xpack.ml.job.categorization.FirstNonBlankLineCharFilter; import org.elasticsearch.xpack.ml.job.categorization.FirstNonBlankLineCharFilterFactory; @@ -855,6 +865,18 @@ public Collection createComponents(Client client, ClusterService cluster // Perform node startup operations nativeStorageProvider.cleanupLocalTmpStorageInCaseOfUncleanShutdown(); + // allocation service objects + final TrainedModelAllocationService trainedModelAllocationService = new TrainedModelAllocationService( + client, + clusterService, + threadPool + ); + final TrainedModelAllocationClusterService 
trainedModelAllocationClusterService = new TrainedModelAllocationClusterService( + settings, + clusterService, + new NodeLoadDetector(memoryTracker) + ); + mlAutoscalingDeciderService.set(new MlAutoscalingDeciderService(memoryTracker, settings, clusterService)); return Arrays.asList( @@ -882,7 +904,10 @@ public Collection createComponents(Client client, ClusterService cluster dataFrameAnalyticsConfigProvider, nativeStorageProvider, modelLoadingService, - trainedModelProvider + trainedModelProvider, + trainedModelAllocationService, + trainedModelAllocationClusterService, + deploymentManager.get() ); } @@ -917,12 +942,7 @@ public List> getPersistentTasksExecutor(ClusterServic autodetectProcessManager.get(), memoryTracker.get(), expressionResolver, - client), - new TransportStartTrainedModelDeploymentAction.TaskExecutor(settings, - clusterService, - expressionResolver, - memoryTracker.get(), - deploymentManager.get()) + client) ); } @@ -1094,6 +1114,12 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(StopTrainedModelDeploymentAction.INSTANCE, TransportStopTrainedModelDeploymentAction.class), new ActionHandler<>(InferTrainedModelDeploymentAction.INSTANCE, TransportInferTrainedModelDeploymentAction.class), new ActionHandler<>(GetDatafeedRunningStateAction.INSTANCE, TransportGetDatafeedRunningStateAction.class), + new ActionHandler<>(CreateTrainedModelAllocationAction.INSTANCE, TransportCreateTrainedModelAllocationAction.class), + new ActionHandler<>(DeleteTrainedModelAllocationAction.INSTANCE, TransportDeleteTrainedModelAllocationAction.class), + new ActionHandler<>( + UpdateTrainedModelAllocationStateAction.INSTANCE, + TransportUpdateTrainedModelAllocationStateAction.class + ), usageAction, infoAction); } @@ -1217,6 +1243,13 @@ public List getNamedXContent() { ModelAliasMetadata::fromXContent ) ); + namedXContent.add( + new NamedXContentRegistry.Entry( + Metadata.Custom.class, + new ParseField(TrainedModelAllocationMetadata.NAME), + TrainedModelAllocationMetadata::fromXContent + ) + ); namedXContent.addAll(new CorrelationNamedContentProvider().getNamedXContentParsers()); return namedXContent; } @@ -1230,6 +1263,20 @@ public List getNamedWriteables() { namedWriteables.add(new NamedWriteableRegistry.Entry(NamedDiff.class, "ml", MlMetadata.MlMetadataDiff::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, ModelAliasMetadata.NAME, ModelAliasMetadata::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelAliasMetadata.NAME, ModelAliasMetadata::readDiffFrom)); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + Metadata.Custom.class, + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + NamedDiff.class, + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata::readDiffFrom + ) + ); // Persistent tasks params namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.DATAFEED_TASK_NAME, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAllocationAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAllocationAction.java new file mode 100644 index 0000000000000..45861de90a87f --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAllocationAction.java @@ -0,0 +1,82 @@ +/* + * Copyright 
Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction.Request; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction.Response; +import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationClusterService; +import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationNodeService; +import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationService; +import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager; + +public class TransportCreateTrainedModelAllocationAction extends TransportMasterNodeAction { + + private final TrainedModelAllocationClusterService trainedModelAllocationClusterService; + + @Inject + public TransportCreateTrainedModelAllocationAction( + TrainedModelAllocationClusterService trainedModelAllocationClusterService, + TrainedModelAllocationService trainedModelAllocationService, + DeploymentManager deploymentManager, + TransportService transportService, + ClusterService clusterService, + ThreadPool threadPool, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver + ) { + super( + CreateTrainedModelAllocationAction.NAME, + false, + transportService, + clusterService, + threadPool, + actionFilters, + Request::new, + indexNameExpressionResolver, + Response::new, + ThreadPool.Names.SAME + ); + this.trainedModelAllocationClusterService = trainedModelAllocationClusterService; + // Here we create our singleton for the node service + clusterService.addListener( + new TrainedModelAllocationNodeService( + trainedModelAllocationService, + clusterService, + deploymentManager, + transportService.getTaskManager(), + threadPool + ) + ); + } + + @Override + protected void masterOperation(Task task, Request request, ClusterState state, ActionListener listener) throws Exception { + trainedModelAllocationClusterService.createNewModelAllocation( + request.getTaskParams(), + ActionListener.wrap(trainedModelAllocation -> listener.onResponse(new Response(trainedModelAllocation)), listener::onFailure) + ); + } + + @Override + protected ClusterBlockException checkBlock(Request request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAllocationAction.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAllocationAction.java new file mode 100644 index 0000000000000..d1c1cd0e42a69 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAllocationAction.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAllocationAction; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAllocationAction.Request; +import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationClusterService; + +public class TransportDeleteTrainedModelAllocationAction extends AcknowledgedTransportMasterNodeAction { + + private final TrainedModelAllocationClusterService trainedModelAllocationClusterService; + + @Inject + public TransportDeleteTrainedModelAllocationAction( + TrainedModelAllocationClusterService trainedModelAllocationClusterService, + TransportService transportService, + ClusterService clusterService, + ThreadPool threadPool, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver + ) { + super( + DeleteTrainedModelAllocationAction.NAME, + false, + transportService, + clusterService, + threadPool, + actionFilters, + Request::new, + indexNameExpressionResolver, + ThreadPool.Names.SAME + ); + this.trainedModelAllocationClusterService = trainedModelAllocationClusterService; + } + + @Override + protected void masterOperation(Task task, Request request, ClusterState state, ActionListener listener) + throws Exception { + trainedModelAllocationClusterService.removeModelAllocation(request.getModelId(), listener); + } + + @Override + protected ClusterBlockException checkBlock(Request request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java index be5a218985a13..a9f0d409c2d08 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java @@ -13,14 +13,15 @@ import org.elasticsearch.action.support.ActionFilters; import 
org.elasticsearch.action.support.tasks.TransportTasksAction; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Randomness; import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata; import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; import java.util.List; @@ -42,15 +43,24 @@ protected void doExecute(Task task, InferTrainedModelDeploymentAction.Request re String deploymentId = request.getDeploymentId(); // We need to check whether there is at least an assigned task here, otherwise we cannot redirect to the // node running the job task. - PersistentTasksCustomMetadata tasks = clusterService.state().getMetadata().custom(PersistentTasksCustomMetadata.TYPE); - PersistentTasksCustomMetadata.PersistentTask deploymentTask = MlTasks.getTrainedModelDeploymentTask(deploymentId, tasks); - if (deploymentTask == null || deploymentTask.isAssigned() == false) { + TrainedModelAllocation allocation = TrainedModelAllocationMetadata + .allocationForModelId(clusterService.state(), request.getDeploymentId()) + .orElse(null); + if (allocation == null) { String message = "Cannot perform requested action because deployment [" + deploymentId + "] is not started"; listener.onFailure(ExceptionsHelper.conflictStatusException(message)); - } else { - request.setNodes(deploymentTask.getExecutorNode()); - super.doExecute(task, request, listener); + return; + } + String[] randomRunningNode = allocation.getStartedNodes(); + if (randomRunningNode.length == 0) { + String message = "Cannot perform requested action because deployment [" + deploymentId + "] is not yet running on any node"; + listener.onFailure(ExceptionsHelper.conflictStatusException(message)); + return; } + // TODO Do better routing for inference calls + int nodeIndex = Randomness.get().nextInt(randomRunningNode.length); + request.setNodes(randomRunningNode[nodeIndex]); + super.doExecute(task, request, listener); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java index 270f886d3cb8c..cf7bd0ecdeac2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java @@ -20,58 +20,49 @@ import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; -import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import 
org.elasticsearch.core.TimeValue; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; -import org.elasticsearch.persistent.AllocatedPersistentTask; -import org.elasticsearch.persistent.PersistentTaskParams; -import org.elasticsearch.persistent.PersistentTaskState; import org.elasticsearch.persistent.PersistentTasksCustomMetadata; -import org.elasticsearch.persistent.PersistentTasksService; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.XPackField; -import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; -import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; -import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState; -import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; -import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState; +import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; -import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager; -import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; import org.elasticsearch.xpack.ml.inference.persistence.ChunkedTrainedModelRestorer; -import org.elasticsearch.xpack.ml.job.JobNodeSelector; import org.elasticsearch.xpack.ml.process.MlMemoryTracker; -import org.elasticsearch.xpack.ml.task.AbstractJobPersistentTasksExecutor; -import java.util.Collection; import java.util.Map; import java.util.Objects; -import java.util.Optional; +import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation; +import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationService; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Set; import java.util.function.Predicate; public class TransportStartTrainedModelDeploymentAction - extends TransportMasterNodeAction { + extends TransportMasterNodeAction { private static final Logger logger = LogManager.getLogger(TransportStartTrainedModelDeploymentAction.class); private final XPackLicenseState licenseState; private final Client client; - private final PersistentTasksService persistentTasksService; + private final TrainedModelAllocationService trainedModelAllocationService; private final NamedXContentRegistry xContentRegistry; private final MlMemoryTracker memoryTracker; @@ -79,31 +70,32 @@ public class TransportStartTrainedModelDeploymentAction public TransportStartTrainedModelDeploymentAction(TransportService transportService, Client client, ClusterService clusterService, ThreadPool threadPool, ActionFilters actionFilters, XPackLicenseState licenseState, IndexNameExpressionResolver 
indexNameExpressionResolver, - PersistentTasksService persistentTasksService, + TrainedModelAllocationService trainedModelAllocationService, NamedXContentRegistry xContentRegistry, MlMemoryTracker memoryTracker) { super(StartTrainedModelDeploymentAction.NAME, transportService, clusterService, threadPool, actionFilters, - StartTrainedModelDeploymentAction.Request::new, indexNameExpressionResolver, NodeAcknowledgedResponse::new, + StartTrainedModelDeploymentAction.Request::new, indexNameExpressionResolver, CreateTrainedModelAllocationAction.Response::new, ThreadPool.Names.SAME); this.licenseState = Objects.requireNonNull(licenseState); this.client = Objects.requireNonNull(client); - this.persistentTasksService = Objects.requireNonNull(persistentTasksService); this.xContentRegistry = Objects.requireNonNull(xContentRegistry); this.memoryTracker = Objects.requireNonNull(memoryTracker); + this.trainedModelAllocationService = Objects.requireNonNull(trainedModelAllocationService); } @Override protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Request request, ClusterState state, - ActionListener listener) throws Exception { - logger.debug(() -> new ParameterizedMessage("[{}] received deploy request", request.getModelId())); + ActionListener listener) throws Exception { + logger.trace(() -> new ParameterizedMessage("[{}] received deploy request", request.getModelId())); if (licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING) == false) { listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); return; } - ActionListener> waitForDeploymentToStart = + ActionListener waitForDeploymentToStart = ActionListener.wrap( - startedTask -> waitForDeploymentStarted(startedTask, request.getTimeout(), listener), + modelAllocation -> waitForDeploymentStarted(request.getModelId(), request.getTimeout(), listener), e -> { + logger.warn(() -> new ParameterizedMessage("[{}] creating new allocation failed", request.getModelId()), e); if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) { e = new ElasticsearchStatusException( "Cannot start deployment [{}] because it has already been started", @@ -150,9 +142,7 @@ protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Requ PersistentTasksCustomMetadata persistentTasks = clusterService.state().getMetadata().custom( PersistentTasksCustomMetadata.TYPE); memoryTracker.refresh(persistentTasks, ActionListener.wrap( - aVoid -> persistentTasksService.sendStartRequest( - MlTasks.trainedModelDeploymentTaskId(request.getModelId()), - MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, + aVoid -> trainedModelAllocationService.createNewModelAllocation( taskParams, waitForDeploymentToStart ), @@ -162,6 +152,7 @@ protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Requ }, listener::onFailure )); + }, listener::onFailure ); @@ -191,17 +182,20 @@ private void getModelBytes(TrainedModelConfig trainedModelConfig, ActionListener ); } - private void waitForDeploymentStarted(PersistentTasksCustomMetadata.PersistentTask task, - TimeValue timeout, ActionListener listener) { - DeploymentStartedPredicate predicate = new DeploymentStartedPredicate(); - persistentTasksService.waitForPersistentTaskCondition(task.getId(), predicate, timeout, - new PersistentTasksService.WaitForPersistentTaskListener() { + private void waitForDeploymentStarted( + String modelId, + TimeValue timeout, + ActionListener listener + ) { + DeploymentStartedPredicate predicate = new 
DeploymentStartedPredicate(modelId); + trainedModelAllocationService.waitForAllocationCondition(modelId, predicate, timeout, + new TrainedModelAllocationService.WaitForAllocationListener() { @Override - public void onResponse(PersistentTasksCustomMetadata.PersistentTask persistentTask) { + public void onResponse(TrainedModelAllocation allocation) { if (predicate.exception != null) { - cancelFailedDeployment(task, predicate.exception, listener); + deleteFailedDeployment(modelId, predicate.exception, listener); } else { - listener.onResponse(new NodeAcknowledgedResponse(true, predicate.node)); + listener.onResponse(new CreateTrainedModelAllocationAction.Response(allocation)); } } @@ -212,15 +206,20 @@ public void onFailure(Exception e) { }); } - private void cancelFailedDeployment( - PersistentTasksCustomMetadata.PersistentTask persistentTask, Exception exception, - ActionListener listener) { - persistentTasksService.sendRemoveRequest(persistentTask.getId(), ActionListener.wrap( + private void deleteFailedDeployment( + String modelId, + Exception exception, + ActionListener listener + ) { + trainedModelAllocationService.deleteModelAllocation(modelId, ActionListener.wrap( pTask -> listener.onFailure(exception), e -> { logger.error( - new ParameterizedMessage("[{}] Failed to cancel persistent task that had failed with the reason [{}]", - persistentTask.getParams().getModelId(), exception.getMessage()), + new ParameterizedMessage( + "[{}] Failed to delete model allocation that had failed with the reason [{}]", + modelId, + exception.getMessage() + ), e ); listener.onFailure(exception); @@ -237,148 +236,56 @@ protected ClusterBlockException checkBlock(StartTrainedModelDeploymentAction.Req return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); } - private static class DeploymentStartedPredicate implements Predicate> { + private static class DeploymentStartedPredicate implements Predicate { private volatile Exception exception; - private volatile String node = ""; + + // for logging + private final String modelId; + + DeploymentStartedPredicate(String modelId) { + this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id"); + } @Override - public boolean test(PersistentTasksCustomMetadata.PersistentTask persistentTask) { - if (persistentTask == null) { - return false; + public boolean test(TrainedModelAllocation trainedModelAllocation) { + if (trainedModelAllocation == null) { + // Something weird happened, it should NEVER be null... 
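// Illustrative aside: the wait decision this predicate implements, sketched as standalone Java.
// The class, enum, and method names below are stand-ins, not the production types, and the failure
// case is reduced to "stop waiting" where the real predicate also records an exception for the caller.
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

class DeploymentStartWaitSketch {

    enum Route { STARTING, STARTED, FAILED, STOPPING, STOPPED }

    // The wait is over once no node is still STARTING, or as soon as any node reports FAILED.
    static boolean doneWaiting(Map<String, Route> nodeRoutingTable) {
        Set<String> failedNodes = nodeRoutingTable.entrySet().stream()
            .filter(e -> e.getValue() == Route.FAILED)
            .map(Map.Entry::getKey)
            .collect(Collectors.toSet());
        Set<String> stillStarting = nodeRoutingTable.entrySet().stream()
            .filter(e -> e.getValue() == Route.STARTING)
            .map(Map.Entry::getKey)
            .collect(Collectors.toSet());
        if (failedNodes.isEmpty() == false) {
            return true;
        }
        return stillStarting.isEmpty();
    }

    public static void main(String[] args) {
        // One node started, one still loading the model: keep waiting.
        System.out.println(doneWaiting(Map.of("node-a", Route.STARTED, "node-b", Route.STARTING))); // false
        // All routed nodes started: the deployment counts as started.
        System.out.println(doneWaiting(Map.of("node-a", Route.STARTED, "node-b", Route.STARTED)));  // true
    }
}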
+ return true; } - PersistentTasksCustomMetadata.Assignment assignment = persistentTask.getAssignment(); - - String reason = "__unknown__"; + final Set> nodesAndState = trainedModelAllocation + .getNodeRoutingTable() + .entrySet(); - if (assignment != null) { - if (assignment.equals(JobNodeSelector.AWAITING_LAZY_ASSIGNMENT)) { - return true; + Map nodeFailuresAndReasons = new HashMap<>(); + Set nodesStillInitializing = new HashSet<>(); + for (Map.Entry nodeIdAndState : nodesAndState) { + if (RoutingState.FAILED.equals(nodeIdAndState.getValue().getState())) { + nodeFailuresAndReasons.put(nodeIdAndState.getKey(), nodeIdAndState.getValue().getReason()); } - if (assignment.equals(PersistentTasksCustomMetadata.INITIAL_ASSIGNMENT) == false && assignment.isAssigned() == false) { - exception = new ElasticsearchStatusException("Could not start trained model deployment, allocation explanation [{}]", - RestStatus.TOO_MANY_REQUESTS, assignment.getExplanation()); - return true; + if (RoutingState.STARTING.equals(nodeIdAndState.getValue().getState())) { + nodesStillInitializing.add(nodeIdAndState.getKey()); } } - TrainedModelDeploymentTaskState taskState = (TrainedModelDeploymentTaskState) persistentTask.getState(); - reason = taskState != null ? taskState.getReason() : reason; - TrainedModelDeploymentState deploymentState = taskState == null ? TrainedModelDeploymentState.STARTING : taskState.getState(); - - switch (deploymentState) { - case STARTED: - node = persistentTask.getExecutorNode(); - return true; - case STARTING: - case STOPPING: - case STOPPED: - return false; - case FAILED: - exception = ExceptionsHelper.serverError("Deployment failed with reason: {}", reason); - return true; - default: - exception = ExceptionsHelper.serverError("Unexpected task state [{}] with reason [{}] while waiting to be started", - taskState.getState(), reason); - return true; + if (nodeFailuresAndReasons.isEmpty() == false) { + exception = new ElasticsearchStatusException( + "Could not start trained model deployment, the following nodes failed with errors [{}]", + RestStatus.INTERNAL_SERVER_ERROR, + nodeFailuresAndReasons + ); + return true; } - } - } - - public static class TaskExecutor extends AbstractJobPersistentTasksExecutor { - - private final DeploymentManager manager; - - public TaskExecutor(Settings settings, ClusterService clusterService, IndexNameExpressionResolver expressionResolver, - MlMemoryTracker memoryTracker, DeploymentManager manager) { - super(MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, - MachineLearning.UTILITY_THREAD_POOL_NAME, - settings, - clusterService, - memoryTracker, - expressionResolver); - this.manager = Objects.requireNonNull(manager); - } - @Override - protected AllocatedPersistentTask createTask( - long id, String type, String action, TaskId parentTaskId, - PersistentTasksCustomMetadata.PersistentTask persistentTask, - Map headers) { - return new TrainedModelDeploymentTask(id, type, action, parentTaskId, headers, persistentTask.getParams()); - } - - @Override - public PersistentTasksCustomMetadata.Assignment getAssignment(TaskParams params, - Collection candidateNodes, - ClusterState clusterState) { - - boolean isMemoryTrackerRecentlyRefreshed = memoryTracker.isRecentlyRefreshed(); - Optional optionalAssignment = - getPotentialAssignment(params, clusterState, isMemoryTrackerRecentlyRefreshed); - // NOTE: this will return here if isMemoryTrackerRecentlyRefreshed is false, we don't allow assignment with stale memory - if (optionalAssignment.isPresent()) { - return 
optionalAssignment.get(); + if (nodesStillInitializing.isEmpty()) { + return true; } - - JobNodeSelector jobNodeSelector = - new JobNodeSelector( - clusterState, - candidateNodes, - params.getModelId(), - MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, - memoryTracker, - maxLazyMLNodes, - node -> nodeFilter(node, params) - ); - - PersistentTasksCustomMetadata.Assignment assignment = jobNodeSelector.selectNode( - params.estimateMemoryUsageBytes(), - maxOpenJobs, - Integer.MAX_VALUE, - maxMachineMemoryPercent, - maxNodeMemory, - useAutoMemoryPercentage + logger.trace( + () -> new ParameterizedMessage("[{}] tested and nodes {} still initializing", modelId, nodesStillInitializing) ); - return assignment; - } - - public static String nodeFilter(DiscoveryNode node, TaskParams params) { - String id = params.getModelId(); - - if (node.getVersion().before(TaskParams.VERSION_INTRODUCED)) { - return "Not opening job [" + id + "] on node [" + JobNodeSelector.nodeNameAndVersion(node) - + "], because the data frame analytics requires a node of version [" - + TaskParams.VERSION_INTRODUCED + "] or higher"; - } - - return null; - } - - @Override - protected void nodeOperation(AllocatedPersistentTask task, TaskParams params, PersistentTaskState state) { - TrainedModelDeploymentTask trainedModelDeploymentTask = (TrainedModelDeploymentTask) task; - trainedModelDeploymentTask.setDeploymentManager(manager); - - TrainedModelDeploymentTaskState deployingState = new TrainedModelDeploymentTaskState( - TrainedModelDeploymentState.STARTING, task.getAllocationId(), null); - task.updatePersistentTaskState(deployingState, ActionListener.wrap( - response -> manager.startDeployment(trainedModelDeploymentTask), - task::markAsFailed - )); - } - - @Override - protected String[] indicesOfInterest(TaskParams params) { - return new String[] { - InferenceIndexConstants.INDEX_PATTERN - }; - } - - @Override - protected String getJobId(TaskParams params) { - return params.getModelId(); + return false; } } + } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java index 30b3d663dd5b6..7b8753e421034 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java @@ -9,6 +9,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListenerResponseHandler; @@ -17,51 +18,59 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.tasks.TransportTasksAction; import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.discovery.MasterNotDiscoveredException; -import org.elasticsearch.persistent.PersistentTasksCustomMetadata; -import 
org.elasticsearch.persistent.PersistentTasksService; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState; -import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; +import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationClusterService; +import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata; +import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationService; import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.Set; +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; + +/** + * Class for transporting stop trained model deployment requests. + * + * NOTE: this class gets routed to each individual deployment running on the nodes. This way, when the request returns, we are assured + * that the model is no longer running on any node. + */ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAction { private static final Logger logger = LogManager.getLogger(TransportStopTrainedModelDeploymentAction.class); private final Client client; - private final ThreadPool threadPool; - private final PersistentTasksService persistentTasksService; + private final TrainedModelAllocationService trainedModelAllocationService; + private final TrainedModelAllocationClusterService trainedModelAllocationClusterService; @Inject public TransportStopTrainedModelDeploymentAction(ClusterService clusterService, TransportService transportService, - ActionFilters actionFilters, Client client, ThreadPool threadPool, - PersistentTasksService persistentTasksService) { + ActionFilters actionFilters, Client client, + TrainedModelAllocationService trainedModelAllocationService, + TrainedModelAllocationClusterService trainedModelAllocationClusterService) { super(StopTrainedModelDeploymentAction.NAME, clusterService, transportService, actionFilters, StopTrainedModelDeploymentAction.Request::new, StopTrainedModelDeploymentAction.Response::new, StopTrainedModelDeploymentAction.Response::new, ThreadPool.Names.SAME); - this.client = client; - this.threadPool = threadPool; - this.persistentTasksService = persistentTasksService; + this.client = new OriginSettingClient(client, ML_ORIGIN); + this.trainedModelAllocationService = trainedModelAllocationService; + this.trainedModelAllocationClusterService = trainedModelAllocationClusterService; } @Override @@ -69,6 +78,7 @@ protected void doExecute(Task task, StopTrainedModelDeploymentAction.Request req ActionListener listener) { ClusterState state = clusterService.state(); DiscoveryNodes nodes = state.nodes(); + // Master node is required for initial pre-checks and deletion preparation if (nodes.isLocalNodeElectedMaster() == false) { redirectToMasterNode(nodes.getMasterNode(), request, 
listener); return; @@ -88,15 +98,30 @@ protected void doExecute(Task task, StopTrainedModelDeploymentAction.Request req return; } - ClusterState clusterState = clusterService.state(); - PersistentTasksCustomMetadata tasks = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE); - PersistentTasksCustomMetadata.PersistentTask deployTrainedModelTask = - MlTasks.getTrainedModelDeploymentTask(request.getId(), tasks); - if (deployTrainedModelTask == null) { + Optional maybeAllocation = TrainedModelAllocationMetadata.allocationForModelId( + clusterService.state(), + models.get(0).getModelId() + ); + + if (maybeAllocation.isEmpty()) { listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); return; } - normalUndeploy(task, deployTrainedModelTask, request, listener); + final String modelId = models.get(0).getModelId(); + // NOTE, should only run on Master node + trainedModelAllocationClusterService.setModelAllocationToStopping( + modelId, + ActionListener.wrap( + setToStopping -> normalUndeploy(task, models.get(0).getModelId(), maybeAllocation.get(), request, listener), + failure -> { + if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) { + listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); + return; + } + listener.onFailure(failure); + } + ) + ); }, listener::onFailure ); @@ -117,13 +142,39 @@ private void redirectToMasterNode(DiscoveryNode masterNode, StopTrainedModelDepl } } - private void normalUndeploy(Task task, PersistentTasksCustomMetadata.PersistentTask deployTrainedModelTask, + private void normalUndeploy(Task task, + String modelId, + TrainedModelAllocation modelAllocation, StopTrainedModelDeploymentAction.Request request, ActionListener listener) { - request.setNodes(deployTrainedModelTask.getExecutorNode()); - + request.setNodes(modelAllocation.getNodeRoutingTable().keySet().toArray(String[]::new)); ActionListener finalListener = ActionListener.wrap( - r -> waitForTaskRemoved(Collections.singleton(deployTrainedModelTask.getId()), request, r, listener), + r -> { + waitForTaskRemoved(modelId, modelAllocation, request, r, ActionListener.wrap( + waited -> { + trainedModelAllocationService.deleteModelAllocation( + modelId, + ActionListener.wrap( + deleted -> listener.onResponse(r), + deletionFailed -> { + logger.error( + () -> new ParameterizedMessage( + "[{}] failed to delete model allocation after nodes unallocated the deployment", + modelId + ),deletionFailed); + listener.onFailure(ExceptionsHelper.serverError( + "failed to delete model allocation after nodes unallocated the deployment. Attempt to stop again", + deletionFailed + )); + } + ) + ); + }, + // TODO should we attempt to delete the deployment here? + listener::onFailure + )); + + }, e -> { if (ExceptionsHelper.unwrapCause(e) instanceof FailedNodeException) { // A node has dropped out of the cluster since we started executing the requests. 
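Taken together, these hunks change the stop path to: mark the allocation as STOPPING on the master, fan the stop request out to every node in the allocation's routing table, wait for the node-level tasks to finish, and only then delete the allocation from cluster state. A minimal sketch of that ordering, using CompletableFuture instead of ActionListener purely for brevity (the method names and bodies below are placeholders, not the plugin's APIs):

import java.util.List;
import java.util.concurrent.CompletableFuture;

class StopDeploymentFlowSketch {

    // The ordering matters: the allocation is only deleted after every routed node has stopped the deployment.
    static CompletableFuture<Void> stop(String modelId, List<String> routedNodeIds) {
        return setAllocationToStopping(modelId)                                   // 1. master flips the allocation to STOPPING
            .thenCompose(v -> stopTaskOnNodes(modelId, routedNodeIds))            // 2. task operation runs on each routed node
            .thenCompose(v -> waitForNodeTasksToComplete(modelId, routedNodeIds)) // 3. wait until no node still runs it
            .thenCompose(v -> deleteAllocation(modelId));                         // 4. remove the allocation metadata
    }

    // Placeholder implementations so the sketch compiles; in the real action this work is done by the
    // allocation cluster service, the transport tasks fan-out, and the list-tasks wait elsewhere in this class.
    static CompletableFuture<Void> setAllocationToStopping(String modelId) { return CompletableFuture.completedFuture(null); }
    static CompletableFuture<Void> stopTaskOnNodes(String modelId, List<String> nodes) { return CompletableFuture.completedFuture(null); }
    static CompletableFuture<Void> waitForNodeTasksToComplete(String modelId, List<String> nodes) { return CompletableFuture.completedFuture(null); }
    static CompletableFuture<Void> deleteAllocation(String modelId) { return CompletableFuture.completedFuture(null); }
}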
@@ -137,22 +188,26 @@ private void normalUndeploy(Task task, PersistentTasksCustomMetadata.PersistentT } } ); - super.doExecute(task, request, finalListener); } - void waitForTaskRemoved(Set taskIds, StopTrainedModelDeploymentAction.Request request, + void waitForTaskRemoved(String modelId, + TrainedModelAllocation trainedModelAllocation, + StopTrainedModelDeploymentAction.Request request, StopTrainedModelDeploymentAction.Response response, ActionListener listener) { - persistentTasksService.waitForPersistentTasksCondition(persistentTasks -> - persistentTasks.findTasks(MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, t -> taskIds.contains(t.getId())).isEmpty(), - request.getTimeout(), ActionListener.wrap( - booleanResponse -> { - listener.onResponse(response); - }, + final Set nodesOfConcern = trainedModelAllocation.getNodeRoutingTable().keySet(); + client.admin() + .cluster() + .prepareListTasks(nodesOfConcern.toArray(String[]::new)) + .setDetailed(true) + .setWaitForCompletion(true) + .setActions(modelId) + .setTimeout(request.getTimeout()) + .execute(ActionListener.wrap( + complete -> listener.onResponse(response), listener::onFailure - ) - ); + )); } @Override @@ -172,31 +227,7 @@ protected StopTrainedModelDeploymentAction.Response newResponse(StopTrainedModel @Override protected void taskOperation(StopTrainedModelDeploymentAction.Request request, TrainedModelDeploymentTask task, ActionListener listener) { - TrainedModelDeploymentTaskState undeployingState = new TrainedModelDeploymentTaskState( - TrainedModelDeploymentState.STOPPING, task.getAllocationId(), "api"); - task.updatePersistentTaskState(undeployingState, ActionListener.wrap( - updatedTask -> { - threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(new AbstractRunnable() { - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - - @Override - protected void doRun() throws Exception { - task.stop("undeploy_trained_model (api)"); - listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); - } - }); - }, - e -> { - if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { - // the task has disappeared so must have stopped - listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); - } else { - listener.onFailure(e); - } - } - )); + task.stop("undeploy_trained_model (api)"); + listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelAllocationStateAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelAllocationStateAction.java new file mode 100644 index 0000000000000..3c87ab800d50f --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelAllocationStateAction.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction.Request; +import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationClusterService; + +public class TransportUpdateTrainedModelAllocationStateAction extends AcknowledgedTransportMasterNodeAction { + + private final TrainedModelAllocationClusterService trainedModelAllocationClusterService; + + @Inject + public TransportUpdateTrainedModelAllocationStateAction( + TrainedModelAllocationClusterService trainedModelAllocationClusterService, + TransportService transportService, + ClusterService clusterService, + ThreadPool threadPool, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver + ) { + super( + UpdateTrainedModelAllocationStateAction.NAME, + false, + transportService, + clusterService, + threadPool, + actionFilters, + Request::new, + indexNameExpressionResolver, + ThreadPool.Names.SAME + ); + this.trainedModelAllocationClusterService = trainedModelAllocationClusterService; + } + + @Override + protected void masterOperation(Task task, Request request, ClusterState state, ActionListener listener) + throws Exception { + trainedModelAllocationClusterService.updateModelRoutingTable(request, listener); + } + + @Override + protected ClusterBlockException checkBlock(Request request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterService.java new file mode 100644 index 0000000000000..8bb3e9e25b0ff --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterService.java @@ -0,0 +1,407 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.allocation; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.ClusterStateUpdateTask; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.NodesShutdownMetadata; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.gateway.GatewayService; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction; +import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState; +import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState; +import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.job.NodeLoad; +import org.elasticsearch.xpack.ml.job.NodeLoadDetector; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +public class TrainedModelAllocationClusterService implements ClusterStateListener { + + private static final Logger logger = LogManager.getLogger(TrainedModelAllocationClusterService.class); + + private final ClusterService clusterService; + private final NodeLoadDetector nodeLoadDetector; + private volatile int maxMemoryPercentage; + private volatile boolean useAuto; + + public TrainedModelAllocationClusterService(Settings settings, ClusterService clusterService, NodeLoadDetector nodeLoadDetector) { + this.clusterService = clusterService; + this.nodeLoadDetector = nodeLoadDetector; + this.maxMemoryPercentage = MachineLearning.MAX_MACHINE_MEMORY_PERCENT.get(settings); + this.useAuto = MachineLearning.USE_AUTO_MACHINE_MEMORY_PERCENT.get(settings); + // Only nodes that can possibly be master nodes really need this service running + if (DiscoveryNode.isMasterNode(settings)) { + clusterService.addListener(this); + clusterService.getClusterSettings() + .addSettingsUpdateConsumer(MachineLearning.MAX_MACHINE_MEMORY_PERCENT, this::setMaxMemoryPercentage); + clusterService.getClusterSettings() + .addSettingsUpdateConsumer(MachineLearning.USE_AUTO_MACHINE_MEMORY_PERCENT, this::setUseAuto); + } + } + + private void setMaxMemoryPercentage(int maxMemoryPercentage) { + this.maxMemoryPercentage = maxMemoryPercentage; + } + + private void setUseAuto(boolean useAuto) { + this.useAuto = useAuto; + } + + @Override + public void clusterChanged(ClusterChangedEvent event) { + if (event.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) { + return; + } + if (event.localNodeMaster() && shouldAllocateModels(event)) { + 
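            // The state update task submitted below re-balances every allocation that is not STOPPING
            // against the current node set: ML-capable nodes that joined the cluster (and are not shutting
            // down) are given a new route if they have capacity, or a failed route with a reason if they do
            // not, while routes pointing at nodes that have left the cluster or are shutting down are removed.
            // See addRemoveAllocationNodes(ClusterState) later in this class.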
clusterService.submitStateUpdateTask("allocating models to nodes", new ClusterStateUpdateTask() { + @Override + public ClusterState execute(ClusterState currentState) { + // TODO this has a weird side-effect for allocating to nodes + // If the event indicates there were nodes added/removed, this method only looks at the current state and has + // no previous knowledge of existing nodes. Consequently, if a model was manually removed (task-kill) from a node + // it may get re-allocated to that node when another node is added/removed... + return addRemoveAllocationNodes(currentState); + } + + @Override + public void onFailure(String source, Exception e) { + logger.warn("failed to allocate models", e); + } + + @Override + public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { + logger.trace( + () -> new ParameterizedMessage( + "updated model allocations based on node changes in the cluster; new metadata [{}]", + Strings.toString(TrainedModelAllocationMetadata.fromState(newState), false, true) + ) + ); + } + }); + } + } + + public void updateModelRoutingTable( + UpdateTrainedModelAllocationStateAction.Request request, + ActionListener listener + ) { + clusterService.submitStateUpdateTask("updating model routing for node allocation", new ClusterStateUpdateTask() { + @Override + public ClusterState execute(ClusterState currentState) { + return updateModelRoutingTable(currentState, request); + } + + @Override + public void onFailure(String source, Exception e) { + listener.onFailure(e); + } + + @Override + public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { + listener.onResponse(AcknowledgedResponse.TRUE); + } + }); + } + + public void createNewModelAllocation( + StartTrainedModelDeploymentAction.TaskParams params, + ActionListener listener + ) { + clusterService.submitStateUpdateTask("create model allocation", new ClusterStateUpdateTask() { + @Override + public ClusterState execute(ClusterState currentState) { + return createModelAllocation(currentState, params); + } + + @Override + public void onFailure(String source, Exception e) { + listener.onFailure(e); + } + + @Override + public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { + listener.onResponse(TrainedModelAllocationMetadata.fromState(newState).getModelAllocation(params.getModelId())); + } + }); + } + + public void setModelAllocationToStopping(String modelId, ActionListener listener) { + clusterService.submitStateUpdateTask("set model allocation stopping", new ClusterStateUpdateTask() { + @Override + public ClusterState execute(ClusterState currentState) { + return setToStopping(currentState, modelId); + } + + @Override + public void onFailure(String source, Exception e) { + listener.onFailure(e); + } + + @Override + public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { + listener.onResponse(AcknowledgedResponse.TRUE); + } + }); + } + + public void removeModelAllocation(String modelId, ActionListener listener) { + clusterService.submitStateUpdateTask("delete model allocation", new ClusterStateUpdateTask() { + @Override + public ClusterState execute(ClusterState currentState) { + return removeAllocation(currentState, modelId); + } + + @Override + public void onFailure(String source, Exception e) { + listener.onFailure(e); + } + + @Override + public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { + 
listener.onResponse(AcknowledgedResponse.TRUE); + } + }); + } + + private static ClusterState update(ClusterState currentState, TrainedModelAllocationMetadata.Builder modelAllocations) { + if (modelAllocations.isChanged()) { + return ClusterState.builder(currentState) + .metadata( + Metadata.builder(currentState.metadata()).putCustom(TrainedModelAllocationMetadata.NAME, modelAllocations.build()) + ) + .build(); + } else { + return currentState; + } + } + + ClusterState createModelAllocation(ClusterState currentState, StartTrainedModelDeploymentAction.TaskParams params) { + TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState); + if (builder.hasModel(params.getModelId())) { + throw new ResourceAlreadyExistsException("allocation for model with id [" + params.getModelId() + "] already exist"); + } + + Set shuttingDownNodes = nodesShuttingDown(currentState); + builder.addNewAllocation(params); + for (DiscoveryNode node : currentState.getNodes().getAllNodes()) { + if (StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node) + && shuttingDownNodes.contains(node.getId()) == false) { + Optional maybeError = nodeHasCapacity(currentState, params, node); + if (maybeError.isPresent()) { + builder.addFailedNode(params.getModelId(), node.getId(), maybeError.get()); + } else { + builder.addNode(params.getModelId(), node.getId()); + } + } + } + return update(currentState, builder); + } + + static ClusterState setToStopping(ClusterState clusterState, String modelId) { + TrainedModelAllocationMetadata metadata = TrainedModelAllocationMetadata.fromState(clusterState); + final TrainedModelAllocation existingAllocation = metadata.getModelAllocation(modelId); + if (existingAllocation == null) { + throw new ResourceNotFoundException("allocation for model with id [{}] not found", modelId); + } + // If we are stopping, don't update anything + if (existingAllocation.getAllocationState().equals(AllocationState.STOPPING)) { + return clusterState; + } + + TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(clusterState); + builder.setAllocationToStopping(modelId); + return update(clusterState, builder); + } + + static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTrainedModelAllocationStateAction.Request request) { + final String modelId = request.getModelId(); + final String nodeId = request.getNodeId(); + TrainedModelAllocationMetadata metadata = TrainedModelAllocationMetadata.fromState(currentState); + logger.trace( + () -> new ParameterizedMessage("[{}] [{}] current metadata before update {}", modelId, nodeId, Strings.toString(metadata)) + ); + final TrainedModelAllocation existingAllocation = metadata.getModelAllocation(modelId); + + // If state is stopped, this indicates the node process is closed, remove the node from the allocation + if (request.getRoutingState().getState().equals(RoutingState.STOPPED)) { + if (existingAllocation == null || existingAllocation.isRoutedToNode(nodeId) == false) { + return currentState; + } + return update(currentState, TrainedModelAllocationMetadata.builder(currentState).removeNode(modelId, nodeId)); + } + + if (existingAllocation == null) { + throw new ResourceNotFoundException("allocation for model with id [{}] not found", modelId); + } + // If we are stopping, don't update anything + if (existingAllocation.getAllocationState().equals(AllocationState.STOPPING)) { + logger.debug(() -> new ParameterizedMessage( + "[{}] requested update from node [{}] to update 
route state to [{}]", + modelId, + nodeId, + request.getRoutingState() + )); + return currentState; + } + if (existingAllocation.isRoutedToNode(nodeId) == false) { + throw new ResourceNotFoundException("allocation for model with id [{}]] is not routed to node [{}]", modelId, nodeId); + } + TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState); + builder.updateAllocation(modelId, nodeId, request.getRoutingState()); + return update(currentState, builder); + } + + static ClusterState removeAllocation(ClusterState currentState, String modelId) { + TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState); + if (builder.hasModel(modelId) == false) { + throw new ResourceNotFoundException("allocation for model with id [{}] not found", modelId); + } + return update(currentState, builder.removeAllocation(modelId)); + } + + ClusterState addRemoveAllocationNodes(ClusterState currentState) { + TrainedModelAllocationMetadata previousState = TrainedModelAllocationMetadata.fromState(currentState); + TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState); + Map> removedNodeModelLookUp = new HashMap<>(); + Set shuttingDownNodes = nodesShuttingDown(currentState); + // TODO: make more efficient, right now this is O(nm) where n = sizeof(models) and m = sizeof(nodes) + // It could probably be O(max(n, m)) + // Add nodes and keep track of currently routed nodes + // Should we indicate a partial allocation somehow if some nodes don't have space? + for (Map.Entry modelAllocationEntry : previousState.modelAllocations().entrySet()) { + // Don't bother adding/removing nodes if this allocation is stopping + if (modelAllocationEntry.getValue().getAllocationState().equals(AllocationState.STOPPING)) { + continue; + } + for (DiscoveryNode node : currentState.getNodes()) { + // Only add the route if the node is NOT shutting down, this would be a weird case of the node + // just being added to the cluster and immediately shutting down... 
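                // Capacity here is a memory check (nodeHasCapacity, below): the route is only added if the
                // node's free ML memory is at least the model's estimated memory usage, where free memory is
                // roughly the node's maximum ML memory minus what is already assigned to jobs and models.
                // For example, a node with 2 GB of ML memory and 1.5 GB already assigned has about 512 MB free,
                // so a model estimated at 700 MB would be recorded as a failed route with an explanatory reason.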
+ if (shuttingDownNodes.contains(node.getId()) == false + && StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node) + && modelAllocationEntry.getValue().isRoutedToNode(node.getId()) == false) { + nodeHasCapacity(currentState, modelAllocationEntry.getValue().getTaskParams(), node).ifPresentOrElse( + (error) -> builder.addFailedNode(modelAllocationEntry.getKey(), node.getId(), error), + () -> builder.addNode(modelAllocationEntry.getKey(), node.getId()) + ); + } + } + for (String nodeId : modelAllocationEntry.getValue().getNodeRoutingTable().keySet()) { + removedNodeModelLookUp.computeIfAbsent(nodeId, k -> new ArrayList<>()).add(modelAllocationEntry.getKey()); + } + } + + // Remove nodes + currentState.getNodes() + .forEach( + d -> { + // If a node is referenced in the current state, we shouldn't remove the node + // But, if that node that is referenced is shutting down, we should remove the node + if (shuttingDownNodes.contains(d.getId()) == false) { + removedNodeModelLookUp.remove(d.getId()); + } + } + ); + for (Map.Entry> nodeToModels : removedNodeModelLookUp.entrySet()) { + final String nodeId = nodeToModels.getKey(); + for (String modelId : nodeToModels.getValue()) { + builder.removeNode(modelId, nodeId); + } + } + return update(currentState, builder); + } + + static boolean shouldAllocateModels(final ClusterChangedEvent event) { + // If there are no allocations created at all, there is nothing to update + final TrainedModelAllocationMetadata newMetadata = event.state().getMetadata().custom(TrainedModelAllocationMetadata.NAME); + if (newMetadata == null) { + return false; + } + if (event.nodesChanged()) { + Set shuttingDownNodes = nodesShuttingDown(event.state()); + DiscoveryNodes.Delta nodesDelta = event.nodesDelta(); + for (TrainedModelAllocation trainedModelAllocation : newMetadata.modelAllocations().values()) { + if (trainedModelAllocation.getAllocationState().equals(AllocationState.STOPPING)) { + continue; + } + for (DiscoveryNode removed : nodesDelta.removedNodes()) { + if (trainedModelAllocation.isRoutedToNode(removed.getId())) { + return true; + } + } + for (DiscoveryNode added : nodesDelta.addedNodes()) { + if (StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(added) + && shuttingDownNodes.contains(added.getId()) == false) { + return true; + } + } + } + } + return false; + } + + Optional nodeHasCapacity(ClusterState state, StartTrainedModelDeploymentAction.TaskParams params, DiscoveryNode node) { + NodeLoad load = nodeLoadDetector.detectNodeLoad(state, true, node, Integer.MAX_VALUE, maxMemoryPercentage, useAuto); + if (Strings.isNullOrEmpty(load.getError()) == false) { + logger.warn("[{}] failed to calculate current node load with error [{}]", params.getModelId(), node.getId()); + return Optional.of(load.getError()); + } + if (load.getFreeMemory() < params.estimateMemoryUsageBytes()) { + return Optional.of( + ParameterizedMessage.format( + "This node has insufficient available memory. 
Available memory for ML [{} ({})], " + + "memory required by existing jobs and models [{} ({})], " + + "estimated memory required for this model [{} ({})].", + new Object[] { + load.getMaxMlMemory(), + ByteSizeValue.ofBytes(load.getMaxMlMemory()).toString(), + load.getAssignedJobMemory(), + ByteSizeValue.ofBytes(load.getAssignedJobMemory()).toString(), + params.estimateMemoryUsageBytes(), + ByteSizeValue.ofBytes(params.estimateMemoryUsageBytes()).toString() } + ) + ); + } + return Optional.empty(); + } + + /** + * Returns true if the given node is marked as shutting down with any + * shutdown type. + */ + static Set nodesShuttingDown(final ClusterState state) { + return NodesShutdownMetadata.getShutdowns(state) + .map(NodesShutdownMetadata::getAllNodeMetadataMap) + .map(Map::keySet) + .orElse(Collections.emptySet()); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadata.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadata.java new file mode 100644 index 0000000000000..fd2386c8985ac --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadata.java @@ -0,0 +1,278 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.allocation; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.Version; +import org.elasticsearch.cluster.AbstractDiffable; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.Diff; +import org.elasticsearch.cluster.DiffableUtils; +import org.elasticsearch.cluster.NamedDiff; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason; +import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.EnumSet; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.TreeMap; + +import static org.elasticsearch.cluster.metadata.Metadata.ALL_CONTEXTS; + +public class TrainedModelAllocationMetadata implements Metadata.Custom { + + private static final TrainedModelAllocationMetadata EMPTY = new TrainedModelAllocationMetadata(Collections.emptyMap()); + public static final String NAME = "trained_model_allocation"; + private final Map modelRoutingEntries; + + public static TrainedModelAllocationMetadata fromXContent(XContentParser parser) throws IOException { + return new TrainedModelAllocationMetadata(parser.map(LinkedHashMap::new, TrainedModelAllocation::fromXContent)); + } + + public static NamedDiff readDiffFrom(StreamInput in) throws IOException { + return new TrainedModelAllocationMetadata.TrainedModeAllocationDiff(in); + } + + public static Builder 
builder(ClusterState clusterState) { + return Builder.fromMetadata(fromState(clusterState)); + } + + public static TrainedModelAllocationMetadata fromState(ClusterState clusterState) { + TrainedModelAllocationMetadata trainedModelAllocationMetadata = clusterState.getMetadata().custom(NAME); + return trainedModelAllocationMetadata == null ? EMPTY : trainedModelAllocationMetadata; + } + + public static Optional allocationForModelId(ClusterState clusterState, String modelId) { + return Optional.ofNullable(TrainedModelAllocationMetadata.fromState(clusterState)) + .map(metadata -> metadata.getModelAllocation(modelId)); + } + + public TrainedModelAllocationMetadata(Map modelRoutingEntries) { + this.modelRoutingEntries = ExceptionsHelper.requireNonNull(modelRoutingEntries, NAME); + } + + public TrainedModelAllocationMetadata(StreamInput in) throws IOException { + this.modelRoutingEntries = in.readOrderedMap(StreamInput::readString, TrainedModelAllocation::new); + } + + public TrainedModelAllocation getModelAllocation(String modelId) { + return modelRoutingEntries.get(modelId); + } + + public Map modelAllocations() { + return Collections.unmodifiableMap(modelRoutingEntries); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.mapContents(modelRoutingEntries); + return builder; + } + + @Override + public Diff diff(Metadata.Custom previousState) { + return new TrainedModeAllocationDiff((TrainedModelAllocationMetadata) previousState, this); + } + + @Override + public EnumSet context() { + return ALL_CONTEXTS; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public Version getMinimalSupportedVersion() { + return Version.V_8_0_0; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(modelRoutingEntries, StreamOutput::writeString, (o, w) -> w.writeTo(o)); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TrainedModelAllocationMetadata that = (TrainedModelAllocationMetadata) o; + return Objects.equals(modelRoutingEntries, that.modelRoutingEntries); + } + + @Override + public int hashCode() { + return Objects.hash(modelRoutingEntries); + } + + public static class Builder { + + public static Builder empty(){ + return new Builder(); + } + + private final Map modelRoutingEntries; + private boolean isChanged; + + public static Builder fromMetadata(TrainedModelAllocationMetadata modelAllocationMetadata) { + return new Builder(modelAllocationMetadata); + } + + private Builder() { + modelRoutingEntries = new LinkedHashMap<>(); + } + + private Builder(TrainedModelAllocationMetadata modelAllocationMetadata) { + this.modelRoutingEntries = new LinkedHashMap<>(); + modelAllocationMetadata.modelRoutingEntries.forEach( + (modelId, allocation) -> modelRoutingEntries.put(modelId, TrainedModelAllocation.Builder.fromAllocation(allocation)) + ); + } + + public boolean hasModel(String modelId) { + return modelRoutingEntries.containsKey(modelId); + } + + Builder addNewAllocation(StartTrainedModelDeploymentAction.TaskParams taskParams) { + if (modelRoutingEntries.containsKey(taskParams.getModelId())) { + return this; + } + modelRoutingEntries.put(taskParams.getModelId(), TrainedModelAllocation.Builder.empty(taskParams)); + isChanged = true; + return this; + } + + Builder updateAllocation(String modelId, String nodeId, RoutingStateAndReason state) { + TrainedModelAllocation.Builder 
allocation = modelRoutingEntries.get(modelId); + if (allocation == null) { + return this; + } + isChanged |= allocation.updateExistingRoutingEntry(nodeId, state).isChanged(); + return this; + } + + Builder addNode(String modelId, String nodeId) { + TrainedModelAllocation.Builder allocation = modelRoutingEntries.get(modelId); + if (allocation == null) { + throw new ResourceNotFoundException( + "unable to add node [{}] to model [{}] routing table as allocation does not exist", + nodeId, + modelId + ); + } + isChanged |= allocation.addNewRoutingEntry(nodeId).isChanged(); + return this; + } + + Builder addFailedNode(String modelId, String nodeId, String reason) { + TrainedModelAllocation.Builder allocation = modelRoutingEntries.get(modelId); + if (allocation == null) { + throw new ResourceNotFoundException( + "unable to add failed node [{}] to model [{}] routing table as allocation does not exist", + nodeId, + modelId + ); + } + isChanged |= allocation.addNewFailedRoutingEntry(nodeId, reason).isChanged(); + return this; + } + + Builder removeNode(String modelId, String nodeId) { + TrainedModelAllocation.Builder allocation = modelRoutingEntries.get(modelId); + if (allocation == null) { + return this; + } + isChanged |= allocation.removeRoutingEntry(nodeId).isChanged(); + return this; + } + + public Builder removeAllocation(String modelId) { + isChanged |= modelRoutingEntries.remove(modelId) != null; + return this; + } + + public Builder setAllocationToStopping(String modelId) { + TrainedModelAllocation.Builder allocation = modelRoutingEntries.get(modelId); + if (allocation == null) { + throw new ResourceNotFoundException( + "unable to set model allocation [{}] to stopping as it does not exist", + modelId + ); + } + isChanged |= allocation.stopAllocation().isChanged(); + return this; + } + + public boolean isChanged() { + return isChanged; + } + + public TrainedModelAllocationMetadata build() { + Map allocations = new LinkedHashMap<>(); + modelRoutingEntries.forEach((modelId, allocation) -> allocations.put(modelId, allocation.build())); + return new TrainedModelAllocationMetadata(allocations); + } + } + + public static class TrainedModeAllocationDiff implements NamedDiff { + + private final Diff> modelRoutingEntries; + + static Diff readFrom(final StreamInput in) throws IOException { + return AbstractDiffable.readDiffFrom(TrainedModelAllocation::new, in); + } + + public TrainedModeAllocationDiff(TrainedModelAllocationMetadata before, TrainedModelAllocationMetadata after) { + this.modelRoutingEntries = DiffableUtils.diff( + before.modelRoutingEntries, + after.modelRoutingEntries, + DiffableUtils.getStringKeySerializer() + ); + } + + public TrainedModeAllocationDiff(final StreamInput in) throws IOException { + this.modelRoutingEntries = DiffableUtils.readJdkMapDiff( + in, + DiffableUtils.getStringKeySerializer(), + TrainedModelAllocation::new, + TrainedModeAllocationDiff::readFrom + ); + } + + @Override + public Metadata.Custom apply(Metadata.Custom part) { + return new TrainedModelAllocationMetadata( + new TreeMap<>(modelRoutingEntries.apply(((TrainedModelAllocationMetadata) part).modelRoutingEntries)) + ); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + modelRoutingEntries.writeTo(out); + } + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java new file mode 100644 index 0000000000000..b6bc19be9e206 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java @@ -0,0 +1,380 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.allocation; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.component.LifecycleListener; +import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskAwareRequest; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.tasks.TaskManager; +import org.elasticsearch.threadpool.Scheduler; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState; +import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason; +import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager; +import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; + +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedDeque; + +public class TrainedModelAllocationNodeService implements ClusterStateListener { + + private static final String TASK_NAME = "trained_model_allocation"; + private static final TimeValue MODEL_LOADING_CHECK_INTERVAL = TimeValue.timeValueSeconds(1); + private static final Logger logger = LogManager.getLogger(TrainedModelAllocationNodeService.class); + private final TrainedModelAllocationService trainedModelAllocationService; + private final DeploymentManager deploymentManager; + private final TaskManager taskManager; + private final Map modelIdToTask; + private final ThreadPool threadPool; + private final Deque loadingModels; + private volatile Scheduler.Cancellable scheduledFuture; + private volatile boolean stopped; + private volatile String nodeId; + + public TrainedModelAllocationNodeService( + TrainedModelAllocationService trainedModelAllocationService, + ClusterService clusterService, + DeploymentManager deploymentManager, + TaskManager 
taskManager, + ThreadPool threadPool + ) { + this.trainedModelAllocationService = trainedModelAllocationService; + this.deploymentManager = deploymentManager; + this.taskManager = taskManager; + this.modelIdToTask = new ConcurrentHashMap<>(); + this.loadingModels = new ConcurrentLinkedDeque<>(); + this.threadPool = threadPool; + clusterService.addLifecycleListener(new LifecycleListener() { + @Override + public void afterStart() { + nodeId = clusterService.localNode().getId(); + start(); + } + + @Override + public void beforeStop() { + stop(); + } + }); + } + + TrainedModelAllocationNodeService( + TrainedModelAllocationService trainedModelAllocationService, + ClusterService clusterService, + DeploymentManager deploymentManager, + TaskManager taskManager, + ThreadPool threadPool, + String nodeId + ) { + this.trainedModelAllocationService = trainedModelAllocationService; + this.deploymentManager = deploymentManager; + this.taskManager = taskManager; + this.modelIdToTask = new ConcurrentHashMap<>(); + this.loadingModels = new ConcurrentLinkedDeque<>(); + this.threadPool = threadPool; + this.nodeId = nodeId; + clusterService.addLifecycleListener(new LifecycleListener() { + @Override + public void afterStart() { + start(); + } + + @Override + public void beforeStop() { + stop(); + } + }); + } + + void stopDeployment(TrainedModelDeploymentTask task) { + if (stopped) { + return; + } + deploymentManager.stopDeployment(task); + taskManager.unregister(task); + modelIdToTask.remove(task.getModelId()); + } + + void stopDeploymentAsync(TrainedModelDeploymentTask task, ActionListener listener) { + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> { + try { + stopDeployment(task); + listener.onResponse(null); + } catch (Exception e) { + listener.onFailure(e); + } + }); + } + + public void start() { + stopped = false; + scheduledFuture = threadPool.scheduleWithFixedDelay( + this::loadQueuedModels, + MODEL_LOADING_CHECK_INTERVAL, + MachineLearning.UTILITY_THREAD_POOL_NAME + ); + } + + public void stop() { + stopped = true; + ThreadPool.Cancellable cancellable = this.scheduledFuture; + if (cancellable != null) { + cancellable.cancel(); + } + } + + void loadQueuedModels() { + TrainedModelDeploymentTask loadingTask; + logger.trace("attempting to load all currently queued models"); + // NOTE: As soon as this method exits, the timer for the scheduler starts ticking + while ((loadingTask = loadingModels.poll()) != null) { + if (loadingTask.isStopped()) { + continue; + } + if (stopped) { + return; + } + final String modelId = loadingTask.getModelId(); + logger.trace(() -> new ParameterizedMessage("[{}] attempting to load model", modelId)); + final PlainActionFuture listener = new PlainActionFuture<>(); + deploymentManager.startDeployment(loadingTask, listener); + try { + // This needs to be synchronous here in the utility thread to keep queueing order + TrainedModelDeploymentTask deployedTask = listener.actionGet(); + // kicks off asynchronous cluster state update + handleLoadSuccess(deployedTask); + } catch (Exception ex) { + // kicks off asynchronous cluster state update + handleLoadFailure(loadingTask, ex); + } + } + } + + public void stopDeploymentAndNotify(TrainedModelDeploymentTask task) { + ActionListener notifyDeploymentOfStopped = ActionListener.wrap( + stopped -> updateStoredState( + task.getModelId(), + new RoutingStateAndReason(RoutingState.STOPPED, ""), + ActionListener.wrap(s -> {}, failure -> {}) + ), + failed -> { // if we failed to stop the process, something strange is 
going on, but we should still notify of stop + logger.warn(() -> new ParameterizedMessage("[{}] failed to stop due to error", task.getModelId()), failed); + updateStoredState( + task.getModelId(), + new RoutingStateAndReason(RoutingState.STOPPED, ""), + ActionListener.wrap(s -> {}, failure -> {}) + ); + } + ); + updateStoredState( + task.getModelId(), + new RoutingStateAndReason(RoutingState.STOPPING, "task locally canceled"), + ActionListener.wrap(success -> stopDeploymentAsync(task, notifyDeploymentOfStopped), e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + logger.debug( + () -> new ParameterizedMessage( + "[{}] failed to set routing state to stopping as allocation already removed", + task.getModelId() + ), + e + ); + } else { + // this is an unexpected error + // TODO this means requests may still be routed here, should we not stop deployment? + logger.warn( + () -> new ParameterizedMessage("[{}] failed to set routing state to stopping due to error", task.getModelId()), + e + ); + } + stopDeploymentAsync(task, notifyDeploymentOfStopped); + }) + ); + } + + public void infer(TrainedModelDeploymentTask task, String input, TimeValue timeout, ActionListener listener) { + deploymentManager.infer(task, input, timeout, listener); + } + + private TaskAwareRequest taskAwareRequest(StartTrainedModelDeploymentAction.TaskParams params) { + final TrainedModelAllocationNodeService trainedModelAllocationNodeService = this; + return new TaskAwareRequest() { + @Override + public void setParentTask(TaskId taskId) { + throw new UnsupportedOperationException("parent task id for model allocation tasks shouldn't change"); + } + + @Override + public TaskId getParentTask() { + return TaskId.EMPTY_TASK_ID; + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new TrainedModelDeploymentTask(id, type, action, parentTaskId, headers, params, trainedModelAllocationNodeService); + } + }; + } + + @Override + public void clusterChanged(ClusterChangedEvent event) { + if (event.metadataChanged()) { + TrainedModelAllocationMetadata modelAllocationMetadata = TrainedModelAllocationMetadata.fromState(event.state()); + final String currentNode = event.state().nodes().getLocalNodeId(); + for (TrainedModelAllocation trainedModelAllocation : modelAllocationMetadata.modelAllocations().values()) { + RoutingStateAndReason routingStateAndReason = trainedModelAllocation.getNodeRoutingTable().get(currentNode); + // Add new models to start loading + if (routingStateAndReason != null + // periodic retries should be handled in a separate thread, we think + && routingStateAndReason.getState().equals(RoutingState.STARTING) + // This means we don't already have a task and should attempt creating one and starting the model loading + && modelIdToTask.containsKey(trainedModelAllocation.getTaskParams().getModelId()) == false) { + prepareModelToLoad(trainedModelAllocation.getTaskParams()); + } + // This model is not routed to the current node at all + if (routingStateAndReason == null) { + TrainedModelDeploymentTask task = modelIdToTask.remove(trainedModelAllocation.getTaskParams().getModelId()); + if (task != null) { + task.stopWithoutNotification("node no longer referenced in model routing table"); + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> stopDeployment(task)); + } + } + } + List toCancel = new ArrayList<>(); + for (String modelIds : Sets.difference(modelIdToTask.keySet(),
modelAllocationMetadata.modelAllocations().keySet())) { + toCancel.add(modelIdToTask.remove(modelIds)); + } + // should all be stopped in the same executor thread? + for (TrainedModelDeploymentTask t : toCancel) { + t.stopWithoutNotification("model allocation no longer exists"); + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> stopDeployment(t)); + } + } + } + + // For testing purposes + TrainedModelDeploymentTask getTask(String modelId) { + return modelIdToTask.get(modelId); + } + + void prepareModelToLoad(StartTrainedModelDeploymentAction.TaskParams taskParams) { + TrainedModelDeploymentTask task = (TrainedModelDeploymentTask) taskManager.register( + TASK_NAME, + taskParams.getModelId(), + taskAwareRequest(taskParams) + ); + // threadsafe check to verify we are not loading/loaded the model + if (modelIdToTask.putIfAbsent(taskParams.getModelId(), task) == null) { + loadingModels.add(task); + } else { + // If there is already a task for the model, unregister the new task + taskManager.unregister(task); + } + } + + private void handleLoadSuccess(TrainedModelDeploymentTask task) { + final String modelId = task.getModelId(); + logger.debug( + () -> new ParameterizedMessage("[{}] model successfully loaded and ready for inference. Notifying master node", modelId) + ); + if (task.isStopped()) { + logger.debug( + () -> new ParameterizedMessage("[{}] model loaded successfully, but stopped before routing table was updated", modelId) + ); + return; + } + updateStoredState( + modelId, + new RoutingStateAndReason(RoutingState.STARTED, ""), + ActionListener.wrap( + r -> logger.debug(() -> new ParameterizedMessage("[{}] model loaded and accepting routes", modelId)), + e -> { + // This means that either the allocation has been deleted, or this node's particular route has been removed + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + logger.debug( + () -> new ParameterizedMessage( + "[{}] model loaded but failed to start accepting routes as allocation to this node was removed", + modelId + ), + e + ); + } + // this is an unexpected error + logger.warn(() -> new ParameterizedMessage("[{}] model loaded but failed to start accepting routes", modelId), e); + } + ) + ); + } + + private void updateStoredState( + String modelId, + RoutingStateAndReason routingStateAndReason, + ActionListener listener + ) { + if (stopped) { + return; + } + trainedModelAllocationService.updateModelAllocationState( + new UpdateTrainedModelAllocationStateAction.Request(nodeId, modelId, routingStateAndReason), + ActionListener.wrap(success -> { + logger.debug( + () -> new ParameterizedMessage("[{}] model is [{}] and master notified", modelId, routingStateAndReason.getState()) + ); + listener.onResponse(AcknowledgedResponse.TRUE); + }, + error -> { + logger.warn( + () -> new ParameterizedMessage( + "[{}] model is [{}] but failed to notify master", + modelId, + routingStateAndReason.getState() + ), + error + ); + listener.onFailure(error); + } + ) + ); + } + + private void handleLoadFailure(TrainedModelDeploymentTask task, Exception ex) { + logger.error(() -> new ParameterizedMessage("[{}] model failed to load", task.getModelId()), ex); + if (task.isStopped()) { + logger.debug(() -> new ParameterizedMessage("[{}] model failed to load, but is now stopped", task.getModelId())); + } + // TODO: Do we want to remove from the modelIdToTask map? 
This would cause it to be reloaded by state updates on INITIALIZING + modelIdToTask.remove(task.getModelId()); + updateStoredState( + task.getModelId(), + new RoutingStateAndReason(RoutingState.FAILED, ExceptionsHelper.unwrapCause(ex).getMessage()), + ActionListener.wrap(r -> {}, e -> {}) + ); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationService.java new file mode 100644 index 0000000000000..87c9415f76460 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationService.java @@ -0,0 +1,173 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.allocation; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateObserver; +import org.elasticsearch.cluster.MasterNodeChangePredicate; +import org.elasticsearch.cluster.NotMasterException; +import org.elasticsearch.cluster.coordination.FailedToCommitClusterStateException; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.node.NodeClosedException; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.ConnectTransportException; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAllocationAction; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction; +import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation; + +import java.util.Objects; +import java.util.function.Predicate; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; + +public class TrainedModelAllocationService { + + private static final Logger logger = LogManager.getLogger(TrainedModelAllocationService.class); + + private final Client client; + private final ClusterService clusterService; + private final ThreadPool threadPool; + + public TrainedModelAllocationService(Client client, ClusterService clusterService, ThreadPool threadPool) { + this.client = new OriginSettingClient(client, ML_ORIGIN); + this.clusterService = Objects.requireNonNull(clusterService); + this.threadPool = Objects.requireNonNull(threadPool); + } + + public void updateModelAllocationState( + UpdateTrainedModelAllocationStateAction.Request request, + ActionListener listener + ) { + ClusterState currentState = clusterService.state(); + ClusterStateObserver observer = new ClusterStateObserver(currentState, clusterService, null, logger, 
threadPool.getThreadContext()); + Predicate changePredicate = MasterNodeChangePredicate.build(currentState); + DiscoveryNode masterNode = currentState.nodes().getMasterNode(); + if (masterNode == null) { + logger.warn( + "[{}] no master known for allocation state update [{}]", + request.getModelId(), + request.getRoutingState().getState() + ); + waitForNewMasterAndRetry(observer, UpdateTrainedModelAllocationStateAction.INSTANCE, request, listener, changePredicate); + return; + } + client.execute(UpdateTrainedModelAllocationStateAction.INSTANCE, request, ActionListener.wrap(listener::onResponse, failure -> { + if (isMasterChannelException(failure)) { + logger.info( + "[{}] master channel exception, will retry on new master node for allocation state update [{}]", + request.getModelId(), + request.getRoutingState().getState() + ); + waitForNewMasterAndRetry(observer, UpdateTrainedModelAllocationStateAction.INSTANCE, request, listener, changePredicate); + return; + } + listener.onFailure(failure); + })); + } + + public void createNewModelAllocation( + StartTrainedModelDeploymentAction.TaskParams taskParams, + ActionListener listener + ) { + client.execute(CreateTrainedModelAllocationAction.INSTANCE, new CreateTrainedModelAllocationAction.Request(taskParams), listener); + } + + public void deleteModelAllocation(String modelId, ActionListener listener) { + client.execute(DeleteTrainedModelAllocationAction.INSTANCE, new DeleteTrainedModelAllocationAction.Request(modelId), listener); + } + + public void waitForAllocationCondition( + final String modelId, + final Predicate predicate, + final @Nullable TimeValue timeout, + final WaitForAllocationListener listener + ) { + final Predicate clusterStatePredicate = clusterState -> predicate.test( + TrainedModelAllocationMetadata.allocationForModelId(clusterState, modelId).orElse(null) + ); + + final ClusterStateObserver observer = new ClusterStateObserver(clusterService, timeout, logger, threadPool.getThreadContext()); + final ClusterState clusterState = observer.setAndGetObservedState(); + if (clusterStatePredicate.test(clusterState)) { + listener.onResponse(TrainedModelAllocationMetadata.allocationForModelId(clusterState, modelId).orElse(null)); + } else { + observer.waitForNextChange(new ClusterStateObserver.Listener() { + @Override + public void onNewClusterState(ClusterState state) { + listener.onResponse(TrainedModelAllocationMetadata.allocationForModelId(state, modelId).orElse(null)); + } + + @Override + public void onClusterServiceClose() { + listener.onFailure(new NodeClosedException(clusterService.localNode())); + } + + @Override + public void onTimeout(TimeValue timeout) { + listener.onTimeout(timeout); + } + }, clusterStatePredicate); + } + } + + public interface WaitForAllocationListener extends ActionListener { + default void onTimeout(TimeValue timeout) { + onFailure(new IllegalStateException("Timed out when waiting for trained model allocation after " + timeout)); + } + } + + protected void waitForNewMasterAndRetry( + ClusterStateObserver observer, + ActionType action, + ActionRequest request, + ActionListener listener, + Predicate changePredicate + ) { + observer.waitForNextChange(new ClusterStateObserver.Listener() { + @Override + public void onNewClusterState(ClusterState state) { + client.execute(action, request, listener); + } + + @Override + public void onClusterServiceClose() { + logger.warn("node closed while executing action [{}] for request [{}]", action.name(), request); + listener.onFailure(new
NodeClosedException(clusterService.localNode())); + } + + @Override + public void onTimeout(TimeValue timeout) { + // we wait indefinitely for a new master + assert false; + } + }, changePredicate); + } + + private static final Class[] MASTER_CHANNEL_EXCEPTIONS = new Class[] { + NotMasterException.class, + ConnectTransportException.class, + FailedToCommitClusterStateException.class }; + + private static boolean isMasterChannelException(Exception exp) { + return org.elasticsearch.ExceptionsHelper.unwrap(exp, MASTER_CHANNEL_EXCEPTIONS) != null; + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 15d236831b54a..a57a4f374b012 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.lucene.util.SetOnce; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchRequest; @@ -26,12 +27,9 @@ import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.query.IdsQueryBuilder; -import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState; -import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -45,7 +43,6 @@ import java.io.IOException; import java.io.InputStream; -import java.util.Locale; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -77,17 +74,18 @@ public DeploymentManager(Client client, NamedXContentRegistry xContentRegistry, this.executorServiceForProcess = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); } - public void startDeployment(TrainedModelDeploymentTask task) { - doStartDeployment(task); + public void startDeployment(TrainedModelDeploymentTask task, ActionListener listener) { + doStartDeployment(task, listener); } - private void doStartDeployment(TrainedModelDeploymentTask task) { + private void doStartDeployment(TrainedModelDeploymentTask task, ActionListener listener) { logger.debug("[{}] Starting model deployment", task.getModelId()); ProcessContext processContext = new ProcessContext(task.getModelId(), task.getIndex(), executorServiceForProcess); - if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) { - throw ExceptionsHelper.serverError("[{}] Could not create process as one already exists", task.getModelId()); + if (processContextByAllocation.putIfAbsent(task.getId(), processContext) != null) { + listener.onFailure(ExceptionsHelper.serverError("[{}] Could not create process as one already exists", 
task.getModelId())); + return; } String taskConfigDocId = NlpTaskConfig.documentId(task.getModelId()); @@ -95,22 +93,19 @@ private void doStartDeployment(TrainedModelDeploymentTask task) { ActionListener modelLoadedListener = ActionListener.wrap( success -> { executorServiceForProcess.execute(() -> processContext.resultProcessor.process(processContext.process.get())); - - setTaskStateToStarted(task, ActionListener.wrap( - response -> logger.info("[{}] trained model loaded", task.getModelId()), - e -> failTask(task, - String.format(Locale.ROOT, "[%s] error setting task state to [%s] [%s]", - task.getModelId(), TrainedModelDeploymentState.STARTED, e)) - )); + listener.onResponse(task); }, - e -> failTask(task, - String.format(Locale.ROOT, "[%s] error loading model [%s]", task.getModelId(), e)) + listener::onFailure ); ActionListener configListener = ActionListener.wrap( searchResponse -> { if (searchResponse.getHits().getHits().length == 0) { - failTask(task, Messages.getMessage(Messages.TASK_CONFIG_NOT_FOUND, task.getModelId(), taskConfigDocId)); + listener.onFailure( + new ResourceNotFoundException( + Messages.getMessage(Messages.TASK_CONFIG_NOT_FOUND, task.getModelId(), taskConfigDocId) + ) + ); return; } @@ -121,10 +116,9 @@ private void doStartDeployment(TrainedModelDeploymentTask task) { // here, we are being called back on the searching thread, which MAY be a network thread // `startAndLoad` creates named pipes, blocking the calling thread, better to execute that in our utility // executor. - executorServiceForDeployment.execute(() -> startAndLoad(task, processContext, modelLoadedListener)); + executorServiceForProcess.execute(() -> startAndLoad(processContext, modelLoadedListener)); }, - e -> failTask(task, - String.format(Locale.ROOT, "[%s] creating NLP task from configuration failed with error [%s]", task.getModelId(), e)) + listener::onFailure ); SearchRequest searchRequest = taskConfigSearchRequest(taskConfigDocId, task.getIndex()); @@ -151,22 +145,19 @@ NlpTaskConfig parseConfigDocLeniently(SearchHit hit) throws IOException { } } - private void startAndLoad(TrainedModelDeploymentTask task, - ProcessContext processContext, - ActionListener loadedListener) { + private void startAndLoad(ProcessContext processContext, ActionListener loadedListener) { try { processContext.startProcess(); processContext.loadModel(loadedListener); } catch (Exception e) { - failTask(task, - String.format(Locale.ROOT, "[%s] loading the model failed with error [%s]", task.getModelId(), e)); + loadedListener.onFailure(e); } } public void stopDeployment(TrainedModelDeploymentTask task) { ProcessContext processContext; synchronized (processContextByAllocation) { - processContext = processContextByAllocation.get(task.getAllocationId()); + processContext = processContextByAllocation.get(task.getId()); } if (processContext != null) { logger.info("[{}] Stopping deployment", task.getModelId()); @@ -179,7 +170,7 @@ public void stopDeployment(TrainedModelDeploymentTask task) { public void infer(TrainedModelDeploymentTask task, String input, TimeValue timeout, ActionListener listener) { - ProcessContext processContext = processContextByAllocation.get(task.getAllocationId()); + ProcessContext processContext = processContextByAllocation.get(task.getId()); if (processContext == null) { listener.onFailure(new IllegalStateException("[" + task.getModelId() + "] process context missing")); @@ -248,27 +239,6 @@ private void waitForResult(ProcessContext processContext, } } - private void 
setTaskStateToStarted(TrainedModelDeploymentTask task, - ActionListener> listener) { - TrainedModelDeploymentTaskState startedState = new TrainedModelDeploymentTaskState( - TrainedModelDeploymentState.STARTED, task.getAllocationId(), null); - task.updatePersistentTaskState(startedState, listener); - } - private void failTask(TrainedModelDeploymentTask task, - String reason) { - - logger.error("[{}] failed with reason [{}]", task.getModelId(), reason); - - TrainedModelDeploymentTaskState taskState = - new TrainedModelDeploymentTaskState(TrainedModelDeploymentState.FAILED, task.getAllocationId(), reason); - - task.updatePersistentTaskState(taskState, ActionListener.wrap( - persistentTask -> {}, - e -> logger.error(new ParameterizedMessage("[{}] error setting model deployment state to failed. " + - "Failure reason: [{}]", task.getModelId(), reason), e) - )); - } - class ProcessContext { private final String modelId; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java index 7bfd604863b6c..2a6d6ffe08e8c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -11,26 +11,40 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.persistent.AllocatedPersistentTask; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationNodeService; import java.util.Map; -public class TrainedModelDeploymentTask extends AllocatedPersistentTask implements StartTrainedModelDeploymentAction.TaskMatcher { +public class TrainedModelDeploymentTask extends CancellableTask implements StartTrainedModelDeploymentAction.TaskMatcher { private static final Logger logger = LogManager.getLogger(TrainedModelDeploymentTask.class); private final TaskParams params; - private volatile DeploymentManager manager; + private final TrainedModelAllocationNodeService trainedModelAllocationNodeService; + private volatile boolean stopped; - public TrainedModelDeploymentTask(long id, String type, String action, TaskId parentTask, Map headers, - TaskParams taskParams) { + public TrainedModelDeploymentTask( + long id, + String type, + String action, + TaskId parentTask, + Map headers, + TaskParams taskParams, + TrainedModelAllocationNodeService trainedModelAllocationNodeService + ) { super(id, type, action, MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX + taskParams.getModelId(), parentTask, headers); this.params = taskParams; + this.trainedModelAllocationNodeService = ExceptionsHelper.requireNonNull( + trainedModelAllocationNodeService, + "trainedModelAllocationNodeService" + ); } public String getModelId() { @@ -47,14 +61,17 @@ public long estimateMemoryUsageBytes() { public void stop(String reason) { 
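// Stopping is now delegated to the allocation node service: stopDeploymentAndNotify first marks this node's
// routing entry as STOPPING, then stops the native process asynchronously and finally reports STOPPED to the master.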
logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason); + stopped = true; + trainedModelAllocationNodeService.stopDeploymentAndNotify(this); + } - assert manager != null : "manager should not be unset when stop is called"; - manager.stopDeployment(this); - markAsCompleted(); + public void stopWithoutNotification(String reason) { + logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason); + stopped = true; } - public void setDeploymentManager(DeploymentManager manager) { - this.manager = manager; + public boolean isStopped() { + return stopped; } @Override @@ -64,6 +81,6 @@ protected void onCancelled() { } public void infer(String input, TimeValue timeout, ActionListener listener) { - manager.infer(this, input, timeout, listener); + trainedModelAllocationNodeService.infer(this, input, timeout, listener); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java index 8ee5ad193a13a..8342aa1d77e95 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java @@ -326,7 +326,7 @@ public static String nodeNameAndVersion(DiscoveryNode node) { return builder.toString(); } - static String nodeNameAndMlAttributes(DiscoveryNode node) { + public static String nodeNameAndMlAttributes(DiscoveryNode node) { String nodeNameOrID = nodeNameOrId(node); StringBuilder builder = new StringBuilder("{").append(nodeNameOrID).append('}'); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java new file mode 100644 index 0000000000000..31fcc1bbe50f6 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java @@ -0,0 +1,670 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.allocation; + +import org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.Version; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.NodesShutdownMetadata; +import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodeRole; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.collect.MapBuilder; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction; +import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState; +import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState; +import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason; +import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.job.NodeLoadDetector; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; +import org.junit.Before; + +import java.util.Collections; +import java.util.Set; +import java.util.function.Function; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TrainedModelAllocationClusterServiceTests extends ESTestCase { + + private ClusterService clusterService; + private NodeLoadDetector nodeLoadDetector; + + @Before + public void setupObjects() { + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Sets.newHashSet(MachineLearning.MAX_MACHINE_MEMORY_PERCENT, MachineLearning.USE_AUTO_MACHINE_MEMORY_PERCENT) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + MlMemoryTracker memoryTracker = mock(MlMemoryTracker.class); + when(memoryTracker.isRecentlyRefreshed()).thenReturn(true); + nodeLoadDetector = new NodeLoadDetector(memoryTracker); + } + + public void testUpdateModelRoutingTable() { + String modelId = "existing-model"; + String nodeId = "ml-node-with-room"; + ClusterState currentState = ClusterState.builder(new ClusterName("testUpdateModelRoutingTable")) + .nodes(DiscoveryNodes.builder().add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes())).build()) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(modelId, 10_000L)) + 
.addNode(modelId, nodeId) + .build() + ) + .build() + ) + .build(); + + assertThatStoppingAllocationPreventsMutation( + state -> TrainedModelAllocationClusterService.updateModelRoutingTable( + state, + new UpdateTrainedModelAllocationStateAction.Request(nodeId, modelId, new RoutingStateAndReason(RoutingState.STARTED, "")) + ), + currentState + ); + + ClusterState newState = TrainedModelAllocationClusterService.updateModelRoutingTable( + currentState, + new UpdateTrainedModelAllocationStateAction.Request(nodeId, modelId, new RoutingStateAndReason(RoutingState.STARTED, "")) + ); + assertThat( + TrainedModelAllocationMetadata.fromState(newState).getModelAllocation(modelId).getNodeRoutingTable().get(nodeId).getState(), + equalTo(RoutingState.STARTED) + ); + + expectThrows( + ResourceNotFoundException.class, + () -> TrainedModelAllocationClusterService.updateModelRoutingTable( + currentState, + new UpdateTrainedModelAllocationStateAction.Request( + "missingNode", + modelId, + new RoutingStateAndReason(RoutingState.STARTED, "") + ) + ) + ); + expectThrows( + ResourceNotFoundException.class, + () -> TrainedModelAllocationClusterService.updateModelRoutingTable( + currentState, + new UpdateTrainedModelAllocationStateAction.Request( + nodeId, + "missingModel", + new RoutingStateAndReason(RoutingState.STARTED, "") + ) + ) + ); + + // TEST Stopped + + // We should allow a "stopped" update on missing models and nodes as entries may have already been deleted + TrainedModelAllocationClusterService.updateModelRoutingTable( + currentState, + new UpdateTrainedModelAllocationStateAction.Request("missingNode", modelId, new RoutingStateAndReason(RoutingState.STOPPED, "")) + ); + TrainedModelAllocationClusterService.updateModelRoutingTable( + currentState, + new UpdateTrainedModelAllocationStateAction.Request(nodeId, "missingModel", new RoutingStateAndReason(RoutingState.STOPPED, "")) + ); + + ClusterState updateState = TrainedModelAllocationClusterService.updateModelRoutingTable( + currentState, + new UpdateTrainedModelAllocationStateAction.Request(nodeId, modelId, new RoutingStateAndReason(RoutingState.STOPPED, "")) + ); + assertThat( + TrainedModelAllocationMetadata.fromState(updateState).getModelAllocation(modelId).getNodeRoutingTable(), + not(hasKey(nodeId)) + ); + } + + public void testRemoveAllocation() { + ClusterState clusterStateWithoutAllocation = ClusterState.builder(new ClusterName("testRemoveAllocation")) + .metadata(Metadata.builder().build()) + .build(); + String modelId = "remove-allocation"; + + expectThrows( + ResourceNotFoundException.class, + () -> TrainedModelAllocationClusterService.removeAllocation(clusterStateWithoutAllocation, modelId) + ); + + ClusterState clusterStateWithAllocation = ClusterState.builder(new ClusterName("testRemoveAllocation")) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(modelId, randomNonNegativeLong())).build() + ) + .build() + ) + .build(); + assertThat(TrainedModelAllocationMetadata.fromState(clusterStateWithAllocation).getModelAllocation(modelId), is(not(nullValue()))); + + ClusterState modified = TrainedModelAllocationClusterService.removeAllocation(clusterStateWithAllocation, modelId); + assertThat(TrainedModelAllocationMetadata.fromState(modified).getModelAllocation(modelId), is(nullValue())); + + } + + public void testCreateAllocation() { + ClusterState currentState = ClusterState.builder(new ClusterName("testCreateAllocation")) + .nodes( + 
DiscoveryNodes.builder() + .add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes())) + .add(buildNode("ml-node-without-room", true, 1000L)) + .add(buildNode("not-ml-node", false, ByteSizeValue.ofGb(4).getBytes())) + .add(buildNode("ml-node-shutting-down", true, ByteSizeValue.ofGb(4).getBytes())) + .build() + ) + .metadata(Metadata.builder().putCustom(NodesShutdownMetadata.TYPE, shutdownMetadata("ml-node-shutting-down"))) + .build(); + + TrainedModelAllocationClusterService trainedModelAllocationClusterService = createClusterService(); + ClusterState newState = trainedModelAllocationClusterService.createModelAllocation(currentState, newParams("new-model", 150)); + TrainedModelAllocation createdAllocation = TrainedModelAllocationMetadata.fromState(newState).getModelAllocation("new-model"); + + assertThat(createdAllocation, is(not(nullValue()))); + assertThat(createdAllocation.getNodeRoutingTable().keySet(), hasSize(2)); + assertThat(createdAllocation.getNodeRoutingTable(), hasKey("ml-node-with-room")); + assertThat(createdAllocation.getNodeRoutingTable().get("ml-node-with-room").getState(), equalTo(RoutingState.STARTING)); + assertThat(createdAllocation.getNodeRoutingTable(), hasKey("ml-node-without-room")); + assertThat(createdAllocation.getNodeRoutingTable().get("ml-node-without-room").getState(), equalTo(RoutingState.FAILED)); + assertThat( + createdAllocation.getNodeRoutingTable().get("ml-node-without-room").getReason(), + containsString("This node has insufficient available memory.") + ); + + expectThrows( + ResourceAlreadyExistsException.class, + () -> trainedModelAllocationClusterService.createModelAllocation(newState, newParams("new-model", 150)) + ); + } + + public void testAddRemoveAllocationNodes() { + ClusterState currentState = ClusterState.builder(new ClusterName("testAddRemoveAllocationNodes")) + .nodes( + DiscoveryNodes.builder() + .add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes())) + .add(buildNode("new-ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes())) + .add(buildNode("ml-node-without-room", true, 1000L)) + .add(buildNode("not-ml-node", false, ByteSizeValue.ofGb(4).getBytes())) + .add(buildNode("ml-node-shutting-down", true, ByteSizeValue.ofGb(4).getBytes())) + .build() + ) + .metadata( + Metadata.builder() + .putCustom(NodesShutdownMetadata.TYPE, shutdownMetadata("ml-node-shutting-down")) + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams("model-1", 10_000)) + .addNode("model-1", "ml-node-with-room") + .updateAllocation("model-1", "ml-node-with-room", new RoutingStateAndReason(RoutingState.STARTED, "")) + .addNode("model-1", "old-ml-node-with-room") + .updateAllocation("model-1", "old-ml-node-with-room", new RoutingStateAndReason(RoutingState.STARTED, "")) + .addNode("model-1", "ml-node-shutting-down") + .addNewAllocation(newParams("model-2", 10_000)) + .addNode("model-2", "old-ml-node-with-room") + .updateAllocation("model-2", "old-ml-node-with-room", new RoutingStateAndReason(RoutingState.STARTED, "")) + .build() + ) + ) + .build(); + TrainedModelAllocationClusterService trainedModelAllocationClusterService = createClusterService(); + + // Stopping shouldn't cause any updates + assertThatStoppingAllocationPreventsMutation( + trainedModelAllocationClusterService::addRemoveAllocationNodes, + currentState + ); + + ClusterState modified = trainedModelAllocationClusterService.addRemoveAllocationNodes(currentState); + 
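// After rebalancing, routes on nodes that have left the cluster (old-ml-node-with-room) or are shutting down are dropped,
// the new ML node gets a STARTING route, and the ML node without enough memory is marked FAILED.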
TrainedModelAllocationMetadata trainedModelAllocationMetadata = TrainedModelAllocationMetadata.fromState(modified); + assertThat(trainedModelAllocationMetadata.modelAllocations().keySet(), hasSize(2)); + assertThat(trainedModelAllocationMetadata.modelAllocations(), allOf(hasKey("model-1"), hasKey("model-2"))); + + assertThat(trainedModelAllocationMetadata.getModelAllocation("model-1").getNodeRoutingTable().keySet(), hasSize(3)); + assertThat( + trainedModelAllocationMetadata.getModelAllocation("model-1").getNodeRoutingTable(), + allOf(hasKey("ml-node-with-room"), hasKey("new-ml-node-with-room"), hasKey("ml-node-without-room")) + ); + assertNodeState(trainedModelAllocationMetadata, "model-1", "ml-node-with-room", RoutingState.STARTED); + assertNodeState(trainedModelAllocationMetadata, "model-1", "new-ml-node-with-room", RoutingState.STARTING); + assertNodeState(trainedModelAllocationMetadata, "model-1", "ml-node-without-room", RoutingState.FAILED); + + assertThat(trainedModelAllocationMetadata.getModelAllocation("model-2").getNodeRoutingTable().keySet(), hasSize(3)); + assertThat( + trainedModelAllocationMetadata.getModelAllocation("model-2").getNodeRoutingTable(), + allOf(hasKey("ml-node-with-room"), hasKey("new-ml-node-with-room"), hasKey("ml-node-without-room")) + ); + assertNodeState(trainedModelAllocationMetadata, "model-2", "ml-node-with-room", RoutingState.STARTING); + assertNodeState(trainedModelAllocationMetadata, "model-2", "new-ml-node-with-room", RoutingState.STARTING); + assertNodeState(trainedModelAllocationMetadata, "model-2", "ml-node-without-room", RoutingState.FAILED); + } + + public void testShouldAllocateModels() { + String mlNode1 = "ml-node-with-room"; + String mlNode2 = "new-ml-node-with-room"; + DiscoveryNode mlNode1Node = buildNode(mlNode1, true, ByteSizeValue.ofGb(4).getBytes()); + DiscoveryNode mlNode2Node = buildNode(mlNode2, true, ByteSizeValue.ofGb(4).getBytes()); + ClusterState stateWithTwoNodes = ClusterState.builder(new ClusterName("testShouldAllocateModels")) + .nodes(DiscoveryNodes.builder().add(mlNode1Node).add(mlNode2Node)) + .build(); + ClusterState stateWithOneNode = ClusterState.builder(new ClusterName("testShouldAllocateModels")) + .nodes(DiscoveryNodes.builder().add(mlNode1Node)) + .build(); + ClusterState stateWithOneNodeNotMl = ClusterState.builder(new ClusterName("testShouldAllocateModels")) + .nodes(DiscoveryNodes.builder().add(mlNode1Node).add(buildNode("not-ml-node", false, ByteSizeValue.ofGb(4).getBytes()))) + .build(); + + // No metadata in the new state means no allocations, so no updates + assertThat( + TrainedModelAllocationClusterService.shouldAllocateModels( + new ClusterChangedEvent( + "test", + ClusterState.builder(randomFrom(stateWithOneNodeNotMl, stateWithOneNode, stateWithTwoNodes)).build(), + ClusterState.builder(randomFrom(stateWithOneNodeNotMl, stateWithOneNode, stateWithTwoNodes)) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build() + ) + .build() + ) + .build() + ) + ), + is(false) + ); + + // Even with metadata changes, unless there are node changes, do nothing + ClusterState randomState = randomFrom(stateWithOneNodeNotMl, stateWithOneNode, stateWithTwoNodes); + assertThat( + TrainedModelAllocationClusterService.shouldAllocateModels( + new ClusterChangedEvent( + "test", + ClusterState.builder(randomState) + .metadata( + Metadata.builder() + .putCustom(TrainedModelAllocationMetadata.NAME, 
TrainedModelAllocationMetadataTests.randomInstance()) + .build() + ) + .build(), + ClusterState.builder(randomState) + .metadata( + Metadata.builder() + .putCustom(TrainedModelAllocationMetadata.NAME, TrainedModelAllocationMetadataTests.randomInstance()) + .build() + ) + .build() + ) + ), + is(false) + ); + + // If the node removed is not even an ML node, we should not attempt to re-allocate + assertThat( + TrainedModelAllocationClusterService.shouldAllocateModels( + new ClusterChangedEvent( + "test", + ClusterState.builder(stateWithOneNode) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build() + ) + .build() + ) + .build(), + ClusterState.builder(stateWithOneNodeNotMl) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build() + ) + .build() + ) + .build() + ) + ), + is(false) + ); + + // If the node removed is an ML node, but no models are allocated to it, we should not attempt to re-allocate + assertThat( + TrainedModelAllocationClusterService.shouldAllocateModels( + new ClusterChangedEvent( + "test", + ClusterState.builder(stateWithOneNode) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build() + ) + .build() + ) + .build(), + ClusterState.builder(stateWithTwoNodes) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build() + ) + .build() + ) + .build() + ) + ), + is(false) + ); + + // If a new ML node is added, we should attempt to re-allocate + assertThat( + TrainedModelAllocationClusterService.shouldAllocateModels( + new ClusterChangedEvent( + "test", + ClusterState.builder(stateWithTwoNodes) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build() + ) + .build() + ) + .build(), + ClusterState.builder(stateWithOneNode) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build() + ) + .build() + ) + .build() + ) + ), + is(true) + ); + + // If a new ML node is added, but allocation is stopping, we should not re-allocate + assertThat( + TrainedModelAllocationClusterService.shouldAllocateModels( + new ClusterChangedEvent( + "test", + ClusterState.builder(stateWithTwoNodes) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty() + .addNewAllocation(newParams(mlNode1, 100)) + .setAllocationToStopping(mlNode1) + .build() + ) + .build() + ) + .build(), + ClusterState.builder(stateWithOneNode) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build() + ) + .build() + ) + .build() + ) + ), + is(false) + ); + + // If a new ML node is added, but it's shutting down, don't re-allocate + assertThat( + TrainedModelAllocationClusterService.shouldAllocateModels( + new ClusterChangedEvent( + "test", + ClusterState.builder(stateWithTwoNodes)
+ .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build() + ) + .putCustom(NodesShutdownMetadata.TYPE, shutdownMetadata(mlNode2)) + .build() + ) + .build(), + ClusterState.builder(stateWithOneNode) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build() + ) + .build() + ) + .build() + ) + ), + is(false) + ); + + // If an ML node is removed and it's routed to, re-allocate + assertThat( + TrainedModelAllocationClusterService.shouldAllocateModels( + new ClusterChangedEvent( + "test", + ClusterState.builder(stateWithOneNode) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams("model-1", 100)) + .addNode("model-1", mlNode1) + .addNewAllocation(newParams("model-2", 100)) + .addNode("model-2", mlNode1) + .addNode("model-2", mlNode2) + .build() + ) + .build() + ) + .build(), + ClusterState.builder(stateWithTwoNodes) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams("model-1", 100)) + .addNode("model-1", mlNode1) + .addNewAllocation(newParams("model-2", 100)) + .addNode("model-2", mlNode1) + .addNode("model-2", mlNode2) + .build() + ) + .build() + ) + .build() + ) + ), + is(true) + ); + + // If an ML node is removed and it's routed to, but the allocation is stopping, don't re-allocate + assertThat( + TrainedModelAllocationClusterService.shouldAllocateModels( + new ClusterChangedEvent( + "test", + ClusterState.builder(stateWithOneNode) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty() + .addNewAllocation(newParams("model-1", 100)) + .addNode("model-1", mlNode1) + .addNewAllocation(newParams("model-2", 100)) + .addNode("model-2", mlNode1) + .addNode("model-2", mlNode2) + .setAllocationToStopping("model-2") + .build() + ) + .build() + ) + .build(), + ClusterState.builder(stateWithTwoNodes) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty() + .addNewAllocation(newParams("model-1", 100)) + .addNode("model-1", mlNode1) + .addNewAllocation(newParams("model-2", 100)) + .addNode("model-2", mlNode1) + .addNode("model-2", mlNode2) + .build() + ) + .build() + ) + .build() + ) + ), + is(false) + ); + } + + public void testSetAllocationToStopping() { + ClusterState clusterStateWithoutAllocation = ClusterState.builder(new ClusterName("testSetAllocationToStopping")) + .metadata(Metadata.builder().build()) + .build(); + String modelId = "stopping-allocation"; + + expectThrows( + ResourceNotFoundException.class, + () -> TrainedModelAllocationClusterService.setToStopping(clusterStateWithoutAllocation, modelId) + ); + + ClusterState clusterStateWithAllocation = ClusterState.builder(new ClusterName("testSetAllocationToStopping")) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(modelId, randomNonNegativeLong())).build() + ) + .build() + ) + .build(); + TrainedModelAllocationMetadata before = TrainedModelAllocationMetadata.fromState(clusterStateWithAllocation); +
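// A freshly created allocation starts in AllocationState.STARTED; setToStopping should move it to STOPPING.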
assertThat(before.getModelAllocation(modelId), is(not(nullValue()))); + assertThat(before.getModelAllocation(modelId).getAllocationState(), equalTo(AllocationState.STARTED)); + + ClusterState modified = TrainedModelAllocationClusterService.setToStopping(clusterStateWithAllocation, modelId); + assertThat( + TrainedModelAllocationMetadata.fromState(modified).getModelAllocation(modelId).getAllocationState(), + equalTo(AllocationState.STOPPING) + ); + } + + private void assertThatStoppingAllocationPreventsMutation( + Function<ClusterState, ClusterState> mutationFunction, + ClusterState original + ) { + TrainedModelAllocationMetadata tempMetadata = TrainedModelAllocationMetadata.fromState(original); + if (tempMetadata.modelAllocations().isEmpty()) { + return; + } + TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(original); + for (String modelId : tempMetadata.modelAllocations().keySet()) { + builder.setAllocationToStopping(modelId); + } + TrainedModelAllocationMetadata metadataWithStopping = builder.build(); + ClusterState originalWithStoppingAllocations = ClusterState.builder(original) + .metadata(Metadata.builder(original.metadata()).putCustom(TrainedModelAllocationMetadata.NAME, metadataWithStopping).build()) + .build(); + + assertThat( + "setting all allocations to stopping did not prevent mutation", + TrainedModelAllocationMetadata.fromState(mutationFunction.apply(originalWithStoppingAllocations)), + equalTo(metadataWithStopping) + ); + } + + private TrainedModelAllocationClusterService createClusterService() { + return new TrainedModelAllocationClusterService(Settings.EMPTY, clusterService, nodeLoadDetector); + } + + private static DiscoveryNode buildNode(String name, boolean isML, long nativeMemory) { + return new DiscoveryNode( + name, + name, + buildNewFakeTransportAddress(), + MapBuilder.<String, String>newMapBuilder() + .put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, String.valueOf(nativeMemory)) + .put(MachineLearning.MAX_JVM_SIZE_NODE_ATTR, String.valueOf(10)) + .put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, String.valueOf(10)) + .map(), + isML ? DiscoveryNodeRole.roles() : Set.of(DiscoveryNodeRole.DATA_ROLE, DiscoveryNodeRole.MASTER_ROLE), + Version.CURRENT + ); + } + + private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId, long modelSize) { + return new StartTrainedModelDeploymentAction.TaskParams(modelId, "test-index", modelSize); + } + + private static void assertNodeState(TrainedModelAllocationMetadata metadata, String modelId, String nodeId, RoutingState routingState) { + assertThat(metadata.getModelAllocation(modelId).getNodeRoutingTable().get(nodeId).getState(), equalTo(routingState)); + } + + private static NodesShutdownMetadata shutdownMetadata(String nodeId) { + return new NodesShutdownMetadata( + Collections.singletonMap( + nodeId, + SingleNodeShutdownMetadata.builder() + .setType(SingleNodeShutdownMetadata.Type.REMOVE) + .setStartedAtMillis(randomNonNegativeLong()) + .setReason("tests") + .setNodeId(nodeId) + .build() + ) + ); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadataTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadataTests.java new file mode 100644 index 0000000000000..8bdd7390e1aa4 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadataTests.java @@ -0,0 +1,108 @@ +/* + * Copyright Elasticsearch B.V.
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.allocation; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReasonTests; +import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation; +import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocationTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.is; + +public class TrainedModelAllocationMetadataTests extends AbstractSerializingTestCase<TrainedModelAllocationMetadata> { + + public static TrainedModelAllocationMetadata randomInstance() { + LinkedHashMap<String, TrainedModelAllocation> map = Stream.generate(() -> randomAlphaOfLength(10)) + .limit(randomInt(5)) + .collect( + Collectors.toMap(Function.identity(), (k) -> TrainedModelAllocationTests.randomInstance(), (k, k1) -> k, LinkedHashMap::new) + ); + return new TrainedModelAllocationMetadata(map); + } + + @Override + protected TrainedModelAllocationMetadata doParseInstance(XContentParser parser) throws IOException { + return TrainedModelAllocationMetadata.fromXContent(parser); + } + + @Override + protected Writeable.Reader<TrainedModelAllocationMetadata> instanceReader() { + return TrainedModelAllocationMetadata::new; + } + + @Override + protected TrainedModelAllocationMetadata createTestInstance() { + return new TrainedModelAllocationMetadata(new HashMap<>()); + } + + public void testBuilderChanged_WhenAddingRemovingModel() { + TrainedModelAllocationMetadata original = randomInstance(); + String newModel = "foo_model"; + + TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.Builder.fromMetadata(original); + assertThat(builder.isChanged(), is(false)); + + assertUnchanged(builder, b -> b.removeAllocation(newModel)); + assertUnchanged(builder, b -> b.updateAllocation(newModel, "foo", RoutingStateAndReasonTests.randomInstance())); + assertUnchanged(builder, b -> b.removeNode(newModel, "foo")); + + if (original.modelAllocations().isEmpty() == false) { + String randomExistingModel = randomFrom(original.modelAllocations().keySet().toArray(String[]::new)); + assertUnchanged(builder, b -> b.addNewAllocation(randomParams(randomExistingModel))); + } + + builder.addNewAllocation(new StartTrainedModelDeploymentAction.TaskParams(newModel, "test-index", randomNonNegativeLong())); + assertThat(builder.isChanged(), is(true)); + } + + public void testBuilderChanged_WhenAddingRemovingNodeFromModel() { + String newModel = "foo_model"; + TrainedModelAllocationMetadata original = TrainedModelAllocationMetadata.Builder.fromMetadata(randomInstance()) + .addNewAllocation(randomParams(newModel)) + .build(); + TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.Builder.fromMetadata(original); + assertThat(builder.isChanged(), is(false)); + + String newNode = "foo"; + if (randomBoolean()) { + builder.addNode(newModel, newNode); + } else { + builder.addFailedNode(newModel, newNode, "failure"); + } +
assertThat(builder.isChanged(), is(true)); + + builder = TrainedModelAllocationMetadata.Builder.fromMetadata(builder.build()); + assertThat(builder.isChanged(), is(false)); + + builder.removeNode(newModel, newNode); + assertThat(builder.isChanged(), is(true)); + } + + private static TrainedModelAllocationMetadata.Builder assertUnchanged( + TrainedModelAllocationMetadata.Builder builder, + Function<TrainedModelAllocationMetadata.Builder, TrainedModelAllocationMetadata.Builder> function + ) { + function.apply(builder); + assertThat(builder.isChanged(), is(false)); + return builder; + } + + private static StartTrainedModelDeploymentAction.TaskParams randomParams(String modelId) { + return new StartTrainedModelDeploymentAction.TaskParams(modelId, "test-index", randomNonNegativeLong()); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java new file mode 100644 index 0000000000000..d7c37a120cdb4 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java @@ -0,0 +1,359 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.allocation; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodeRole; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.tasks.TaskManager; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ScalingExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction; +import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState; +import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager; +import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; +import org.junit.After; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.util.Collections; + +import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME; +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +public class TrainedModelAllocationNodeServiceTests extends ESTestCase { + +
private static final String NODE_ID = "test-node"; + + private ClusterService clusterService; + private DeploymentManager deploymentManager; + private ThreadPool threadPool; + private TrainedModelAllocationService trainedModelAllocationService; + private TaskManager taskManager; + + @Before + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void setupObjects() { + trainedModelAllocationService = mock(TrainedModelAllocationService.class); + clusterService = mock(ClusterService.class); + threadPool = new TestThreadPool( + "TrainedModelAllocationNodeServiceTests", + new ScalingExecutorBuilder(UTILITY_THREAD_POOL_NAME, 1, 4, TimeValue.timeValueMinutes(10), "xpack.ml.utility_thread_pool") + ); + taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet()); + deploymentManager = mock(DeploymentManager.class); + doAnswer(invocationOnMock -> { + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + listener.onResponse(invocationOnMock.getArguments()[0]); + return null; + }).when(deploymentManager).startDeployment(any(), any()); + doAnswer(invocationOnMock -> { + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + listener.onResponse(AcknowledgedResponse.TRUE); + return null; + }).when(trainedModelAllocationService).updateModelAllocationState(any(), any()); + } + + @After + public void shutdown() throws InterruptedException { + terminate(threadPool); + } + + public void testLoadQueuedModels() { + TrainedModelAllocationNodeService trainedModelAllocationNodeService = createService(); + + // When there are no queued models + trainedModelAllocationNodeService.loadQueuedModels(); + verify(deploymentManager, never()).startDeployment(any(), any()); + + String modelToLoad = "loading-model"; + String anotherModel = "loading-model-again"; + + // Should only load each model once + trainedModelAllocationNodeService.prepareModelToLoad(newParams(modelToLoad)); + trainedModelAllocationNodeService.prepareModelToLoad(newParams(modelToLoad)); + trainedModelAllocationNodeService.prepareModelToLoad(newParams(anotherModel)); + + trainedModelAllocationNodeService.loadQueuedModels(); + + ArgumentCaptor<TrainedModelDeploymentTask> taskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); + ArgumentCaptor<UpdateTrainedModelAllocationStateAction.Request> requestCapture = ArgumentCaptor.forClass( + UpdateTrainedModelAllocationStateAction.Request.class + ); + verify(deploymentManager, times(2)).startDeployment(taskCapture.capture(), any()); + verify(trainedModelAllocationService, times(2)).updateModelAllocationState(requestCapture.capture(), any()); + + assertThat(taskCapture.getAllValues().get(0).getModelId(), equalTo(modelToLoad)); + assertThat(requestCapture.getAllValues().get(0).getModelId(), equalTo(modelToLoad)); + assertThat(requestCapture.getAllValues().get(0).getNodeId(), equalTo(NODE_ID)); + assertThat(requestCapture.getAllValues().get(0).getRoutingState().getState(), equalTo(RoutingState.STARTED)); + + assertThat(taskCapture.getAllValues().get(1).getModelId(), equalTo(anotherModel)); + assertThat(requestCapture.getAllValues().get(1).getModelId(), equalTo(anotherModel)); + assertThat(requestCapture.getAllValues().get(1).getNodeId(), equalTo(NODE_ID)); + assertThat(requestCapture.getAllValues().get(1).getRoutingState().getState(), equalTo(RoutingState.STARTED)); + + // Since the models are already loaded, no further loads should occur + trainedModelAllocationNodeService.prepareModelToLoad(newParams(anotherModel)); + trainedModelAllocationNodeService.loadQueuedModels(); +
verifyNoMoreInteractions(deploymentManager, trainedModelAllocationService); + } + + public void testLoadQueuedModelsWhenStopped() { + TrainedModelAllocationNodeService trainedModelAllocationNodeService = createService(); + + // Queue a model, then stop the service before loading + String modelToLoad = "loading-model"; + + // Nothing should be loaded after the service is stopped + trainedModelAllocationNodeService.prepareModelToLoad(newParams(modelToLoad)); + trainedModelAllocationNodeService.stop(); + + trainedModelAllocationNodeService.loadQueuedModels(); + verifyNoMoreInteractions(deploymentManager, trainedModelAllocationService); + } + + public void testLoadQueuedModelsWhenTaskIsStopped() throws Exception { + TrainedModelAllocationNodeService trainedModelAllocationNodeService = createService(); + + // Queue two models, one of which is stopped before loading + String modelToLoad = "loading-model"; + String stoppedModelToLoad = "stopped-loading-model"; + + // Only one model should be loaded, the other should be stopped + trainedModelAllocationNodeService.prepareModelToLoad(newParams(modelToLoad)); + trainedModelAllocationNodeService.prepareModelToLoad(newParams(stoppedModelToLoad)); + trainedModelAllocationNodeService.getTask(stoppedModelToLoad).stop("testing"); + trainedModelAllocationNodeService.loadQueuedModels(); + + assertBusy(() -> { + ArgumentCaptor<TrainedModelDeploymentTask> stoppedTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); + verify(deploymentManager, times(1)).stopDeployment(stoppedTaskCapture.capture()); + assertThat(stoppedTaskCapture.getValue().getModelId(), equalTo(stoppedModelToLoad)); + }); + ArgumentCaptor<TrainedModelDeploymentTask> startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); + ArgumentCaptor<UpdateTrainedModelAllocationStateAction.Request> requestCapture = ArgumentCaptor.forClass( + UpdateTrainedModelAllocationStateAction.Request.class + ); + verify(deploymentManager, times(1)).startDeployment(startTaskCapture.capture(), any()); + assertBusy(() -> verify(trainedModelAllocationService, times(3)).updateModelAllocationState(requestCapture.capture(), any())); + + boolean seenStopping = false; + for (int i = 0; i < 3; i++) { + UpdateTrainedModelAllocationStateAction.Request request = requestCapture.getAllValues().get(i); + assertThat(request.getNodeId(), equalTo(NODE_ID)); + if (request.getModelId().equals(stoppedModelToLoad)) { + if (seenStopping) { + assertThat(request.getRoutingState().getState(), equalTo(RoutingState.STOPPED)); + } else { + assertThat(request.getRoutingState().getState(), equalTo(RoutingState.STOPPING)); + seenStopping = true; + } + } else { + assertThat(request.getModelId(), equalTo(modelToLoad)); + assertThat(request.getRoutingState().getState(), equalTo(RoutingState.STARTED)); + } + } + assertThat(startTaskCapture.getAllValues().get(0).getModelId(), equalTo(modelToLoad)); + + verifyNoMoreInteractions(deploymentManager, trainedModelAllocationService); + } + + public void testLoadQueuedModelsWhenOneFails() { + String modelToLoad = "loading-model"; + String failedModelToLoad = "failed-loading-model"; + withLoadFailure(failedModelToLoad); + TrainedModelAllocationNodeService trainedModelAllocationNodeService = createService(); + + trainedModelAllocationNodeService.prepareModelToLoad(newParams(modelToLoad)); + trainedModelAllocationNodeService.prepareModelToLoad(newParams(failedModelToLoad)); + + trainedModelAllocationNodeService.loadQueuedModels(); + + ArgumentCaptor<TrainedModelDeploymentTask> startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); + ArgumentCaptor<UpdateTrainedModelAllocationStateAction.Request> requestCapture = ArgumentCaptor.forClass( + UpdateTrainedModelAllocationStateAction.Request.class + );
+ verify(deploymentManager, times(2)).startDeployment(startTaskCapture.capture(), any()); + verify(trainedModelAllocationService, times(2)).updateModelAllocationState(requestCapture.capture(), any()); + + assertThat(startTaskCapture.getAllValues().get(0).getModelId(), equalTo(modelToLoad)); + assertThat(requestCapture.getAllValues().get(0).getModelId(), equalTo(modelToLoad)); + assertThat(requestCapture.getAllValues().get(0).getNodeId(), equalTo(NODE_ID)); + assertThat(requestCapture.getAllValues().get(0).getRoutingState().getState(), equalTo(RoutingState.STARTED)); + + assertThat(startTaskCapture.getAllValues().get(1).getModelId(), equalTo(failedModelToLoad)); + assertThat(requestCapture.getAllValues().get(1).getModelId(), equalTo(failedModelToLoad)); + assertThat(requestCapture.getAllValues().get(1).getNodeId(), equalTo(NODE_ID)); + assertThat(requestCapture.getAllValues().get(1).getRoutingState().getState(), equalTo(RoutingState.FAILED)); + + verifyNoMoreInteractions(deploymentManager, trainedModelAllocationService); + } + + public void testClusterChanged() throws Exception { + final TrainedModelAllocationNodeService trainedModelAllocationNodeService = createService(); + final DiscoveryNodes nodes = DiscoveryNodes.builder() + .localNodeId(NODE_ID) + .add( + new DiscoveryNode( + NODE_ID, + NODE_ID, + buildNewFakeTransportAddress(), + Collections.emptyMap(), + DiscoveryNodeRole.roles(), + Version.CURRENT + ) + ) + .build(); + String modelOne = "model-1"; + String modelTwo = "model-2"; + String notUsedModel = "model-3"; + ClusterChangedEvent event = new ClusterChangedEvent( + "testClusterChanged", + ClusterState.builder(new ClusterName("testClusterChanged")) + .nodes(nodes) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty() + .addNewAllocation(newParams(modelOne)) + .addNode(modelOne, NODE_ID) + .addNewAllocation(newParams(modelTwo)) + .addNode(modelTwo, NODE_ID) + .addNewAllocation(newParams(notUsedModel)) + .addNode(notUsedModel, "some-other-node") + .build() + ) + .build() + ) + .build(), + ClusterState.EMPTY_STATE + ); + + trainedModelAllocationNodeService.clusterChanged(event); + + event = new ClusterChangedEvent( + "testClusterChanged", + ClusterState.builder(new ClusterName("testClusterChanged")) + .nodes(nodes) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty() + .addNewAllocation(newParams(modelOne)) + .addNode(modelOne, NODE_ID) + .addNewAllocation(newParams(modelTwo)) + .addNode(modelTwo, "some-other-node") + .addNewAllocation(newParams(notUsedModel)) + .addNode(notUsedModel, "some-other-node") + .build() + ) + .build() + ) + .build(), + ClusterState.EMPTY_STATE + ); + trainedModelAllocationNodeService.clusterChanged(event); + + trainedModelAllocationNodeService.loadQueuedModels(); + + assertBusy(() -> { + ArgumentCaptor<TrainedModelDeploymentTask> stoppedTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); + verify(deploymentManager, times(1)).stopDeployment(stoppedTaskCapture.capture()); + assertThat(stoppedTaskCapture.getAllValues().get(0).getModelId(), equalTo(modelTwo)); + }); + ArgumentCaptor<TrainedModelDeploymentTask> startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); + ArgumentCaptor<UpdateTrainedModelAllocationStateAction.Request> requestCapture = ArgumentCaptor.forClass( + UpdateTrainedModelAllocationStateAction.Request.class + ); + verify(deploymentManager, times(1)).startDeployment(startTaskCapture.capture(), any()); + verify(trainedModelAllocationService,
times(1)).updateModelAllocationState(requestCapture.capture(), any()); + + assertThat(startTaskCapture.getAllValues().get(0).getModelId(), equalTo(modelOne)); + assertThat(requestCapture.getAllValues().get(0).getModelId(), equalTo(modelOne)); + assertThat(requestCapture.getAllValues().get(0).getNodeId(), equalTo(NODE_ID)); + assertThat(requestCapture.getAllValues().get(0).getRoutingState().getState(), equalTo(RoutingState.STARTED)); + + event = new ClusterChangedEvent( + "testClusterChanged", + ClusterState.builder(new ClusterName("testClusterChanged")) + .nodes(nodes) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAllocationMetadata.NAME, + TrainedModelAllocationMetadata.Builder.empty() + .addNewAllocation(newParams(modelOne)) + .addNode(modelOne, NODE_ID) + .build() + ) + .build() + ) + .build(), + ClusterState.EMPTY_STATE + ); + trainedModelAllocationNodeService.clusterChanged(event); + + trainedModelAllocationNodeService.loadQueuedModels(); + + verifyNoMoreInteractions(deploymentManager, trainedModelAllocationService); + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private void withLoadFailure(String modelId) { + doAnswer(invocationOnMock -> { + TrainedModelDeploymentTask task = (TrainedModelDeploymentTask) invocationOnMock.getArguments()[0]; + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + if (task.getModelId().equals(modelId)) { + listener.onFailure(new ResourceNotFoundException("model not found")); + } else { + listener.onResponse(task); + } + return null; + }).when(deploymentManager).startDeployment(any(), any()); + } + + private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId) { + return new StartTrainedModelDeploymentAction.TaskParams(modelId, "any-index", randomNonNegativeLong()); + } + + private TrainedModelAllocationNodeService createService() { + return new TrainedModelAllocationNodeService( + trainedModelAllocationService, + clusterService, + deploymentManager, + taskManager, + threadPool, + NODE_ID + ); + } + +} diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index b7036fe0aa9c6..a54b1b98af8a1 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -228,6 +228,9 @@ public class Constants { "cluster:internal/xpack/ml/job/finalize_job_execution", "cluster:internal/xpack/ml/job/kill/process", "cluster:internal/xpack/ml/job/update/process", + "cluster:internal/xpack/ml/model_allocation/create", + "cluster:internal/xpack/ml/model_allocation/delete", + "cluster:internal/xpack/ml/model_allocation/update", "cluster:internal/xpack/ml/reset_mode", "cluster:internal/xpack/transform/reset_mode", "cluster:monitor/allocation/explain",