diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java index d8a3a76a..a643b29e 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java @@ -15,35 +15,81 @@ package com.amazon.opendistroforelasticsearch.ad.ml; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; import java.time.ZoneOffset; import java.time.ZonedDateTime; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.ArrayDeque; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import java.util.Map.Entry; import java.util.Optional; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.locks.ReentrantLock; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.bulk.BulkAction; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.delete.DeleteRequest; import org.elasticsearch.action.delete.DeleteResponse; import org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.Client; +import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.query.MatchQueryBuilder; +import 
org.elasticsearch.index.reindex.BulkByScrollResponse; +import org.elasticsearch.index.reindex.DeleteByQueryAction; +import org.elasticsearch.index.reindex.DeleteByQueryRequest; +import org.elasticsearch.index.reindex.ScrollableHitSource; +import com.amazon.opendistroforelasticsearch.ad.constant.CommonName; +import com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices; +import com.amazon.opendistroforelasticsearch.ad.util.BulkUtil; import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; +import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.serialize.RandomCutForestSerDe; +import com.google.common.util.concurrent.RateLimiter; +import com.google.gson.Gson; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; /** * DAO for model checkpoints. */ public class CheckpointDao { - protected static final String DOC_TYPE = "_doc"; - protected static final String FIELD_MODEL = "model"; - public static final String TIMESTAMP = "timestamp"; - private static final Logger logger = LogManager.getLogger(CheckpointDao.class); + static final String TIMEOUT_LOG_MSG = "Timeout while deleting checkpoints of"; + static final String BULK_FAILURE_LOG_MSG = "Bulk failure while deleting checkpoints of"; + static final String SEARCH_FAILURE_LOG_MSG = "Search failure while deleting checkpoints of"; + static final String DOC_GOT_DELETED_LOG_MSG = "checkpoints docs get deleted"; + static final String INDEX_DELETED_LOG_MSG = "Checkpoint index has been deleted. 
Has nothing to do:"; + static final String NOT_ABLE_TO_DELETE_LOG_MSG = "Cannot delete all checkpoints of detector"; + + // ====================================== + // Model serialization/deserialization + // ====================================== + public static final String ENTITY_SAMPLE = "sp"; + public static final String ENTITY_RCF = "rcf"; + public static final String ENTITY_THRESHOLD = "th"; + public static final String FIELD_MODEL = "model"; + public static final String TIMESTAMP = "timestamp"; + public static final String DETECTOR_ID = "detectorId"; // dependencies private final Client client; @@ -52,17 +98,60 @@ public class CheckpointDao { // configuration private final String indexName; + private Gson gson; + private RandomCutForestSerDe rcfSerde; + + private ConcurrentLinkedQueue> requests; + private final ReentrantLock lock; + private final Class thresholdingModelClass; + private final Duration checkpointInterval; + private final Clock clock; + private final AnomalyDetectionIndices indexUtil; + private final RateLimiter bulkRateLimiter; + private final int maxBulkRequestSize; + /** * Constructor with dependencies and configuration. 
* * @param client ES search client * @param clientUtil utility with ES client * @param indexName name of the index for model checkpoints + * @param gson accessor to Gson functionality + * @param rcfSerde accessor to rcf serialization/deserialization + * @param thresholdingModelClass thresholding model's class + * @param clock a UTC clock + * @param checkpointInterval how often we should save a checkpoint + * @param indexUtil Index utility methods + * @param maxBulkRequestSize max number of index request a bulk can contain + * @param bulkPerSecond bulk requests per second */ - public CheckpointDao(Client client, ClientUtil clientUtil, String indexName) { + public CheckpointDao( + Client client, + ClientUtil clientUtil, + String indexName, + Gson gson, + RandomCutForestSerDe rcfSerde, + Class thresholdingModelClass, + Clock clock, + Duration checkpointInterval, + AnomalyDetectionIndices indexUtil, + int maxBulkRequestSize, + double bulkPerSecond + ) { this.client = client; this.clientUtil = clientUtil; this.indexName = indexName; + this.gson = gson; + this.rcfSerde = rcfSerde; + this.requests = new ConcurrentLinkedQueue<>(); + this.lock = new ReentrantLock(); + this.thresholdingModelClass = thresholdingModelClass; + this.clock = clock; + this.checkpointInterval = checkpointInterval; + this.indexUtil = indexUtil; + this.maxBulkRequestSize = maxBulkRequestSize; + // 1 bulk request per minute. 1 / 60 seconds = 0. 
02 + this.bulkRateLimiter = RateLimiter.create(bulkPerSecond); } /** @@ -79,12 +168,15 @@ public void putModelCheckpoint(String modelId, String modelCheckpoint) { source.put(FIELD_MODEL, modelCheckpoint); source.put(TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); - clientUtil - .timedRequest( - new IndexRequest(indexName, DOC_TYPE, modelId).source(source), - logger, - client::index - ); + if (indexUtil.doesCheckpointIndexExist()) { + saveModelCheckpointSync(source, modelId); + } else { + onCheckpointNotExist(source, modelId, false, null); + } + } + + private void saveModelCheckpointSync(Map source, String modelId) { + clientUtil.timedRequest(new IndexRequest(indexName).id(modelId).source(source), logger, client::index); } /** @@ -98,14 +190,160 @@ public void putModelCheckpoint(String modelId, String modelCheckpoint, ActionLis Map source = new HashMap<>(); source.put(FIELD_MODEL, modelCheckpoint); source.put(TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); + if (indexUtil.doesCheckpointIndexExist()) { + saveModelCheckpointAsync(source, modelId, listener); + } else { + onCheckpointNotExist(source, modelId, true, listener); + } + } + + /** + * Creates the missing checkpoint index, then saves the checkpoint. + * An index that already exists (creation race) is treated as success. Any other + * failure is logged and, for async callers, reported through the listener. + */ + private void onCheckpointNotExist(Map source, String modelId, boolean isAsync, ActionListener listener) { + indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { + if (initResponse.isAcknowledged()) { + if (isAsync) { + saveModelCheckpointAsync(source, modelId, listener); + } else { + saveModelCheckpointSync(source, modelId); + } + } else { + throw new RuntimeException("Creating checkpoint with mappings call not acknowledged."); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we were sending the create request + if (isAsync) { + saveModelCheckpointAsync(source, modelId, listener); + } else { + saveModelCheckpointSync(source, modelId); + } + } else { + logger.error(String.format("Unexpected error creating index
%s", indexName), exception); + if (isAsync && listener != null) { + // complete the caller's listener; otherwise an async put never finishes + listener.onFailure(exception); + } + } + })); + } + + private void saveModelCheckpointAsync(Map source, String modelId, ActionListener listener) { clientUtil .asyncRequest( - new IndexRequest(indexName, DOC_TYPE, modelId).source(source), + new IndexRequest(indexName).id(modelId).source(source), client::index, ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) ); } + /** + * Bulk writing model states prepared previously + */ + public void flush() { + try { + // in case that other threads are doing bulk as well. + if (!lock.tryLock()) { + return; + } + if (requests.size() > 0 && bulkRateLimiter.tryAcquire()) { + final BulkRequest bulkRequest = new BulkRequest(); + // at most maxBulkRequestSize index requests per bulk + for (int i = 0; i < maxBulkRequestSize; i++) { + DocWriteRequest req = requests.poll(); + if (req == null) { + break; + } + + bulkRequest.add(req); + } + if (indexUtil.doesCheckpointIndexExist()) { + flush(bulkRequest); + } else { + indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { + if (initResponse.isAcknowledged()) { + flush(bulkRequest); + } else { + throw new RuntimeException("Creating checkpoint with mappings call not acknowledged."); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we were sending the create request + flush(bulkRequest); + } else { + logger.error(String.format("Unexpected error creating index %s", indexName), exception); + // put the drained requests back so the next flush retries them instead of losing them + requests.addAll(bulkRequest.requests()); + } + })); + } + } + } finally { + if (lock.isHeldByCurrentThread()) { + lock.unlock(); + } + } + } + + private void flush(BulkRequest bulkRequest) { + clientUtil.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(r -> { + if (r.hasFailures()) { + requests.addAll(BulkUtil.getIndexRequestToRetry(bulkRequest, r)); + } + }, e -> { + logger.error("Failed bulking checkpoints", e); + // retry during next bulk.
+ for (DocWriteRequest req : bulkRequest.requests()) { + requests.add(req); + } + })); + } + + /** + * Prepare bulking the input model state to the checkpoint index. + * We don't save checkpoints within checkpointInterval again. + * @param modelState Model state + * @param modelId Model Id + */ + public void write(ModelState modelState, String modelId) { + write(modelState, modelId, false); + } + + /** + * Prepare bulking the input model state to the checkpoint index. + * We don't save checkpoints within checkpointInterval again, except this + * is from cold start. This method will update the input state's last + * checkpoint time if the checkpoint is staged (ready to be written in the + * next batch). + * @param modelState Model state + * @param modelId Model Id + * @param coldStart whether the checkpoint comes from cold start + */ + public void write(ModelState modelState, String modelId, boolean coldStart) { + Instant instant = modelState.getLastCheckpointTime(); + // Instant.MIN is the default value. We don't save until we are sure. + if ((instant == Instant.MIN || instant.plus(checkpointInterval).isAfter(clock.instant())) && !coldStart) { + return; + } + // It is possible 2 states of the same model id gets saved: one overwrite another. + // This can happen if previous checkpoint hasn't been saved to disk, while the + // 1st one creates a new state without restoring. + if (modelState.getModel() != null) { + Map source = new HashMap<>(); + source.put(DETECTOR_ID, modelState.getDetectorId()); + source.put(FIELD_MODEL, toCheckpoint(modelState.getModel())); + source.put(TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); + requests.add(new IndexRequest(indexName).id(modelId).source(source)); + modelState.setLastCheckpointTime(clock.instant()); + if (requests.size() >= maxBulkRequestSize) { + flush(); + } + } + } + /** * Returns the checkpoint for the model. 
* @@ -117,33 +344,24 @@ public void putModelCheckpoint(String modelId, String modelCheckpoint, ActionLis @Deprecated public Optional getModelCheckpoint(String modelId) { return clientUtil - .timedRequest(new GetRequest(indexName, DOC_TYPE, modelId), logger, client::get) + .timedRequest(new GetRequest(indexName, modelId), logger, client::get) .filter(GetResponse::isExists) .map(GetResponse::getSource) .map(source -> (String) source.get(FIELD_MODEL)); } - /** - * Returns to listener the checkpoint for the model. - * - * @param modelId id of the model - * @param listener onResponse is called with the model checkpoint, or empty for no such model - */ - public void getModelCheckpoint(String modelId, ActionListener> listener) { - clientUtil - .asyncRequest( - new GetRequest(indexName, DOC_TYPE, modelId), - client::get, - ActionListener.wrap(response -> listener.onResponse(processModelCheckpoint(response)), listener::onFailure) - ); - } - - private Optional processModelCheckpoint(GetResponse response) { - return Optional - .ofNullable(response) - .filter(GetResponse::isExists) - .map(GetResponse::getSource) - .map(source -> (String) source.get(FIELD_MODEL)); + String toCheckpoint(EntityModel model) { + return AccessController.doPrivileged((PrivilegedAction) () -> { + JsonObject json = new JsonObject(); + json.add(ENTITY_SAMPLE, gson.toJsonTree(model.getSamples())); + if (model.getRcf() != null) { + json.addProperty(ENTITY_RCF, rcfSerde.toJson(model.getRcf())); + } + if (model.getThreshold() != null) { + json.addProperty(ENTITY_THRESHOLD, gson.toJson(model.getThreshold())); + } + return gson.toJson(json); + }); } /** @@ -155,7 +373,7 @@ private Optional processModelCheckpoint(GetResponse response) { */ @Deprecated public void deleteModelCheckpoint(String modelId) { - clientUtil.timedRequest(new DeleteRequest(indexName, DOC_TYPE, modelId), logger, client::delete); + clientUtil.timedRequest(new DeleteRequest(indexName, modelId), logger, client::delete); } /** @@ -167,9 
+385,134 @@ public void deleteModelCheckpoint(String modelId) { public void deleteModelCheckpoint(String modelId, ActionListener listener) { clientUtil .asyncRequest( - new DeleteRequest(indexName, DOC_TYPE, modelId), + new DeleteRequest(indexName, modelId), client::delete, ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) ); } + + /** + * Delete checkpoints associated with a detector. Used in multi-entity detector. + * @param detectorID Detector Id + */ + public void deleteModelCheckpointByDetectorId(String detectorID) { + // A bulk delete request is performed for each batch of matching documents. If a + // search or bulk request is rejected, the requests are retried up to 10 times, + // with exponential back off. If the maximum retry limit is reached, processing + // halts and all failed requests are returned in the response. Any delete + // requests that completed successfully still stick, they are not rolled back. + DeleteByQueryRequest deleteRequest = new DeleteByQueryRequest(CommonName.CHECKPOINT_INDEX_NAME) + .setQuery(new MatchQueryBuilder(DETECTOR_ID, detectorID)) + .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) + .setAbortOnVersionConflict(false) // when current delete happens, previous might not finish. + // Retry in this case + .setRequestsPerSecond(500); // throttle delete requests + logger.info("Delete checkpoints of detector {}", detectorID); + client.execute(DeleteByQueryAction.INSTANCE, deleteRequest, ActionListener.wrap(response -> { + if (response.isTimedOut() || !response.getBulkFailures().isEmpty() || !response.getSearchFailures().isEmpty()) { + logFailure(response, detectorID); + } + // if 0 docs get deleted, it means we cannot find matching docs + logger.info("{} " + DOC_GOT_DELETED_LOG_MSG, response.getDeleted()); + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof IndexNotFoundException) { + logger.info(INDEX_DELETED_LOG_MSG + " {}", detectorID); + } else { + // Gonna eventually delete in daily cron.
+ logger.error(NOT_ABLE_TO_DELETE_LOG_MSG + " " + detectorID, exception); + } + })); + } + + /** + * Logs every failure category reported by a delete-by-query response. + * One response can time out and carry bulk and search failures at the same time. + */ + private void logFailure(BulkByScrollResponse response, String detectorID) { + if (response.isTimedOut()) { + logger.warn(TIMEOUT_LOG_MSG + " {}", detectorID); + } + if (!response.getBulkFailures().isEmpty()) { + logger.warn(BULK_FAILURE_LOG_MSG + " {}", detectorID); + for (BulkItemResponse.Failure bulkFailure : response.getBulkFailures()) { + logger.warn(bulkFailure); + } + } + if (!response.getSearchFailures().isEmpty()) { + logger.warn(SEARCH_FAILURE_LOG_MSG + " {}", detectorID); + for (ScrollableHitSource.SearchFailure searchFailure : response.getSearchFailures()) { + logger.warn(searchFailure); + } + } + } + + private Entry fromEntityModelCheckpoint(Map checkpoint, String modelId) { + try { + return AccessController.doPrivileged((PrivilegedAction>) () -> { + String model = (String) (checkpoint.get(FIELD_MODEL)); + JsonObject json = JsonParser.parseString(model).getAsJsonObject(); + ArrayDeque samples = new ArrayDeque<>( + Arrays.asList(this.gson.fromJson(json.getAsJsonArray(ENTITY_SAMPLE), new double[0][0].getClass())) + ); + RandomCutForest rcf = null; + if (json.has(ENTITY_RCF)) { + rcf = rcfSerde.fromJson(json.getAsJsonPrimitive(ENTITY_RCF).getAsString()); + } + ThresholdingModel threshold = null; + if (json.has(ENTITY_THRESHOLD)) { + threshold = this.gson.fromJson(json.getAsJsonPrimitive(ENTITY_THRESHOLD).getAsString(), thresholdingModelClass); + } + + String lastCheckpointTimeString = (String) (checkpoint.get(TIMESTAMP)); + Instant timestamp = Instant.parse(lastCheckpointTimeString); + return new SimpleImmutableEntry<>(new EntityModel(modelId, samples, rcf, threshold), timestamp); + }); + } catch (RuntimeException e) { + logger.warn("Exception while deserializing checkpoint", e); + throw e; + } + } + + /** + * Read a checkpoint from the index and return the EntityModel object + * @param modelId Model Id + * @param listener Listener to return the EntityModel object + */ + public void restoreModelCheckpoint(String modelId,
ActionListener>> listener) { + clientUtil.asyncRequest(new GetRequest(indexName, modelId), client::get, ActionListener.wrap(response -> { + Optional> checkpointString = processRawCheckpoint(response); + if (checkpointString.isPresent()) { + listener.onResponse(Optional.of(fromEntityModelCheckpoint(checkpointString.get(), modelId))); + } else { + listener.onResponse(Optional.empty()); + } + }, listener::onFailure)); + } + + /** + * Returns to listener the checkpoint for the model. + * + * @param modelId id of the model + * @param listener onResponse is called with the model checkpoint, or empty for no such model + */ + public void getModelCheckpoint(String modelId, ActionListener> listener) { + clientUtil + .asyncRequest( + new GetRequest(indexName, modelId), + client::get, + ActionListener.wrap(response -> listener.onResponse(processModelCheckpoint(response)), listener::onFailure) + ); + } + + private Optional processModelCheckpoint(GetResponse response) { + return Optional + .ofNullable(response) + .filter(GetResponse::isExists) + .map(GetResponse::getSource) + .map(source -> (String) source.get(FIELD_MODEL)); + } + + private Optional> processRawCheckpoint(GetResponse response) { + return Optional.ofNullable(response).filter(GetResponse::isExists).map(GetResponse::getSource); + } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/BulkUtil.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/BulkUtil.java new file mode 100644 index 00000000..23a35057 --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/BulkUtil.java @@ -0,0 +1,62 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file.
This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.util; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; + +/** + * Static helpers for working with bulk requests and their responses. + */ +public class BulkUtil { + private static final Logger logger = LogManager.getLogger(BulkUtil.class); + + private BulkUtil() {} + + /** + * Collects the requests of a bulk whose item responses are marked as failed. + * + * @param bulkRequest the bulk request that was executed + * @param bulkResponse the response returned for that bulk request + * @return the requests whose id matches the id of a failed item; empty if no item failed + */ + public static List> getIndexRequestToRetry(BulkRequest bulkRequest, BulkResponse bulkResponse) { + List> res = new ArrayList<>(); + + Set failedId = new HashSet<>(); + for (BulkItemResponse response : bulkResponse.getItems()) { + if (response.isFailed()) { + failedId.add(response.getId()); + } + } + + for (DocWriteRequest request : bulkRequest.requests()) { + if (failedId.contains(request.id())) { + res.add(request); + } + } + return res; + } +} diff --git a/src/main/resources/mappings/checkpoint.json b/src/main/resources/mappings/checkpoint.json new file mode 100644 index 00000000..8fe8a099 --- /dev/null +++ b/src/main/resources/mappings/checkpoint.json @@ -0,0 +1,18 @@ +{ + "dynamic": true, + "_meta": { + "schema_version": 1 + }, + "properties": { + "detectorId": { + "type": "keyword" + }, + "model": { + "type": "text" + }, + "timestamp": { + "type": "date", + "format": "strict_date_time||epoch_millis" + } + } +} diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDaoTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDaoTests.java index 618d1254..d5e77416 100644 ---
a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDaoTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDaoTests.java @@ -15,6 +15,7 @@ package com.amazon.opendistroforelasticsearch.ad.ml; +import static com.amazon.opendistroforelasticsearch.ad.ml.CheckpointDao.FIELD_MODEL; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -23,37 +24,75 @@ import static org.mockito.Matchers.anyObject; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.io.IOException; +import java.time.Clock; +import java.time.Instant; +import java.time.Month; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.Map; +import java.util.Map.Entry; import java.util.Optional; +import java.util.Queue; import java.util.Set; import java.util.function.BiConsumer; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; +import org.elasticsearch.action.bulk.BulkAction; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.delete.DeleteRequest; import org.elasticsearch.action.delete.DeleteResponse; import 
org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.client.Client; +import org.elasticsearch.index.engine.VersionConflictEngineException; +import org.elasticsearch.index.shard.ShardId; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; import org.mockito.Answers; import org.mockito.ArgumentCaptor; import org.mockito.Matchers; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import test.com.amazon.opendistroforelasticsearch.ad.util.MLUtil; + +import com.amazon.opendistroforelasticsearch.ad.constant.CommonName; +import com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices; +import com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings; import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; +import com.amazon.randomcutforest.serialize.RandomCutForestSerDe; +import com.google.gson.Gson; +@RunWith(PowerMockRunner.class) +@PrepareForTest({ Gson.class }) public class CheckpointDaoTests { + private static final Logger logger = LogManager.getLogger(CheckpointDaoTests.class); private CheckpointDao checkpointDao; @@ -67,6 +106,15 @@ public class CheckpointDaoTests { @Mock private GetResponse getResponse; + @Mock + private RandomCutForestSerDe rcfSerde; + + @Mock + private Clock clock; + + @Mock + private AnomalyDetectionIndices indexUtil; + // configuration private String indexName; @@ -75,24 +123,47 @@ public class CheckpointDaoTests { private String model; private Map docSource; + private Gson gson; + private Class thresholdingModelClass; + private int maxBulkSize; + @Before public void setup() 
{ MockitoAnnotations.initMocks(this); indexName = "testIndexName"; - checkpointDao = new CheckpointDao(client, clientUtil, indexName); + gson = PowerMockito.mock(Gson.class); + + thresholdingModelClass = HybridThresholdingModel.class; + + when(clock.instant()).thenReturn(Instant.now()); + + maxBulkSize = 10; + + checkpointDao = new CheckpointDao( + client, + clientUtil, + indexName, + gson, + rcfSerde, + thresholdingModelClass, + clock, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + indexUtil, + maxBulkSize, + 200.0 + ); + + when(indexUtil.doesCheckpointIndexExist()).thenReturn(true); modelId = "testModelId"; model = "testModel"; docSource = new HashMap<>(); - docSource.put(CheckpointDao.FIELD_MODEL, model); + docSource.put(FIELD_MODEL, model); } - @Test - public void putModelCheckpoint_getIndexRequest() { - checkpointDao.putModelCheckpoint(modelId, model); - + private void verifySuccessfulPutModelCheckpointSync() { ArgumentCaptor indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class); verify(clientUtil) .timedRequest( @@ -102,14 +173,65 @@ public void putModelCheckpoint_getIndexRequest() { ); IndexRequest indexRequest = indexRequestCaptor.getValue(); assertEquals(indexName, indexRequest.index()); - assertEquals(CheckpointDao.DOC_TYPE, indexRequest.type()); assertEquals(modelId, indexRequest.id()); - Set expectedSourceKeys = new HashSet(Arrays.asList(CheckpointDao.FIELD_MODEL, CheckpointDao.TIMESTAMP)); + Set expectedSourceKeys = new HashSet(Arrays.asList(FIELD_MODEL, CheckpointDao.TIMESTAMP)); assertEquals(expectedSourceKeys, indexRequest.sourceAsMap().keySet()); - assertEquals(model, indexRequest.sourceAsMap().get(CheckpointDao.FIELD_MODEL)); + assertEquals(model, indexRequest.sourceAsMap().get(FIELD_MODEL)); assertNotNull(indexRequest.sourceAsMap().get(CheckpointDao.TIMESTAMP)); } + @Test + public void putModelCheckpoint_getIndexRequest() { + checkpointDao.putModelCheckpoint(modelId, model); + + verifySuccessfulPutModelCheckpointSync(); + } + + @Test 
+ public void putModelCheckpoint_no_checkpoint_index() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(new CreateIndexResponse(true, true, CommonName.CHECKPOINT_INDEX_NAME)); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + checkpointDao.putModelCheckpoint(modelId, model); + + verifySuccessfulPutModelCheckpointSync(); + } + + @Test + public void putModelCheckpoint_index_race_condition() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new ResourceAlreadyExistsException(CommonName.CHECKPOINT_INDEX_NAME)); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + checkpointDao.putModelCheckpoint(modelId, model); + + verifySuccessfulPutModelCheckpointSync(); + } + + @Test + public void putModelCheckpoint_unexpected_exception() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException("")); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + checkpointDao.putModelCheckpoint(modelId, model); + + verify(clientUtil, never()).timedRequest(any(), any(), any()); + } + @Test public void getModelCheckpoint_returnExpected() { ArgumentCaptor getRequestCaptor = ArgumentCaptor.forClass(GetRequest.class); @@ -129,7 +251,6 @@ public void getModelCheckpoint_returnExpected() { assertEquals(model, result.get()); GetRequest getRequest = getRequestCaptor.getValue(); assertEquals(indexName, getRequest.index()); - assertEquals(CheckpointDao.DOC_TYPE, getRequest.type()); assertEquals(modelId, getRequest.id()); } @@ -158,13 +279,11 @@ public void deleteModelCheckpoint_getDeleteRequest() { ); DeleteRequest deleteRequest = deleteRequestCaptor.getValue(); 
assertEquals(indexName, deleteRequest.index()); - assertEquals(CheckpointDao.DOC_TYPE, deleteRequest.type()); assertEquals(modelId, deleteRequest.id()); } - @Test @SuppressWarnings("unchecked") - public void putModelCheckpoint_callListener_whenCompleted() { + private void verifyPutModelCheckpointAsync() { ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(IndexRequest.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -177,11 +296,10 @@ public void putModelCheckpoint_callListener_whenCompleted() { IndexRequest indexRequest = requestCaptor.getValue(); assertEquals(indexName, indexRequest.index()); - assertEquals(CheckpointDao.DOC_TYPE, indexRequest.type()); assertEquals(modelId, indexRequest.id()); - Set expectedSourceKeys = new HashSet(Arrays.asList(CheckpointDao.FIELD_MODEL, CheckpointDao.TIMESTAMP)); + Set expectedSourceKeys = new HashSet(Arrays.asList(FIELD_MODEL, CheckpointDao.TIMESTAMP)); assertEquals(expectedSourceKeys, indexRequest.sourceAsMap().keySet()); - assertEquals(model, indexRequest.sourceAsMap().get(CheckpointDao.FIELD_MODEL)); + assertEquals(model, indexRequest.sourceAsMap().get(FIELD_MODEL)); assertNotNull(indexRequest.sourceAsMap().get(CheckpointDao.TIMESTAMP)); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Void.class); @@ -190,6 +308,54 @@ public void putModelCheckpoint_callListener_whenCompleted() { assertEquals(null, response); } + @Test + public void putModelCheckpoint_callListener_whenCompleted() { + verifyPutModelCheckpointAsync(); + } + + @Test + public void putModelCheckpoint_callListener_no_checkpoint_index() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(new CreateIndexResponse(true, true, CommonName.CHECKPOINT_INDEX_NAME)); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + verifyPutModelCheckpointAsync(); + } + + @Test + public void 
putModelCheckpoint_callListener_race_condition() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new ResourceAlreadyExistsException(CommonName.CHECKPOINT_INDEX_NAME)); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + verifyPutModelCheckpointAsync(); + } + + @SuppressWarnings("unchecked") + @Test + public void putModelCheckpoint_callListener_unexpected_exception() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException("")); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + ActionListener listener = mock(ActionListener.class); + checkpointDao.putModelCheckpoint(modelId, model, listener); + + verify(clientUtil, never()).asyncRequest(any(), any(), any()); + } + @Test @SuppressWarnings("unchecked") public void getModelCheckpoint_returnExpectedToListener() { @@ -207,7 +373,6 @@ public void getModelCheckpoint_returnExpectedToListener() { GetRequest getRequest = requestCaptor.getValue(); assertEquals(indexName, getRequest.index()); - assertEquals(CheckpointDao.DOC_TYPE, getRequest.type()); assertEquals(modelId, getRequest.id()); ArgumentCaptor> responseCaptor = ArgumentCaptor.forClass(Optional.class); verify(listener).onResponse(responseCaptor.capture()); @@ -232,7 +397,6 @@ public void getModelCheckpoint_returnEmptyToListener_whenModelNotFound() { GetRequest getRequest = requestCaptor.getValue(); assertEquals(indexName, getRequest.index()); - assertEquals(CheckpointDao.DOC_TYPE, getRequest.type()); assertEquals(modelId, getRequest.id()); ArgumentCaptor> responseCaptor = ArgumentCaptor.forClass(Optional.class); verify(listener).onResponse(responseCaptor.capture()); @@ -255,7 +419,6 @@ public void deleteModelCheckpoint_callListener_whenCompleted() { DeleteRequest deleteRequest = 
requestCaptor.getValue();
         assertEquals(indexName, deleteRequest.index());
-        assertEquals(CheckpointDao.DOC_TYPE, deleteRequest.type());
         assertEquals(modelId, deleteRequest.id());
 
         ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Void.class);
@@ -263,4 +426,309 @@ public void deleteModelCheckpoint_callListener_whenCompleted() {
         Void response = responseCaptor.getValue();
         assertEquals(null, response);
     }
+
+    /**
+     * Builds a bulk response whose first {@code failed} items are version-conflict failures and whose
+     * remaining {@code succeeded} items are update responses.
+     *
+     * @param succeeded number of successful items to append after the failures
+     * @param failed number of failed items placed first
+     * @param failedId ids reported for the failed items; only read when {@code failed > 0}
+     * @return a bulk response containing {@code succeeded + failed} items
+     */
+    private BulkResponse createBulkResponse(int succeeded, int failed, String[] failedId) {
+        BulkItemResponse[] bulkItemResponses = new BulkItemResponse[succeeded + failed];
+
+        ShardId shardId = new ShardId(CommonName.CHECKPOINT_INDEX_NAME, "", 1);
+        int i = 0;
+        for (; i < failed; i++) {
+            bulkItemResponses[i] = new BulkItemResponse(
+                i,
+                DocWriteRequest.OpType.UPDATE,
+                new BulkItemResponse.Failure(
+                    CommonName.CHECKPOINT_INDEX_NAME,
+                    CommonName.MAPPING_TYPE,
+                    failedId[i],
+                    new VersionConflictEngineException(shardId, "id", "test")
+                )
+            );
+        }
+
+        for (; i < failed + succeeded; i++) {
+            bulkItemResponses[i] = new BulkItemResponse(
+                i,
+                DocWriteRequest.OpType.UPDATE,
+                new UpdateResponse(shardId, CommonName.MAPPING_TYPE, "1", 0L, 1L, 1L, DocWriteResponse.Result.CREATED)
+            );
+        }
+
+        return new BulkResponse(bulkItemResponses, 507);
+    }
+
+    /** Fewer writes than the bulk size stay queued until an explicit flush sends them in a single bulk request. */
+    @SuppressWarnings("unchecked")
+    @Test
+    public void flush_less_than_1k() {
+        int writeRequests = maxBulkSize - 1;
+        for (int i = 0; i < writeRequests; i++) {
+            ModelState<EntityModel> state = MLUtil.randomModelState();
+            checkpointDao.write(state, state.getModelId(), true);
+        }
+
+        doAnswer(invocation -> {
+            BulkRequest request = invocation.getArgument(1);
+            assertEquals(writeRequests, request.numberOfActions());
+            ActionListener<BulkResponse> listener = invocation.getArgument(2);
+
+            listener.onResponse(createBulkResponse(request.numberOfActions(), 0, null));
+            return null;
+        }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class));
+
+        checkpointDao.flush();
+
+        verify(clientUtil, times(1)).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class));
+    }
+
+    /** Shared scenario: writing maxBulkSize + 1 states must trigger exactly one bulk request of maxBulkSize actions. */
+    @SuppressWarnings("unchecked")
+    public void flush_more_than_1k() {
+        int writeRequests = maxBulkSize + 1;
+
+        doAnswer(invocation -> {
+            BulkRequest request = invocation.getArgument(1);
+            assertEquals(maxBulkSize, request.numberOfActions());
+            ActionListener<BulkResponse> listener = invocation.getArgument(2);
+
+            listener.onResponse(createBulkResponse(request.numberOfActions(), 0, null));
+            return null;
+        }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class));
+
+        for (int i = 0; i < writeRequests; i++) {
+            ModelState<EntityModel> state = MLUtil.randomModelState();
+            // should trigger auto flush
+            checkpointDao.write(state, state.getModelId(), true);
+        }
+
+        verify(clientUtil, times(1)).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class));
+    }
+
+    /** Runs the auto-flush scenario with the fixture's default index stubbing. */
+    @Test
+    public void flush_more_than_1k_has_index() {
+        flush_more_than_1k();
+    }
+
+    /** Runs the auto-flush scenario when the checkpoint index is missing and its creation is acknowledged. */
+    @Test
+    public void flush_more_than_1k_no_index() {
+        when(indexUtil.doesCheckpointIndexExist()).thenReturn(false);
+
+        doAnswer(invocation -> {
+            ActionListener<CreateIndexResponse> listener = invocation.getArgument(0);
+            listener.onResponse(new CreateIndexResponse(true, true, CommonName.CHECKPOINT_INDEX_NAME));
+            return null;
+        }).when(indexUtil).initCheckpointIndex(any());
+
+        flush_more_than_1k();
+    }
+
+    /** Runs the auto-flush scenario when index creation fails with ResourceAlreadyExistsException. */
+    @Test
+    public void flush_more_than_1k_race_condition() {
+        when(indexUtil.doesCheckpointIndexExist()).thenReturn(false);
+
+        doAnswer(invocation -> {
+            ActionListener<CreateIndexResponse> listener = invocation.getArgument(0);
+            listener.onFailure(new ResourceAlreadyExistsException(CommonName.CHECKPOINT_INDEX_NAME));
+            return null;
+        }).when(indexUtil).initCheckpointIndex(any());
+
+        flush_more_than_1k();
+    }
+
+    /** Index creation fails with an unexpected exception, so the triggered flush must not issue a bulk request. */
+    @SuppressWarnings("unchecked")
+    @Test
+    public void flush_more_than_1k_unexpected_exception() {
+        when(indexUtil.doesCheckpointIndexExist()).thenReturn(false);
+
+        doAnswer(invocation -> {
+            ActionListener<CreateIndexResponse> listener = invocation.getArgument(0);
+            listener.onFailure(new RuntimeException(""));
+            return null;
+        }).when(indexUtil).initCheckpointIndex(any());
+
+        // Exercise the DAO the same way flush_more_than_1k does; without writes the check below passes vacuously.
+        for (int i = 0; i < maxBulkSize + 1; i++) {
+            ModelState<EntityModel> state = MLUtil.randomModelState();
+            checkpointDao.write(state, state.getModelId(), true);
+        }
+
+        verify(clientUtil, never()).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class));
+    }
+
+    /** A bulk response with one failed item must lead the next flush to send only that item again. */
+    @SuppressWarnings("unchecked")
+    @Test
+    public void bulk_has_failure() throws InterruptedException {
+        int writeRequests = maxBulkSize - 1;
+        int failureCount = 1;
+        String[] failedId = new String[failureCount];
+        for (int i = 0; i < writeRequests; i++) {
+            ModelState<EntityModel> state = MLUtil.randomModelState();
+            checkpointDao.write(state, state.getModelId(), true);
+            if (i < failureCount) {
+                failedId[i] = state.getModelId();
+            }
+        }
+
+        doAnswer(invocation -> {
+            BulkRequest request = invocation.getArgument(1);
+            assertEquals(writeRequests, request.numberOfActions());
+            ActionListener<BulkResponse> listener = invocation.getArgument(2);
+
+            listener.onResponse(createBulkResponse(request.numberOfActions() - failureCount, failureCount, failedId));
+            return null;
+        }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class));
+
+        checkpointDao.flush();
+
+        doAnswer(invocation -> {
+            BulkRequest request = invocation.getArgument(1);
+            assertEquals(failureCount, request.numberOfActions());
+            ActionListener<BulkResponse> listener = invocation.getArgument(2);
+
+            listener.onResponse(createBulkResponse(request.numberOfActions(), 0, null));
+            return null;
+        }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class));
+
+        checkpointDao.flush();
+
+        verify(clientUtil, times(2)).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class));
+    }
+
+    /** A bulk request that fails outright must lead the next flush to send all of its actions again. */
+    @SuppressWarnings("unchecked")
+    @Test
+    public void bulk_all_failure() throws InterruptedException {
+        int writeRequests = maxBulkSize - 1;
+        for (int i = 0; i < writeRequests; i++) {
+            ModelState<EntityModel> state = MLUtil.randomModelState();
+            checkpointDao.write(state, state.getModelId(), true);
+        }
+
+        doAnswer(invocation -> {
+            BulkRequest request = invocation.getArgument(1);
+            assertEquals(writeRequests, request.numberOfActions());
+            ActionListener<BulkResponse> listener = invocation.getArgument(2);
+
+            listener.onFailure(new RuntimeException(""));
+            return null;
+        }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class));
+
+        checkpointDao.flush();
+
+        doAnswer(invocation -> {
+            BulkRequest request = invocation.getArgument(1);
+            assertEquals(writeRequests, request.numberOfActions());
+            ActionListener<BulkResponse> listener = invocation.getArgument(2);
+
+            listener.onResponse(createBulkResponse(request.numberOfActions(), 0, null));
+            return null;
+        }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class));
+
+        checkpointDao.flush();
+
+        verify(clientUtil, times(2)).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class));
+    }
+
+    /** A state checkpointed just now and written through the two-argument write must not produce a bulk request. */
+    @SuppressWarnings("unchecked")
+    @Test
+    public void checkpoint_saved_less_than_1_hr() {
+        ModelState<EntityModel> state = MLUtil.randomModelState();
+        state.setLastCheckpointTime(Instant.now());
+        checkpointDao.write(state, state.getModelId());
+
+        checkpointDao.flush();
+
+        verify(clientUtil, never()).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class));
+    }
+
+    /** A state checkpointed just now but written with the third argument set to true is still sent on flush. */
+    @SuppressWarnings("unchecked")
+    @Test
+    public void checkpoint_coldstart_checkpoint() {
+        ModelState<EntityModel> state = MLUtil.randomModelState();
+        state.setLastCheckpointTime(Instant.now());
+        // cold start checkpoint will save whatever
+        checkpointDao.write(state, state.getModelId(), true);
+
+        checkpointDao.flush();
+
+        verify(clientUtil, times(1)).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class));
+    }
+
+    /** A stored checkpoint document is restored into an entity model together with its timestamp, checked in UTC. */
+    @SuppressWarnings("unchecked")
+    @Test
+    public void restore() throws IOException {
+        ModelState<EntityModel> state = MLUtil.randomNonEmptyModelState();
+        EntityModel modelToSave = state.getModel();
+
+        checkpointDao = new CheckpointDao(
+            client,
+            clientUtil,
+            indexName,
+            new Gson(),
+            new RandomCutForestSerDe(),
+            thresholdingModelClass,
+            clock,
+            AnomalyDetectorSettings.HOURLY_MAINTENANCE,
+            indexUtil,
+            maxBulkSize,
+            2
+        );
+
+        GetResponse getResponse = mock(GetResponse.class);
+        when(getResponse.isExists()).thenReturn(true);
+        Map<String, Object> source = new HashMap<>();
+        source.put(CheckpointDao.DETECTOR_ID, state.getDetectorId());
+        source.put(CheckpointDao.FIELD_MODEL, checkpointDao.toCheckpoint(modelToSave));
+        source.put(CheckpointDao.TIMESTAMP, "2020-10-11T22:58:23.610392Z");
+        when(getResponse.getSource()).thenReturn(source);
+
+        doAnswer(invocation -> {
+            ActionListener<GetResponse> listener = invocation.getArgument(2);
+
+            listener.onResponse(getResponse);
+            return null;
+        }).when(clientUtil).asyncRequest(any(GetRequest.class), any(BiConsumer.class), any(ActionListener.class));
+
+        ActionListener<Optional<Entry<EntityModel, Instant>>> listener = mock(ActionListener.class);
+        checkpointDao.restoreModelCheckpoint(modelId, listener);
+
+        ArgumentCaptor<Optional<Entry<EntityModel, Instant>>> responseCaptor = ArgumentCaptor.forClass(Optional.class);
+        verify(listener).onResponse(responseCaptor.capture());
+        Optional<Entry<EntityModel, Instant>> response = responseCaptor.getValue();
+        assertTrue(response.isPresent());
+        Entry<EntityModel, Instant> entry = response.get();
+        OffsetDateTime utcTime = entry.getValue().atOffset(ZoneOffset.UTC);
+        assertEquals(2020, utcTime.getYear());
+        assertEquals(Month.OCTOBER, utcTime.getMonth());
+        assertEquals(11, utcTime.getDayOfMonth());
+        assertEquals(22, utcTime.getHour());
+        assertEquals(58, utcTime.getMinute());
+        assertEquals(23, utcTime.getSecond());
+
+        EntityModel model = entry.getKey();
+        Queue<double[]> queue = model.getSamples();
+        Queue<double[]> samplesToSave = modelToSave.getSamples();
+        assertEquals(samplesToSave.size(), queue.size());
+        assertTrue(Arrays.equals(samplesToSave.peek(), queue.peek()));
+        logger.info(modelToSave.getRcf());
+        logger.info(model.getRcf());
+        assertEquals(modelToSave.getRcf().getTotalUpdates(), model.getRcf().getTotalUpdates());
+        assertNotNull(model.getThreshold());
+    }
 }