From 7d637b89923bbcc0f41a11884191aa1349e74c0d Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 19 Oct 2021 23:57:25 +0300 Subject: [PATCH] [ML] Add queue_capacity setting to start deployment API (#79433) Adds a setting to the start trained model deployment API that allows configuring the capacity of the queueing mechanism that handles inference requests. --- .../StartTrainedModelDeploymentAction.java | 62 ++++++++++++++----- ...inedModelAllocationActionRequestTests.java | 9 +-- ...artTrainedModelDeploymentRequestTests.java | 33 ++++++++++ ...TrainedModelDeploymentTaskParamsTests.java | 3 +- .../TrainedModelAllocationTests.java | 9 ++- ...portStartTrainedModelDeploymentAction.java | 5 +- .../TrainedModelAllocationNodeService.java | 2 + .../deployment/DeploymentManager.java | 6 +- .../process/ProcessWorkerExecutorService.java | 6 +- ...RestStartTrainedModelDeploymentAction.java | 2 + ...nedModelAllocationClusterServiceTests.java | 2 +- .../TrainedModelAllocationMetadataTests.java | 3 +- ...rainedModelAllocationNodeServiceTests.java | 2 +- .../xpack/ml/job/NodeLoadDetectorTests.java | 3 +- 14 files changed, 106 insertions(+), 41 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java index 0aa300346c58d..8d250900a11b3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -60,6 +60,7 @@ public static class Request extends MasterNodeRequest implements ToXCon public static final ParseField WAIT_FOR = new ParseField("wait_for"); public static final ParseField INFERENCE_THREADS = TaskParams.INFERENCE_THREADS; public static final ParseField MODEL_THREADS = TaskParams.MODEL_THREADS; + public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY; public static final ObjectParser PARSER = new ObjectParser<>(NAME, Request::new); @@ -69,6 +70,7 @@ public static class Request extends MasterNodeRequest implements ToXCon PARSER.declareString((request, waitFor) -> request.setWaitForState(AllocationStatus.State.fromString(waitFor)), WAIT_FOR); PARSER.declareInt(Request::setInferenceThreads, INFERENCE_THREADS); PARSER.declareInt(Request::setModelThreads, MODEL_THREADS); + PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY); } public static Request parseRequest(String modelId, XContentParser parser) { @@ -87,6 +89,7 @@ public static Request parseRequest(String modelId, XContentParser parser) { private AllocationStatus.State waitForState = AllocationStatus.State.STARTED; private int modelThreads = 1; private int inferenceThreads = 1; + private int queueCapacity = 1024; private Request() {} @@ -101,6 +104,7 @@ public Request(StreamInput in) throws IOException { waitForState = in.readEnum(AllocationStatus.State.class); modelThreads = in.readVInt(); inferenceThreads = in.readVInt(); + queueCapacity = in.readVInt(); } public final void setModelId(String modelId) { @@ -144,6 +148,14 @@ public void setInferenceThreads(int inferenceThreads) { this.inferenceThreads = inferenceThreads; } + public int getQueueCapacity() { + return queueCapacity; + } + + public void setQueueCapacity(int queueCapacity) { + this.queueCapacity = queueCapacity; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -152,6 +164,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeEnum(waitForState); out.writeVInt(modelThreads); out.writeVInt(inferenceThreads); + out.writeVInt(queueCapacity); } @Override @@ -162,6 +175,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(WAIT_FOR.getPreferredName(), waitForState); builder.field(MODEL_THREADS.getPreferredName(), modelThreads); builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads); + builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity); builder.endObject(); return builder; } @@ -183,12 +197,15 @@ public ActionRequestValidationException validate() { if (inferenceThreads < 1) { validationException.addValidationError("[" + INFERENCE_THREADS + "] must be a positive integer"); } + if (queueCapacity < 1) { + validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be a positive integer"); + } return validationException.validationErrors().isEmpty() ? null : validationException; } @Override public int hashCode() { - return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads); + return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads, queueCapacity); } @Override @@ -204,7 +221,8 @@ public boolean equals(Object obj) { && Objects.equals(timeout, other.timeout) && Objects.equals(waitForState, other.waitForState) && modelThreads == other.modelThreads - && inferenceThreads == other.inferenceThreads; + && inferenceThreads == other.inferenceThreads + && queueCapacity == other.queueCapacity; } @Override @@ -226,16 +244,20 @@ public static boolean mayAllocateToNode(DiscoveryNode node) { private static final ParseField MODEL_BYTES = new ParseField("model_bytes"); public static final ParseField MODEL_THREADS = new ParseField("model_threads"); public static final ParseField INFERENCE_THREADS = new ParseField("inference_threads"); + public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity"); + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "trained_model_deployment_params", true, - a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3]) + a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3], (int) a[4]) ); + static { PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID); PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES); PARSER.declareInt(ConstructingObjectParser.constructorArg(), INFERENCE_THREADS); PARSER.declareInt(ConstructingObjectParser.constructorArg(), MODEL_THREADS); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY); } public static TaskParams fromXContent(XContentParser parser) { @@ -253,28 +275,22 @@ public static TaskParams fromXContent(XContentParser parser) { private final long modelBytes; private final int inferenceThreads; private final int modelThreads; + private final int queueCapacity; - public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads) { + public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads, int queueCapacity) { this.modelId = Objects.requireNonNull(modelId); this.modelBytes = modelBytes; - if (modelBytes < 0) { - throw new IllegalArgumentException("modelBytes must be non-negative"); - } this.inferenceThreads = inferenceThreads; - if (inferenceThreads < 1) { - throw new IllegalArgumentException(INFERENCE_THREADS + " must be positive"); - } this.modelThreads = modelThreads; - if (modelThreads < 1) { - throw new IllegalArgumentException(MODEL_THREADS + " must be positive"); - } + this.queueCapacity = queueCapacity; } public TaskParams(StreamInput in) throws IOException { this.modelId = in.readString(); - this.modelBytes = in.readVLong(); + this.modelBytes = in.readLong(); this.inferenceThreads = in.readVInt(); this.modelThreads = in.readVInt(); + this.queueCapacity = in.readVInt(); } public String getModelId() { @@ -293,9 +309,10 @@ public Version getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); - out.writeVLong(modelBytes); + out.writeLong(modelBytes); out.writeVInt(inferenceThreads); out.writeVInt(modelThreads); + out.writeVInt(queueCapacity); } @Override @@ -305,13 +322,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(MODEL_BYTES.getPreferredName(), modelBytes); builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads); builder.field(MODEL_THREADS.getPreferredName(), modelThreads); + builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity); builder.endObject(); return builder; } @Override public int hashCode() { - return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads); + return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads, queueCapacity); } @Override @@ -323,7 +341,8 @@ public boolean equals(Object o) { return Objects.equals(modelId, other.modelId) && modelBytes == other.modelBytes && inferenceThreads == other.inferenceThreads - && modelThreads == other.modelThreads; + && modelThreads == other.modelThreads + && queueCapacity == other.queueCapacity; } @Override @@ -342,6 +361,15 @@ public int getInferenceThreads() { public int getModelThreads() { return modelThreads; } + + public int getQueueCapacity() { + return queueCapacity; + } + + @Override + public String toString() { + return Strings.toString(this); + } } public interface TaskMatcher { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java index 6a6fa0453ff7e..978139c44e142 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java @@ -14,14 +14,7 @@ public class CreateTrainedModelAllocationActionRequestTests extends AbstractWire @Override protected Request createTestInstance() { - return new Request( - new StartTrainedModelDeploymentAction.TaskParams( - randomAlphaOfLength(10), - randomNonNegativeLong(), - randomIntBetween(1, 8), - randomIntBetween(1, 8) - ) - ); + return new Request(StartTrainedModelDeploymentTaskParamsTests.createRandom()); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java index 6bd27634dcf69..4c17f35025e7a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java @@ -18,6 +18,7 @@ import java.io.IOException; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; @@ -53,6 +54,9 @@ public static Request createRandom() { if (randomBoolean()) { request.setModelThreads(randomIntBetween(1, 8)); } + if (randomBoolean()) { + request.setQueueCapacity(randomIntBetween(1, 10000)); + } return request; } @@ -95,4 +99,33 @@ public void testValidate_GivenModelThreadsIsNegative() { assertThat(e, is(not(nullValue()))); assertThat(e.getMessage(), containsString("[model_threads] must be a positive integer")); } + + public void testValidate_GivenQueueCapacityIsZero() { + Request request = createRandom(); + request.setQueueCapacity(0); + + ActionRequestValidationException e = request.validate(); + + assertThat(e, is(not(nullValue()))); + assertThat(e.getMessage(), containsString("[queue_capacity] must be a positive integer")); + } + + public void testValidate_GivenQueueCapacityIsNegative() { + Request request = createRandom(); + request.setQueueCapacity(randomIntBetween(Integer.MIN_VALUE, -1)); + + ActionRequestValidationException e = request.validate(); + + assertThat(e, is(not(nullValue()))); + assertThat(e.getMessage(), containsString("[queue_capacity] must be a positive integer")); + } + + public void testDefaults() { + Request request = new Request(randomAlphaOfLength(10)); + assertThat(request.getTimeout(), equalTo(TimeValue.timeValueSeconds(20))); + assertThat(request.getWaitForState(), equalTo(AllocationStatus.State.STARTED)); + assertThat(request.getInferenceThreads(), equalTo(1)); + assertThat(request.getModelThreads(), equalTo(1)); + assertThat(request.getQueueCapacity(), equalTo(1024)); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java index 95a529d3ccc1e..c5160f96663a3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java @@ -36,7 +36,8 @@ public static StartTrainedModelDeploymentAction.TaskParams createRandom() { randomAlphaOfLength(10), randomNonNegativeLong(), randomIntBetween(1, 8), - randomIntBetween(1, 8) + randomIntBetween(1, 8), + randomIntBetween(1, 10000) ); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java index 473730901cac7..82ca307f0e024 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java @@ -13,9 +13,10 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentTaskParamsTests; import java.io.IOException; import java.util.List; @@ -31,9 +32,7 @@ public class TrainedModelAllocationTests extends AbstractSerializingTestCase { public static TrainedModelAllocation randomInstance() { - TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty( - new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1) - ); + TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(randomParams()); List nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList()); for (String node : nodes) { if (randomBoolean()) { @@ -249,7 +248,7 @@ private static DiscoveryNode buildNode() { } private static StartTrainedModelDeploymentAction.TaskParams randomParams() { - return new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1); + return StartTrainedModelDeploymentTaskParamsTests.createRandom(); } private static void assertUnchanged( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java index 682fab432fbf6..f6d260c61fd2f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java @@ -26,7 +26,6 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.core.TimeValue; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; @@ -35,6 +34,7 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.ml.MachineLearningField; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction; @@ -161,7 +161,8 @@ protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Requ trainedModelConfig.getModelId(), modelBytes, request.getInferenceThreads(), - request.getModelThreads() + request.getModelThreads(), + request.getQueueCapacity() ); PersistentTasksCustomMetadata persistentTasks = clusterService.state().getMetadata().custom( PersistentTasksCustomMetadata.TYPE); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java index 6db99931d8c46..a40aad7a7f075 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java @@ -332,6 +332,8 @@ TrainedModelDeploymentTask getTask(String modelId) { } void prepareModelToLoad(StartTrainedModelDeploymentAction.TaskParams taskParams) { + logger.debug(() -> new ParameterizedMessage("[{}] preparing to load model with task params: {}", + taskParams.getModelId(), taskParams)); TrainedModelDeploymentTask task = (TrainedModelDeploymentTask) taskManager.register( TRAINED_MODEL_ALLOCATION_TASK_TYPE, TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX + taskParams.getModelId(), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 8ae0aeb3f4dde..c19c0b8776179 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -392,7 +392,11 @@ class ProcessContext { this.task = Objects.requireNonNull(task); resultProcessor = new PyTorchResultProcessor(task.getModelId()); this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry); - this.executorService = new ProcessWorkerExecutorService(threadPool.getThreadContext(), "pytorch_inference", 1024); + this.executorService = new ProcessWorkerExecutorService( + threadPool.getThreadContext(), + "pytorch_inference", + task.getParams().getQueueCapacity() + ); } PyTorchResultProcessor getResultProcessor() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java index f66018793a03c..ffa35a1549aaa 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java @@ -45,14 +45,14 @@ public class ProcessWorkerExecutorService extends AbstractExecutorService { /** * @param contextHolder the thread context holder * @param processName the name of the process to be used in logging - * @param queueSize the size of the queue holding operations. If an operation is added + * @param queueCapacity the capacity of the queue holding operations. If an operation is added * for execution when the queue is full a 429 error is thrown. */ @SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors") - public ProcessWorkerExecutorService(ThreadContext contextHolder, String processName, int queueSize) { + public ProcessWorkerExecutorService(ThreadContext contextHolder, String processName, int queueCapacity) { this.contextHolder = Objects.requireNonNull(contextHolder); this.processName = Objects.requireNonNull(processName); - this.queue = new LinkedBlockingQueue<>(queueSize); + this.queue = new LinkedBlockingQueue<>(queueCapacity); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java index 6fa6405b461b0..f3cf06bf70684 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java @@ -23,6 +23,7 @@ import static org.elasticsearch.rest.RestRequest.Method.POST; import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.INFERENCE_THREADS; import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.MODEL_THREADS; +import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.QUEUE_CAPACITY; import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.TIMEOUT; import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.WAIT_FOR; @@ -59,6 +60,7 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient )); request.setInferenceThreads(restRequest.paramAsInt(INFERENCE_THREADS.getPreferredName(), request.getInferenceThreads())); request.setModelThreads(restRequest.paramAsInt(MODEL_THREADS.getPreferredName(), request.getModelThreads())); + request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity())); } return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java index 0c974d17fbce1..e6b82316b9663 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java @@ -940,7 +940,7 @@ private static DiscoveryNode buildOldNode(String name, boolean isML, long native } private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId, long modelSize) { - return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, 1, 1); + return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, 1, 1, 1024); } private static void assertNodeState(TrainedModelAllocationMetadata metadata, String modelId, String nodeId, RoutingState routingState) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadataTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadataTests.java index 1de97cc5991f3..ccb9b27be591e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadataTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadataTests.java @@ -99,7 +99,8 @@ private static StartTrainedModelDeploymentAction.TaskParams randomParams(String modelId, randomNonNegativeLong(), randomIntBetween(1, 8), - randomIntBetween(1, 8) + randomIntBetween(1, 8), + randomIntBetween(1, 10000) ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java index 88cd412c30147..0b8bffdf302aa 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java @@ -497,7 +497,7 @@ private void withSearchingLoadFailure(String modelId) { } private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId) { - return new StartTrainedModelDeploymentAction.TaskParams(modelId, randomNonNegativeLong(), 1, 1); + return new StartTrainedModelDeploymentAction.TaskParams(modelId, randomNonNegativeLong(), 1, 1, 1024); } private TrainedModelAllocationNodeService createService() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java index 770d11ae9dab6..c99758f27008c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java @@ -89,7 +89,8 @@ public void testNodeLoadDetection() { .addNewAllocation( "model1", TrainedModelAllocation.Builder - .empty(new StartTrainedModelDeploymentAction.TaskParams("model1", MODEL_MEMORY_REQUIREMENT, 1, 1)) + .empty(new StartTrainedModelDeploymentAction.TaskParams( + "model1", MODEL_MEMORY_REQUIREMENT, 1, 1, 1024)) .addNewRoutingEntry("_node_id4") .addNewFailedRoutingEntry("_node_id2", "test") .addNewRoutingEntry("_node_id1")