diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportState.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/NodeState.java similarity index 92% rename from src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportState.java rename to src/main/java/com/amazon/opendistroforelasticsearch/ad/NodeState.java index 38aa663f..83807cfa 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportState.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/NodeState.java @@ -13,7 +13,7 @@ * permissions and limitations under the License. */ -package com.amazon.opendistroforelasticsearch.ad.transport; +package com.amazon.opendistroforelasticsearch.ad; import java.time.Clock; import java.time.Duration; @@ -27,7 +27,7 @@ * Storing intermediate state during the execution of transport action * */ -public class TransportState { +public class NodeState implements ExpiringState { private String detectorId; // detector definition private AnomalyDetector detectorDef; @@ -35,8 +35,8 @@ public class TransportState { private int partitonNumber; // checkpoint fetch time private Instant lastAccessTime; - // last detection error. Used by DetectorStateHandler to check if the error for a - // detector has changed or not. If changed, trigger indexing. + // last detection error recorded in result index. Used by DetectorStateHandler + // to check if the error for a detector has changed or not. If changed, trigger indexing. private Optional lastDetectionError; // last training error. Used to save cold start error by a concurrent cold start thread. 
private Optional lastColdStartException; @@ -47,7 +47,7 @@ public class TransportState { // cold start running flag to prevent concurrent cold start private boolean coldStartRunning; - public TransportState(String detectorId, Clock clock) { + public NodeState(String detectorId, Clock clock) { this.detectorId = detectorId; this.detectorDef = null; this.partitonNumber = -1; @@ -182,7 +182,8 @@ private void refreshLastUpdateTime() { * @param stateTtl time to leave for the state * @return whether the transport state is expired */ + @Override public boolean expired(Duration stateTtl) { - return lastAccessTime.plus(stateTtl).isBefore(clock.instant()); + return expired(lastAccessTime, stateTtl, clock.instant()); } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportStateManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/NodeStateManager.java similarity index 70% rename from src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportStateManager.java rename to src/main/java/com/amazon/opendistroforelasticsearch/ad/NodeStateManager.java index f38df53a..104be239 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/TransportStateManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/NodeStateManager.java @@ -13,12 +13,13 @@ * permissions and limitations under the License. 
*/ -package com.amazon.opendistroforelasticsearch.ad.transport; +package com.amazon.opendistroforelasticsearch.ad; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import java.time.Clock; import java.time.Duration; +import java.time.Instant; import java.util.Map; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; @@ -29,6 +30,7 @@ import org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.client.Client; +import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; @@ -36,22 +38,25 @@ import org.elasticsearch.common.xcontent.XContentType; import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException; +import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException; import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException; +import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages; import com.amazon.opendistroforelasticsearch.ad.constant.CommonName; -import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; +import com.amazon.opendistroforelasticsearch.ad.ml.ModelPartitioner; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; +import com.amazon.opendistroforelasticsearch.ad.transport.BackPressureRouting; import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; /** - * ADStateManager is used by transport layer to manage AnomalyDetector object - * and the number of partitions for a detector id. 
+ * NodeStateManager is used to manage states shared by transport and ml components + * like AnomalyDetector object * */ -public class TransportStateManager { - private static final Logger LOG = LogManager.getLogger(TransportStateManager.class); - private ConcurrentHashMap transportStates; +public class NodeStateManager implements MaintenanceState, CleanState { + private static final Logger LOG = LogManager.getLogger(NodeStateManager.class); + private ConcurrentHashMap states; private Client client; - private ModelManager modelManager; + private ModelPartitioner modelPartitioner; private NamedXContentRegistry xContentRegistry; private ClientUtil clientUtil; // map from ES node id to the node's backpressureMuter @@ -59,27 +64,42 @@ public class TransportStateManager { private final Clock clock; private final Settings settings; private final Duration stateTtl; + // last time we are throttled due to too much index pressure + private Instant lastIndexThrottledTime; public static final String NO_ERROR = "no_error"; - public TransportStateManager( + /** + * Constructor + * + * @param client Client to make calls to ElasticSearch + * @param xContentRegistry ES named content registry + * @param settings ES settings + * @param clientUtil AD Client utility + * @param clock A UTC clock + * @param stateTtl Max time to keep state in memory + * @param modelPartitioner Used to partition an RCF forest + + */ + public NodeStateManager( Client client, NamedXContentRegistry xContentRegistry, - ModelManager modelManager, Settings settings, ClientUtil clientUtil, Clock clock, - Duration stateTtl + Duration stateTtl, + ModelPartitioner modelPartitioner ) { - this.transportStates = new ConcurrentHashMap<>(); + this.states = new ConcurrentHashMap<>(); this.client = client; - this.modelManager = modelManager; + this.modelPartitioner = modelPartitioner; this.xContentRegistry = xContentRegistry; this.clientUtil = clientUtil; this.backpressureMuter = new ConcurrentHashMap<>(); this.clock = 
clock; this.settings = settings; this.stateTtl = stateTtl; + this.lastIndexThrottledTime = Instant.MIN; } /** @@ -90,20 +110,30 @@ public TransportStateManager( * @throws LimitExceededException when there is no sufficient resource available */ public int getPartitionNumber(String adID, AnomalyDetector detector) { - TransportState state = transportStates.get(adID); + NodeState state = states.get(adID); if (state != null && state.getPartitonNumber() > 0) { return state.getPartitonNumber(); } - int partitionNum = modelManager.getPartitionedForestSizes(detector).getKey(); - state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock)); + int partitionNum = modelPartitioner.getPartitionedForestSizes(detector).getKey(); + state = states.computeIfAbsent(adID, id -> new NodeState(id, clock)); state.setPartitonNumber(partitionNum); return partitionNum; } + /** + * Get Detector config object if present + * @param adID detector Id + * @return the Detector config object or empty Optional + */ + public Optional getAnomalyDetectorIfPresent(String adID) { + NodeState state = states.get(adID); + return Optional.ofNullable(state).map(NodeState::getDetectorDef); + } + public void getAnomalyDetector(String adID, ActionListener> listener) { - TransportState state = transportStates.get(adID); + NodeState state = states.get(adID); if (state != null && state.getDetectorDef() != null) { listener.onResponse(Optional.of(state.getDetectorDef())); } else { @@ -127,7 +157,12 @@ private ActionListener onGetDetectorResponse(String adID, ActionLis ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation); AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId()); - TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock)); + // end execution if all features are disabled + if (detector.getEnabledFeatureIds().isEmpty()) { + listener.onFailure(new EndRunException(adID, 
CommonErrorMessages.ALL_FEATURES_DISABLED_ERR_MSG, true)); + return; + } + NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock)); state.setDetectorDef(detector); listener.onResponse(Optional.of(detector)); @@ -145,13 +180,13 @@ private ActionListener onGetDetectorResponse(String adID, ActionLis * @param listener listener to handle get request */ public void getDetectorCheckpoint(String adID, ActionListener listener) { - TransportState state = transportStates.get(adID); + NodeState state = states.get(adID); if (state != null && state.doesCheckpointExists()) { listener.onResponse(Boolean.TRUE); return; } - GetRequest request = new GetRequest(CommonName.CHECKPOINT_INDEX_NAME, modelManager.getRcfModelId(adID, 0)); + GetRequest request = new GetRequest(CommonName.CHECKPOINT_INDEX_NAME, modelPartitioner.getRcfModelId(adID, 0)); clientUtil.asyncRequest(request, client::get, onGetCheckpointResponse(adID, listener)); } @@ -161,7 +196,7 @@ private ActionListener onGetCheckpointResponse(String adID, ActionL if (response == null || !response.isExists()) { listener.onResponse(Boolean.FALSE); } else { - TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock)); + NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock)); state.setCheckpointExists(true); listener.onResponse(Boolean.TRUE); } @@ -173,8 +208,9 @@ private ActionListener onGetCheckpointResponse(String adID, ActionL * * @param adID detector ID */ + @Override public void clear(String adID) { - transportStates.remove(adID); + states.remove(adID); } /** @@ -183,18 +219,9 @@ public void clear(String adID) { * java.util.ConcurrentModificationException. 
* */ + @Override public void maintenance() { - transportStates.entrySet().stream().forEach(entry -> { - String detectorId = entry.getKey(); - try { - TransportState state = entry.getValue(); - if (state.expired(stateTtl)) { - transportStates.remove(detectorId); - } - } catch (Exception e) { - LOG.warn("Failed to finish maintenance for detector id " + detectorId, e); - } - }); + maintenance(states, stateTtl); } public boolean isMuted(String nodeId) { @@ -232,7 +259,7 @@ public boolean hasRunningQuery(AnomalyDetector detector) { * @return last error for the detector */ public String getLastDetectionError(String adID) { - return Optional.ofNullable(transportStates.get(adID)).flatMap(state -> state.getLastDetectionError()).orElse(NO_ERROR); + return Optional.ofNullable(states.get(adID)).flatMap(state -> state.getLastDetectionError()).orElse(NO_ERROR); } /** @@ -241,7 +268,7 @@ public String getLastDetectionError(String adID) { * @param error error, can be null */ public void setLastDetectionError(String adID, String error) { - TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock)); + NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock)); state.setLastDetectionError(error); } @@ -251,7 +278,7 @@ public void setLastDetectionError(String adID, String error) { * @param exception exception, can be null */ public void setLastColdStartException(String adID, AnomalyDetectionException exception) { - TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock)); + NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock)); state.setLastColdStartException(exception); } @@ -262,7 +289,7 @@ public void setLastColdStartException(String adID, AnomalyDetectionException exc * @return last cold start exception for the detector */ public Optional fetchColdStartException(String adID) { - TransportState state = transportStates.get(adID); + NodeState state = 
states.get(adID); if (state == null) { return Optional.empty(); } @@ -279,7 +306,7 @@ public Optional fetchColdStartException(String adID) * @return running or not */ public boolean isColdStartRunning(String adID) { - TransportState state = transportStates.get(adID); + NodeState state = states.get(adID); if (state != null) { return state.isColdStartRunning(); } @@ -290,10 +317,24 @@ public boolean isColdStartRunning(String adID) { /** * Mark the cold start status of the detector * @param adID detector ID - * @param running whether it is running + * @return a callback when cold start is done */ - public void setColdStartRunning(String adID, boolean running) { - TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock)); - state.setColdStartRunning(running); + public Releasable markColdStartRunning(String adID) { + NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock)); + state.setColdStartRunning(true); + return () -> { + NodeState nodeState = states.get(adID); + if (nodeState != null) { + nodeState.setColdStartRunning(false); + } + }; + } + + public Instant getLastIndexThrottledTime() { + return lastIndexThrottledTime; + } + + public void setLastIndexThrottledTime(Instant lastIndexThrottledTime) { + this.lastIndexThrottledTime = lastIndexThrottledTime; } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunnerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunnerTests.java index 35dd8a65..d386e422 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunnerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunnerTests.java @@ -68,7 +68,6 @@ import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyResult; import 
com.amazon.opendistroforelasticsearch.ad.model.IntervalTimeConfiguration; -import com.amazon.opendistroforelasticsearch.ad.transport.TransportStateManager; import com.amazon.opendistroforelasticsearch.ad.transport.handler.AnomalyIndexHandler; import com.amazon.opendistroforelasticsearch.ad.transport.handler.DetectionStateHandler; import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; @@ -152,7 +151,7 @@ public void setup() throws Exception { AnomalyDetectionIndices anomalyDetectionIndices = mock(AnomalyDetectionIndices.class); IndexNameExpressionResolver indexNameResolver = mock(IndexNameExpressionResolver.class); IndexUtils indexUtils = new IndexUtils(client, clientUtil, clusterService, indexNameResolver); - TransportStateManager stateManager = mock(TransportStateManager.class); + NodeStateManager stateManager = mock(NodeStateManager.class); detectorStateHandler = new DetectionStateHandler( client, settings, diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/NodeStateManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/NodeStateManagerTests.java new file mode 100644 index 00000000..44f7caac --- /dev/null +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/NodeStateManagerTests.java @@ -0,0 +1,299 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package com.amazon.opendistroforelasticsearch.ad; + +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.Arrays; +import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.get.GetRequest; +import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; + +import com.amazon.opendistroforelasticsearch.ad.ml.ModelPartitioner; +import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; +import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; +import com.amazon.opendistroforelasticsearch.ad.util.Throttler; +import com.google.common.collect.ImmutableMap; + +public class NodeStateManagerTests extends ESTestCase { + private NodeStateManager stateManager; + private ModelPartitioner modelPartitioner; + private Client client; + private ClientUtil clientUtil; + private Clock clock; + private Duration duration; + private Throttler throttler; + private ThreadPool context; + private AnomalyDetector detectorToCheck; + 
private Settings settings; + private String adId = "123"; + + private GetResponse checkpointResponse; + + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); + return new NamedXContentRegistry(searchModule.getNamedXContents()); + } + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + modelPartitioner = mock(ModelPartitioner.class); + when(modelPartitioner.getPartitionedForestSizes(any(AnomalyDetector.class))).thenReturn(new SimpleImmutableEntry<>(2, 20)); + client = mock(Client.class); + settings = Settings + .builder() + .put("opendistro.anomaly_detection.max_retry_for_unresponsive_node", 3) + .put("opendistro.anomaly_detection.ad_mute_minutes", TimeValue.timeValueMinutes(10)) + .build(); + clock = mock(Clock.class); + duration = Duration.ofHours(1); + context = TestHelpers.createThreadPool(); + throttler = new Throttler(clock); + + clientUtil = new ClientUtil(Settings.EMPTY, client, throttler, mock(ThreadPool.class)); + stateManager = new NodeStateManager(client, xContentRegistry(), settings, clientUtil, clock, duration, modelPartitioner); + + checkpointResponse = mock(GetResponse.class); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + stateManager = null; + modelPartitioner = null; + client = null; + clientUtil = null; + detectorToCheck = null; + } + + @SuppressWarnings("unchecked") + private String setupDetector() throws IOException { + detectorToCheck = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null, true); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue(String.format("The size of args is %d. 
Its content is %s", args.length, Arrays.toString(args)), args.length >= 2); + + GetRequest request = null; + ActionListener listener = null; + if (args[0] instanceof GetRequest) { + request = (GetRequest) args[0]; + } + if (args[1] instanceof ActionListener) { + listener = (ActionListener) args[1]; + } + + assertTrue(request != null && listener != null); + listener + .onResponse( + TestHelpers.createGetResponse(detectorToCheck, detectorToCheck.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX) + ); + + return null; + }).when(client).get(any(), any(ActionListener.class)); + return detectorToCheck.getDetectorId(); + } + + @SuppressWarnings("unchecked") + private void setupCheckpoint(boolean responseExists) throws IOException { + when(checkpointResponse.isExists()).thenReturn(responseExists); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length >= 2); + + GetRequest request = null; + ActionListener listener = null; + if (args[0] instanceof GetRequest) { + request = (GetRequest) args[0]; + } + if (args[1] instanceof ActionListener) { + listener = (ActionListener) args[1]; + } + + assertTrue(request != null && listener != null); + listener.onResponse(checkpointResponse); + + return null; + }).when(client).get(any(), any(ActionListener.class)); + } + + public void testGetPartitionNumber() throws IOException, InterruptedException { + String detectorId = setupDetector(); + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null); + for (int i = 0; i < 2; i++) { + // call two times should return the same result + int partitionNumber = stateManager.getPartitionNumber(detectorId, detector); + assertEquals(2, partitionNumber); + } + + // the 2nd call should directly fetch cached result + verify(modelPartitioner, times(1)).getPartitionedForestSizes(any()); + } + + public void 
testGetLastError() throws IOException, InterruptedException { + String error = "blah"; + assertEquals(NodeStateManager.NO_ERROR, stateManager.getLastDetectionError(adId)); + stateManager.setLastDetectionError(adId, error); + assertEquals(error, stateManager.getLastDetectionError(adId)); + } + + public void testShouldMute() { + String nodeId = "123"; + assertTrue(!stateManager.isMuted(nodeId)); + + when(clock.millis()).thenReturn(10000L); + IntStream.range(0, 4).forEach(j -> stateManager.addPressure(nodeId)); + + when(clock.millis()).thenReturn(20000L); + assertTrue(stateManager.isMuted(nodeId)); + + // > 15 minutes have passed, we should not mute anymore + when(clock.millis()).thenReturn(1000001L); + assertTrue(!stateManager.isMuted(nodeId)); + + // the backpressure counter should be reset + when(clock.millis()).thenReturn(100001L); + stateManager.resetBackpressureCounter(nodeId); + assertTrue(!stateManager.isMuted(nodeId)); + } + + public void testMaintenanceDoNothing() { + stateManager.maintenance(); + + verifyZeroInteractions(clock); + } + + public void testHasRunningQuery() throws IOException { + stateManager = new NodeStateManager( + client, + xContentRegistry(), + settings, + new ClientUtil(settings, client, throttler, context), + clock, + duration, + modelPartitioner + ); + + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of(), null); + SearchRequest dummySearchRequest = new SearchRequest(); + assertFalse(stateManager.hasRunningQuery(detector)); + throttler.insertFilteredQuery(detector.getDetectorId(), dummySearchRequest); + assertTrue(stateManager.hasRunningQuery(detector)); + } + + public void testGetAnomalyDetector() throws IOException, InterruptedException { + String detectorId = setupDetector(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + stateManager.getAnomalyDetector(detectorId, ActionListener.wrap(asDetector -> { + assertEquals(detectorToCheck, asDetector.get()); + inProgressLatch.countDown(); + }, 
exception -> { + assertTrue(false); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void getCheckpointTestTemplate(boolean exists) throws IOException { + setupCheckpoint(exists); + when(clock.instant()).thenReturn(Instant.MIN); + stateManager + .getDetectorCheckpoint(adId, ActionListener.wrap(checkpointExists -> { assertEquals(exists, checkpointExists); }, exception -> { + for (StackTraceElement ste : exception.getStackTrace()) { + logger.info(ste); + } + assertTrue(false); + })); + } + + public void testCheckpointExists() throws IOException { + getCheckpointTestTemplate(true); + } + + public void testCheckpointNotExists() throws IOException { + getCheckpointTestTemplate(false); + } + + public void testMaintenanceNotRemove() throws IOException { + setupCheckpoint(true); + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1)); + stateManager + .getDetectorCheckpoint( + adId, + ActionListener.wrap(gotCheckpoint -> { assertTrue(gotCheckpoint); }, exception -> assertTrue(false)) + ); + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1)); + stateManager.maintenance(); + stateManager + .getDetectorCheckpoint(adId, ActionListener.wrap(gotCheckpoint -> assertTrue(gotCheckpoint), exception -> assertTrue(false))); + verify(client, times(1)).get(any(), any()); + } + + public void testMaintenanceRemove() throws IOException { + setupCheckpoint(true); + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1)); + stateManager + .getDetectorCheckpoint( + adId, + ActionListener.wrap(gotCheckpoint -> { assertTrue(gotCheckpoint); }, exception -> assertTrue(false)) + ); + when(clock.instant()).thenReturn(Instant.ofEpochSecond(7200L)); + stateManager.maintenance(); + stateManager + .getDetectorCheckpoint( + adId, + ActionListener.wrap(gotCheckpoint -> { assertTrue(gotCheckpoint); }, exception -> assertTrue(false)) + ); + verify(client, times(2)).get(any(), any()); + } + + public void testColdStartRunning() { + 
assertTrue(!stateManager.isColdStartRunning(adId)); + stateManager.markColdStartRunning(adId); + assertTrue(stateManager.isColdStartRunning(adId)); + } +} diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/NodeStateTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/NodeStateTests.java new file mode 100644 index 00000000..e29df63b --- /dev/null +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/NodeStateTests.java @@ -0,0 +1,108 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package com.amazon.opendistroforelasticsearch.ad; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; + +import org.elasticsearch.test.ESTestCase; + +import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException; + +public class NodeStateTests extends ESTestCase { + private NodeState state; + private Clock clock; + + @Override + public void setUp() throws Exception { + super.setUp(); + clock = mock(Clock.class); + state = new NodeState("123", clock); + } + + private Duration duration = Duration.ofHours(1); + + public void testMaintenanceNotRemoveSingle() throws IOException { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1000)); + state.setDetectorDef(TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null)); + + when(clock.instant()).thenReturn(Instant.MIN); + assertTrue(!state.expired(duration)); + } + + public void testMaintenanceNotRemove() throws IOException { + when(clock.instant()).thenReturn(Instant.ofEpochSecond(1000)); + state.setDetectorDef(TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null)); + state.setLastDetectionError(null); + + when(clock.instant()).thenReturn(Instant.ofEpochSecond(3700)); + assertTrue(!state.expired(duration)); + } + + public void testMaintenanceRemoveLastError() throws IOException { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1000)); + state + .setDetectorDef( + + TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null) + ); + state.setLastDetectionError(null); + + when(clock.instant()).thenReturn(Instant.ofEpochSecond(3700)); + assertTrue(state.expired(duration)); + } + + public void testMaintenancRemoveDetector() throws IOException { + when(clock.instant()).thenReturn(Instant.MIN); + state.setDetectorDef(TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null)); + 
when(clock.instant()).thenReturn(Instant.MAX); + assertTrue(state.expired(duration)); + + } + + public void testMaintenanceFlagNotRemove() throws IOException { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1000)); + state.setCheckpointExists(true); + when(clock.instant()).thenReturn(Instant.MIN); + assertTrue(!state.expired(duration)); + } + + public void testMaintenancFlagRemove() throws IOException { + when(clock.instant()).thenReturn(Instant.MIN); + state.setCheckpointExists(true); + when(clock.instant()).thenReturn(Instant.MIN); + assertTrue(!state.expired(duration)); + } + + public void testMaintenanceLastColdStartRemoved() { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1000)); + state.setLastColdStartException(new AnomalyDetectionException("123", "")); + when(clock.instant()).thenReturn(Instant.ofEpochSecond(3700)); + assertTrue(state.expired(duration)); + } + + public void testMaintenanceLastColdStartNotRemoved() { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1_000_000L)); + state.setLastColdStartException(new AnomalyDetectionException("123", "")); + when(clock.instant()).thenReturn(Instant.ofEpochSecond(3700)); + assertTrue(!state.expired(duration)); + } +} diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/CronTransportActionTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/CronTransportActionTests.java index 172df264..77690aae 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/CronTransportActionTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/CronTransportActionTests.java @@ -42,6 +42,9 @@ import test.com.amazon.opendistroforelasticsearch.ad.util.JsonDeserializer; import com.amazon.opendistroforelasticsearch.ad.AbstractADTest; +import com.amazon.opendistroforelasticsearch.ad.NodeStateManager; +import com.amazon.opendistroforelasticsearch.ad.caching.CacheProvider; +import 
com.amazon.opendistroforelasticsearch.ad.caching.EntityCache; import com.amazon.opendistroforelasticsearch.ad.common.exception.JsonPathNotFoundException; import com.amazon.opendistroforelasticsearch.ad.feature.FeatureManager; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; @@ -64,9 +67,12 @@ public void setUp() throws Exception { TransportService transportService = mock(TransportService.class); ActionFilters actionFilters = mock(ActionFilters.class); - TransportStateManager tarnsportStatemanager = mock(TransportStateManager.class); + NodeStateManager tarnsportStatemanager = mock(NodeStateManager.class); ModelManager modelManager = mock(ModelManager.class); FeatureManager featureManager = mock(FeatureManager.class); + CacheProvider cacheProvider = mock(CacheProvider.class); + EntityCache entityCache = mock(EntityCache.class); + when(cacheProvider.get()).thenReturn(entityCache); action = new CronTransportAction( threadPool, @@ -75,7 +81,8 @@ public void setUp() throws Exception { actionFilters, tarnsportStatemanager, modelManager, - featureManager + featureManager, + cacheProvider ); } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/DeleteModelTransportActionTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/DeleteModelTransportActionTests.java index 8defa7b4..614925a8 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/DeleteModelTransportActionTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/DeleteModelTransportActionTests.java @@ -46,6 +46,9 @@ import test.com.amazon.opendistroforelasticsearch.ad.util.JsonDeserializer; import com.amazon.opendistroforelasticsearch.ad.AbstractADTest; +import com.amazon.opendistroforelasticsearch.ad.NodeStateManager; +import com.amazon.opendistroforelasticsearch.ad.caching.CacheProvider; +import com.amazon.opendistroforelasticsearch.ad.caching.EntityCache; import 
com.amazon.opendistroforelasticsearch.ad.common.exception.JsonPathNotFoundException; import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages; import com.amazon.opendistroforelasticsearch.ad.feature.FeatureManager; @@ -69,18 +72,22 @@ public void setUp() throws Exception { TransportService transportService = mock(TransportService.class); ActionFilters actionFilters = mock(ActionFilters.class); - TransportStateManager tarnsportStatemanager = mock(TransportStateManager.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); ModelManager modelManager = mock(ModelManager.class); FeatureManager featureManager = mock(FeatureManager.class); + CacheProvider cacheProvider = mock(CacheProvider.class); + EntityCache entityCache = mock(EntityCache.class); + when(cacheProvider.get()).thenReturn(entityCache); action = new DeleteModelTransportAction( threadPool, clusterService, transportService, actionFilters, - tarnsportStatemanager, + nodeStateManager, modelManager, - featureManager + featureManager, + cacheProvider ); }