From b351d409bb61453bb656ad5ac1acf067fd223f17 Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Mon, 4 May 2020 15:05:28 -0700 Subject: [PATCH] Fix that AD job cannot be terminated due to missing training data (#102) We expect an EndRunException to be thrown due to missing training data. But the exception is not appropriately propagated back to AD job and results in InternalFailure instead. This PR fixes the bug. Testing done: 1. Manually reproduced all possible EndRunExceptions, check AD job is terminated, and check profile API status is correct. 2. Added unit tests to expose the bug. --- .../ad/AnomalyDetectorJobRunner.java | 2 +- .../ad/ml/ModelManager.java | 2 +- .../AnomalyResultTransportAction.java | 35 +++++---- .../ad/util/ColdStartRunner.java | 10 ++- .../ad/transport/AnomalyResultTests.java | 77 +++++++++++++++++-- 5 files changed, 100 insertions(+), 26 deletions(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunner.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunner.java index 82f6d502..06814e93 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunner.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunner.java @@ -320,7 +320,7 @@ protected void handleAdException( detectorEndRunExceptionCount.remove(detectorId); if (exception instanceof InternalFailure) { // AnomalyResultTransportAction already prints exception stack trace - log.error("InternalFailure happened when executed anomaly result action for " + detectorId); + log.error("InternalFailure happened when executing anomaly result action for " + detectorId); } else { log.error("Failed to execute anomaly result action for " + detectorId, exception); } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java index 15f18b15..ea8f08db 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java @@ -781,7 +781,7 @@ private boolean isHostingAllowed(String detectorId, RandomCutForest rcf) { } else { throw new LimitExceededException( detectorId, - String.format("Exceeded memory limit. New size is %d and max limit is %f", total, heapLimit) + String.format("Exceeded memory limit. New size is %d bytes and max limit is %f bytes", total, heapLimit) ); } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java index f130ab9c..95c2ba2a 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java @@ -418,15 +418,7 @@ private void findException(Throwable cause, String adID, AtomicReference previousException = globalRunner.fetchException(adID); - - if (previousException.isPresent()) { - LOG.error("Previous exception of {}: {}", () -> adID, () -> previousException.get()); - failure.set(previousException.get()); - } else { - failure.set(new ResourceNotFoundException(adID, causeException.getMessage())); - } + failure.set(new ResourceNotFoundException(adID, causeException.getMessage())); } else if (isException(causeException, LimitExceededException.class, LIMIT_EXCEEDED_EXCEPTION_NAME_UNDERSCORE)) { failure.set(new LimitExceededException(adID, causeException.getMessage())); } else if (causeException instanceof ElasticsearchTimeoutException) { @@ -553,7 +545,7 @@ public void onResponse(RCFResultResponse response) { @Override public void onFailure(Exception e) { try { - handlePredictionFailure(e, modelID, rcfNodeID, failure); + handlePredictionFailure(e, adID, rcfNodeID, failure); } catch (Exception ex) { LOG.error("Unexpected exception: {} for {}", ex, adID); } finally { @@ -565,7 +557,20 @@ public void onFailure(Exception e) { private void handleRCFResults() { try { - if (coldStartIfNoModel(failure, detector) || rcfResults.isEmpty()) { + if (coldStartIfNoModel(failure, detector)) { + // fetch previous cold start exception + Optional previousException = globalRunner.fetchException(adID); + + if (previousException.isPresent()) { + LOG.error("Previous exception of {}: {}", () -> adID, () -> previousException.get()); + listener.onFailure(previousException.get()); + } else { + listener.onFailure(new InternalFailure(adID, NO_MODEL_ERR_MSG)); + } + return; + } + + if (rcfResults.isEmpty()) { listener.onFailure(new InternalFailure(adID, NO_MODEL_ERR_MSG)); return; } @@ -580,7 +585,6 @@ private void handleRCFResults() { ThresholdActionListener thresholdListener = new ThresholdActionListener( anomalyResultResponse, featureInResponse, - thresholdModelID, thresholdNodeId, detector, combinedResult, @@ -604,7 +608,6 @@ private void handleRCFResults() { class ThresholdActionListener implements ActionListener { private AtomicReference anomalyResultResponse; private List features; - private String modelID; private AtomicReference failure; private String thresholdNodeID; private ActionListener listener; @@ -615,7 +618,6 @@ class ThresholdActionListener implements ActionListener ThresholdActionListener( AtomicReference anomalyResultResponse, List features, - String modelID, String thresholdNodeID, AnomalyDetector detector, CombinedRcfResult combinedResult, @@ -624,7 +626,6 @@ class ThresholdActionListener implements ActionListener ) { this.anomalyResultResponse = anomalyResultResponse; this.features = features; - this.modelID = modelID; this.thresholdNodeID = thresholdNodeID; this.detector = detector; this.combinedResult = combinedResult; @@ -649,7 +650,7 @@ public void onResponse(ThresholdResultResponse response) { @Override public void onFailure(Exception e) { try { - handlePredictionFailure(e, modelID, thresholdNodeID, failure); + handlePredictionFailure(e, adID, thresholdNodeID, failure); } catch (Exception ex) { LOG.error("Unexpected exception: {} for {}", ex, adID); } finally { @@ -807,6 +808,8 @@ public Boolean call() { "Time out while indexing cold start checkpoint or get training data", timeoutEx ); + } catch (EndRunException endRunEx) { + throw endRunEx; } catch (Exception ex) { throw new EndRunException(detector.getDetectorId(), "Error while cold start", ex, false); } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/ColdStartRunner.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/ColdStartRunner.java index 2b5864df..00dc757f 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/ColdStartRunner.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/ColdStartRunner.java @@ -71,6 +71,9 @@ Optional checkResult() { if (cause instanceof AnomalyDetectionException) { AnomalyDetectionException adException = (AnomalyDetectionException) cause; currentExceptions.put(adException.getAnomalyDetectorId(), adException); + LOG.info("added cause for {}", adException.getAnomalyDetectorId()); + } else { + LOG.error("Get an unexpected exception"); } } return Optional.empty(); @@ -79,9 +82,12 @@ Optional checkResult() { public Optional fetchException(String adID) { checkResult(); - if (currentExceptions.containsKey(adID)) { - LOG.error("Found matching exception for {}", adID); + AnomalyDetectionException ex = currentExceptions.get(adID); + if (ex != null) { + LOG.error("Found a matching exception for " + adID, ex); return Optional.of(currentExceptions.remove(adID)); + } else { + LOG.info("Cannot find a matching exception for {}", adID); } return Optional.empty(); } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java index f0a929c8..7d84537e 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java @@ -57,7 +57,6 @@ import com.amazon.opendistroforelasticsearch.ad.AbstractADTest; import com.amazon.opendistroforelasticsearch.ad.cluster.HashRing; import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException; -import com.amazon.opendistroforelasticsearch.ad.common.exception.ClientException; import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException; import com.amazon.opendistroforelasticsearch.ad.common.exception.InternalFailure; import com.amazon.opendistroforelasticsearch.ad.common.exception.JsonPathNotFoundException; @@ -65,6 +64,7 @@ import com.amazon.opendistroforelasticsearch.ad.common.exception.ResourceNotFoundException; import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages; import com.amazon.opendistroforelasticsearch.ad.constant.CommonMessageAttributes; +import com.amazon.opendistroforelasticsearch.ad.constant.CommonName; import com.amazon.opendistroforelasticsearch.ad.feature.FeatureManager; import com.amazon.opendistroforelasticsearch.ad.feature.SinglePointFeatures; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; @@ -105,6 +105,7 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.index.Index; +import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.NodeNotConnectedException; @@ -165,6 +166,7 @@ public void setUp() throws Exception { super.setUpLog4jForJUnit(AnomalyResultTransportAction.class); setupTestNodes(Settings.EMPTY); FakeNode.connectNodes(testNodes); + runner = new ColdStartRunner(); transportService = testNodes[0].transportService; clusterService = testNodes[0].clusterService; stateManager = mock(ADStateManager.class); @@ -264,7 +266,6 @@ public void setupTestNodes(Settings settings) { for (int i = 0; i < testNodes.length; i++) { testNodes[i] = new FakeNode("node" + i, threadPool, settings); } - runner = new ColdStartRunner(); } @Override @@ -273,6 +274,7 @@ public final void tearDown() throws Exception { for (FakeNode testNode : testNodes) { testNode.close(); } + testNodes = null; runner.shutDown(); runner = null; client = null; @@ -887,7 +889,7 @@ public void testColdStartNoTrainingData() throws Exception { ); AnomalyResultTransportAction.ColdStartJob job = action.new ColdStartJob(detector); - expectThrows(AnomalyDetectionException.class, () -> job.call()); + expectThrows(EndRunException.class, () -> job.call()); } public void testColdStartTimeoutPutCheckpoint() throws Exception { @@ -912,7 +914,30 @@ public void testColdStartTimeoutPutCheckpoint() throws Exception { ); AnomalyResultTransportAction.ColdStartJob job = action.new ColdStartJob(detector); - expectThrows(ClientException.class, () -> job.call()); + expectThrows(InternalFailure.class, () -> job.call()); + } + + public void testColdStartIllegalArgumentException() throws Exception { + when(featureQuery.getColdStartData(any(AnomalyDetector.class))).thenReturn(Optional.of(new double[][] { { 1.0 } })); + doThrow(new IllegalArgumentException("")).when(normalModelManager).trainModel(any(AnomalyDetector.class), any(double[][].class)); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + stateManager, + runner, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats + ); + + AnomalyResultTransportAction.ColdStartJob job = action.new ColdStartJob(detector); + expectThrows(EndRunException.class, () -> job.call()); } enum FeatureTestMode { @@ -1089,8 +1114,7 @@ public void testNullRCFResult() { @SuppressWarnings("unchecked") public void testAllFeaturesDisabled() { - // doThrow(IllegalArgumentException.class).when(featureQuery) - // .getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); + doAnswer(invocation -> { Object[] args = invocation.getArguments(); ActionListener listener = (ActionListener) args[3]; @@ -1120,4 +1144,45 @@ public void testAllFeaturesDisabled() { assertException(listener, EndRunException.class, AnomalyResultTransportAction.ALL_FEATURES_DISABLED_ERR_MSG); } + + @SuppressWarnings("unchecked") + public void testEndRunDueToNoTrainingData() { + ModelManager rcfManager = mock(ModelManager.class); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[3]; + listener.onFailure(new IndexNotFoundException(CommonName.CHECKPOINT_INDEX_NAME)); + return null; + }).when(rcfManager).getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); + when(rcfManager.getRcfModelId(any(String.class), anyInt())).thenReturn(rcfModelID); + + ColdStartRunner mockRunner = mock(ColdStartRunner.class); + when(mockRunner.fetchException(any(String.class))) + .thenReturn(Optional.of(new EndRunException(adID, "Cannot get training data", false))); + + // These constructors register handler in transport service + new RCFResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, rcfManager, adCircuitBreakerService); + new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + stateManager, + mockRunner, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + assertException(listener, EndRunException.class); + } }