Skip to content

Commit

Permalink
[ML] Add queue_capacity setting to start deployment API (elastic#79433)
Browse files Browse the repository at this point in the history
Adds a setting to the start trained model deployment API
that allows configuring the capacity of the queueing mechanism
that handles inference requests.
  • Loading branch information
dimitris-athanasiou authored Oct 19, 2021
1 parent 79260bc commit 7d637b8
Show file tree
Hide file tree
Showing 14 changed files with 106 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
public static final ParseField WAIT_FOR = new ParseField("wait_for");
public static final ParseField INFERENCE_THREADS = TaskParams.INFERENCE_THREADS;
public static final ParseField MODEL_THREADS = TaskParams.MODEL_THREADS;
public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;

public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);

Expand All @@ -69,6 +70,7 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
PARSER.declareString((request, waitFor) -> request.setWaitForState(AllocationStatus.State.fromString(waitFor)), WAIT_FOR);
PARSER.declareInt(Request::setInferenceThreads, INFERENCE_THREADS);
PARSER.declareInt(Request::setModelThreads, MODEL_THREADS);
PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
}

public static Request parseRequest(String modelId, XContentParser parser) {
Expand All @@ -87,6 +89,7 @@ public static Request parseRequest(String modelId, XContentParser parser) {
private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
private int modelThreads = 1;
private int inferenceThreads = 1;
private int queueCapacity = 1024;

private Request() {}

Expand All @@ -101,6 +104,7 @@ public Request(StreamInput in) throws IOException {
waitForState = in.readEnum(AllocationStatus.State.class);
modelThreads = in.readVInt();
inferenceThreads = in.readVInt();
queueCapacity = in.readVInt();
}

public final void setModelId(String modelId) {
Expand Down Expand Up @@ -144,6 +148,14 @@ public void setInferenceThreads(int inferenceThreads) {
this.inferenceThreads = inferenceThreads;
}

public int getQueueCapacity() {
return queueCapacity;
}

public void setQueueCapacity(int queueCapacity) {
this.queueCapacity = queueCapacity;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
Expand All @@ -152,6 +164,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(waitForState);
out.writeVInt(modelThreads);
out.writeVInt(inferenceThreads);
out.writeVInt(queueCapacity);
}

@Override
Expand All @@ -162,6 +175,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(WAIT_FOR.getPreferredName(), waitForState);
builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
builder.endObject();
return builder;
}
Expand All @@ -183,12 +197,15 @@ public ActionRequestValidationException validate() {
if (inferenceThreads < 1) {
validationException.addValidationError("[" + INFERENCE_THREADS + "] must be a positive integer");
}
if (queueCapacity < 1) {
validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be a positive integer");
}
return validationException.validationErrors().isEmpty() ? null : validationException;
}

@Override
public int hashCode() {
return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads);
return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads, queueCapacity);
}

@Override
Expand All @@ -204,7 +221,8 @@ public boolean equals(Object obj) {
&& Objects.equals(timeout, other.timeout)
&& Objects.equals(waitForState, other.waitForState)
&& modelThreads == other.modelThreads
&& inferenceThreads == other.inferenceThreads;
&& inferenceThreads == other.inferenceThreads
&& queueCapacity == other.queueCapacity;
}

@Override
Expand All @@ -226,16 +244,20 @@ public static boolean mayAllocateToNode(DiscoveryNode node) {
private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
public static final ParseField MODEL_THREADS = new ParseField("model_threads");
public static final ParseField INFERENCE_THREADS = new ParseField("inference_threads");
public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");

private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
"trained_model_deployment_params",
true,
a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3])
a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3], (int) a[4])
);

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
PARSER.declareInt(ConstructingObjectParser.constructorArg(), INFERENCE_THREADS);
PARSER.declareInt(ConstructingObjectParser.constructorArg(), MODEL_THREADS);
PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
}

public static TaskParams fromXContent(XContentParser parser) {
Expand All @@ -253,28 +275,22 @@ public static TaskParams fromXContent(XContentParser parser) {
private final long modelBytes;
private final int inferenceThreads;
private final int modelThreads;
private final int queueCapacity;

public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads) {
public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads, int queueCapacity) {
this.modelId = Objects.requireNonNull(modelId);
this.modelBytes = modelBytes;
if (modelBytes < 0) {
throw new IllegalArgumentException("modelBytes must be non-negative");
}
this.inferenceThreads = inferenceThreads;
if (inferenceThreads < 1) {
throw new IllegalArgumentException(INFERENCE_THREADS + " must be positive");
}
this.modelThreads = modelThreads;
if (modelThreads < 1) {
throw new IllegalArgumentException(MODEL_THREADS + " must be positive");
}
this.queueCapacity = queueCapacity;
}

public TaskParams(StreamInput in) throws IOException {
this.modelId = in.readString();
this.modelBytes = in.readVLong();
this.modelBytes = in.readLong();
this.inferenceThreads = in.readVInt();
this.modelThreads = in.readVInt();
this.queueCapacity = in.readVInt();
}

public String getModelId() {
Expand All @@ -293,9 +309,10 @@ public Version getMinimalSupportedVersion() {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
out.writeVLong(modelBytes);
out.writeLong(modelBytes);
out.writeVInt(inferenceThreads);
out.writeVInt(modelThreads);
out.writeVInt(queueCapacity);
}

@Override
Expand All @@ -305,13 +322,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(MODEL_BYTES.getPreferredName(), modelBytes);
builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads);
return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads, queueCapacity);
}

@Override
Expand All @@ -323,7 +341,8 @@ public boolean equals(Object o) {
return Objects.equals(modelId, other.modelId)
&& modelBytes == other.modelBytes
&& inferenceThreads == other.inferenceThreads
&& modelThreads == other.modelThreads;
&& modelThreads == other.modelThreads
&& queueCapacity == other.queueCapacity;
}

@Override
Expand All @@ -342,6 +361,15 @@ public int getInferenceThreads() {
public int getModelThreads() {
return modelThreads;
}

public int getQueueCapacity() {
return queueCapacity;
}

@Override
public String toString() {
return Strings.toString(this);
}
}

public interface TaskMatcher {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,7 @@ public class CreateTrainedModelAllocationActionRequestTests extends AbstractWire

@Override
protected Request createTestInstance() {
return new Request(
new StartTrainedModelDeploymentAction.TaskParams(
randomAlphaOfLength(10),
randomNonNegativeLong(),
randomIntBetween(1, 8),
randomIntBetween(1, 8)
)
);
return new Request(StartTrainedModelDeploymentTaskParamsTests.createRandom());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.io.IOException;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
Expand Down Expand Up @@ -53,6 +54,9 @@ public static Request createRandom() {
if (randomBoolean()) {
request.setModelThreads(randomIntBetween(1, 8));
}
if (randomBoolean()) {
request.setQueueCapacity(randomIntBetween(1, 10000));
}
return request;
}

Expand Down Expand Up @@ -95,4 +99,33 @@ public void testValidate_GivenModelThreadsIsNegative() {
assertThat(e, is(not(nullValue())));
assertThat(e.getMessage(), containsString("[model_threads] must be a positive integer"));
}

public void testValidate_GivenQueueCapacityIsZero() {
Request request = createRandom();
request.setQueueCapacity(0);

ActionRequestValidationException e = request.validate();

assertThat(e, is(not(nullValue())));
assertThat(e.getMessage(), containsString("[queue_capacity] must be a positive integer"));
}

public void testValidate_GivenQueueCapacityIsNegative() {
Request request = createRandom();
request.setQueueCapacity(randomIntBetween(Integer.MIN_VALUE, -1));

ActionRequestValidationException e = request.validate();

assertThat(e, is(not(nullValue())));
assertThat(e.getMessage(), containsString("[queue_capacity] must be a positive integer"));
}

public void testDefaults() {
Request request = new Request(randomAlphaOfLength(10));
assertThat(request.getTimeout(), equalTo(TimeValue.timeValueSeconds(20)));
assertThat(request.getWaitForState(), equalTo(AllocationStatus.State.STARTED));
assertThat(request.getInferenceThreads(), equalTo(1));
assertThat(request.getModelThreads(), equalTo(1));
assertThat(request.getQueueCapacity(), equalTo(1024));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ public static StartTrainedModelDeploymentAction.TaskParams createRandom() {
randomAlphaOfLength(10),
randomNonNegativeLong(),
randomIntBetween(1, 8),
randomIntBetween(1, 8)
randomIntBetween(1, 8),
randomIntBetween(1, 10000)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentTaskParamsTests;

import java.io.IOException;
import java.util.List;
Expand All @@ -31,9 +32,7 @@
public class TrainedModelAllocationTests extends AbstractSerializingTestCase<TrainedModelAllocation> {

public static TrainedModelAllocation randomInstance() {
TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(
new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1)
);
TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(randomParams());
List<String> nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList());
for (String node : nodes) {
if (randomBoolean()) {
Expand Down Expand Up @@ -249,7 +248,7 @@ private static DiscoveryNode buildNode() {
}

private static StartTrainedModelDeploymentAction.TaskParams randomParams() {
return new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1);
return StartTrainedModelDeploymentTaskParamsTests.createRandom();
}

private static void assertUnchanged(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
Expand All @@ -35,6 +34,7 @@
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction;
Expand Down Expand Up @@ -161,7 +161,8 @@ protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Requ
trainedModelConfig.getModelId(),
modelBytes,
request.getInferenceThreads(),
request.getModelThreads()
request.getModelThreads(),
request.getQueueCapacity()
);
PersistentTasksCustomMetadata persistentTasks = clusterService.state().getMetadata().custom(
PersistentTasksCustomMetadata.TYPE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ TrainedModelDeploymentTask getTask(String modelId) {
}

void prepareModelToLoad(StartTrainedModelDeploymentAction.TaskParams taskParams) {
logger.debug(() -> new ParameterizedMessage("[{}] preparing to load model with task params: {}",
taskParams.getModelId(), taskParams));
TrainedModelDeploymentTask task = (TrainedModelDeploymentTask) taskManager.register(
TRAINED_MODEL_ALLOCATION_TASK_TYPE,
TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX + taskParams.getModelId(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,11 @@ class ProcessContext {
this.task = Objects.requireNonNull(task);
resultProcessor = new PyTorchResultProcessor(task.getModelId());
this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry);
this.executorService = new ProcessWorkerExecutorService(threadPool.getThreadContext(), "pytorch_inference", 1024);
this.executorService = new ProcessWorkerExecutorService(
threadPool.getThreadContext(),
"pytorch_inference",
task.getParams().getQueueCapacity()
);
}

PyTorchResultProcessor getResultProcessor() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ public class ProcessWorkerExecutorService extends AbstractExecutorService {
/**
* @param contextHolder the thread context holder
* @param processName the name of the process to be used in logging
* @param queueSize the size of the queue holding operations. If an operation is added
* @param queueCapacity the capacity of the queue holding operations. If an operation is added
* for execution when the queue is full a 429 error is thrown.
*/
@SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors")
public ProcessWorkerExecutorService(ThreadContext contextHolder, String processName, int queueSize) {
public ProcessWorkerExecutorService(ThreadContext contextHolder, String processName, int queueCapacity) {
this.contextHolder = Objects.requireNonNull(contextHolder);
this.processName = Objects.requireNonNull(processName);
this.queue = new LinkedBlockingQueue<>(queueSize);
this.queue = new LinkedBlockingQueue<>(queueCapacity);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import static org.elasticsearch.rest.RestRequest.Method.POST;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.INFERENCE_THREADS;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.MODEL_THREADS;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.QUEUE_CAPACITY;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.TIMEOUT;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.WAIT_FOR;

Expand Down Expand Up @@ -59,6 +60,7 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
));
request.setInferenceThreads(restRequest.paramAsInt(INFERENCE_THREADS.getPreferredName(), request.getInferenceThreads()));
request.setModelThreads(restRequest.paramAsInt(MODEL_THREADS.getPreferredName(), request.getModelThreads()));
request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
}

return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));
Expand Down
Loading

0 comments on commit 7d637b8

Please sign in to comment.