From 825984b1d1468f6d4f08eb9948367d31bb467d53 Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Wed, 19 Aug 2020 11:23:03 -0700 Subject: [PATCH] Change to use callbacks in cold start (#208) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Change to use callbacks in cold start This PR changes the code path of cold start in the transport layer to use callbacks. Previously, I created AD’s ExecutorService that has one thread for cold starts in ColdStartRunner. When we need to trigger a cold start, we can submit a task in the ExecutorService, and consult a hash map (keyed by detector Id) that cached the results of recent cold start results. Since I have to invoke the cold start thread in various callbacks, I created a cold start thread pool and put the cold start result in the transport state. This PR also handles new exceptions like invalid queries introduced by recent changes on ModelManager and FeatureManager. This PR lowers the severity of a couple of log messages in HashRing and RCFPollingTransportAction to avoid overwhelming readers of log files. These log messages are common. This PR corrects typos and updates known causes of EndRunException in comments. Testing done: 1. Simulated cold start failures: Exceptions of cold starts can be seen by the transport layer.  EndRunException can cause AD jobs to be terminated. 2. Happy case of a cold start still works. --- codecov.yml | 1 + .../ad/AnomalyDetectorJobRunner.java | 3 +- .../ad/AnomalyDetectorPlugin.java | 7 +- .../ad/cluster/HashRing.java | 2 +- .../ad/feature/FeatureManager.java | 22 +- .../ad/ml/ModelManager.java | 4 +- .../AnomalyResultTransportAction.java | 240 ++++++--- .../transport/RCFPollingTransportAction.java | 2 +- .../ad/transport/TransportState.java | 174 +++++-- .../ad/transport/TransportStateManager.java | 110 +++-- .../handler/DetectionStateHandler.java | 4 +- .../ad/util/ColdStartRunner.java | 94 ---- .../ad/AbstractADTest.java | 15 +- .../ad/feature/FeatureManagerTests.java | 25 +- .../ad/transport/AnomalyResultTests.java | 461 ++++++++++++++---- .../ad/transport/RCFPollingTests.java | 1 - .../transport/TransportStateManagerTests.java | 12 +- .../ad/transport/TransportStateTests.java | 83 ++-- .../handler/DetectorStateHandlerTests.java | 6 +- .../ad/util/ColdStartRunnerTests.java | 90 ---- 20 files changed, 851 insertions(+), 505 deletions(-) delete mode 100644 src/main/java/com/amazon/opendistroforelasticsearch/ad/util/ColdStartRunner.java delete mode 100644 src/test/java/com/amazon/opendistroforelasticsearch/ad/util/ColdStartRunnerTests.java diff --git a/codecov.yml b/codecov.yml index 4922c28d..7b15d37a 100644 --- a/codecov.yml +++ b/codecov.yml @@ -16,6 +16,7 @@ coverage: - "cli/" flags: - cli + patch: off comment: layout: "reach, diff, flags, files" behavior: default diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunner.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunner.java index abc22f38..28348bc4 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunner.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunner.java @@ -326,8 +326,7 @@ protected void handleAdException( } else { detectorEndRunExceptionCount.remove(detectorId); if (exception instanceof InternalFailure) { - // AnomalyResultTransportAction already prints exception stack trace - log.error("InternalFailure happened when executing anomaly result action for " + detectorId); + log.error("InternalFailure happened when executing anomaly result action for " + detectorId, exception); } else { log.error("Failed to execute anomaly result action for " + detectorId, exception); } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java index 6f2d567a..ba49287d 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java @@ -121,7 +121,6 @@ import com.amazon.opendistroforelasticsearch.ad.transport.handler.AnomalyIndexHandler; import com.amazon.opendistroforelasticsearch.ad.transport.handler.DetectionStateHandler; import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; -import com.amazon.opendistroforelasticsearch.ad.util.ColdStartRunner; import com.amazon.opendistroforelasticsearch.ad.util.DiscoveryNodeFilterer; import com.amazon.opendistroforelasticsearch.ad.util.IndexUtils; import com.amazon.opendistroforelasticsearch.ad.util.Throttler; @@ -320,7 +319,6 @@ public Collection createComponents( clock, AnomalyDetectorSettings.HOURLY_MAINTENANCE ); - ColdStartRunner runner = new ColdStartRunner(); FeatureManager featureManager = new FeatureManager( searchFeatureDao, interpolator, @@ -333,7 +331,9 @@ public Collection createComponents( AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, - AnomalyDetectorSettings.HOURLY_MAINTENANCE + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + threadPool, + AD_THREAD_POOL_NAME ); anomalyDetectorRunner = new AnomalyDetectorRunner(modelManager, featureManager, AnomalyDetectorSettings.MAX_PREVIEW_RESULTS); @@ -386,7 +386,6 @@ public Collection createComponents( modelManager, clock, stateManager, - runner, new ADClusterEventListener(clusterService, hashRing, modelManager, nodeFilter), adCircuitBreakerService, adStats, diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/cluster/HashRing.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/cluster/HashRing.java index b82ababe..6cccb3b1 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/cluster/HashRing.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/cluster/HashRing.java @@ -78,7 +78,7 @@ public boolean build() { // Check cooldown period if (clock.millis() - lastUpdate <= coolDownPeriod.getMillis()) { - LOG.info(COOLDOWN_MSG); + LOG.debug(COOLDOWN_MSG); return false; } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManager.java index 57390a67..64cb2cee 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManager.java @@ -42,6 +42,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ThreadedActionListener; +import org.elasticsearch.threadpool.ThreadPool; import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException; import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages; @@ -72,6 +74,8 @@ public class FeatureManager { private final double previewSampleRate; private final int maxPreviewSamples; private final Duration featureBufferTtl; + private final ThreadPool threadPool; + private final String adThreadPoolName; /** * Constructor with dependencies and configuration. @@ -88,6 +92,7 @@ public class FeatureManager { * @param previewSampleRate number of samples to number of all the data points in the preview time range * @param maxPreviewSamples max number of samples from search for preview features * @param featureBufferTtl time to live for stale feature buffers + * @param threadPool object through which we can invoke different threadpool using different names */ public FeatureManager( SearchFeatureDao searchFeatureDao, @@ -101,7 +106,9 @@ public FeatureManager( int maxNeighborDistance, double previewSampleRate, int maxPreviewSamples, - Duration featureBufferTtl + Duration featureBufferTtl, + ThreadPool threadPool, + String adThreadPoolName ) { this.searchFeatureDao = searchFeatureDao; this.interpolator = interpolator; @@ -117,6 +124,8 @@ public FeatureManager( this.featureBufferTtl = featureBufferTtl; this.detectorIdsToTimeShingles = new ConcurrentHashMap<>(); + this.threadPool = threadPool; + this.adThreadPoolName = adThreadPoolName; } /** @@ -331,11 +340,10 @@ public Optional getColdStartData(AnomalyDetector detector) { * onFailure is called with EndRunException on feature query creation errors */ public void getColdStartData(AnomalyDetector detector, ActionListener> listener) { + ActionListener> latestTimeListener = ActionListener + .wrap(latest -> getColdStartSamples(latest, detector, listener), listener::onFailure); searchFeatureDao - .getLatestDataTime( - detector, - ActionListener.wrap(latest -> getColdStartSamples(latest, detector, listener), listener::onFailure) - ); + .getLatestDataTime(detector, new ThreadedActionListener<>(logger, threadPool, adThreadPoolName, latestTimeListener, false)); } private void getColdStartSamples(Optional latest, AnomalyDetector detector, ActionListener> listener) { @@ -343,11 +351,13 @@ private void getColdStartSamples(Optional latest, AnomalyDetector detector if (latest.isPresent()) { List> sampleRanges = getColdStartSampleRanges(detector, latest.get()); try { + ActionListener>> getFeaturesListener = ActionListener + .wrap(samples -> processColdStartSamples(samples, shingleSize, listener), listener::onFailure); searchFeatureDao .getFeatureSamplesForPeriods( detector, sampleRanges, - ActionListener.wrap(samples -> processColdStartSamples(samples, shingleSize, listener), listener::onFailure) + new ThreadedActionListener<>(logger, threadPool, adThreadPoolName, getFeaturesListener, false) ); } catch (IOException e) { listener.onFailure(new EndRunException(detector.getDetectorId(), CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, e, true)); diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java index 5b5246f1..10f3bab0 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java @@ -255,7 +255,7 @@ public String getDetectorIdForModelId(String modelId) { * @param forest RCF configuration, including forest size * @param detectorId ID of the detector with no effects on partitioning * @return a pair of number of partitions and size of a parition (number of trees) - * @throws LimitExceededException when there is no sufficient resouce available + * @throws LimitExceededException when there is no sufficient resource available */ public Entry getPartitionedForestSizes(RandomCutForest forest, String detectorId) { long totalSize = estimateModelSize(forest); @@ -295,7 +295,7 @@ public Entry getPartitionedForestSizes(RandomCutForest forest, * * @param detector detector object * @return a pair of number of partitions and size of a parition (number of trees) - * @throws LimitExceededException when there is no sufficient resouce available + * @throws LimitExceededException when there is no sufficient resource available */ public Entry getPartitionedForestSizes(AnomalyDetector detector) { int shingleSize = detector.getShingleSize(); diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java index 4a4a200a..45146188 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java @@ -19,7 +19,6 @@ import java.util.List; import java.util.Locale; import java.util.Optional; -import java.util.concurrent.Callable; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -35,6 +34,7 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; @@ -46,11 +46,13 @@ import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.ReceiveTimeoutTransportException; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportService; +import com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorPlugin; import com.amazon.opendistroforelasticsearch.ad.breaker.ADCircuitBreakerService; import com.amazon.opendistroforelasticsearch.ad.cluster.HashRing; import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException; @@ -73,7 +75,6 @@ import com.amazon.opendistroforelasticsearch.ad.settings.EnabledSetting; import com.amazon.opendistroforelasticsearch.ad.stats.ADStats; import com.amazon.opendistroforelasticsearch.ad.stats.StatNames; -import com.amazon.opendistroforelasticsearch.ad.util.ColdStartRunner; import com.amazon.opendistroforelasticsearch.ad.util.ExceptionUtil; public class AnomalyResultTransportAction extends HandledTransportAction { @@ -89,10 +90,10 @@ public class AnomalyResultTransportAction extends HandledTransportAction getFeatureData(double[] currentFeature, AnomalyDetector detector) { @@ -170,22 +172,24 @@ private List getFeatureData(double[] currentFeature, AnomalyDetecto * Also, AD is responsible for logging the stack trace. To avoid bloating our logs, alerting * should always just log the message of an AnomalyDetectionException exception by default. * - * Known cause of EndRunException with endNow returning false: + * Known causes of EndRunException with endNow returning false: * + training data for cold start not available * + cold start cannot succeed * + unknown prediction error * + memory circuit breaker tripped * - * Known cause of EndRunException with endNow returning true: - * + a model's memory size reached limit + * Known causes of EndRunException with endNow returning true: + * + a model partition's memory size reached limit * + models' total memory size reached limit * + Having trouble querying feature data due to * * index does not exist * * all features have been disabled + * * invalid search query * + anomaly detector is not available * + AD plugin is disabled + * + training data is invalid due to serious internal bug(s) * - * Known cause of InternalFailure: + * Known causes of InternalFailure: * + threshold model node is not available * + cluster read/write is blocked * + cold start hasn't been finished @@ -285,20 +289,12 @@ private ActionListener onFeatureResponse( } if (!featureOptional.getProcessedFeatures().isPresent()) { - stateManager.getDetectorCheckpoint(adID, ActionListener.wrap(checkpointExists -> { - if (!checkpointExists) { - LOG.info("Trigger cold start for {}", adID); - globalRunner.compute(new ColdStartJob(detector)); - } - }, exception -> { - Throwable cause = ExceptionsHelper.unwrapCause(exception); - if (cause instanceof IndexNotFoundException) { - LOG.info("Trigger cold start for {}", adID); - globalRunner.compute(new ColdStartJob(detector)); - } else { - LOG.error(String.format("Fail to get checkpoint state for %s", adID), exception); - } - })); + Optional exception = coldStartIfNoCheckPoint(detector); + if (exception.isPresent()) { + listener.onFailure(exception.get()); + return; + } + if (!featureOptional.getUnprocessedFeatures().isPresent()) { // Feature not available is common when we have data holes. Respond empty response // so that alerting will not print stack trace to avoid bloating our logs. @@ -384,6 +380,9 @@ private ActionListener onFeatureResponse( listener.onFailure(new EndRunException(adID, "Having trouble querying data: " + exception.getMessage(), true)); } else if (exception instanceof IllegalArgumentException && detector.getEnabledFeatureIds().isEmpty()) { listener.onFailure(new EndRunException(adID, ALL_FEATURES_DISABLED_ERR_MSG, true)); + } else if (exception instanceof EndRunException) { + // invalid feature query + listener.onFailure(exception); } else { handleExecuteException(exception, listener, adID); } @@ -397,7 +396,7 @@ private ActionListener onFeatureResponse( * * @param failure object that may contain exceptions thrown * @param detector detector object - * @return whether cold start runs + * @return exception if AD job execution gets resource not found exception * @throws AnomalyDetectionException List of exceptions we can throw * 1. Exception from cold start: * 1). InternalFailure due to @@ -405,24 +404,38 @@ private ActionListener onFeatureResponse( * 2). EndRunException with endNow equal to false * a. training data not available * b. cold start cannot succeed + * c. invalid training data + * 3) EndRunException with endNow equal to true + * a. invalid search query * 2. LimitExceededException from one of RCF model node when the total size of the models * is more than X% of heap memory. * 3. InternalFailure wrapping ElasticsearchTimeoutException inside caused by * RCF/Threshold model node failing to get checkpoint to restore model before timeout. */ - private boolean coldStartIfNoModel(AtomicReference failure, AnomalyDetector detector) + private AnomalyDetectionException coldStartIfNoModel(AtomicReference failure, AnomalyDetector detector) throws AnomalyDetectionException { AnomalyDetectionException exp = failure.get(); - if (exp != null) { - if (exp instanceof ResourceNotFoundException) { - LOG.info("Trigger cold start for {}", detector.getDetectorId()); - globalRunner.compute(new ColdStartJob(detector)); - return true; - } else { - throw exp; + if (exp == null) { + return null; + } + + if (!(exp instanceof ResourceNotFoundException)) { + throw exp; + } + + // fetch previous cold start exception + String adID = detector.getDetectorId(); + final Optional previousException = stateManager.fetchColdStartException(adID); + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error("Previous exception of {}: {}", () -> adID, () -> exception); + if (exception instanceof EndRunException && ((EndRunException) exception).isEndNow()) { + return (EndRunException) exception; } } - return false; + LOG.info("Trigger cold start for {}", detector.getDetectorId()); + coldStart(detector); + return previousException.orElse(new InternalFailure(adID, NO_MODEL_ERR_MSG)); } private void findException(Throwable cause, String adID, AtomicReference failure) { @@ -446,7 +459,7 @@ private void findException(Throwable cause, String adID, AtomicReference previousException = globalRunner.fetchException(adID); - - if (previousException.isPresent()) { - LOG.error("Previous exception of {}: {}", () -> adID, () -> previousException.get()); - listener.onFailure(previousException.get()); - } else { - listener.onFailure(new InternalFailure(adID, NO_MODEL_ERR_MSG)); - } + AnomalyDetectionException exception = coldStartIfNoModel(failure, detector); + if (exception != null) { + listener.onFailure(exception); return; } @@ -647,8 +653,9 @@ public void onFailure(Exception e) { private void handleThresholdResult() { try { - if (coldStartIfNoModel(failure, detector)) { - listener.onFailure(new InternalFailure(adID, NO_MODEL_ERR_MSG)); + AnomalyDetectionException exception = coldStartIfNoModel(failure, detector); + if (exception != null) { + listener.onFailure(exception); return; } @@ -770,39 +777,124 @@ private boolean shouldStart( return true; } - class ColdStartJob implements Callable { + private void coldStart(AnomalyDetector detector) { + String detectorId = detector.getDetectorId(); - private AnomalyDetector detector; - - ColdStartJob(AnomalyDetector detector) { - this.detector = detector; + // If last cold start is not finished, we don't trigger another one + if (stateManager.isColdStartRunning(detectorId)) { + return; } - @Override - public Boolean call() { - String detectorId = detector.getDetectorId(); - try { - Optional traingData = featureManager.getColdStartData(detector); - if (traingData.isPresent()) { - double[][] trainingPoints = traingData.get(); - modelManager.trainModel(detector, trainingPoints); - return true; - } else { - throw new EndRunException(detectorId, "Cannot get training data", false); - } + stateManager.setColdStartRunning(detectorId, true); - } catch (ElasticsearchTimeoutException timeoutEx) { - throw new InternalFailure( - detector.getDetectorId(), - "Time out while indexing cold start checkpoint or get training data", - timeoutEx - ); - } catch (EndRunException endRunEx) { - throw endRunEx; - } catch (Exception ex) { - throw new EndRunException(detector.getDetectorId(), "Error while cold start", ex, false); + ActionListener> listener = ActionListener.wrap(trainingData -> { + if (trainingData.isPresent()) { + double[][] dataPoints = trainingData.get(); + + ActionListener trainModelListener = ActionListener.wrap(res -> { + stateManager.setColdStartRunning(detectorId, false); + LOG.info("Succeeded in training {}", detectorId); + }, exception -> { + if (exception instanceof AnomalyDetectionException) { + // e.g., partitioned model exceeds memory limit + stateManager.setLastColdStartException(detectorId, (AnomalyDetectionException) exception); + } else if (exception instanceof IllegalArgumentException) { + // IllegalArgumentException due to invalid training data + stateManager + .setLastColdStartException( + detectorId, + new EndRunException(detectorId, "Invalid training data", exception, false) + ); + } else if (exception instanceof ElasticsearchTimeoutException) { + stateManager + .setLastColdStartException( + detectorId, + new InternalFailure(detectorId, "Time out while indexing cold start checkpoint", exception) + ); + } else { + stateManager + .setLastColdStartException( + detectorId, + new EndRunException(detectorId, "Error while training model", exception, false) + ); + } + stateManager.setColdStartRunning(detectorId, false); + }); + + modelManager + .trainModel( + detector, + dataPoints, + new ThreadedActionListener<>(LOG, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, trainModelListener, false) + ); + } else { + stateManager.setLastColdStartException(detectorId, new EndRunException(detectorId, "Cannot get training data", false)); + stateManager.setColdStartRunning(detectorId, false); + } + }, exception -> { + if (exception instanceof ElasticsearchTimeoutException) { + stateManager + .setLastColdStartException( + detectorId, + new InternalFailure(detectorId, "Time out while getting training data", exception) + ); + } else if (exception instanceof AnomalyDetectionException) { + // e.g., Invalid search query + stateManager.setLastColdStartException(detectorId, (AnomalyDetectionException) exception); + } else { + stateManager + .setLastColdStartException(detectorId, new EndRunException(detectorId, "Error while cold start", exception, false)); + } + stateManager.setColdStartRunning(detectorId, false); + }); + + threadPool + .executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME) + .execute( + () -> featureManager + .getColdStartData( + detector, + new ThreadedActionListener<>(LOG, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, listener, false) + ) + ); + } + + /** + * Check if checkpoint for an detector exists or not. If not and previous + * run is not EndRunException whose endNow is true, trigger cold start. + * @param detector detector object + * @return previous cold start exception + */ + private Optional coldStartIfNoCheckPoint(AnomalyDetector detector) { + String detectorId = detector.getDetectorId(); + + Optional previousException = stateManager.fetchColdStartException(detectorId); + + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error("Previous exception of {}: {}", detectorId, exception); + if (exception instanceof EndRunException && ((EndRunException) exception).isEndNow()) { + return previousException; } } + stateManager.getDetectorCheckpoint(detectorId, ActionListener.wrap(checkpointExists -> { + if (!checkpointExists) { + LOG.info("Trigger cold start for {}", detectorId); + coldStart(detector); + } + }, exception -> { + Throwable cause = ExceptionsHelper.unwrapCause(exception); + if (cause instanceof IndexNotFoundException) { + LOG.info("Trigger cold start for {}", detectorId); + coldStart(detector); + } else { + String errorMsg = String.format("Fail to get checkpoint state for %s", detectorId); + LOG.error(errorMsg, exception); + stateManager.setLastColdStartException(detectorId, new AnomalyDetectionException(errorMsg, exception)); + } + })); + + return previousException; } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFPollingTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFPollingTransportAction.java index f288534f..ec14d586 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFPollingTransportAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFPollingTransportAction.java @@ -107,7 +107,7 @@ protected void doExecute(Task task, RCFPollingRequest request, ActionListener detectorDef; - // number of partitions and the number's fetch time - private Entry partitonNumber; + // detector definition + private AnomalyDetector detectorDef; + // number of partitions + private int partitonNumber; // checkpoint fetch time - private Instant checkpoint; - // last error. Used by DetectorStateHandler to check if the error for a + private Instant lastAccessTime; + // last detection error. Used by DetectorStateHandler to check if the error for a // detector has changed or not. If changed, trigger indexing. - private Entry lastError; - - public TransportState(String detectorId) { + private Optional lastDetectionError; + // last training error. Used to save cold start error by a concurrent cold start thread. + private Optional lastColdStartException; + // flag indicating whether checkpoint for the detector exists + private boolean checkPointExists; + // clock to get current time + private final Clock clock; + // cold start running flag to prevent concurrent cold start + private boolean coldStartRunning; + + public TransportState(String detectorId, Clock clock) { this.detectorId = detectorId; - detectorDef = null; - partitonNumber = null; - checkpoint = null; - lastError = null; + this.detectorDef = null; + this.partitonNumber = -1; + this.lastAccessTime = clock.instant(); + this.lastDetectionError = Optional.empty(); + this.lastColdStartException = Optional.empty(); + this.checkPointExists = false; + this.clock = clock; + this.coldStartRunning = false; } public String getDetectorId() { return detectorId; } - public Entry getDetectorDef() { + /** + * + * @return Detector configuration object + */ + public AnomalyDetector getDetectorDef() { + refreshLastUpdateTime(); return detectorDef; } - public void setDetectorDef(Entry detectorDef) { + /** + * + * @param detectorDef Detector configuration object + */ + public void setDetectorDef(AnomalyDetector detectorDef) { this.detectorDef = detectorDef; + refreshLastUpdateTime(); } - public Entry getPartitonNumber() { + /** + * + * @return RCF partition number of the detector + */ + public int getPartitonNumber() { + refreshLastUpdateTime(); return partitonNumber; } - public void setPartitonNumber(Entry partitonNumber) { + /** + * + * @param partitonNumber RCF partition number + */ + public void setPartitonNumber(int partitonNumber) { this.partitonNumber = partitonNumber; + refreshLastUpdateTime(); } - public Instant getCheckpoint() { - return checkpoint; + /** + * Used to indicate whether cold start succeeds or not + * @return whether checkpoint of models exists or not. + */ + public boolean doesCheckpointExists() { + refreshLastUpdateTime(); + return checkPointExists; } - public void setCheckpoint(Instant checkpoint) { - this.checkpoint = checkpoint; + /** + * + * @param checkpointExists mark whether checkpoint of models exists or not. + */ + public void setCheckpointExists(boolean checkpointExists) { + refreshLastUpdateTime(); + this.checkPointExists = checkpointExists; }; - public Entry getLastError() { - return lastError; + /** + * + * @return last model inference error + */ + public Optional getLastDetectionError() { + refreshLastUpdateTime(); + return lastDetectionError; + } + + /** + * + * @param lastError last model inference error + */ + public void setLastDetectionError(String lastError) { + this.lastDetectionError = Optional.ofNullable(lastError); + refreshLastUpdateTime(); + } + + /** + * + * @return last cold start exception if any + */ + public Optional getLastColdStartException() { + refreshLastUpdateTime(); + return lastColdStartException; + } + + /** + * + * @param lastColdStartError last cold start exception if any + */ + public void setLastColdStartException(AnomalyDetectionException lastColdStartError) { + this.lastColdStartException = Optional.ofNullable(lastColdStartError); + refreshLastUpdateTime(); + } + + /** + * Used to prevent concurrent cold start + * @return whether cold start is running or not + */ + public boolean isColdStartRunning() { + refreshLastUpdateTime(); + return coldStartRunning; } - public void setLastError(Entry lastError) { - this.lastError = lastError; + /** + * + * @param coldStartRunning whether cold start is running or not + */ + public void setColdStartRunning(boolean coldStartRunning) { + this.coldStartRunning = coldStartRunning; + refreshLastUpdateTime(); } - public boolean expired(Duration stateTtl, Instant now) { - boolean ans = true; - if (detectorDef != null) { - ans = ans && expired(stateTtl, now, detectorDef.getValue()); - } - if (partitonNumber != null) { - ans = ans && expired(stateTtl, now, partitonNumber.getValue()); - } - if (checkpoint != null) { - ans = ans && expired(stateTtl, now, checkpoint); - } - if (lastError != null) { - ans = ans && expired(stateTtl, now, lastError.getValue()); - } - return ans; + /** + * refresh last access time. + */ + private void refreshLastUpdateTime() { + lastAccessTime = clock.instant(); } - private boolean expired(Duration stateTtl, Instant now, Instant toCheck) { - return toCheck.plus(stateTtl).isBefore(now); + /** + * @param stateTtl time to leave for the state + * @return whether the transport state is expired + */ + public boolean expired(Duration stateTtl) { + return lastAccessTime.plus(stateTtl).isBefore(clock.instant()); } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportStateManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportStateManager.java index ce06fc8c..f38df53a 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportStateManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportStateManager.java @@ -19,10 +19,7 @@ import java.time.Clock; import java.time.Duration; -import java.time.Instant; -import java.util.AbstractMap.SimpleEntry; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; @@ -38,6 +35,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; +import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException; import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException; import com.amazon.opendistroforelasticsearch.ad.constant.CommonName; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; @@ -92,24 +90,22 @@ public TransportStateManager( * @throws LimitExceededException when there is no sufficient resource available */ public int getPartitionNumber(String adID, AnomalyDetector detector) { - if (transportStates.containsKey(adID) && transportStates.get(adID).getPartitonNumber() != null) { - Entry partitonAndTime = transportStates.get(adID).getPartitonNumber(); - partitonAndTime.setValue(clock.instant()); - return partitonAndTime.getKey(); + TransportState state = transportStates.get(adID); + if (state != null && state.getPartitonNumber() > 0) { + return state.getPartitonNumber(); } int partitionNum = modelManager.getPartitionedForestSizes(detector).getKey(); - TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id)); - state.setPartitonNumber(new SimpleEntry<>(partitionNum, clock.instant())); + state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock)); + state.setPartitonNumber(partitionNum); return partitionNum; } public void getAnomalyDetector(String adID, ActionListener> listener) { - if (transportStates.containsKey(adID) && transportStates.get(adID).getDetectorDef() != null) { - Entry detectorAndTime = transportStates.get(adID).getDetectorDef(); - detectorAndTime.setValue(clock.instant()); - listener.onResponse(Optional.of(detectorAndTime.getKey())); + TransportState state = transportStates.get(adID); + if (state != null && state.getDetectorDef() != null) { + listener.onResponse(Optional.of(state.getDetectorDef())); } else { GetRequest request = new GetRequest(AnomalyDetector.ANOMALY_DETECTORS_INDEX, adID); clientUtil.asyncRequest(request, client::get, onGetDetectorResponse(adID, listener)); @@ -131,8 +127,8 @@ private ActionListener onGetDetectorResponse(String adID, ActionLis ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation); AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId()); - TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id)); - state.setDetectorDef(new SimpleEntry<>(detector, clock.instant())); + TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock)); + state.setDetectorDef(detector); listener.onResponse(Optional.of(detector)); } catch (Exception t) { @@ -149,8 +145,8 @@ private ActionListener onGetDetectorResponse(String adID, ActionLis * @param listener listener to handle get request */ public void getDetectorCheckpoint(String adID, ActionListener listener) { - if (transportStates.containsKey(adID) && transportStates.get(adID).getCheckpoint() != null) { - transportStates.get(adID).setCheckpoint(clock.instant()); + TransportState state = transportStates.get(adID); + if (state != null && state.doesCheckpointExists()) { listener.onResponse(Boolean.TRUE); return; } @@ -165,12 +161,8 @@ private ActionListener onGetCheckpointResponse(String adID, ActionL if (response == null || !response.isExists()) { listener.onResponse(Boolean.FALSE); } else { - TransportState state = transportStates.get(adID); - if (state == null) { - state = new TransportState(adID); - transportStates.put(adID, state); - } - state.setCheckpoint(clock.instant()); + TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock)); + state.setCheckpointExists(true); listener.onResponse(Boolean.TRUE); } }, listener::onFailure); @@ -196,7 +188,7 @@ public void maintenance() { String detectorId = entry.getKey(); try { TransportState state = entry.getValue(); - if (state.expired(stateTtl, clock.instant())) { + if (state.expired(stateTtl)) { transportStates.remove(detectorId); } } catch (Exception e) { @@ -239,23 +231,69 @@ public boolean hasRunningQuery(AnomalyDetector detector) { * @param adID detector id * @return last error for the detector */ - public String getLastError(String adID) { - if (transportStates.containsKey(adID) && transportStates.get(adID).getLastError() != null) { - Entry errorAndTime = transportStates.get(adID).getLastError(); - errorAndTime.setValue(clock.instant()); - return errorAndTime.getKey(); - } - - return NO_ERROR; + public String getLastDetectionError(String adID) { + return Optional.ofNullable(transportStates.get(adID)).flatMap(state -> state.getLastDetectionError()).orElse(NO_ERROR); } /** - * Set last error of a detector + * Set last detection error of a detector * @param adID detector id * @param error error, can be null */ - public void setLastError(String adID, String error) { - TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id)); - state.setLastError(new SimpleEntry<>(error, clock.instant())); + public void setLastDetectionError(String adID, String error) { + TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock)); + state.setLastDetectionError(error); + } + + /** + * Set last cold start error of a detector + * @param adID detector id + * @param exception exception, can be null + */ + public void setLastColdStartException(String adID, AnomalyDetectionException exception) { + TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock)); + state.setLastColdStartException(exception); + } + + /** + * Get last cold start exception of a detector. The method has side effect. + * We reset error after calling the method since cold start exception can stop job running. + * @param adID detector id + * @return last cold start exception for the detector + */ + public Optional fetchColdStartException(String adID) { + TransportState state = transportStates.get(adID); + if (state == null) { + return Optional.empty(); + } + + Optional exception = state.getLastColdStartException(); + // since cold start exception can stop job running, we set it to null after using it once. + exception.ifPresent(e -> setLastColdStartException(adID, null)); + return exception; + } + + /** + * Whether last cold start for the detector is running + * @param adID detector ID + * @return running or not + */ + public boolean isColdStartRunning(String adID) { + TransportState state = transportStates.get(adID); + if (state != null) { + return state.isColdStartRunning(); + } + + return false; + } + + /** + * Mark the cold start status of the detector + * @param adID detector ID + * @param running whether it is running + */ + public void setColdStartRunning(String adID, boolean running) { + TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock)); + state.setColdStartRunning(running); } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/handler/DetectionStateHandler.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/handler/DetectionStateHandler.java index d1aebbb8..4d09f6b4 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/handler/DetectionStateHandler.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/handler/DetectionStateHandler.java @@ -112,9 +112,9 @@ public DetectionStateHandler( public void saveError(String error, String detectorId) { // trigger indexing if no error recorded (e.g., this detector got enabled just now) // or the recorded error is different than this one. - if (!Objects.equal(adStateManager.getLastError(detectorId), error)) { + if (!Objects.equal(adStateManager.getLastDetectionError(detectorId), error)) { update(detectorId, new ErrorStrategy(error)); - adStateManager.setLastError(detectorId, error); + adStateManager.setLastDetectionError(detectorId, error); } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/ColdStartRunner.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/ColdStartRunner.java deleted file mode 100644 index 377c0918..00000000 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/ColdStartRunner.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package com.amazon.opendistroforelasticsearch.ad.util; - -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.Callable; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ExecutorCompletionService; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - -import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException; -import com.google.common.util.concurrent.ThreadFactoryBuilder; - -/** - * The runner allows us to have a parallel thread start cold start in the - * coordinating AD node. We can check the execution results and exceptions if - * any when cold start finishes. - * - */ -public class ColdStartRunner { - private static final Logger LOG = LogManager.getLogger(ColdStartRunner.class); - - private ExecutorService exec; - private ExecutorCompletionService runner; - - private Map currentExceptions; - - public ColdStartRunner() { - // when the thread is daemon thread, it will end immediately when the application exits. - exec = Executors.newFixedThreadPool(1, new ThreadFactoryBuilder().setNameFormat("ad-thread-%d").setDaemon(true).build()); - this.runner = new ExecutorCompletionService(exec); - this.currentExceptions = new ConcurrentHashMap<>(); - } - - public Future compute(Callable task) { - return runner.submit(task); - } - - public void shutDown() { - exec.shutdown(); - } - - Optional checkResult() { - try { - Future result = runner.poll(); - if (result != null) { - return Optional.of(result.get()); - } - } catch (Throwable e) { - LOG.error("Could not get result", e); - Throwable cause = e.getCause(); - if (cause instanceof AnomalyDetectionException) { - AnomalyDetectionException adException = (AnomalyDetectionException) cause; - currentExceptions.put(adException.getAnomalyDetectorId(), adException); - LOG.info("added cause for {}", adException.getAnomalyDetectorId()); - } else { - LOG.error("Get an unexpected exception"); - } - } - return Optional.empty(); - } - - public Optional fetchException(String adID) { - checkResult(); - - AnomalyDetectionException ex = currentExceptions.get(adID); - if (ex != null) { - LOG.error("Found a matching exception for " + adID, ex); - return Optional.of(currentExceptions.remove(adID)); - } else { - LOG.info("Cannot find a matching exception for {}", adID); - } - return Optional.empty(); - } -} diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java index dc60acf9..f15833c8 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java @@ -34,6 +34,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.FixedExecutorBuilder; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportInterceptor; @@ -169,11 +170,21 @@ protected void tearDownLog4jForJUnit() { } protected static void setUpThreadPool(String name) { - threadPool = new TestThreadPool(name); + threadPool = new TestThreadPool( + name, + new FixedExecutorBuilder( + Settings.EMPTY, + AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, + 1, + 1000, + "opendistro.ad." + AnomalyDetectorPlugin.AD_THREAD_POOL_NAME + ) + ); } protected static void tearDownThreadPool() { - ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + LOG.info("tear down threadPool"); + assertTrue(ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS)); threadPool = null; } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManagerTests.java index 0ffecce5..a27025e2 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManagerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManagerTests.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Matchers.any; import static org.mockito.Matchers.argThat; import static org.mockito.Matchers.eq; @@ -44,12 +45,14 @@ import java.util.List; import java.util.Map.Entry; import java.util.Optional; +import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicBoolean; import junitparams.JUnitParamsRunner; import junitparams.Parameters; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.threadpool.ThreadPool; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -57,6 +60,7 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorPlugin; import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException; import com.amazon.opendistroforelasticsearch.ad.dataprocessor.Interpolator; import com.amazon.opendistroforelasticsearch.ad.dataprocessor.LinearUniformInterpolator; @@ -98,6 +102,9 @@ public class FeatureManagerTests { @Mock private TransportStateManager stateManager; + @Mock + private ThreadPool threadPool; + private FeatureManager featureManager; @Before @@ -122,6 +129,16 @@ public void setup() { intervalInMilliseconds = detectorIntervalTimeConfig.toDuration().toMillis(); Interpolator interpolator = new LinearUniformInterpolator(new SingleFeatureLinearUniformInterpolator()); + + ExecutorService executorService = mock(ExecutorService.class); + + when(threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + this.featureManager = spy( new FeatureManager( searchFeatureDao, @@ -135,7 +152,9 @@ public void setup() { maxNeighborDistance, previewSampleRate, maxPreviewSamples, - featureBufferTtl + featureBufferTtl, + threadPool, + AnomalyDetectorPlugin.AD_THREAD_POOL_NAME ) ); } @@ -226,7 +245,9 @@ public void getColdStartData_returnExpectedToListener( 1, /*maxNeighborDistance*/ previewSampleRate, maxPreviewSamples, - featureBufferTtl + featureBufferTtl, + threadPool, + AnomalyDetectorPlugin.AD_THREAD_POOL_NAME ) ); featureManager.getColdStartData(detector, listener); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java index 454f1b0e..37a1037e 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java @@ -25,11 +25,11 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.anyBoolean; import static org.mockito.Mockito.anyDouble; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyLong; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; @@ -50,6 +50,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; @@ -82,8 +83,13 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.NodeNotConnectedException; +import org.elasticsearch.transport.RemoteTransportException; +import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportException; +import org.elasticsearch.transport.TransportInterceptor; +import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportRequestOptions; +import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportResponseHandler; import org.elasticsearch.transport.TransportService; import org.junit.After; @@ -91,10 +97,12 @@ import org.junit.Assert; import org.junit.Before; import org.junit.BeforeClass; +import org.mockito.ArgumentCaptor; import test.com.amazon.opendistroforelasticsearch.ad.util.JsonDeserializer; import com.amazon.opendistroforelasticsearch.ad.AbstractADTest; +import com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorPlugin; import com.amazon.opendistroforelasticsearch.ad.TestHelpers; import com.amazon.opendistroforelasticsearch.ad.breaker.ADCircuitBreakerService; import com.amazon.opendistroforelasticsearch.ad.cluster.HashRing; @@ -122,7 +130,6 @@ import com.amazon.opendistroforelasticsearch.ad.stats.StatNames; import com.amazon.opendistroforelasticsearch.ad.stats.suppliers.CounterSupplier; import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; -import com.amazon.opendistroforelasticsearch.ad.util.ColdStartRunner; import com.amazon.opendistroforelasticsearch.ad.util.IndexUtils; import com.amazon.opendistroforelasticsearch.ad.util.Throttler; import com.google.gson.JsonElement; @@ -132,7 +139,6 @@ public class AnomalyResultTests extends AbstractADTest { private TransportService transportService; private ClusterService clusterService; private TransportStateManager stateManager; - private ColdStartRunner runner; private FeatureManager featureQuery; private ModelManager normalModelManager; private Client client; @@ -146,6 +152,7 @@ public class AnomalyResultTests extends AbstractADTest { private String featureName; private ADCircuitBreakerService adCircuitBreakerService; private ADStats adStats; + private int partitionNum; @BeforeClass public static void setUpBeforeClass() { @@ -165,12 +172,12 @@ public void setUp() throws Exception { super.setUpLog4jForJUnit(AnomalyResultTransportAction.class); setupTestNodes(settings); - runner = new ColdStartRunner(); transportService = testNodes[0].transportService; clusterService = testNodes[0].clusterService; stateManager = mock(TransportStateManager.class); // return 2 RCF partitions - when(stateManager.getPartitionNumber(any(String.class), any(AnomalyDetector.class))).thenReturn(2); + partitionNum = 2; + when(stateManager.getPartitionNumber(any(String.class), any(AnomalyDetector.class))).thenReturn(partitionNum); when(stateManager.isMuted(any(String.class))).thenReturn(false); detector = mock(AnomalyDetector.class); @@ -182,7 +189,9 @@ public void setUp() throws Exception { List userIndex = new ArrayList<>(); userIndex.add("test*"); when(detector.getIndices()).thenReturn(userIndex); - when(detector.getDetectorId()).thenReturn("testDetectorId"); + adID = "123"; + when(detector.getDetectorId()).thenReturn(adID); + // when(detector.getDetectorId()).thenReturn("testDetectorId"); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.of(detector)); @@ -213,7 +222,7 @@ public void setUp() throws Exception { return null; }).when(normalModelManager).getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); when(normalModelManager.combineRcfResults(any())).thenReturn(new CombinedRcfResult(0, 1.0d)); - adID = "123"; + rcfModelID = "123-rcf-1"; when(normalModelManager.getRcfModelId(any(String.class), anyInt())).thenReturn(rcfModelID); thresholdModelID = "123-threshold"; @@ -282,8 +291,6 @@ public void setUp() throws Exception { @After public final void tearDown() throws Exception { tearDownTestNodes(); - runner.shutDown(); - runner = null; client = null; super.tearDownLog4jForJUnit(); super.tearDown(); @@ -309,14 +316,14 @@ public void testNormal() throws IOException { transportService, settings, stateManager, - runner, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + threadPool ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -337,57 +344,120 @@ private void assertAnomalyResultResponse(AnomalyResultResponse response, double assertEquals(featureName, responseFeature.getFeatureName()); } - public Throwable noModelExceptionTemplate( + /** + * Create handler that would return a failure + * @param handler callback handler + * @return handler that would return a failure + */ + private TransportResponseHandler rcfFailureHandler( + TransportResponseHandler handler, + Exception exception + ) { + return new TransportResponseHandler() { + @Override + public T read(StreamInput in) throws IOException { + return handler.read(in); + } + + @Override + public void handleResponse(T response) { + handler.handleException(new RemoteTransportException("test", exception)); + } + + @Override + public void handleException(TransportException exp) { + handler.handleException(exp); + } + + @Override + public String executor() { + return handler.executor(); + } + }; + } + + public void noModelExceptionTemplate( Exception thrownException, - ColdStartRunner globalRunner, String adID, Class expectedExceptionType, String error ) { + TransportInterceptor failureTransportInterceptor = new TransportInterceptor() { + @Override + public AsyncSender interceptSender(AsyncSender sender) { + return new AsyncSender() { + @Override + public void sendRequest( + Transport.Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + if (RCFResultAction.NAME.equals(action)) { + sender.sendRequest(connection, action, request, options, rcfFailureHandler(handler, thrownException)); + } else { + sender.sendRequest(connection, action, request, options, handler); + } + } + }; + } + }; + ModelManager rcfManager = mock(ModelManager.class); - doThrow(thrownException).when(rcfManager).getRcfResult(any(String.class), any(String.class), any(double[].class)); when(rcfManager.getRcfModelId(any(String.class), anyInt())).thenReturn(rcfModelID); - doNothing().when(normalModelManager).trainModel(any(AnomalyDetector.class), any(double[][].class)); - when(featureQuery.getColdStartData(any(AnomalyDetector.class))).thenReturn(Optional.of(new double[][] { { 0 } })); + // need to close nodes created in the setUp nodes and create new nodes + // for the failure interceptor. Otherwise, we will get thread leak error. + tearDownTestNodes(); + setupTestNodes(Settings.EMPTY, failureTransportInterceptor); - // These constructors register handler in transport service - new RCFResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, rcfManager, adCircuitBreakerService); - new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); + // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor + when(hashRing.getOwningNode(any(String.class))).thenReturn(Optional.of(testNodes[1].discoveryNode())); + // register handler on testNodes[1] + new RCFResultTransportAction( + new ActionFilters(Collections.emptySet()), + testNodes[1].transportService, + normalModelManager, + adCircuitBreakerService + ); + + TransportService realTransportService = testNodes[0].transportService; + ClusterService realClusterService = testNodes[0].clusterService; AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), - transportService, + realTransportService, settings, stateManager, - globalRunner, featureQuery, normalModelManager, hashRing, - clusterService, + realClusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + threadPool ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); PlainActionFuture listener = new PlainActionFuture<>(); action.doExecute(null, request, listener); - return assertException(listener, expectedExceptionType); + Throwable exception = assertException(listener, expectedExceptionType); + assertTrue("actual message: " + exception.getMessage(), exception.getMessage().contains(error)); } - public void noModelExceptionTemplate(Exception exception, ColdStartRunner globalRunner, String adID, String error) { - noModelExceptionTemplate(exception, globalRunner, adID, exception.getClass(), error); + public void noModelExceptionTemplate(Exception exception, String adID, String error) { + noModelExceptionTemplate(exception, adID, exception.getClass(), error); } public void testNormalColdStart() { noModelExceptionTemplate( new ResourceNotFoundException(adID, ""), - runner, adID, - AnomalyDetectionException.class, + InternalFailure.class, AnomalyResultTransportAction.NO_MODEL_ERR_MSG ); } @@ -395,7 +465,6 @@ public void testNormalColdStart() { public void testNormalColdStartRemoteException() { noModelExceptionTemplate( new NotSerializableExceptionWrapper(new ResourceNotFoundException(adID, "")), - runner, adID, AnomalyDetectionException.class, AnomalyResultTransportAction.NO_MODEL_ERR_MSG @@ -403,21 +472,14 @@ public void testNormalColdStartRemoteException() { } public void testNullPointerExceptionWhenRCF() { - noModelExceptionTemplate( - new NullPointerException(), - runner, - adID, - AnomalyDetectionException.class, - AnomalyResultTransportAction.NO_MODEL_ERR_MSG - ); + noModelExceptionTemplate(new NullPointerException(), adID, EndRunException.class, AnomalyResultTransportAction.BUG_RESPONSE); } public void testADExceptionWhenColdStart() { String error = "blah"; - ColdStartRunner mockRunner = mock(ColdStartRunner.class); - when(mockRunner.fetchException(any(String.class))).thenReturn(Optional.of(new AnomalyDetectionException(adID, error))); + when(stateManager.fetchColdStartException(any(String.class))).thenReturn(Optional.of(new AnomalyDetectionException(adID, error))); - noModelExceptionTemplate(new AnomalyDetectionException(adID, ""), mockRunner, adID, error); + noModelExceptionTemplate(new ResourceNotFoundException(adID, ""), adID, AnomalyDetectionException.class, error); } @SuppressWarnings("unchecked") @@ -429,8 +491,7 @@ public void testInsufficientCapacityExceptionDuringColdStart() { .getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); when(rcfManager.getRcfModelId(any(String.class), anyInt())).thenReturn(rcfModelID); - ColdStartRunner mockRunner = mock(ColdStartRunner.class); - when(mockRunner.fetchException(any(String.class))) + when(stateManager.fetchColdStartException(any(String.class))) .thenReturn(Optional.of(new LimitExceededException(adID, CommonErrorMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG))); // These constructors register handler in transport service @@ -442,14 +503,14 @@ public void testInsufficientCapacityExceptionDuringColdStart() { transportService, settings, stateManager, - mockRunner, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + threadPool ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -476,14 +537,14 @@ public void testInsufficientCapacityExceptionDuringRestoringModel() { transportService, settings, stateManager, - runner, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + threadPool ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -493,42 +554,105 @@ public void testInsufficientCapacityExceptionDuringRestoringModel() { assertException(listener, LimitExceededException.class); } - public void testThresholdException() { + private TransportResponseHandler rcfResponseHandler(TransportResponseHandler handler) { + return new TransportResponseHandler() { + @Override + public T read(StreamInput in) throws IOException { + return handler.read(in); + } - ModelManager exceptionThreadholdfManager = mock(ModelManager.class); - doThrow(NullPointerException.class) - .when(exceptionThreadholdfManager) - .getThresholdingResult(any(String.class), any(String.class), anyDouble()); + @Override + @SuppressWarnings("unchecked") + public void handleResponse(T response) { + handler.handleResponse((T) new RCFResultResponse(1, 1, 100)); + } - // These constructors register handler in transport service - new RCFResultTransportAction( - new ActionFilters(Collections.emptySet()), - transportService, - normalModelManager, - adCircuitBreakerService - ); - new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, exceptionThreadholdfManager); + @Override + public void handleException(TransportException exp) { + handler.handleException(exp); + } + + @Override + public String executor() { + return handler.executor(); + } + }; + } + + public void thresholdExceptionTestTemplate( + Exception thrownException, + String adID, + Class expectedExceptionType, + String error + ) { + + TransportInterceptor failureTransportInterceptor = new TransportInterceptor() { + @Override + public AsyncSender interceptSender(AsyncSender sender) { + return new AsyncSender() { + @Override + public void sendRequest( + Transport.Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + if (ThresholdResultAction.NAME.equals(action)) { + sender.sendRequest(connection, action, request, options, rcfFailureHandler(handler, thrownException)); + } else if (RCFResultAction.NAME.equals(action)) { + sender.sendRequest(connection, action, request, options, rcfResponseHandler(handler)); + } else { + sender.sendRequest(connection, action, request, options, handler); + } + } + }; + } + }; + + ModelManager rcfManager = mock(ModelManager.class); + when(rcfManager.getRcfModelId(any(String.class), anyInt())).thenReturn(rcfModelID); + + // need to close nodes created in the setUp nodes and create new nodes + // for the failure interceptor. Otherwise, we will get thread leak error. + tearDownTestNodes(); + setupTestNodes(Settings.EMPTY, failureTransportInterceptor); + + // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor + when(hashRing.getOwningNode(any(String.class))).thenReturn(Optional.of(testNodes[1].discoveryNode())); + // register handlers on testNodes[1] + ActionFilters actionFilters = new ActionFilters(Collections.emptySet()); + new RCFResultTransportAction(actionFilters, testNodes[1].transportService, normalModelManager, adCircuitBreakerService); + new ThresholdResultTransportAction(actionFilters, testNodes[1].transportService, normalModelManager); + + TransportService realTransportService = testNodes[0].transportService; + ClusterService realClusterService = testNodes[0].clusterService; AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), - transportService, + realTransportService, settings, stateManager, - runner, featureQuery, normalModelManager, hashRing, - clusterService, + realClusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + threadPool ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); PlainActionFuture listener = new PlainActionFuture<>(); action.doExecute(null, request, listener); - assertException(listener, AnomalyDetectionException.class); + Throwable exception = assertException(listener, expectedExceptionType); + assertTrue("actual message: " + exception.getMessage(), exception.getMessage().contains(error)); + } + + public void testThresholdException() { + thresholdExceptionTestTemplate(new NullPointerException(), adID, EndRunException.class, AnomalyResultTransportAction.BUG_RESPONSE); } public void testCircuitBreaker() { @@ -545,14 +669,14 @@ public void testCircuitBreaker() { transportService, settings, stateManager, - runner, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, breakerService, - adStats + adStats, + threadPool ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -609,14 +733,14 @@ private void nodeNotConnectedExceptionTemplate(boolean isRCF, boolean temporary, exceptionTransportService, settings, stateManager, - runner, featureQuery, normalModelManager, hashRing, hackedClusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + threadPool ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -674,14 +798,14 @@ public void testMute() { transportService, settings, muteStateManager, - runner, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + threadPool ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); PlainActionFuture listener = new PlainActionFuture<>(); @@ -706,14 +830,14 @@ public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOE transportService, settings, stateManager, - runner, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + threadPool ); TransportRequestOptions option = TransportRequestOptions @@ -858,14 +982,14 @@ public void testOnFailureNull() throws IOException { transportService, settings, stateManager, - new ColdStartRunner(), featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + threadPool ); AnomalyResultTransportAction.RCFActionListener listener = action.new RCFActionListener( null, null, null, null, null, null, null, null, null, 0, new AtomicInteger(), null @@ -873,73 +997,182 @@ public void testOnFailureNull() throws IOException { listener.onFailure(null); } + @SuppressWarnings("unchecked") + private void setUpColdStart(ThreadPool mockThreadPool, boolean coldStartRunning) { + SinglePointFeatures mockSinglePoint = mock(SinglePointFeatures.class); + + when(mockSinglePoint.getProcessedFeatures()).thenReturn(Optional.empty()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockSinglePoint); + return null; + }).when(featureQuery).getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(Boolean.FALSE); + return null; + }).when(stateManager).getDetectorCheckpoint(any(String.class), any(ActionListener.class)); + + when(stateManager.isColdStartRunning(any(String.class))).thenReturn(coldStartRunning); + + ExecutorService executorService = mock(ExecutorService.class); + + when(mockThreadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + } + + @SuppressWarnings("unchecked") public void testColdStartNoTrainingData() throws Exception { - when(featureQuery.getColdStartData(any(AnomalyDetector.class))).thenReturn(Optional.empty()); + ThreadPool mockThreadPool = mock(ThreadPool.class); + setUpColdStart(mockThreadPool, false); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(featureQuery).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + mockThreadPool + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + verify(stateManager, times(1)).setLastColdStartException(eq(adID), any(EndRunException.class)); + verify(stateManager, times(2)).setColdStartRunning(eq(adID), anyBoolean()); + } + + @SuppressWarnings("unchecked") + public void testConcurrentColdStart() throws Exception { + ThreadPool mockThreadPool = mock(ThreadPool.class); + setUpColdStart(mockThreadPool, true); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(featureQuery).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); + AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, settings, stateManager, - runner, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + mockThreadPool ); - AnomalyResultTransportAction.ColdStartJob job = action.new ColdStartJob(detector); - expectThrows(EndRunException.class, () -> job.call()); + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + verify(stateManager, never()).setLastColdStartException(eq(adID), any(EndRunException.class)); + verify(stateManager, never()).setColdStartRunning(eq(adID), anyBoolean()); } + @SuppressWarnings("unchecked") public void testColdStartTimeoutPutCheckpoint() throws Exception { - when(featureQuery.getColdStartData(any(AnomalyDetector.class))).thenReturn(Optional.of(new double[][] { { 1.0 } })); - doThrow(new ElasticsearchTimeoutException("")) - .when(normalModelManager) - .trainModel(any(AnomalyDetector.class), any(double[][].class)); + ThreadPool mockThreadPool = mock(ThreadPool.class); + setUpColdStart(mockThreadPool, false); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(new double[][] { { 1.0 } })); + return null; + }).when(featureQuery).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onFailure(new ElasticsearchTimeoutException("")); + return null; + }).when(normalModelManager).trainModel(any(AnomalyDetector.class), any(double[][].class), any(ActionListener.class)); AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, settings, stateManager, - runner, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + mockThreadPool ); - AnomalyResultTransportAction.ColdStartJob job = action.new ColdStartJob(detector); - expectThrows(InternalFailure.class, () -> job.call()); + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + verify(stateManager, times(1)).setLastColdStartException(eq(adID), any(InternalFailure.class)); + verify(stateManager, times(2)).setColdStartRunning(eq(adID), anyBoolean()); } + @SuppressWarnings("unchecked") public void testColdStartIllegalArgumentException() throws Exception { - when(featureQuery.getColdStartData(any(AnomalyDetector.class))).thenReturn(Optional.of(new double[][] { { 1.0 } })); - doThrow(new IllegalArgumentException("")).when(normalModelManager).trainModel(any(AnomalyDetector.class), any(double[][].class)); + ThreadPool mockThreadPool = mock(ThreadPool.class); + setUpColdStart(mockThreadPool, false); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(new double[][] { { 1.0 } })); + return null; + }).when(featureQuery).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onFailure(new IllegalArgumentException("")); + return null; + }).when(normalModelManager).trainModel(any(AnomalyDetector.class), any(double[][].class), any(ActionListener.class)); AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, settings, stateManager, - runner, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + mockThreadPool ); - AnomalyResultTransportAction.ColdStartJob job = action.new ColdStartJob(detector); - expectThrows(EndRunException.class, () -> job.call()); + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + verify(stateManager, times(1)).setLastColdStartException(eq(adID), any(EndRunException.class)); + verify(stateManager, times(2)).setColdStartRunning(eq(adID), anyBoolean()); } enum FeatureTestMode { @@ -971,14 +1204,14 @@ public void featureTestTemplate(FeatureTestMode mode) throws IOException { transportService, settings, stateManager, - runner, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + threadPool ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -1054,14 +1287,14 @@ private void globalBlockTemplate(BlockType type, String errLogMsg, Settings inde transportService, settings, stateManager, - runner, featureQuery, normalModelManager, hashRing, hackedClusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + threadPool ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -1098,14 +1331,14 @@ public void testNullRCFResult() { transportService, settings, stateManager, - runner, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + threadPool ); AnomalyResultTransportAction.RCFActionListener listener = action.new RCFActionListener( null, "123-rcf-0", null, "123", null, null, null, null, null, 0, new AtomicInteger(), null @@ -1129,14 +1362,14 @@ public void testAllFeaturesDisabled() throws IOException { transportService, settings, stateManager, - runner, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + threadPool ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -1148,6 +1381,9 @@ public void testAllFeaturesDisabled() throws IOException { @SuppressWarnings("unchecked") public void testEndRunDueToNoTrainingData() { + ThreadPool mockThreadPool = mock(ThreadPool.class); + setUpColdStart(mockThreadPool, false); + ModelManager rcfManager = mock(ModelManager.class); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -1157,10 +1393,21 @@ public void testEndRunDueToNoTrainingData() { }).when(rcfManager).getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); when(rcfManager.getRcfModelId(any(String.class), anyInt())).thenReturn(rcfModelID); - ColdStartRunner mockRunner = mock(ColdStartRunner.class); - when(mockRunner.fetchException(any(String.class))) + when(stateManager.fetchColdStartException(any(String.class))) .thenReturn(Optional.of(new EndRunException(adID, "Cannot get training data", false))); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(new double[][] { { 1.0 } })); + return null; + }).when(featureQuery).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(normalModelManager).trainModel(any(AnomalyDetector.class), any(double[][].class), any(ActionListener.class)); + // These constructors register handler in transport service new RCFResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, rcfManager, adCircuitBreakerService); new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); @@ -1170,14 +1417,14 @@ public void testEndRunDueToNoTrainingData() { transportService, settings, stateManager, - mockRunner, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, adCircuitBreakerService, - adStats + adStats, + mockThreadPool ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -1185,5 +1432,11 @@ public void testEndRunDueToNoTrainingData() { action.doExecute(null, request, listener); assertException(listener, EndRunException.class); + ArgumentCaptor booleanCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(stateManager, times(2)).setColdStartRunning(eq(adID), booleanCaptor.capture()); + List capturedBoolean = booleanCaptor.getAllValues(); + // first, we set cold start running to true; then false + assertTrue(capturedBoolean.get(0)); + assertTrue(!capturedBoolean.get(1)); } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFPollingTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFPollingTests.java index cc8b8799..00d50346 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFPollingTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFPollingTests.java @@ -83,7 +83,6 @@ public class RCFPollingTests extends AbstractADTest { @BeforeClass public static void setUpBeforeClass() { setUpThreadPool(RCFPollingTests.class.getSimpleName()); - } @AfterClass diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportStateManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportStateManagerTests.java index eeacd5d5..9806417d 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportStateManagerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportStateManagerTests.java @@ -175,9 +175,9 @@ public void testGetPartitionNumber() throws IOException, InterruptedException { public void testGetLastError() throws IOException, InterruptedException { String error = "blah"; - assertEquals(TransportStateManager.NO_ERROR, stateManager.getLastError(adId)); - stateManager.setLastError(adId, error); - assertEquals(error, stateManager.getLastError(adId)); + assertEquals(TransportStateManager.NO_ERROR, stateManager.getLastDetectionError(adId)); + stateManager.setLastDetectionError(adId, error); + assertEquals(error, stateManager.getLastDetectionError(adId)); } public void testShouldMute() { @@ -285,4 +285,10 @@ public void testMaintenanceRemove() throws IOException { ); verify(client, times(2)).get(any(), any()); } + + public void testColdStartRunning() { + assertTrue(!stateManager.isColdStartRunning(adId)); + stateManager.setColdStartRunning(adId, true); + assertTrue(stateManager.isColdStartRunning(adId)); + } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportStateTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportStateTests.java index 88087d77..c4431781 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportStateTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportStateTests.java @@ -15,82 +15,95 @@ package com.amazon.opendistroforelasticsearch.ad.transport; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + import java.io.IOException; +import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.util.AbstractMap.SimpleImmutableEntry; import org.elasticsearch.test.ESTestCase; import com.amazon.opendistroforelasticsearch.ad.TestHelpers; +import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException; public class TransportStateTests extends ESTestCase { private TransportState state; + private Clock clock; @Override public void setUp() throws Exception { super.setUp(); - state = new TransportState("123"); + clock = mock(Clock.class); + state = new TransportState("123", clock); } private Duration duration = Duration.ofHours(1); public void testMaintenanceNotRemoveSingle() throws IOException { - state - .setDetectorDef( - new SimpleImmutableEntry<>( - TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null), - Instant.ofEpochMilli(1000) - ) - ); + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1000)); + state.setDetectorDef(TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null)); - assertTrue(!state.expired(duration, Instant.MIN)); + when(clock.instant()).thenReturn(Instant.MIN); + assertTrue(!state.expired(duration)); } public void testMaintenanceNotRemove() throws IOException { - state - .setDetectorDef( - new SimpleImmutableEntry<>( - TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null), - Instant.ofEpochSecond(1000) - ) - ); - state.setLastError(new SimpleImmutableEntry<>(null, Instant.ofEpochMilli(1000))); + when(clock.instant()).thenReturn(Instant.ofEpochSecond(1000)); + state.setDetectorDef(TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null)); + state.setLastDetectionError(null); - assertTrue(!state.expired(duration, Instant.ofEpochSecond(3700))); + when(clock.instant()).thenReturn(Instant.ofEpochSecond(3700)); + assertTrue(!state.expired(duration)); } public void testMaintenanceRemoveLastError() throws IOException { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1000)); state .setDetectorDef( - new SimpleImmutableEntry<>( - TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null), - Instant.ofEpochMilli(1000) - ) + + TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null) ); - state.setLastError(new SimpleImmutableEntry<>(null, Instant.ofEpochMilli(1000))); + state.setLastDetectionError(null); - assertTrue(state.expired(duration, Instant.ofEpochSecond(3700))); + when(clock.instant()).thenReturn(Instant.ofEpochSecond(3700)); + assertTrue(state.expired(duration)); } public void testMaintenancRemoveDetector() throws IOException { - state - .setDetectorDef( - new SimpleImmutableEntry<>(TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null), Instant.MIN) - ); - assertTrue(state.expired(duration, Instant.MAX)); + when(clock.instant()).thenReturn(Instant.MIN); + state.setDetectorDef(TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null)); + when(clock.instant()).thenReturn(Instant.MAX); + assertTrue(state.expired(duration)); } public void testMaintenanceFlagNotRemove() throws IOException { - state.setCheckpoint(Instant.ofEpochMilli(1000)); - assertTrue(!state.expired(duration, Instant.MIN)); - + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1000)); + state.setCheckpointExists(true); + when(clock.instant()).thenReturn(Instant.MIN); + assertTrue(!state.expired(duration)); } public void testMaintenancFlagRemove() throws IOException { - state.setCheckpoint(Instant.MIN); - assertTrue(!state.expired(duration, Instant.MIN)); + when(clock.instant()).thenReturn(Instant.MIN); + state.setCheckpointExists(true); + when(clock.instant()).thenReturn(Instant.MIN); + assertTrue(!state.expired(duration)); + } + + public void testMaintenanceLastColdStartRemoved() { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1000)); + state.setLastColdStartException(new AnomalyDetectionException("123", "")); + when(clock.instant()).thenReturn(Instant.ofEpochSecond(3700)); + assertTrue(state.expired(duration)); + } + public void testMaintenanceLastColdStartNotRemoved() { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1_000_000L)); + state.setLastColdStartException(new AnomalyDetectionException("123", "")); + when(clock.instant()).thenReturn(Instant.ofEpochSecond(3700)); + assertTrue(!state.expired(duration)); } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/handler/DetectorStateHandlerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/handler/DetectorStateHandlerTests.java index 8a52b2d3..852ea1ad 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/handler/DetectorStateHandlerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/handler/DetectorStateHandlerTests.java @@ -113,7 +113,7 @@ public void testBothErrorNull() { } public void testNoUpdateWitoutErrorChange() { - when(stateManager.getLastError(anyString())).thenReturn(error); + when(stateManager.getLastDetectionError(anyString())).thenReturn(error); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @SuppressWarnings("unchecked") @@ -129,7 +129,7 @@ public void testNoUpdateWitoutErrorChange() { } public void testUpdateWithErrorChange() { - when(stateManager.getLastError(anyString())).thenReturn("blah"); + when(stateManager.getLastDetectionError(anyString())).thenReturn("blah"); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @SuppressWarnings("unchecked") @@ -145,7 +145,7 @@ public void testUpdateWithErrorChange() { } public void testUpdateWithFirstChange() { - when(stateManager.getLastError(anyString())).thenReturn(TransportStateManager.NO_ERROR); + when(stateManager.getLastDetectionError(anyString())).thenReturn(TransportStateManager.NO_ERROR); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @SuppressWarnings("unchecked") diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/util/ColdStartRunnerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/util/ColdStartRunnerTests.java deleted file mode 100644 index 0dcb9c12..00000000 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/util/ColdStartRunnerTests.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package com.amazon.opendistroforelasticsearch.ad.util; - -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.is; - -import java.util.Optional; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Future; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.test.ESTestCase; -import org.junit.After; -import org.junit.Before; - -import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException; - -public class ColdStartRunnerTests extends ESTestCase { - private static final Logger LOG = LogManager.getLogger(ColdStartRunnerTests.class); - private ColdStartRunner runner; - - @Override - @Before - public void setUp() throws Exception { - super.setUp(); - runner = new ColdStartRunner(); - } - - @Override - @After - public void tearDown() throws Exception { - runner.shutDown(); - runner = null; - super.tearDown(); - } - - public void testNullPointerException() throws InterruptedException { - Future future = runner.compute(() -> { - LOG.info("Execute.."); - throw new NullPointerException(); - }); - - ExecutionException executionException = expectThrows(ExecutionException.class, () -> future.get()); - assertThat(executionException.getCause(), instanceOf(NullPointerException.class)); - - Optional res = runner.checkResult(); - assertThat(res.isPresent(), is(false)); - } - - public void testADException() throws InterruptedException { - - String adID = "123"; - Future future = runner.compute(() -> { - LOG.info("Execute.."); - throw new AnomalyDetectionException(adID, "blah"); - }); - - ExecutionException executionException = expectThrows(ExecutionException.class, () -> future.get()); - assertThat(executionException.getCause(), instanceOf(AnomalyDetectionException.class)); - - int retries = 10; - Optional res = null; - for (int i = 0; i < retries; i++) { - res = runner.fetchException(adID); - if (!res.isPresent()) { - // wait for ExecutorCompletionService to get the completed task - Thread.sleep(1000); - } else { - break; - } - } - - assertEquals(adID, res.get().getAnomalyDetectorId()); - } -}