Skip to content

Commit

Permalink
[ML] notify inference listeners of pytorch process crash (#75679)
Browse files Browse the repository at this point in the history
If the native process crashes or is stopped, it is helpful to notify the inference listeners that this has occurred.

While we still don't have a way to notify the individual listeners of the specific failure, we should not let listeners
time out when we can avoid it.
  • Loading branch information
benwtrent authored Jul 27, 2021
1 parent 7bf112f commit 809f097
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -198,24 +198,32 @@ protected void doRun() {
processor.validateInputs(input);
BytesReference request = processor.getRequestBuilder().buildRequest(input, requestId);
logger.trace(() -> "Inference Request "+ request.utf8ToString());
PyTorchResultProcessor.PendingResult pendingResult = processContext.resultProcessor.requestWritten(requestId);
processContext.process.get().writeInferenceRequest(request);

waitForResult(processContext, requestId, timeout, processor.getResultProcessor(), listener);
waitForResult(processContext, pendingResult, requestId, timeout, processor.getResultProcessor(), listener);
} catch (IOException e) {
logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.modelId), e);
onFailure(ExceptionsHelper.serverError("error writing to process", e));
} finally {
processContext.resultProcessor.requestAccepted(requestId);
}
}
});
}

private void waitForResult(ProcessContext processContext,
PyTorchResultProcessor.PendingResult pendingResult,
String requestId,
TimeValue timeout,
NlpTask.ResultProcessor inferenceResultsProcessor,
ActionListener<InferenceResults> listener) {
try {
PyTorchResult pyTorchResult = processContext.resultProcessor.waitForResult(requestId, timeout);
PyTorchResult pyTorchResult = processContext.resultProcessor.waitForResult(
processContext.process.get(),
requestId,
pendingResult,
timeout
);
if (pyTorchResult == null) {
listener.onFailure(new ElasticsearchStatusException("timeout [{}] waiting for inference result",
RestStatus.TOO_MANY_REQUESTS, timeout));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

public class PyTorchResultProcessor {

Expand All @@ -29,13 +30,22 @@ public class PyTorchResultProcessor {

private final String deploymentId;
private volatile boolean isStopping;
private volatile boolean stoppedProcessing;
private final LongSummaryStatistics summaryStatistics;

public PyTorchResultProcessor(String deploymentId) {
this.deploymentId = Objects.requireNonNull(deploymentId);
this.summaryStatistics = new LongSummaryStatistics();
}

/**
 * Registers interest in the result for {@code requestId}, creating the pending
 * slot the result-processing thread will deliver into. Idempotent: a repeated
 * call for the same id returns the already-registered {@link PendingResult}.
 */
public PendingResult requestWritten(String requestId) {
    PendingResult fresh = new PendingResult();
    PendingResult existing = pendingResults.putIfAbsent(requestId, fresh);
    return existing == null ? fresh : existing;
}

/**
 * Removes the pending slot for {@code requestId} once the request's lifecycle
 * is over, so entries do not accumulate in the pending-results map.
 */
public void requestAccepted(String requestId) {
pendingResults.remove(requestId);
}

public void process(NativePyTorchProcess process) {
try {
Iterator<PyTorchResult> iterator = process.readResults();
Expand All @@ -47,7 +57,7 @@ public void process(NativePyTorchProcess process) {
if (pendingResult == null) {
logger.warn(() -> new ParameterizedMessage("[{}] no pending result for [{}]", deploymentId, result.getRequestId()));
} else {
pendingResult.result = result;
pendingResult.result.set(result);
pendingResult.latch.countDown();
}
}
Expand All @@ -56,7 +66,32 @@ public void process(NativePyTorchProcess process) {
if (isStopping == false) {
logger.error(new ParameterizedMessage("[{}] Error processing results", deploymentId), e);
}
pendingResults.forEach((id, pendingResults) -> {
if (pendingResults.result.compareAndSet(null, new PyTorchResult(
id,
null,
null,
isStopping ?
"inference canceled as process is stopping" :
"inference native process died unexpectedly with failure [" + e.getMessage() + "]"))) {
pendingResults.latch.countDown();
}
});
pendingResults.clear();
} finally {
pendingResults.forEach((id, pendingResults) -> {
// Only set the result if it has not already been set
if (pendingResults.result.compareAndSet(null, new PyTorchResult(
id,
null,
null,
"inference canceled as process is stopping"))) {
pendingResults.latch.countDown();
}
});
pendingResults.clear();
}
stoppedProcessing = true;
logger.debug(() -> new ParameterizedMessage("[{}] Results processing finished", deploymentId));
}

Expand All @@ -73,14 +108,20 @@ private synchronized void processResult(PyTorchResult result) {
}
}

public PyTorchResult waitForResult(String requestId, TimeValue timeout) throws InterruptedException {
PendingResult pendingResult = pendingResults.computeIfAbsent(requestId, k -> new PendingResult());
try {
if (pendingResult.latch.await(timeout.millis(), TimeUnit.MILLISECONDS)) {
return pendingResult.result;
}
} finally {
pendingResults.remove(requestId);
/**
 * Blocks until the result for {@code requestId} is delivered, the timeout
 * elapses, or the native process is found to be unavailable.
 *
 * @param process       the native process the request was written to; may be null
 * @param requestId     id of the request to wait on
 * @param pendingResult the slot registered for this request (see requestWritten)
 * @param timeout       maximum time to wait for the latch
 * @return the stored result, a synthetic error result when the process is not
 *         running and nothing was stored, or null on timeout
 * @throws InterruptedException if interrupted while awaiting the latch
 */
public PyTorchResult waitForResult(
NativePyTorchProcess process,
String requestId,
PendingResult pendingResult,
TimeValue timeout
) throws InterruptedException {
// Process gone (never started, finished processing, or dead): don't wait out
// the full timeout. Return whatever the processor thread already stored for
// this request, else a synthetic error result.
// NOTE(review): there is a window where the process dies after this check; in
// that case the processor's failure path counts down the latch — confirm.
if (process == null || stoppedProcessing || process.isProcessAlive() == false) {
PyTorchResult storedResult = pendingResult.result.get();
return storedResult == null ?
new PyTorchResult(requestId, null, null, "native process no longer started") :
storedResult;
}
// Normal path: wait for the processor thread to deliver and count down.
if (pendingResult.latch.await(timeout.millis(), TimeUnit.MILLISECONDS)) {
return pendingResult.result.get();
}
// Latch never tripped within the timeout; caller maps null to a timeout error.
return null;
}
Expand All @@ -89,8 +130,8 @@ public void stop() {
isStopping = true;
}

private static class PendingResult {
private volatile PyTorchResult result;
// Holder for one in-flight inference request: the result-processing thread
// sets `result` (compareAndSet so a delivered result is never overwritten by a
// shutdown error) and counts down `latch`; the waiting thread awaits the latch
// then reads `result`.
public static class PendingResult {
private final AtomicReference<PyTorchResult> result = new AtomicReference<>();
private final CountDownLatch latch = new CountDownLatch(1);
}
}

0 comments on commit 809f097

Please sign in to comment.