diff --git a/build.gradle b/build.gradle index 08c233698..51dff8b98 100644 --- a/build.gradle +++ b/build.gradle @@ -108,8 +108,8 @@ dependencies { compileOnly "org.opensearch:opensearch-job-scheduler-spi:${job_scheduler_version}" implementation "org.opensearch:common-utils:${common_utils_version}" implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}" - implementation group: 'com.google.guava', name: 'guava', version:'31.0.1-jre' - implementation group: 'com.google.guava', name: 'failureaccess', version:'1.0.1' + compileOnly group: 'com.google.guava', name: 'guava', version:'31.0.1-jre' + compileOnly group: 'com.google.guava', name: 'failureaccess', version:'1.0.1' implementation group: 'org.javassist', name: 'javassist', version:'3.28.0-GA' implementation group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1' implementation group: 'com.google.code.gson', name: 'gson', version: '2.8.9' @@ -121,10 +121,6 @@ dependencies { implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:3.0-rc3' implementation 'software.amazon.randomcutforest:randomcutforest-core:3.0-rc3' - // we inherit jackson-core from opensearch core - implementation "com.fasterxml.jackson.core:jackson-databind:2.14.1" - implementation "com.fasterxml.jackson.core:jackson-annotations:2.14.1" - // used for serializing/deserializing rcf models. implementation group: 'io.protostuff', name: 'protostuff-core', version: '1.8.0' implementation group: 'io.protostuff', name: 'protostuff-runtime', version: '1.8.0' diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java index aa0ce8075..a7af702a7 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java @@ -560,14 +560,12 @@ private void stopAdJob(String detectorId, AnomalyDetectorFunction function) { }, exception -> { log.error("JobRunner failed to update AD job as disabled for " + detectorId, exception); })); } else { log.info("AD Job was disabled for " + detectorId); - // function.execute(); } } catch (IOException e) { log.error("JobRunner failed to stop detector job " + detectorId, e); } } else { log.info("AD Job was not found for " + detectorId); - // function.execute(); } }, exception -> log.error("JobRunner failed to get detector job " + detectorId, exception)); diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java b/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java index cace64e3f..1ec27253a 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java @@ -762,7 +762,10 @@ public PooledObject wrap(LinkedBuffer obj) { adTaskManager, nodeFilter, threadPool, - client + client, + stateManager, + adTaskCacheManager, + AnomalyDetectorSettings.NUM_MIN_SAMPLES ); // return objects used by Guice to inject dependencies for e.g., diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java index d47494e5a..bb555d0c9 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java @@ -40,7 +40,6 @@ import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorJob; -import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.DetectorProfile; import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.ad.model.DetectorState; @@ -64,8 +63,6 @@ import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.Aggregation; import org.opensearch.search.aggregations.AggregationBuilder; @@ -451,14 +448,15 @@ private void profileMultiEntityDetectorStateRelated( if (profileResponse.getTotalUpdates() < requiredSamples) { // need to double check since what ProfileResponse returns is the highest priority entity currently in memory, but // another entity might have already been initialized and sit somewhere else (in memory or on disk). - confirmMultiEntityDetectorInitStatus( - detector, - job.getEnabledTime().toEpochMilli(), - profileBuilder, - profilesToCollect, - profileResponse.getTotalUpdates(), - listener - ); + long enabledTime = job.getEnabledTime().toEpochMilli(); + long totalUpdates = profileResponse.getTotalUpdates(); + ProfileUtil + .confirmDetectorRealtimeInitStatus( + detector, + enabledTime, + client, + onInittedEver(enabledTime, profileBuilder, profilesToCollect, detector, totalUpdates, listener) + ); } else { createRunningStateAndInitProgress(profilesToCollect, profileBuilder); listener.onResponse(profileBuilder.build()); @@ -471,18 +469,6 @@ private void profileMultiEntityDetectorStateRelated( } } - private void confirmMultiEntityDetectorInitStatus( - AnomalyDetector detector, - long enabledTime, - DetectorProfile.Builder profile, - Set profilesToCollect, - long totalUpdates, - MultiResponsesDelegateActionListener listener - ) { - SearchRequest searchLatestResult = createInittedEverRequest(detector.getDetectorId(), enabledTime, detector.getResultIndex()); - client.search(searchLatestResult, onInittedEver(enabledTime, profile, profilesToCollect, detector, totalUpdates, listener)); - } - private ActionListener onInittedEver( long lastUpdateTimeMs, DetectorProfile.Builder profileBuilder, @@ -602,26 +588,4 @@ private void processInitResponse( listener.onResponse(builder.build()); } - - /** - * Create search request to check if we have at least 1 anomaly score larger than 0 after AD job enabled time - * @param detectorId detector id - * @param enabledTime the time when AD job is enabled in milliseconds - * @return the search request - */ - private SearchRequest createInittedEverRequest(String detectorId, long enabledTime, String resultIndex) { - BoolQueryBuilder filterQuery = new BoolQueryBuilder(); - filterQuery.filter(QueryBuilders.termQuery(AnomalyResult.DETECTOR_ID_FIELD, detectorId)); - filterQuery.filter(QueryBuilders.rangeQuery(AnomalyResult.EXECUTION_END_TIME_FIELD).gte(enabledTime)); - filterQuery.filter(QueryBuilders.rangeQuery(AnomalyResult.ANOMALY_SCORE_FIELD).gt(0)); - - SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1); - - SearchRequest request = new SearchRequest(CommonName.ANOMALY_RESULT_INDEX_ALIAS); - request.source(source); - if (resultIndex != null) { - request.indices(resultIndex); - } - return request; - } } diff --git a/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java b/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java index 19710f0cb..5267b6b70 100644 --- a/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java +++ b/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java @@ -22,6 +22,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionListener; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.ad.common.exception.AnomalyDetectionException; import org.opensearch.ad.common.exception.EndRunException; import org.opensearch.ad.common.exception.ResourceNotFoundException; import org.opensearch.ad.constant.CommonErrorMessages; @@ -32,6 +34,7 @@ import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.ad.model.FeatureData; import org.opensearch.ad.model.IntervalTimeConfiguration; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.AnomalyResultResponse; import org.opensearch.ad.transport.ProfileAction; @@ -40,10 +43,12 @@ import org.opensearch.ad.transport.RCFPollingRequest; import org.opensearch.ad.transport.handler.AnomalyIndexHandler; import org.opensearch.ad.util.DiscoveryNodeFilterer; +import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.unit.TimeValue; import org.opensearch.commons.authuser.User; +import org.opensearch.search.SearchHits; import org.opensearch.threadpool.ThreadPool; public class ExecuteADResultResponseRecorder { @@ -55,6 +60,9 @@ public class ExecuteADResultResponseRecorder { private DiscoveryNodeFilterer nodeFilter; private ThreadPool threadPool; private Client client; + private NodeStateManager nodeStateManager; + private ADTaskCacheManager adTaskCacheManager; + private int rcfMinSamples; public ExecuteADResultResponseRecorder( AnomalyDetectionIndices anomalyDetectionIndices, @@ -62,7 +70,10 @@ public ExecuteADResultResponseRecorder( ADTaskManager adTaskManager, DiscoveryNodeFilterer nodeFilter, ThreadPool threadPool, - Client client + Client client, + NodeStateManager nodeStateManager, + ADTaskCacheManager adTaskCacheManager, + int rcfMinSamples ) { this.anomalyDetectionIndices = anomalyDetectionIndices; this.anomalyResultHandler = anomalyResultHandler; @@ -70,6 +81,9 @@ public ExecuteADResultResponseRecorder( this.nodeFilter = nodeFilter; this.threadPool = threadPool; this.client = client; + this.nodeStateManager = nodeStateManager; + this.adTaskCacheManager = adTaskCacheManager; + this.rcfMinSamples = rcfMinSamples; } public void indexAnomalyResult( @@ -185,27 +199,66 @@ private void updateLatestRealtimeTask( String error ) { // Don't need info as this will be printed repeatedly in each interval - adTaskManager - .updateLatestRealtimeTaskOnCoordinatingNode( + ActionListener listener = ActionListener.wrap(r -> { + if (r != null) { + log.debug("Updated latest realtime task successfully for detector {}, taskState: {}", detectorId, taskState); + } + }, e -> { + if ((e instanceof ResourceNotFoundException) && e.getMessage().contains(CAN_NOT_FIND_LATEST_TASK)) { + // Clear realtime task cache, will recreate AD task in next run, check AnomalyResultTransportAction. + log.error("Can't find latest realtime task of detector " + detectorId); + adTaskManager.removeRealtimeTaskCache(detectorId); + } else { + log.error("Failed to update latest realtime task for detector " + detectorId, e); + } + }); + + // rcfTotalUpdates is null when we save exception messages + if (!adTaskCacheManager.hasQueriedResultIndex(detectorId) && rcfTotalUpdates != null && rcfTotalUpdates < rcfMinSamples) { + // confirm the total updates number since it is possible that we have already had results after job enabling time + // If yes, total updates should be at least rcfMinSamples so that the init progress reaches 100%. + confirmTotalRCFUpdatesFound( detectorId, taskState, rcfTotalUpdates, detectorIntervalInMinutes, error, - ActionListener.wrap(r -> { - if (r != null) { - log.debug("Updated latest realtime task successfully for detector {}, taskState: {}", detectorId, taskState); - } - }, e -> { - if ((e instanceof ResourceNotFoundException) && e.getMessage().contains(CAN_NOT_FIND_LATEST_TASK)) { - // Clear realtime task cache, will recreate AD task in next run, check AnomalyResultTransportAction. - log.error("Can't find latest realtime task of detector " + detectorId); - adTaskManager.removeRealtimeTaskCache(detectorId); - } else { - log.error("Failed to update latest realtime task for detector " + detectorId, e); - } - }) + ActionListener + .wrap( + r -> adTaskManager + .updateLatestRealtimeTaskOnCoordinatingNode( + detectorId, + taskState, + r, + detectorIntervalInMinutes, + error, + listener + ), + e -> { + log.error("Fail to confirm rcf update", e); + adTaskManager + .updateLatestRealtimeTaskOnCoordinatingNode( + detectorId, + taskState, + rcfTotalUpdates, + detectorIntervalInMinutes, + error, + listener + ); + } + ) ); + } else { + adTaskManager + .updateLatestRealtimeTaskOnCoordinatingNode( + detectorId, + taskState, + rcfTotalUpdates, + detectorIntervalInMinutes, + error, + listener + ); + } } /** @@ -285,4 +338,53 @@ public void indexAnomalyResultException( } } + private void confirmTotalRCFUpdatesFound( + String detectorId, + String taskState, + Long rcfTotalUpdates, + Long detectorIntervalInMinutes, + String error, + ActionListener listener + ) { + nodeStateManager.getAnomalyDetector(detectorId, ActionListener.wrap(detectorOptional -> { + if (!detectorOptional.isPresent()) { + listener.onFailure(new AnomalyDetectionException(detectorId, "fail to get detector")); + return; + } + nodeStateManager.getAnomalyDetectorJob(detectorId, ActionListener.wrap(jobOptional -> { + if (!jobOptional.isPresent()) { + listener.onFailure(new AnomalyDetectionException(detectorId, "fail to get job")); + return; + } + + ProfileUtil + .confirmDetectorRealtimeInitStatus( + detectorOptional.get(), + jobOptional.get().getEnabledTime().toEpochMilli(), + client, + ActionListener.wrap(searchResponse -> { + ActionListener.completeWith(listener, () -> { + SearchHits hits = searchResponse.getHits(); + Long correctedTotalUpdates = rcfTotalUpdates; + if (hits.getTotalHits().value > 0L) { + // correct the number if we have already had results after job enabling time + // so that the detector won't stay initialized + correctedTotalUpdates = Long.valueOf(rcfMinSamples); + } + adTaskCacheManager.markResultIndexQueried(detectorId); + return correctedTotalUpdates; + }); + }, exception -> { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + // anomaly result index is not created yet + adTaskCacheManager.markResultIndexQueried(detectorId); + listener.onResponse(0L); + } else { + listener.onFailure(exception); + } + }) + ); + }, e -> listener.onFailure(new AnomalyDetectionException(detectorId, "fail to get job")))); + }, e -> listener.onFailure(new AnomalyDetectionException(detectorId, "fail to get detector")))); + } } diff --git a/src/main/java/org/opensearch/ad/ProfileUtil.java b/src/main/java/org/opensearch/ad/ProfileUtil.java new file mode 100644 index 000000000..b2fc0dbea --- /dev/null +++ b/src/main/java/org/opensearch/ad/ProfileUtil.java @@ -0,0 +1,68 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.CommonName; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.client.Client; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.builder.SearchSourceBuilder; + +public class ProfileUtil { + /** + * Create search request to check if we have at least 1 anomaly score larger than 0 after AD job enabled time. + * Note this function is only meant to check for status of real time analysis. + * + * @param detectorId detector id + * @param enabledTime the time when AD job is enabled in milliseconds + * @return the search request + */ + private static SearchRequest createRealtimeInittedEverRequest(String detectorId, long enabledTime, String resultIndex) { + BoolQueryBuilder filterQuery = new BoolQueryBuilder(); + filterQuery.filter(QueryBuilders.termQuery(AnomalyResult.DETECTOR_ID_FIELD, detectorId)); + filterQuery.filter(QueryBuilders.rangeQuery(AnomalyResult.EXECUTION_END_TIME_FIELD).gte(enabledTime)); + filterQuery.filter(QueryBuilders.rangeQuery(AnomalyResult.ANOMALY_SCORE_FIELD).gt(0)); + // Historical analysis result also stored in result index, which has non-null task_id. + // For realtime detection result, we should filter task_id == null + ExistsQueryBuilder taskIdExistsFilter = QueryBuilders.existsQuery(AnomalyResult.TASK_ID_FIELD); + filterQuery.mustNot(taskIdExistsFilter); + + SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1); + + SearchRequest request = new SearchRequest(CommonName.ANOMALY_RESULT_INDEX_ALIAS); + request.source(source); + if (resultIndex != null) { + request.indices(resultIndex); + } + return request; + } + + public static void confirmDetectorRealtimeInitStatus( + AnomalyDetector detector, + long enabledTime, + Client client, + ActionListener listener + ) { + SearchRequest searchLatestResult = createRealtimeInittedEverRequest( + detector.getDetectorId(), + enabledTime, + detector.getResultIndex() + ); + client.search(searchLatestResult, listener); + } +} diff --git a/src/main/java/org/opensearch/ad/task/ADRealtimeTaskCache.java b/src/main/java/org/opensearch/ad/task/ADRealtimeTaskCache.java index b7399e4c7..bf8cbb860 100644 --- a/src/main/java/org/opensearch/ad/task/ADRealtimeTaskCache.java +++ b/src/main/java/org/opensearch/ad/task/ADRealtimeTaskCache.java @@ -38,12 +38,17 @@ public class ADRealtimeTaskCache { // detector interval in milliseconds. private long detectorIntervalInMillis; + // we query result index to check if there are any result generated for detector to tell whether it passed initialization of not. + // To avoid repeated query when there is no data, record whether we have done that or not. + private boolean queriedResultIndex; + public ADRealtimeTaskCache(String state, Float initProgress, String error, long detectorIntervalInMillis) { this.state = state; this.initProgress = initProgress; this.error = error; this.lastJobRunTime = Instant.now().toEpochMilli(); this.detectorIntervalInMillis = detectorIntervalInMillis; + this.queriedResultIndex = false; } public String getState() { @@ -74,6 +79,14 @@ public void setLastJobRunTime(long lastJobRunTime) { this.lastJobRunTime = lastJobRunTime; } + public boolean hasQueriedResultIndex() { + return queriedResultIndex; + } + + public void setQueriedResultIndex(boolean queriedResultIndex) { + this.queriedResultIndex = queriedResultIndex; + } + public boolean expired() { return lastJobRunTime + 2 * detectorIntervalInMillis < Instant.now().toEpochMilli(); } diff --git a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java index 40fa8e2c4..965f65d05 100644 --- a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java +++ b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java @@ -1013,33 +1013,43 @@ public void clearPendingEntities(String detectorId) { } /** - * Check if realtime task field value changed or not by comparing with cache. - * 1. If new field value is null, will consider this field as not changed. - * 2. If any field value changed, will consider the realtime task changed. - * 3. If realtime task cache not found, will consider the realtime task changed. + * Check if realtime task field value change needed or not by comparing with cache. + * 1. If new field value is null, will consider changed needed to this field. + * 2. will consider the real time task change needed if + * 1) init progress is larger or the old init progress is null, or + * 2) if the state is different, and it is not changing from running to init. + * for other fields, as long as field values changed, will consider the realtime + * task change needed. We did this so that the init progress or state won't go backwards. + * 3. If realtime task cache not found, will consider the realtime task change needed. * * @param detectorId detector id * @param newState new task state * @param newInitProgress new init progress * @param newError new error - * @return true if realtime task changed comparing with realtime task cache. + * @return true if realtime task change needed. */ - public boolean isRealtimeTaskChanged(String detectorId, String newState, Float newInitProgress, String newError) { + public boolean isRealtimeTaskChangeNeeded(String detectorId, String newState, Float newInitProgress, String newError) { if (realtimeTaskCaches.containsKey(detectorId)) { ADRealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(detectorId); - boolean stateChanged = false; - if (newState != null && !newState.equals(realtimeTaskCache.getState())) { - stateChanged = true; + boolean stateChangeNeeded = false; + String oldState = realtimeTaskCache.getState(); + if (newState != null + && !newState.equals(oldState) + && !(ADTaskState.INIT.name().equals(newState) && ADTaskState.RUNNING.name().equals(oldState))) { + stateChangeNeeded = true; } - boolean initProgressChanged = false; - if (newInitProgress != null && !newInitProgress.equals(realtimeTaskCache.getInitProgress())) { - initProgressChanged = true; + boolean initProgressChangeNeeded = false; + Float existingProgress = realtimeTaskCache.getInitProgress(); + if (newInitProgress != null + && !newInitProgress.equals(existingProgress) + && (existingProgress == null || newInitProgress > existingProgress)) { + initProgressChangeNeeded = true; } boolean errorChanged = false; if (newError != null && !newError.equals(realtimeTaskCache.getError())) { errorChanged = true; } - if (stateChanged || initProgressChanged || errorChanged) { + if (stateChangeNeeded || initProgressChangeNeeded || errorChanged) { return true; } return false; @@ -1351,4 +1361,33 @@ public void cleanExpiredHCBatchTaskRunStates() { cleanExpiredHCBatchTaskRunStatesSemaphore.release(); } } + + /** + * We query result index to check if there are any result generated for detector to tell whether it passed initialization of not. + * To avoid repeated query when there is no data, record whether we have done that or not. + * @param id detector id + */ + public void markResultIndexQueried(String id) { + ADRealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(id); + // we initialize a real time cache at the beginning of AnomalyResultTransportAction if it + // cannot be found. If the cache is empty, we will return early and wait it for it to be + // initialized. + if (realtimeTaskCache != null) { + realtimeTaskCache.setQueriedResultIndex(true); + } + } + + /** + * We query result index to check if there are any result generated for detector to tell whether it passed initialization of not. + * + * @param id detector id + * @return whether we have queried result index or not. + */ + public boolean hasQueriedResultIndex(String id) { + ADRealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(id); + if (realtimeTaskCache != null) { + return realtimeTaskCache.hasQueriedResultIndex(); + } + return false; + } } diff --git a/src/main/java/org/opensearch/ad/task/ADTaskManager.java b/src/main/java/org/opensearch/ad/task/ADTaskManager.java index 1b3a775c8..c1e61c536 100644 --- a/src/main/java/org/opensearch/ad/task/ADTaskManager.java +++ b/src/main/java/org/opensearch/ad/task/ADTaskManager.java @@ -2057,7 +2057,7 @@ public void updateLatestRealtimeTaskOnCoordinatingNode( } error = Optional.ofNullable(error).orElse(""); - if (!adTaskCacheManager.isRealtimeTaskChanged(detectorId, newState, initProgress, error)) { + if (!adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId, newState, initProgress, error)) { // If task not changed, no need to update, just return listener.onResponse(null); return; @@ -3091,5 +3091,4 @@ public void maintainRunningRealtimeTasks() { } } } - } diff --git a/src/test/java/org/opensearch/ad/AbstractADTest.java b/src/test/java/org/opensearch/ad/AbstractADTest.java index 42a0e6f1c..3d27d89fd 100644 --- a/src/test/java/org/opensearch/ad/AbstractADTest.java +++ b/src/test/java/org/opensearch/ad/AbstractADTest.java @@ -24,17 +24,20 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.apache.commons.lang3.exception.ExceptionUtils; import org.apache.logging.log4j.Level; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.core.LogEvent; import org.apache.logging.log4j.core.Logger; import org.apache.logging.log4j.core.appender.AbstractAppender; +import org.apache.logging.log4j.core.config.Property; import org.apache.logging.log4j.core.layout.PatternLayout; import org.apache.logging.log4j.util.StackLocatorUtil; import org.opensearch.Version; @@ -79,8 +82,25 @@ public class AbstractADTest extends OpenSearchTestCase { * */ protected class TestAppender extends AbstractAppender { + private static final String EXCEPTION_CLASS = "exception_class"; + private static final String EXCEPTION_MSG = "exception_message"; + private static final String EXCEPTION_STACK_TRACE = "stacktrace"; + + Map, Map> exceptions; + // whether record exception and its stack trace or not. + // If you log(msg, exception), by default we won't record exception and its stack trace. + boolean recordExceptions; + protected TestAppender(String name) { - super(name, null, PatternLayout.createDefaultLayout(), true); + this(name, false); + } + + protected TestAppender(String name, boolean recordExceptions) { + super(name, null, PatternLayout.createDefaultLayout(), true, Property.EMPTY_ARRAY); + this.recordExceptions = recordExceptions; + if (recordExceptions) { + exceptions = new HashMap, Map>(); + } } public List messages = new ArrayList(); @@ -134,9 +154,47 @@ public int countMessage(String msg) { return countMessage(msg, false); } + public Boolean containExceptionClass(Class throwable, String className) { + Map throwableInformation = exceptions.get(throwable); + return Optional.ofNullable(throwableInformation).map(m -> m.get(EXCEPTION_CLASS)).map(s -> s.equals(className)).orElse(false); + } + + public Boolean containExceptionMsg(Class throwable, String msg) { + Map throwableInformation = exceptions.get(throwable); + return Optional + .ofNullable(throwableInformation) + .map(m -> m.get(EXCEPTION_MSG)) + .map(s -> ((String) s).contains(msg)) + .orElse(false); + } + + public Boolean containExceptionTrace(Class throwable, String traceElement) { + Map throwableInformation = exceptions.get(throwable); + return Optional + .ofNullable(throwableInformation) + .map(m -> m.get(EXCEPTION_STACK_TRACE)) + .map(s -> ((String) s).contains(traceElement)) + .orElse(false); + } + @Override public void append(LogEvent event) { messages.add(event.getMessage().getFormattedMessage()); + if (recordExceptions && event.getThrown() != null) { + Map throwableInformation = new HashMap(); + final Throwable throwable = event.getThrown(); + if (throwable.getClass().getCanonicalName() != null) { + throwableInformation.put(EXCEPTION_CLASS, throwable.getClass().getCanonicalName()); + } + if (throwable.getMessage() != null) { + throwableInformation.put(EXCEPTION_MSG, throwable.getMessage()); + } + if (throwable.getMessage() != null) { + StringBuilder stackTrace = new StringBuilder(ExceptionUtils.getStackTrace(throwable)); + throwableInformation.put(EXCEPTION_STACK_TRACE, stackTrace.toString()); + } + exceptions.put(throwable.getClass(), throwableInformation); + } } /** @@ -160,15 +218,19 @@ private String convertToRegex(String formattedStr) { /** * Set up test with junit that a warning was logged with log4j */ - protected void setUpLog4jForJUnit(Class cls) { + protected void setUpLog4jForJUnit(Class cls, boolean recordExceptions) { String loggerName = toLoggerName(callerClass(cls)); logger = (Logger) LogManager.getLogger(loggerName); Loggers.setLevel(logger, Level.DEBUG); - testAppender = new TestAppender(loggerName); + testAppender = new TestAppender(loggerName, recordExceptions); testAppender.start(); logger.addAppender(testAppender); } + protected void setUpLog4jForJUnit(Class cls) { + setUpLog4jForJUnit(cls, false); + } + private static String toLoggerName(final Class cls) { String canonicalName = cls.getCanonicalName(); return canonicalName != null ? canonicalName : cls.getName(); diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java b/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java index 0c3d35037..786780407 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java @@ -13,6 +13,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; @@ -22,12 +23,15 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ad.model.AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.NUM_MIN_SAMPLES; import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; +import java.io.IOException; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.Iterator; import java.util.Locale; import java.util.Optional; @@ -40,6 +44,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; @@ -48,21 +53,28 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.common.exception.AnomalyDetectionException; import org.opensearch.ad.common.exception.EndRunException; +import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.indices.AnomalyDetectionIndices; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.model.FeatureData; import org.opensearch.ad.model.IntervalTimeConfiguration; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.AnomalyResultAction; +import org.opensearch.ad.transport.AnomalyResultResponse; import org.opensearch.ad.transport.handler.AnomalyIndexHandler; import org.opensearch.ad.util.ClientUtil; import org.opensearch.ad.util.DiscoveryNodeFilterer; -import org.opensearch.ad.util.IndexUtils; import org.opensearch.client.Client; -import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; @@ -79,6 +91,8 @@ import org.opensearch.jobscheduler.spi.utils.LockService; import org.opensearch.threadpool.ThreadPool; +import com.google.common.collect.ImmutableList; + public class AnomalyDetectorJobRunnerTests extends AbstractADTest { @Mock @@ -114,9 +128,6 @@ public class AnomalyDetectorJobRunnerTests extends AbstractADTest { @Mock private ADTaskManager adTaskManager; - @Mock - private AnomalyDetectionIndices indexUtil; - private ExecuteADResultResponseRecorder recorder; @Mock @@ -124,6 +135,14 @@ public class AnomalyDetectorJobRunnerTests extends AbstractADTest { private AnomalyDetector detector; + @Mock + private ADTaskCacheManager adTaskCacheManager; + + @Mock + private NodeStateManager nodeStateManager; + + private AnomalyDetectionIndices anomalyDetectionIndices; + @BeforeClass public static void setUpBeforeClass() { setUpThreadPool(AnomalyDetectorJobRunnerTests.class.getSimpleName()); @@ -160,11 +179,9 @@ public void setup() throws Exception { runner.setSettings(settings); - AnomalyDetectionIndices anomalyDetectionIndices = mock(AnomalyDetectionIndices.class); - IndexNameExpressionResolver indexNameResolver = mock(IndexNameExpressionResolver.class); - IndexUtils indexUtils = new IndexUtils(client, clientUtil, clusterService, indexNameResolver); + anomalyDetectionIndices = mock(AnomalyDetectionIndices.class); - runner.setAnomalyDetectionIndices(indexUtil); + runner.setAnomalyDetectionIndices(anomalyDetectionIndices); lockService = new LockService(client, clusterService); doReturn(lockService).when(context).getLockService(); @@ -204,17 +221,28 @@ public void setup() throws Exception { return null; }).when(client).index(any(), any()); - recorder = new ExecuteADResultResponseRecorder(indexUtil, anomalyResultHandler, adTaskManager, nodeFilter, threadPool, client); - runner.setExecuteADResultResponseRecorder(recorder); - detector = TestHelpers.randomAnomalyDetectorWithEmptyFeature(); + when(adTaskCacheManager.hasQueriedResultIndex(anyString())).thenReturn(false); - NodeStateManager stateManager = mock(NodeStateManager.class); + detector = TestHelpers.randomAnomalyDetectorWithEmptyFeature(); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.of(detector)); return null; - }).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); - runner.setNodeStateManager(stateManager); + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + runner.setNodeStateManager(nodeStateManager); + + recorder = new ExecuteADResultResponseRecorder( + anomalyDetectionIndices, + anomalyResultHandler, + adTaskManager, + nodeFilter, + threadPool, + client, + nodeStateManager, + adTaskCacheManager, + 32 + ); + runner.setExecuteADResultResponseRecorder(recorder); } @Rule @@ -499,4 +527,240 @@ private void setUpJobParameter() { when(jobParameter.getWindowDelay()).thenReturn(new IntervalTimeConfiguration(10, ChronoUnit.SECONDS)); } + /** + * Test updateLatestRealtimeTask.confirmTotalRCFUpdatesFound + * @throws InterruptedException + */ + public Instant confirmInitializedSetup() { + // clear the appender created in setUp before creating another association; otherwise + // we will have unexpected error (e.g., some appender does not record messages even + // though we have configured to do so). + super.tearDownLog4jForJUnit(); + setUpLog4jForJUnit(ExecuteADResultResponseRecorder.class, true); + Schedule schedule = mock(IntervalSchedule.class); + when(jobParameter.getSchedule()).thenReturn(schedule); + Instant executionStartTime = Instant.now(); + when(schedule.getNextExecutionTime(executionStartTime)).thenReturn(executionStartTime.plusSeconds(5)); + + AnomalyResultResponse response = new AnomalyResultResponse( + 4d, + 0.993, + 1.01, + Collections.singletonList(new FeatureData("123", "abc", 0d)), + randomAlphaOfLength(4), + // not fully initialized + Long.valueOf(AnomalyDetectorSettings.NUM_MIN_SAMPLES - 1), + randomLong(), + // not an HC detector + false, + randomInt(), + new double[] { randomDoubleBetween(0, 1.0, true), randomDoubleBetween(0, 1.0, true) }, + new double[] { randomDouble(), randomDouble() }, + new double[][] { new double[] { randomDouble(), randomDouble() } }, + new double[] { randomDouble() }, + randomDoubleBetween(1.1, 10.0, true) + ); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(client).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); + return executionStartTime; + } + + @SuppressWarnings("unchecked") + public void testFailtoFindDetector() { + Instant executionStartTime = confirmInitializedSetup(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException()); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + LockModel lock = new LockModel(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); + + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + + verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); + verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); + verify(nodeStateManager, times(1)).getAnomalyDetector(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(0)).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); + assertEquals(1, testAppender.countMessage("Fail to confirm rcf update")); + assertTrue(testAppender.containExceptionMsg(AnomalyDetectionException.class, "fail to get detector")); + } + + @SuppressWarnings("unchecked") + public void testFailtoFindJob() { + Instant executionStartTime = confirmInitializedSetup(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException()); + return null; + }).when(nodeStateManager).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + + LockModel lock = new LockModel(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); + + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + + verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); + verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); + verify(nodeStateManager, times(1)).getAnomalyDetector(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(1)).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); + assertEquals(1, testAppender.countMessage("Fail to confirm rcf update")); + assertTrue(testAppender.containExceptionMsg(AnomalyDetectionException.class, "fail to get job")); + } + + @SuppressWarnings("unchecked") + public void testEmptyDetector() { + Instant executionStartTime = confirmInitializedSetup(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + LockModel lock = new LockModel(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); + + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + + verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); + verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); + verify(nodeStateManager, times(1)).getAnomalyDetector(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(0)).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); + assertEquals(1, testAppender.countMessage("Fail to confirm rcf update")); + assertTrue(testAppender.containExceptionMsg(AnomalyDetectionException.class, "fail to get detector")); + } + + @SuppressWarnings("unchecked") + public void testEmptyJob() { + Instant executionStartTime = confirmInitializedSetup(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(nodeStateManager).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + + LockModel lock = new LockModel(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); + + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + + verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); + verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); + verify(nodeStateManager, times(1)).getAnomalyDetector(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(1)).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); + assertEquals(1, testAppender.countMessage("Fail to confirm rcf update")); + assertTrue(testAppender.containExceptionMsg(AnomalyDetectionException.class, "fail to get job")); + } + + @SuppressWarnings("unchecked") + public void testMarkResultIndexQueried() throws IOException { + detector = TestHelpers.AnomalyDetectorBuilder + .newInstance() + .setDetectionInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)) + .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) + .setResultIndex(CommonName.CUSTOM_RESULT_INDEX_PREFIX + "index") + .build(); + Instant executionStartTime = confirmInitializedSetup(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(TestHelpers.randomAnomalyDetectorJob(true, Instant.ofEpochMilli(1602401500000L), null))); + return null; + }).when(nodeStateManager).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + + ActionListener listener = (ActionListener) args[1]; + + SearchResponse mockResponse = mock(SearchResponse.class); + int totalHits = 1001; + when(mockResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); + + listener.onResponse(mockResponse); + + return null; + }).when(client).search(any(), any(ActionListener.class)); + + // use a unmocked adTaskCacheManager to test the value of hasQueriedResultIndex has changed + Settings settings = Settings + .builder() + .put(AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE.getKey(), 2) + .put(AnomalyDetectorSettings.MAX_CACHED_DELETED_TASKS.getKey(), 100) + .build(); + + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays.asList(AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE, AnomalyDetectorSettings.MAX_CACHED_DELETED_TASKS) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + MemoryTracker memoryTracker = mock(MemoryTracker.class); + adTaskCacheManager = new ADTaskCacheManager(settings, clusterService, memoryTracker); + + // init real time task cache for the detector. We will do this during AnomalyResultTransportAction. + // Since we mocked the execution by returning anomaly result directly, we need to init it explicitly. + adTaskCacheManager.initRealtimeTaskCache(detector.getDetectorId(), 0); + + // recreate recorder since we need to use the unmocked adTaskCacheManager + recorder = new ExecuteADResultResponseRecorder( + anomalyDetectionIndices, + anomalyResultHandler, + adTaskManager, + nodeFilter, + threadPool, + client, + nodeStateManager, + adTaskCacheManager, + 32 + ); + + assertEquals(false, adTaskCacheManager.hasQueriedResultIndex(detector.getDetectorId())); + + LockModel lock = new LockModel(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); + + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + + verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); + verify(client, times(1)).search(any(), any()); + verify(nodeStateManager, times(1)).getAnomalyDetector(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(1)).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + + ArgumentCaptor totalUpdates = ArgumentCaptor.forClass(Long.class); + verify(adTaskManager, times(1)) + .updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), totalUpdates.capture(), any(), any(), any()); + assertEquals(NUM_MIN_SAMPLES, totalUpdates.getValue().longValue()); + assertEquals(true, adTaskCacheManager.hasQueriedResultIndex(detector.getDetectorId())); + } } diff --git a/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java b/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java index c0366e95a..d2846940c 100644 --- a/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java +++ b/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java @@ -35,6 +35,7 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.update.UpdateResponse; import org.opensearch.ad.ExecuteADResultResponseRecorder; +import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.common.exception.ResourceNotFoundException; import org.opensearch.ad.constant.CommonErrorMessages; @@ -44,6 +45,7 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.Feature; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.AnomalyDetectorJobResponse; import org.opensearch.ad.transport.AnomalyResultAction; @@ -83,6 +85,8 @@ public class IndexAnomalyDetectorJobActionHandlerTests extends OpenSearchTestCas private Client client; private IndexAnomalyDetectorJobActionHandler handler; private AnomalyIndexHandler anomalyResultHandler; + private NodeStateManager nodeStateManager; + private ADTaskCacheManager adTaskCacheManager; @BeforeClass public static void setOnce() throws IOException { @@ -153,13 +157,21 @@ public void setUp() throws Exception { anomalyResultHandler = mock(AnomalyIndexHandler.class); + nodeStateManager = mock(NodeStateManager.class); + + adTaskCacheManager = mock(ADTaskCacheManager.class); + when(adTaskCacheManager.hasQueriedResultIndex(anyString())).thenReturn(true); + recorder = new ExecuteADResultResponseRecorder( anomalyDetectionIndices, anomalyResultHandler, adTaskManager, nodeFilter, threadPool, - client + client, + nodeStateManager, + adTaskCacheManager, + 32 ); handler = new IndexAnomalyDetectorJobActionHandler( @@ -318,6 +330,7 @@ public void testUpdateLatestRealtimeTaskOnCoordinatingException() { verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString()); verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); verify(adTaskManager, never()).removeRealtimeTaskCache(anyString()); + verify(adTaskManager, times(1)).skipUpdateHCRealtimeTask(anyString(), anyString()); verify(threadPool, never()).schedule(any(), any(), any()); verify(listener, times(1)).onResponse(any()); } diff --git a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java index c99314608..6f5111566 100644 --- a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java +++ b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java @@ -313,12 +313,12 @@ public void testRealtimeTaskCache() { String newState = ADTaskState.INIT.name(); Float newInitProgress = 0.0f; String newError = randomAlphaOfLength(5); - assertTrue(adTaskCacheManager.isRealtimeTaskChanged(detectorId1, newState, newInitProgress, newError)); + assertTrue(adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId1, newState, newInitProgress, newError)); // Init realtime task cache. adTaskCacheManager.initRealtimeTaskCache(detectorId1, 60_000); adTaskCacheManager.updateRealtimeTaskCache(detectorId1, newState, newInitProgress, newError); - assertFalse(adTaskCacheManager.isRealtimeTaskChanged(detectorId1, newState, newInitProgress, newError)); + assertFalse(adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId1, newState, newInitProgress, newError)); assertArrayEquals(new String[] { detectorId1 }, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()); String detectorId2 = randomAlphaOfLength(10); @@ -331,7 +331,7 @@ public void testRealtimeTaskCache() { newState = ADTaskState.RUNNING.name(); newInitProgress = 1.0f; newError = "test error"; - assertTrue(adTaskCacheManager.isRealtimeTaskChanged(detectorId1, newState, newInitProgress, newError)); + assertTrue(adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId1, newState, newInitProgress, newError)); adTaskCacheManager.updateRealtimeTaskCache(detectorId1, newState, newInitProgress, newError); assertEquals(newInitProgress, adTaskCacheManager.getRealtimeTaskCache(detectorId1).getInitProgress()); assertEquals(newState, adTaskCacheManager.getRealtimeTaskCache(detectorId1).getState()); diff --git a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java index 7f67b8155..3643c7b4f 100644 --- a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java +++ b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java @@ -720,7 +720,7 @@ public void testUpdateLatestRealtimeTaskOnCoordinatingNode() { String error = randomAlphaOfLength(5); ActionListener actionListener = mock(ActionListener.class); doReturn(node1).when(clusterService).localNode(); - when(adTaskCacheManager.isRealtimeTaskChanged(anyString(), anyString(), anyFloat(), anyString())).thenReturn(true); + when(adTaskCacheManager.isRealtimeTaskChangeNeeded(anyString(), anyString(), anyFloat(), anyString())).thenReturn(true); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(new UpdateResponse(ShardId.fromString("[test][1]"), "1", 0L, 1L, 1L, DocWriteResponse.Result.UPDATED)); diff --git a/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java index 3178e8681..3847f4429 100644 --- a/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java @@ -136,15 +136,19 @@ public void testHistoricalAnalysisExceedsMaxRunningTaskLimit() throws IOExceptio } public void testDisableADPlugin() throws IOException { - updateTransientSettings(ImmutableMap.of(AD_PLUGIN_ENABLED, false)); - - ADBatchAnomalyResultRequest request = adBatchAnomalyResultRequest(new DetectionDateRange(startTime, endTime)); - RuntimeException exception = expectThrowsAnyOf( - ImmutableList.of(NotSerializableExceptionWrapper.class, EndRunException.class), - () -> client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(10000) - ); - assertTrue(exception.getMessage().contains("AD plugin is disabled")); - updateTransientSettings(ImmutableMap.of(AD_PLUGIN_ENABLED, true)); + try { + updateTransientSettings(ImmutableMap.of(AD_PLUGIN_ENABLED, false)); + ADBatchAnomalyResultRequest request = adBatchAnomalyResultRequest(new DetectionDateRange(startTime, endTime)); + RuntimeException exception = expectThrowsAnyOf( + ImmutableList.of(NotSerializableExceptionWrapper.class, EndRunException.class), + () -> client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(10000) + ); + assertTrue(exception.getMessage(), exception.getMessage().contains("AD plugin is disabled")); + updateTransientSettings(ImmutableMap.of(AD_PLUGIN_ENABLED, false)); + } finally { + // guarantee reset back to default + updateTransientSettings(ImmutableMap.of(AD_PLUGIN_ENABLED, true)); + } } public void testMultipleTasks() throws IOException, InterruptedException {