From 809f097dd7f12a33a71bcfaa146128964352b1cc Mon Sep 17 00:00:00 2001
From: Benjamin Trent
Date: Tue, 27 Jul 2021 08:24:29 -0400
Subject: [PATCH] [ML] notify inference listeners of pytorch process crash
 (#75679)

If the native process crashes or is stopped, the inference listeners
should be notified that this has occurred. While we still don't have a
way to notify the individual listeners of a specific failure, we should
not let listeners time out when that can be avoided.
---
 .../deployment/DeploymentManager.java    | 14 ++++-
 .../process/PyTorchResultProcessor.java  | 63 +++++++++++++++----
 2 files changed, 63 insertions(+), 14 deletions(-)

diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java
index 0d5b7f73bae4f..2c1469682c8ea 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java
@@ -198,24 +198,32 @@ protected void doRun() {
                     processor.validateInputs(input);
                     BytesReference request = processor.getRequestBuilder().buildRequest(input, requestId);
                     logger.trace(() -> "Inference Request "+ request.utf8ToString());
+                    PyTorchResultProcessor.PendingResult pendingResult = processContext.resultProcessor.requestWritten(requestId);
                     processContext.process.get().writeInferenceRequest(request);
-
-                    waitForResult(processContext, requestId, timeout, processor.getResultProcessor(), listener);
+                    waitForResult(processContext, pendingResult, requestId, timeout, processor.getResultProcessor(), listener);
                 } catch (IOException e) {
                     logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.modelId), e);
                     onFailure(ExceptionsHelper.serverError("error writing to process", e));
+                } finally {
+                    processContext.resultProcessor.requestAccepted(requestId);
                 }
             }
         });
     }
 
     private void waitForResult(ProcessContext processContext,
+                               PyTorchResultProcessor.PendingResult pendingResult,
                                String requestId,
                                TimeValue timeout,
                                NlpTask.ResultProcessor inferenceResultsProcessor,
                                ActionListener<InferenceResults> listener) {
         try {
-            PyTorchResult pyTorchResult = processContext.resultProcessor.waitForResult(requestId, timeout);
+            PyTorchResult pyTorchResult = processContext.resultProcessor.waitForResult(
+                processContext.process.get(),
+                requestId,
+                pendingResult,
+                timeout
+            );
             if (pyTorchResult == null) {
                 listener.onFailure(new ElasticsearchStatusException("timeout [{}] waiting for inference result",
                     RestStatus.TOO_MANY_REQUESTS, timeout));
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java
index a8b40a1e46745..815b8e765c957 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java
@@ -20,6 +20,7 @@
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
 
 public class PyTorchResultProcessor {
 
@@ -29,6 +30,7 @@ public class PyTorchResultProcessor {
     private final String deploymentId;
 
     private volatile boolean isStopping;
+    private volatile boolean stoppedProcessing;
     private final LongSummaryStatistics summaryStatistics;
 
     public PyTorchResultProcessor(String deploymentId) {
@@ -36,6 +38,14 @@ public PyTorchResultProcessor(String deploymentId) {
         this.summaryStatistics = new LongSummaryStatistics();
     }
 
+    public PendingResult requestWritten(String requestId) {
+        return pendingResults.computeIfAbsent(requestId, k -> new PendingResult());
+    }
+
+    public void requestAccepted(String requestId) {
+        pendingResults.remove(requestId);
+    }
+
     public void process(NativePyTorchProcess process) {
         try {
             Iterator<PyTorchResult> iterator = process.readResults();
@@ -47,7 +57,7 @@ public void process(NativePyTorchProcess process) {
                 if (pendingResult == null) {
                     logger.warn(() -> new ParameterizedMessage("[{}] no pending result for [{}]", deploymentId, result.getRequestId()));
                 } else {
-                    pendingResult.result = result;
+                    pendingResult.result.set(result);
                     pendingResult.latch.countDown();
                 }
             }
@@ -56,7 +66,32 @@
             if (isStopping == false) {
                 logger.error(new ParameterizedMessage("[{}] Error processing results", deploymentId), e);
             }
+            pendingResults.forEach((id, pendingResult) -> {
+                if (pendingResult.result.compareAndSet(null, new PyTorchResult(
+                    id,
+                    null,
+                    null,
+                    isStopping ?
+                        "inference canceled as process is stopping" :
+                        "inference native process died unexpectedly with failure [" + e.getMessage() + "]"))) {
+                    pendingResult.latch.countDown();
+                }
+            });
+            pendingResults.clear();
+        } finally {
+            pendingResults.forEach((id, pendingResult) -> {
+                // Only set the result if it has not already been set
+                if (pendingResult.result.compareAndSet(null, new PyTorchResult(
+                    id,
+                    null,
+                    null,
+                    "inference canceled as process is stopping"))) {
+                    pendingResult.latch.countDown();
+                }
+            });
+            pendingResults.clear();
         }
+        stoppedProcessing = true;
         logger.debug(() -> new ParameterizedMessage("[{}] Results processing finished", deploymentId));
     }
 
@@ -73,14 +108,20 @@ private synchronized void processResult(PyTorchResult result) {
         }
     }
 
-    public PyTorchResult waitForResult(String requestId, TimeValue timeout) throws InterruptedException {
-        PendingResult pendingResult = pendingResults.computeIfAbsent(requestId, k -> new PendingResult());
-        try {
-            if (pendingResult.latch.await(timeout.millis(), TimeUnit.MILLISECONDS)) {
-                return pendingResult.result;
-            }
-        } finally {
-            pendingResults.remove(requestId);
+    public PyTorchResult waitForResult(
+        NativePyTorchProcess process,
+        String requestId,
+        PendingResult pendingResult,
+        TimeValue timeout
+    ) throws InterruptedException {
+        if (process == null || stoppedProcessing || process.isProcessAlive() == false) {
+            PyTorchResult storedResult = pendingResult.result.get();
+            return storedResult == null ?
+                new PyTorchResult(requestId, null, null, "native process no longer started") :
+                storedResult;
+        }
+        if (pendingResult.latch.await(timeout.millis(), TimeUnit.MILLISECONDS)) {
+            return pendingResult.result.get();
         }
         return null;
     }
@@ -89,8 +130,8 @@ public void stop() {
         isStopping = true;
     }
 
-    private static class PendingResult {
-        private volatile PyTorchResult result;
+    public static class PendingResult {
+        private final AtomicReference<PyTorchResult> result = new AtomicReference<>();
         private final CountDownLatch latch = new CountDownLatch(1);
     }
 }
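
For reference, the whole mechanism of this patch is the pairing of a one-shot
CountDownLatch with an AtomicReference that is completed via compareAndSet.
The following minimal, self-contained Java sketch shows that pattern in
isolation. All names in it (PendingResultDemo, deliver, failAll, waitFor) are
invented for illustration and are not part of the Elasticsearch codebase:

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

// Illustrative sketch only: each pending request owns a one-shot latch and
// an AtomicReference. compareAndSet guarantees the result is set exactly
// once, whether it comes from the reader thread or from crash/stop cleanup.
final class PendingResultDemo {

    static final class PendingResult {
        final AtomicReference<String> result = new AtomicReference<>();
        final CountDownLatch latch = new CountDownLatch(1);
    }

    private final ConcurrentMap<String, PendingResult> pending = new ConcurrentHashMap<>();

    // Register a request before it is written to the worker process.
    PendingResult requestWritten(String id) {
        return pending.computeIfAbsent(id, k -> new PendingResult());
    }

    // Reader thread: publish a real result and release the waiter.
    void deliver(String id, String result) {
        PendingResult p = pending.remove(id);
        if (p != null && p.result.compareAndSet(null, result)) {
            p.latch.countDown();
        }
    }

    // Crash/stop path: complete every still-pending request with an error
    // so no caller is left blocking for its full timeout.
    void failAll(String reason) {
        pending.forEach((id, p) -> {
            // Never overwrite a result that was already delivered.
            if (p.result.compareAndSet(null, "error: " + reason)) {
                p.latch.countDown();
            }
        });
        pending.clear();
    }

    // Caller side: block until completion or timeout; null means timed out.
    String waitFor(PendingResult p, long millis) throws InterruptedException {
        return p.latch.await(millis, TimeUnit.MILLISECONDS) ? p.result.get() : null;
    }
}

Because only the winner of the compareAndSet counts the latch down, the
reader thread and the cleanup path can race safely: a waiter always observes
exactly one consistent value, and after a process crash it fails fast instead
of waiting out its full timeout.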