[ML] make allocated trained model infer requests fully cancellable #88649
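This change makes inference requests against an allocated (deployed) trained model fully cancellable. The parent action Task is threaded from the transport action through TrainedModelDeploymentTask, TrainedModelAssignmentNodeService, and DeploymentManager into InferencePyTorchAction, which now checks whether that task has been cancelled before doing any work, again after tokenization but before the request is sent to the native process, and once more before the result is processed. A cancelled request fails fast with "inference task cancelled".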

@@ -18,6 +18,7 @@
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -143,11 +144,13 @@ protected void taskOperation(
         TrainedModelDeploymentTask task,
         ActionListener<InferTrainedModelDeploymentAction.Response> listener
     ) {
+        assert actionTask instanceof CancellableTask : "task [" + actionTask + "] not cancellable";
         task.infer(
             request.getDocs().get(0),
             request.getUpdate(),
             request.isSkipQueue(),
             request.getInferenceTimeout(),
+            actionTask,
             ActionListener.wrap(
                 pyTorchResult -> listener.onResponse(new InferTrainedModelDeploymentAction.Response(pyTorchResult)),
                 listener::onFailure
---
@@ -277,9 +277,10 @@ public void infer(
         Map<String, Object> doc,
         boolean skipQueue,
         TimeValue timeout,
+        Task parentActionTask,
         ActionListener<InferenceResults> listener
     ) {
-        deploymentManager.infer(task, config, doc, skipQueue, timeout, listener);
+        deploymentManager.infer(task, config, doc, skipQueue, timeout, parentActionTask, listener);
     }

     public Optional<ModelStats> modelStats(TrainedModelDeploymentTask task) {
---
@@ -20,6 +20,7 @@
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.query.IdsQueryBuilder;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.XContentFactory;
@@ -237,6 +238,7 @@ public void infer(
         Map<String, Object> doc,
         boolean skipQueue,
         TimeValue timeout,
+        Task parentActionTask,
         ActionListener<InferenceResults> listener
     ) {
         var processContext = getProcessContext(task, listener::onFailure);
@@ -254,6 +256,7 @@
             config,
             doc,
             threadPool,
+            parentActionTask,
             listener
         );

---
@@ -10,7 +10,11 @@
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
@@ -33,6 +37,7 @@ class InferencePyTorchAction extends AbstractPyTorchAction<InferenceResults> {

     private final InferenceConfig config;
     private final Map<String, Object> doc;
+    private final Task parentActionTask;

     InferencePyTorchAction(
         String modelId,
@@ -42,11 +47,25 @@ class InferencePyTorchAction extends AbstractPyTorchAction<InferenceResults> {
         InferenceConfig config,
         Map<String, Object> doc,
         ThreadPool threadPool,
+        @Nullable Task parentActionTask,
         ActionListener<InferenceResults> listener
     ) {
         super(modelId, requestId, timeout, processContext, threadPool, listener);
         this.config = config;
         this.doc = doc;
+        this.parentActionTask = parentActionTask;
     }

+    private boolean isCancelled() {
+        if (parentActionTask instanceof CancellableTask cancellableTask) {
+            try {
+                cancellableTask.ensureNotCancelled();
+            } catch (TaskCancelledException ex) {
+                logger.debug(() -> format("[%s] %s", getModelId(), ex.getMessage()));
+                return true;
+            }
+        }
+        return false;
+    }
+
     @Override
Expand All @@ -56,12 +75,15 @@ protected void doRun() throws Exception {
logger.debug(() -> format("[%s] skipping inference on request [%s] as it has timed out", getModelId(), getRequestId()));
return;
}
if (isCancelled()) {
onFailure("inference task cancelled");
return;
}

final String requestIdStr = String.valueOf(getRequestId());
try {
// The request builder expect a list of inputs which are then batched.
// TODO batching was implemented for expected use-cases such as zero-shot
// classification but is not used here.
// TODO batching was implemented for expected use-cases such as zero-shot classification but is not used here.
List<String> text = Collections.singletonList(NlpTask.extractInput(getProcessContext().getModelInput().get(), doc));
NlpTask.Processor processor = getProcessContext().getNlpTaskProcessor().get();
processor.validateInputs(text);
@@ -74,6 +96,11 @@
logger.debug("[{}] [{}] input truncated", getModelId(), getRequestId());
}

// Tokenization is non-trivial, so check for cancellation one last time before sending request to the native process
if (isCancelled()) {
onFailure("inference task cancelled");
return;
}
getProcessContext().getResultProcessor()
.registerRequest(
requestIdStr,
@@ -109,6 +136,10 @@ private void processResult(
             );
             return;
         }
+        if (isCancelled()) {
+            onFailure("inference task cancelled");
+            return;
+        }
         InferenceResults results = inferenceResultsProcessor.processResult(tokenization, pyTorchResult.inferenceResult());
         logger.debug(() -> format("[%s] processed result for request [%s]", getModelId(), getRequestId()));
         onSuccess(results);
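The cancellation model here is cooperative: nothing interrupts the inference thread. Instead, the action polls the parent task at checkpoints where abandoning the request is cheap, before any work starts, after tokenization but before the call to the native process, and before result processing. Below is a minimal standalone sketch of that pattern; FlagTask and CancellableInference are hypothetical stand-ins for illustration, not Elasticsearch APIs.

```java
import java.util.concurrent.atomic.AtomicReference;

// Standalone illustration of the cooperative-cancellation pattern above:
// the pipeline polls a cancellation flag at well-chosen checkpoints
// instead of being interrupted mid-flight.
final class CancellableInference {

    // Minimal stand-in for CancellableTask: cancelled once with a reason, checked cooperatively.
    static final class FlagTask {
        private final AtomicReference<String> cancelReason = new AtomicReference<>();

        void cancel(String reason) {
            cancelReason.compareAndSet(null, reason);
        }

        boolean isCancelled() {
            return cancelReason.get() != null;
        }
    }

    static String run(FlagTask task, String input) {
        // Checkpoint 1: bail out before doing any work.
        if (task.isCancelled()) {
            return "cancelled before start";
        }
        String tokens = input.toLowerCase(); // stands in for tokenization
        // Checkpoint 2: tokenization is non-trivial, so re-check before the expensive call.
        if (task.isCancelled()) {
            return "cancelled after tokenization";
        }
        String rawResult = tokens + " -> result"; // stands in for the native inference call
        // Checkpoint 3: skip result processing if the caller has gone away.
        if (task.isCancelled()) {
            return "cancelled before result processing";
        }
        return rawResult;
    }

    public static void main(String[] args) {
        FlagTask task = new FlagTask();
        System.out.println(run(task, "Hello")); // hello -> result
        task.cancel("client disconnected");
        System.out.println(run(task, "Hello")); // cancelled before start
    }
}
```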
---
@@ -18,6 +18,7 @@
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
@@ -132,6 +133,7 @@ public void infer(
         InferenceConfigUpdate update,
         boolean skipQueue,
         TimeValue timeout,
+        Task parentActionTask,
         ActionListener<InferenceResults> listener
     ) {
         if (inferenceConfigHolder.get() == null) {
@@ -150,7 +152,15 @@
             );
             return;
         }
-        trainedModelAssignmentNodeService.infer(this, update.apply(inferenceConfigHolder.get()), doc, skipQueue, timeout, listener);
+        trainedModelAssignmentNodeService.infer(
+            this,
+            update.apply(inferenceConfigHolder.get()),
+            doc,
+            skipQueue,
+            timeout,
+            parentActionTask,
+            listener
+        );
     }

     public Optional<ModelStats> modelStats() {
---
@@ -101,6 +101,7 @@ public void testRejectedExecution() {
             Map.of(),
             false,
             TimeValue.timeValueMinutes(1),
+            null,
             ActionListener.wrap(result -> fail("unexpected success"), e -> assertThat(e, instanceOf(EsRejectedExecutionException.class)))
         );

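Existing call sites with no parent task, like this test, simply pass null: isCancelled() in InferencePyTorchAction guards with an instanceof CancellableTask pattern, so a null (or non-cancellable) parent task never reports cancellation.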
---
@@ -8,7 +8,13 @@
 package org.elasticsearch.xpack.ml.inference.deployment;

 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskAwareRequest;
+import org.elasticsearch.tasks.TaskId;
+import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.ScalingExecutorBuilder;
 import org.elasticsearch.threadpool.TestThreadPool;
@@ -21,6 +27,7 @@
 import org.junit.Before;

 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;

 import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
@@ -64,7 +71,7 @@ public void testInferListenerOnlyCalledOnce() {
         AtomicInteger timeoutCount = new AtomicInteger();
         when(processContext.getTimeoutCount()).thenReturn(timeoutCount);

-        ListenerCounter listener = new ListenerCounter();
+        TestListenerCounter listener = new TestListenerCounter();
         InferencePyTorchAction action = new InferencePyTorchAction(
             "test-model",
             1,
@@ -73,6 +80,7 @@
             new PassThroughConfig(null, null, null),
             Map.of(),
             tp,
+            null,
             listener
         );
         action.init();
@@ -93,6 +101,7 @@
             new PassThroughConfig(null, null, null),
             Map.of(),
             tp,
+            null,
             listener
         );
         action.init();
@@ -114,6 +123,7 @@
             new PassThroughConfig(null, null, null),
             Map.of(),
             tp,
+            null,
             listener
         );
         action.init();
@@ -134,7 +144,7 @@ public void testRunNotCalledAfterNotified() {
         AtomicInteger timeoutCount = new AtomicInteger();
         when(processContext.getTimeoutCount()).thenReturn(timeoutCount);

-        ListenerCounter listener = new ListenerCounter();
+        TestListenerCounter listener = new TestListenerCounter();
         {
             InferencePyTorchAction action = new InferencePyTorchAction(
                 "test-model",
@@ -144,6 +154,7 @@
                 new PassThroughConfig(null, null, null),
                 Map.of(),
                 tp,
+                null,
                 listener
             );
             action.init();
@@ -161,6 +172,7 @@
                 new PassThroughConfig(null, null, null),
                 Map.of(),
                 tp,
+                null,
                 listener
             );
             action.init();
@@ -170,7 +182,49 @@
         }
     }

-    static class ListenerCounter implements ActionListener<InferenceResults> {
+    public void testCallingRunAfterParentTaskCancellation() throws Exception {
+        DeploymentManager.ProcessContext processContext = mock(DeploymentManager.ProcessContext.class);
+        PyTorchResultProcessor resultProcessor = mock(PyTorchResultProcessor.class);
+        when(processContext.getResultProcessor()).thenReturn(resultProcessor);
+        AtomicInteger timeoutCount = new AtomicInteger();
+        when(processContext.getTimeoutCount()).thenReturn(timeoutCount);
+        TaskManager taskManager = new TaskManager(Settings.EMPTY, tp, Set.of());
+        TestListenerCounter listener = new TestListenerCounter();
+        CancellableTask cancellableTask = (CancellableTask) taskManager.register("test_task", "testAction", new TaskAwareRequest() {
+            @Override
+            public void setParentTask(TaskId taskId) {}
+
+            @Override
+            public TaskId getParentTask() {
+                return TaskId.EMPTY_TASK_ID;
+            }
+
+            @Override
+            public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
+                return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers);
+            }
+        });
+        InferencePyTorchAction action = new InferencePyTorchAction(
+            "test-model",
+            1,
+            TimeValue.MAX_VALUE,
+            processContext,
+            new PassThroughConfig(null, null, null),
+            Map.of(),
+            tp,
+            cancellableTask,
+            listener
+        );
+        action.init();
+        taskManager.cancel(cancellableTask, "test", () -> {});
+
+        action.doRun();
+        assertThat(listener.failureCounts, equalTo(1));
+        assertThat(listener.responseCounts, equalTo(0));
+        verify(resultProcessor, never()).registerRequest(anyString(), any());
+    }
+
+    static class TestListenerCounter implements ActionListener<InferenceResults> {
         private int responseCounts;
         private int failureCounts;

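Rather than mocking the cancelled state, the new test registers a real CancellableTask with a TaskManager and cancels it through taskManager.cancel(...), so doRun() exercises the same ensureNotCancelled() path used in production. The assertions confirm the listener fails exactly once, no response is delivered, and no request is ever registered with the result processor.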