From 39db0f5e42b1f2601b2102ddbeca8350c9f4fbb5 Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Wed, 14 Oct 2020 13:22:47 -0700 Subject: [PATCH 1/5] Add multi-entity checkpoints read and write MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We need checkpoints to save states and models on disk. In single-entity detectors, we store rcf and threshold models separately in different docs.  In multi-entity detectors, we need to store them together as we don't use distributed models anymore.  We also need to store recent sample history when the models are not ready. This PR adds functions to serialize models and samples together in one doc and deserialize them when needed.  Also, we bulk index multi-entity detectors' checkpoints. Bulk requests will yield much better performance than single-document index requests.   Testing done: 1. added unit tests. 2. end-to-end testing --- .../ad/ml/CheckpointDao.java | 410 +++++++++++++-- src/main/resources/mappings/checkpoint.json | 18 + .../ad/ml/CheckpointDaoTests.java | 480 +++++++++++++++++- 3 files changed, 853 insertions(+), 55 deletions(-) create mode 100644 src/main/resources/mappings/checkpoint.json diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java index d8a3a76a..ffc5b6b5 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java @@ -15,35 +15,81 @@ package com.amazon.opendistroforelasticsearch.ad.ml; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; import java.time.ZoneOffset; import java.time.ZonedDateTime; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.ArrayDeque; +import java.util.Arrays; import java.util.HashMap; 
import java.util.Map; +import java.util.Map.Entry; import java.util.Optional; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.locks.ReentrantLock; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.bulk.BulkAction; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.delete.DeleteRequest; import org.elasticsearch.action.delete.DeleteResponse; import org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.Client; +import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.query.MatchQueryBuilder; +import org.elasticsearch.index.reindex.BulkByScrollResponse; +import org.elasticsearch.index.reindex.DeleteByQueryAction; +import org.elasticsearch.index.reindex.DeleteByQueryRequest; +import org.elasticsearch.index.reindex.ScrollableHitSource; +import com.amazon.opendistroforelasticsearch.ad.constant.CommonName; +import com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices; +import com.amazon.opendistroforelasticsearch.ad.util.BulkUtil; import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; +import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.serialize.RandomCutForestSerDe; +import com.google.common.util.concurrent.RateLimiter; +import com.google.gson.Gson; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; 
/** * DAO for model checkpoints. */ public class CheckpointDao { - protected static final String DOC_TYPE = "_doc"; - protected static final String FIELD_MODEL = "model"; - public static final String TIMESTAMP = "timestamp"; - private static final Logger logger = LogManager.getLogger(CheckpointDao.class); + static final String TIMEOUT_LOG_MSG = "Timeout while deleting checkpoints of"; + static final String BULK_FAILURE_LOG_MSG = "Bulk failure while deleting checkpoints of"; + static final String SEARCH_FAILURE_LOG_MSG = "Search failure while deleting checkpoints of"; + static final String DOC_GOT_DELETED_LOG_MSG = "checkpoints docs get deleted"; + static final String INDEX_DELETED_LOG_MSG = "Checkpoint index has been deleted. Has nothing to do:"; + static final String NOT_ABLE_TO_DELETE_LOG_MSG = "Cannot delete all checkpoints of detector"; + + // ====================================== + // Model serialization/deserialization + // ====================================== + public static final String ENTITY_SAMPLE = "sp"; + public static final String ENTITY_RCF = "rcf"; + public static final String ENTITY_THRESHOLD = "th"; + public static final String FIELD_MODEL = "model"; + public static final String TIMESTAMP = "timestamp"; + public static final String DETECTOR_ID = "detectorId"; // dependencies private final Client client; @@ -52,17 +98,60 @@ public class CheckpointDao { // configuration private final String indexName; + private Gson gson; + private RandomCutForestSerDe rcfSerde; + + private ConcurrentLinkedQueue> requests; + private final ReentrantLock lock; + private final Class thresholdingModelClass; + private final Duration checkpointInterval; + private final Clock clock; + private final AnomalyDetectionIndices indexUtil; + private final RateLimiter bulkRateLimiter; + private final int maxBulkRequestSize; + /** * Constructor with dependencies and configuration. 
* * @param client ES search client * @param clientUtil utility with ES client * @param indexName name of the index for model checkpoints + * @param gson accessor to Gson functionality + * @param rcfSerde accessor to rcf serialization/deserialization + * @param thresholdingModelClass thresholding model's class + * @param clock a UTC clock + * @param checkpointInterval how often we should save a checkpoint + * @param indexUtil Index utility methods + * @param maxBulkRequestSize max number of index request a bulk can contain + * @param bulkPerSecond bulk requests per second */ - public CheckpointDao(Client client, ClientUtil clientUtil, String indexName) { + public CheckpointDao( + Client client, + ClientUtil clientUtil, + String indexName, + Gson gson, + RandomCutForestSerDe rcfSerde, + Class thresholdingModelClass, + Clock clock, + Duration checkpointInterval, + AnomalyDetectionIndices indexUtil, + int maxBulkRequestSize, + double bulkPerSecond + ) { this.client = client; this.clientUtil = clientUtil; this.indexName = indexName; + this.gson = gson; + this.rcfSerde = rcfSerde; + this.requests = new ConcurrentLinkedQueue<>(); + this.lock = new ReentrantLock(); + this.thresholdingModelClass = thresholdingModelClass; + this.clock = clock; + this.checkpointInterval = checkpointInterval; + this.indexUtil = indexUtil; + this.maxBulkRequestSize = maxBulkRequestSize; + // 1 bulk request per minute. 1 / 60 seconds = 0. 
02 + this.bulkRateLimiter = RateLimiter.create(bulkPerSecond); } /** @@ -79,12 +168,28 @@ public void putModelCheckpoint(String modelId, String modelCheckpoint) { source.put(FIELD_MODEL, modelCheckpoint); source.put(TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); - clientUtil - .timedRequest( - new IndexRequest(indexName, DOC_TYPE, modelId).source(source), - logger, - client::index - ); + if (indexUtil.doesCheckpointIndexExist()) { + saveModelCheckpointSync(source, modelId); + } else { + indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { + if (initResponse.isAcknowledged()) { + saveModelCheckpointSync(source, modelId); + } else { + throw new RuntimeException("Creating checkpoint with mappings call not acknowledged."); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + saveModelCheckpointSync(source, modelId); + } else { + logger.error(String.format("Unexpected error creating index %s", indexName), exception); + } + })); + } + } + + private void saveModelCheckpointSync(Map source, String modelId) { + clientUtil.timedRequest(new IndexRequest(indexName).id(modelId).source(source), logger, client::index); } /** @@ -98,14 +203,137 @@ public void putModelCheckpoint(String modelId, String modelCheckpoint, ActionLis Map source = new HashMap<>(); source.put(FIELD_MODEL, modelCheckpoint); source.put(TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); + if (indexUtil.doesCheckpointIndexExist()) { + saveModelCheckpointAsync(source, modelId, listener); + } else { + indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { + if (initResponse.isAcknowledged()) { + saveModelCheckpointAsync(source, modelId, listener); + } else { + throw new RuntimeException("Creating checkpoint with mappings call not acknowledged."); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof 
ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + saveModelCheckpointAsync(source, modelId, listener); + } else { + logger.error(String.format("Unexpected error creating index %s", indexName), exception); + } + })); + } + } + + private void saveModelCheckpointAsync(Map source, String modelId, ActionListener listener) { clientUtil .asyncRequest( - new IndexRequest(indexName, DOC_TYPE, modelId).source(source), + new IndexRequest(indexName).id(modelId).source(source), client::index, ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) ); } + /** + * Bulk writing model states prepared previously + */ + public void flush() { + try { + // in case that other threads are doing bulk as well. + if (!lock.tryLock()) { + return; + } + if (requests.size() > 0 && bulkRateLimiter.tryAcquire()) { + final BulkRequest bulkRequest = new BulkRequest(); + // at most 1000 index requests per bulk + for (int i = 0; i < maxBulkRequestSize; i++) { + DocWriteRequest req = requests.poll(); + if (req == null) { + break; + } + + bulkRequest.add(req); + } + if (indexUtil.doesCheckpointIndexExist()) { + flush(bulkRequest); + } else { + indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { + if (initResponse.isAcknowledged()) { + flush(bulkRequest); + } else { + throw new RuntimeException("Creating checkpoint with mappings call not acknowledged."); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + flush(bulkRequest); + } else { + logger.error(String.format("Unexpected error creating index %s", indexName), exception); + } + })); + } + } + } finally { + if (lock.isHeldByCurrentThread()) { + lock.unlock(); + } + } + } + + private void flush(BulkRequest bulkRequest) { + clientUtil.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(r -> { + if 
(r.hasFailures() == false) { + logger.debug("Succeeded in bulking checkpoints"); + } else { + requests.addAll(BulkUtil.getIndexRequestToRetry(bulkRequest, r)); + } + }, e -> { + logger.error("Failed bulking checkpoints", e); + // retry during next bulk. + for (DocWriteRequest req : bulkRequest.requests()) { + requests.add(req); + } + })); + } + + /** + * Prepare bulking the input model state to the checkpoint index. + * We don't save checkpoints within checkpointInterval again. + * @param modelState Model state + * @param modelId Model Id + */ + public void write(ModelState modelState, String modelId) { + write(modelState, modelId, false); + } + + /** + * Prepare bulking the input model state to the checkpoint index. + * We don't save checkpoints within checkpointInterval again, except this + * is from cold start. + * @param modelState Model state + * @param modelId Model Id + * @param coldStart whether the checkpoint comes from cold start + */ + public void write(ModelState modelState, String modelId, boolean coldStart) { + Instant instant = modelState.getLastCheckpointTime(); + // Instant.MIN is the default value. We don't save until we are sure. + if ((instant == Instant.MIN || instant.plus(checkpointInterval).isAfter(clock.instant())) && !coldStart) { + return; + } + // It is possible 2 states of the same model id gets saved: one overwrite another. + // This can happen if previous checkpoint hasn't been saved to disk, while the + // 1st one creates a new state without restoring. + if (modelState.getModel() != null) { + Map source = new HashMap<>(); + source.put(DETECTOR_ID, modelState.getDetectorId()); + source.put(FIELD_MODEL, toCheckpoint(modelState.getModel())); + source.put(TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); + requests.add(new IndexRequest(indexName).id(modelId).source(source)); + modelState.setLastCheckpointTime(Instant.now()); + if (requests.size() >= maxBulkRequestSize) { + flush(); + } + } + } + /** * Returns the checkpoint for the model. 
* @@ -117,33 +345,24 @@ public void putModelCheckpoint(String modelId, String modelCheckpoint, ActionLis @Deprecated public Optional getModelCheckpoint(String modelId) { return clientUtil - .timedRequest(new GetRequest(indexName, DOC_TYPE, modelId), logger, client::get) + .timedRequest(new GetRequest(indexName, modelId), logger, client::get) .filter(GetResponse::isExists) .map(GetResponse::getSource) .map(source -> (String) source.get(FIELD_MODEL)); } - /** - * Returns to listener the checkpoint for the model. - * - * @param modelId id of the model - * @param listener onResponse is called with the model checkpoint, or empty for no such model - */ - public void getModelCheckpoint(String modelId, ActionListener> listener) { - clientUtil - .asyncRequest( - new GetRequest(indexName, DOC_TYPE, modelId), - client::get, - ActionListener.wrap(response -> listener.onResponse(processModelCheckpoint(response)), listener::onFailure) - ); - } - - private Optional processModelCheckpoint(GetResponse response) { - return Optional - .ofNullable(response) - .filter(GetResponse::isExists) - .map(GetResponse::getSource) - .map(source -> (String) source.get(FIELD_MODEL)); + String toCheckpoint(EntityModel model) { + return AccessController.doPrivileged((PrivilegedAction) () -> { + JsonObject json = new JsonObject(); + json.add(ENTITY_SAMPLE, gson.toJsonTree(model.getSamples())); + if (model.getRcf() != null) { + json.addProperty(ENTITY_RCF, rcfSerde.toJson(model.getRcf())); + } + if (model.getThreshold() != null) { + json.addProperty(ENTITY_THRESHOLD, gson.toJson(model.getThreshold())); + } + return gson.toJson(json); + }); } /** @@ -155,7 +374,7 @@ private Optional processModelCheckpoint(GetResponse response) { */ @Deprecated public void deleteModelCheckpoint(String modelId) { - clientUtil.timedRequest(new DeleteRequest(indexName, DOC_TYPE, modelId), logger, client::delete); + clientUtil.timedRequest(new DeleteRequest(indexName, modelId), logger, client::delete); } /** @@ -167,9 
+386,128 @@ public void deleteModelCheckpoint(String modelId) { public void deleteModelCheckpoint(String modelId, ActionListener listener) { clientUtil .asyncRequest( - new DeleteRequest(indexName, DOC_TYPE, modelId), + new DeleteRequest(indexName, modelId), client::delete, ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) ); } + + /** + * Delete checkpoints associated with a detector. Used in multi-entity detector. + * @param detectorID Detector Id + */ + public void deleteModelCheckpointByDetectorId(String detectorID) { + // A bulk delete request is performed for each batch of matching documents. If a + // search or bulk request is rejected, the requests are retried up to 10 times, + // with exponential back off. If the maximum retry limit is reached, processing + // halts and all failed requests are returned in the response. Any delete + // requests that completed successfully still stick, they are not rolled back. + DeleteByQueryRequest deleteRequest = new DeleteByQueryRequest(CommonName.CHECKPOINT_INDEX_NAME) + .setQuery(new MatchQueryBuilder(DETECTOR_ID, detectorID)) + .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) + .setAbortOnVersionConflict(false) // when current delete happens, previous might not finish. + // Retry in this case + .setRequestsPerSecond(500); // throttle delete requests + logger.info("Delete checkpoints of detector {}", detectorID); + client.execute(DeleteByQueryAction.INSTANCE, deleteRequest, ActionListener.wrap(response -> { + if (response.isTimedOut() || !response.getBulkFailures().isEmpty() || !response.getSearchFailures().isEmpty()) { + logFailure(response, detectorID); + } + // if 0 docs get deleted, it means we cannot find matching docs + logger.info("{} " + DOC_GOT_DELETED_LOG_MSG, response.getDeleted()); + }, exception -> { + if (exception instanceof IndexNotFoundException) { + logger.info(INDEX_DELETED_LOG_MSG + " {}", detectorID); + } else { + // Gonna eventually delete in daily cron. 
+ logger.error(NOT_ABLE_TO_DELETE_LOG_MSG, exception); + } + })); + } + + private void logFailure(BulkByScrollResponse response, String detectorID) { + if (response.isTimedOut()) { + logger.warn(TIMEOUT_LOG_MSG + " {}", detectorID); + } else if (!response.getBulkFailures().isEmpty()) { + logger.warn(BULK_FAILURE_LOG_MSG + " {}", detectorID); + for (BulkItemResponse.Failure bulkFailure : response.getBulkFailures()) { + logger.warn(bulkFailure); + } + } else { + logger.warn(SEARCH_FAILURE_LOG_MSG + " {}", detectorID); + for (ScrollableHitSource.SearchFailure searchFailure : response.getSearchFailures()) { + logger.warn(searchFailure); + } + } + } + + private Entry fromEntityModelCheckpoint(Map checkpoint, String modelId) { + try { + return AccessController.doPrivileged((PrivilegedAction>) () -> { + String model = (String) (checkpoint.get(FIELD_MODEL)); + JsonObject json = JsonParser.parseString(model).getAsJsonObject(); + ArrayDeque samples = new ArrayDeque<>( + Arrays.asList(this.gson.fromJson(json.getAsJsonArray(ENTITY_SAMPLE), new double[0][0].getClass())) + ); + RandomCutForest rcf = null; + if (json.has(ENTITY_RCF)) { + rcf = rcfSerde.fromJson(json.getAsJsonPrimitive(ENTITY_RCF).getAsString()); + } + ThresholdingModel threshold = null; + if (json.has(ENTITY_THRESHOLD)) { + threshold = this.gson.fromJson(json.getAsJsonPrimitive(ENTITY_THRESHOLD).getAsString(), thresholdingModelClass); + } + + String lastCheckpointTimeString = (String) (checkpoint.get(TIMESTAMP)); + Instant timestamp = Instant.parse(lastCheckpointTimeString); + return new SimpleImmutableEntry<>(new EntityModel(modelId, samples, rcf, threshold), timestamp); + }); + } catch (RuntimeException e) { + logger.warn("Exception while deserializing checkpoint", e); + throw e; + } + } + + /** + * Read a checkpoint from the index and return the EntityModel object + * @param modelId Model Id + * @param listener Listener to return the EntityModel object + */ + public void restoreModelCheckpoint(String modelId, 
ActionListener>> listener) { + clientUtil.asyncRequest(new GetRequest(indexName, modelId), client::get, ActionListener.wrap(response -> { + Optional> checkpointString = processRawCheckpoint(response); + if (checkpointString.isPresent()) { + listener.onResponse(Optional.of(fromEntityModelCheckpoint(checkpointString.get(), modelId))); + } else { + listener.onResponse(Optional.empty()); + } + }, listener::onFailure)); + } + + /** + * Returns to listener the checkpoint for the model. + * + * @param modelId id of the model + * @param listener onResponse is called with the model checkpoint, or empty for no such model + */ + public void getModelCheckpoint(String modelId, ActionListener> listener) { + clientUtil + .asyncRequest( + new GetRequest(indexName, modelId), + client::get, + ActionListener.wrap(response -> listener.onResponse(processModelCheckpoint(response)), listener::onFailure) + ); + } + + private Optional processModelCheckpoint(GetResponse response) { + return Optional + .ofNullable(response) + .filter(GetResponse::isExists) + .map(GetResponse::getSource) + .map(source -> (String) source.get(FIELD_MODEL)); + } + + private Optional> processRawCheckpoint(GetResponse response) { + return Optional.ofNullable(response).filter(GetResponse::isExists).map(GetResponse::getSource); + } } diff --git a/src/main/resources/mappings/checkpoint.json b/src/main/resources/mappings/checkpoint.json new file mode 100644 index 00000000..8fe8a099 --- /dev/null +++ b/src/main/resources/mappings/checkpoint.json @@ -0,0 +1,18 @@ +{ + "dynamic": true, + "_meta": { + "schema_version": 1 + }, + "properties": { + "detectorId": { + "type": "keyword" + }, + "model": { + "type": "text" + }, + "timestamp": { + "type": "date", + "format": "strict_date_time||epoch_millis" + } + } +} diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDaoTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDaoTests.java index 618d1254..d5e77416 100644 --- 
a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDaoTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDaoTests.java @@ -15,6 +15,7 @@ package com.amazon.opendistroforelasticsearch.ad.ml; +import static com.amazon.opendistroforelasticsearch.ad.ml.CheckpointDao.FIELD_MODEL; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -23,37 +24,75 @@ import static org.mockito.Matchers.anyObject; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.io.IOException; +import java.time.Clock; +import java.time.Instant; +import java.time.Month; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.Map; +import java.util.Map.Entry; import java.util.Optional; +import java.util.Queue; import java.util.Set; import java.util.function.BiConsumer; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; +import org.elasticsearch.action.bulk.BulkAction; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.delete.DeleteRequest; import org.elasticsearch.action.delete.DeleteResponse; import 
org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.client.Client; +import org.elasticsearch.index.engine.VersionConflictEngineException; +import org.elasticsearch.index.shard.ShardId; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; import org.mockito.Answers; import org.mockito.ArgumentCaptor; import org.mockito.Matchers; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import test.com.amazon.opendistroforelasticsearch.ad.util.MLUtil; + +import com.amazon.opendistroforelasticsearch.ad.constant.CommonName; +import com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices; +import com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings; import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; +import com.amazon.randomcutforest.serialize.RandomCutForestSerDe; +import com.google.gson.Gson; +@RunWith(PowerMockRunner.class) +@PrepareForTest({ Gson.class }) public class CheckpointDaoTests { + private static final Logger logger = LogManager.getLogger(CheckpointDaoTests.class); private CheckpointDao checkpointDao; @@ -67,6 +106,15 @@ public class CheckpointDaoTests { @Mock private GetResponse getResponse; + @Mock + private RandomCutForestSerDe rcfSerde; + + @Mock + private Clock clock; + + @Mock + private AnomalyDetectionIndices indexUtil; + // configuration private String indexName; @@ -75,24 +123,47 @@ public class CheckpointDaoTests { private String model; private Map docSource; + private Gson gson; + private Class thresholdingModelClass; + private int maxBulkSize; + @Before public void setup() 
{ MockitoAnnotations.initMocks(this); indexName = "testIndexName"; - checkpointDao = new CheckpointDao(client, clientUtil, indexName); + gson = PowerMockito.mock(Gson.class); + + thresholdingModelClass = HybridThresholdingModel.class; + + when(clock.instant()).thenReturn(Instant.now()); + + maxBulkSize = 10; + + checkpointDao = new CheckpointDao( + client, + clientUtil, + indexName, + gson, + rcfSerde, + thresholdingModelClass, + clock, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + indexUtil, + maxBulkSize, + 200.0 + ); + + when(indexUtil.doesCheckpointIndexExist()).thenReturn(true); modelId = "testModelId"; model = "testModel"; docSource = new HashMap<>(); - docSource.put(CheckpointDao.FIELD_MODEL, model); + docSource.put(FIELD_MODEL, model); } - @Test - public void putModelCheckpoint_getIndexRequest() { - checkpointDao.putModelCheckpoint(modelId, model); - + private void verifySuccessfulPutModelCheckpointSync() { ArgumentCaptor indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class); verify(clientUtil) .timedRequest( @@ -102,14 +173,65 @@ public void putModelCheckpoint_getIndexRequest() { ); IndexRequest indexRequest = indexRequestCaptor.getValue(); assertEquals(indexName, indexRequest.index()); - assertEquals(CheckpointDao.DOC_TYPE, indexRequest.type()); assertEquals(modelId, indexRequest.id()); - Set expectedSourceKeys = new HashSet(Arrays.asList(CheckpointDao.FIELD_MODEL, CheckpointDao.TIMESTAMP)); + Set expectedSourceKeys = new HashSet(Arrays.asList(FIELD_MODEL, CheckpointDao.TIMESTAMP)); assertEquals(expectedSourceKeys, indexRequest.sourceAsMap().keySet()); - assertEquals(model, indexRequest.sourceAsMap().get(CheckpointDao.FIELD_MODEL)); + assertEquals(model, indexRequest.sourceAsMap().get(FIELD_MODEL)); assertNotNull(indexRequest.sourceAsMap().get(CheckpointDao.TIMESTAMP)); } + @Test + public void putModelCheckpoint_getIndexRequest() { + checkpointDao.putModelCheckpoint(modelId, model); + + verifySuccessfulPutModelCheckpointSync(); + } + + @Test 
+ public void putModelCheckpoint_no_checkpoint_index() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(new CreateIndexResponse(true, true, CommonName.CHECKPOINT_INDEX_NAME)); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + checkpointDao.putModelCheckpoint(modelId, model); + + verifySuccessfulPutModelCheckpointSync(); + } + + @Test + public void putModelCheckpoint_index_race_condition() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new ResourceAlreadyExistsException(CommonName.CHECKPOINT_INDEX_NAME)); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + checkpointDao.putModelCheckpoint(modelId, model); + + verifySuccessfulPutModelCheckpointSync(); + } + + @Test + public void putModelCheckpoint_unexpected_exception() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException("")); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + checkpointDao.putModelCheckpoint(modelId, model); + + verify(clientUtil, never()).timedRequest(any(), any(), any()); + } + @Test public void getModelCheckpoint_returnExpected() { ArgumentCaptor getRequestCaptor = ArgumentCaptor.forClass(GetRequest.class); @@ -129,7 +251,6 @@ public void getModelCheckpoint_returnExpected() { assertEquals(model, result.get()); GetRequest getRequest = getRequestCaptor.getValue(); assertEquals(indexName, getRequest.index()); - assertEquals(CheckpointDao.DOC_TYPE, getRequest.type()); assertEquals(modelId, getRequest.id()); } @@ -158,13 +279,11 @@ public void deleteModelCheckpoint_getDeleteRequest() { ); DeleteRequest deleteRequest = deleteRequestCaptor.getValue(); 
assertEquals(indexName, deleteRequest.index()); - assertEquals(CheckpointDao.DOC_TYPE, deleteRequest.type()); assertEquals(modelId, deleteRequest.id()); } - @Test @SuppressWarnings("unchecked") - public void putModelCheckpoint_callListener_whenCompleted() { + private void verifyPutModelCheckpointAsync() { ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(IndexRequest.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -177,11 +296,10 @@ public void putModelCheckpoint_callListener_whenCompleted() { IndexRequest indexRequest = requestCaptor.getValue(); assertEquals(indexName, indexRequest.index()); - assertEquals(CheckpointDao.DOC_TYPE, indexRequest.type()); assertEquals(modelId, indexRequest.id()); - Set expectedSourceKeys = new HashSet(Arrays.asList(CheckpointDao.FIELD_MODEL, CheckpointDao.TIMESTAMP)); + Set expectedSourceKeys = new HashSet(Arrays.asList(FIELD_MODEL, CheckpointDao.TIMESTAMP)); assertEquals(expectedSourceKeys, indexRequest.sourceAsMap().keySet()); - assertEquals(model, indexRequest.sourceAsMap().get(CheckpointDao.FIELD_MODEL)); + assertEquals(model, indexRequest.sourceAsMap().get(FIELD_MODEL)); assertNotNull(indexRequest.sourceAsMap().get(CheckpointDao.TIMESTAMP)); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Void.class); @@ -190,6 +308,54 @@ public void putModelCheckpoint_callListener_whenCompleted() { assertEquals(null, response); } + @Test + public void putModelCheckpoint_callListener_whenCompleted() { + verifyPutModelCheckpointAsync(); + } + + @Test + public void putModelCheckpoint_callListener_no_checkpoint_index() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(new CreateIndexResponse(true, true, CommonName.CHECKPOINT_INDEX_NAME)); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + verifyPutModelCheckpointAsync(); + } + + @Test + public void 
putModelCheckpoint_callListener_race_condition() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new ResourceAlreadyExistsException(CommonName.CHECKPOINT_INDEX_NAME)); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + verifyPutModelCheckpointAsync(); + } + + @SuppressWarnings("unchecked") + @Test + public void putModelCheckpoint_callListener_unexpected_exception() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException("")); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + ActionListener listener = mock(ActionListener.class); + checkpointDao.putModelCheckpoint(modelId, model, listener); + + verify(clientUtil, never()).asyncRequest(any(), any(), any()); + } + @Test @SuppressWarnings("unchecked") public void getModelCheckpoint_returnExpectedToListener() { @@ -207,7 +373,6 @@ public void getModelCheckpoint_returnExpectedToListener() { GetRequest getRequest = requestCaptor.getValue(); assertEquals(indexName, getRequest.index()); - assertEquals(CheckpointDao.DOC_TYPE, getRequest.type()); assertEquals(modelId, getRequest.id()); ArgumentCaptor> responseCaptor = ArgumentCaptor.forClass(Optional.class); verify(listener).onResponse(responseCaptor.capture()); @@ -232,7 +397,6 @@ public void getModelCheckpoint_returnEmptyToListener_whenModelNotFound() { GetRequest getRequest = requestCaptor.getValue(); assertEquals(indexName, getRequest.index()); - assertEquals(CheckpointDao.DOC_TYPE, getRequest.type()); assertEquals(modelId, getRequest.id()); ArgumentCaptor> responseCaptor = ArgumentCaptor.forClass(Optional.class); verify(listener).onResponse(responseCaptor.capture()); @@ -255,7 +419,6 @@ public void deleteModelCheckpoint_callListener_whenCompleted() { DeleteRequest deleteRequest = 
requestCaptor.getValue(); assertEquals(indexName, deleteRequest.index()); - assertEquals(CheckpointDao.DOC_TYPE, deleteRequest.type()); assertEquals(modelId, deleteRequest.id()); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Void.class); @@ -263,4 +426,283 @@ public void deleteModelCheckpoint_callListener_whenCompleted() { Void response = responseCaptor.getValue(); assertEquals(null, response); } + + private BulkResponse createBulkResponse(int succeeded, int failed, String[] failedId) { + BulkItemResponse[] bulkItemResponses = new BulkItemResponse[succeeded + failed]; + + ShardId shardId = new ShardId(CommonName.CHECKPOINT_INDEX_NAME, "", 1); + int i = 0; + for (; i < failed; i++) { + bulkItemResponses[i] = new BulkItemResponse( + i, + DocWriteRequest.OpType.UPDATE, + new BulkItemResponse.Failure( + CommonName.CHECKPOINT_INDEX_NAME, + CommonName.MAPPING_TYPE, + failedId[i], + new VersionConflictEngineException(shardId, "id", "test") + ) + ); + } + + for (; i < failed + succeeded; i++) { + bulkItemResponses[i] = new BulkItemResponse( + i, + DocWriteRequest.OpType.UPDATE, + new UpdateResponse(shardId, CommonName.MAPPING_TYPE, "1", 0L, 1L, 1L, DocWriteResponse.Result.CREATED) + ); + } + + return new BulkResponse(bulkItemResponses, 507); + } + + @SuppressWarnings("unchecked") + @Test + public void flush_less_than_1k() { + int writeRequests = maxBulkSize - 1; + for (int i = 0; i < writeRequests; i++) { + ModelState state = MLUtil.randomModelState(); + checkpointDao.write(state, state.getModelId(), true); + } + + doAnswer(invocation -> { + BulkRequest request = invocation.getArgument(1); + assertEquals(writeRequests, request.numberOfActions()); + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(createBulkResponse(request.numberOfActions(), 0, null)); + return null; + }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + + checkpointDao.flush(); + + verify(clientUtil, 
times(1)).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + } + + @SuppressWarnings("unchecked") + public void flush_more_than_1k() { + int writeRequests = maxBulkSize + 1; + + doAnswer(invocation -> { + BulkRequest request = invocation.getArgument(1); + assertEquals(maxBulkSize, request.numberOfActions()); + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(createBulkResponse(request.numberOfActions(), 0, null)); + return null; + }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + + for (int i = 0; i < writeRequests; i++) { + ModelState state = MLUtil.randomModelState(); + // should trigger auto flush + checkpointDao.write(state, state.getModelId(), true); + } + + verify(clientUtil, times(1)).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + } + + @Test + public void flush_more_than_1k_has_index() { + flush_more_than_1k(); + } + + @Test + public void flush_more_than_1k_no_index() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(new CreateIndexResponse(true, true, CommonName.CHECKPOINT_INDEX_NAME)); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + flush_more_than_1k(); + } + + @Test + public void flush_more_than_1k_race_condition() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new ResourceAlreadyExistsException(CommonName.CHECKPOINT_INDEX_NAME)); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + flush_more_than_1k(); + } + + @SuppressWarnings("unchecked") + @Test + public void flush_more_than_1k_unexpected_exception() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener 
listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException("")); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + verify(clientUtil, never()).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + } + + @SuppressWarnings("unchecked") + @Test + public void bulk_has_failure() throws InterruptedException { + int writeRequests = maxBulkSize - 1; + int failureCount = 1; + String[] failedId = new String[failureCount]; + for (int i = 0; i < writeRequests; i++) { + ModelState state = MLUtil.randomModelState(); + checkpointDao.write(state, state.getModelId(), true); + if (i < failureCount) { + failedId[i] = state.getModelId(); + } + } + + doAnswer(invocation -> { + BulkRequest request = invocation.getArgument(1); + assertEquals(writeRequests, request.numberOfActions()); + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(createBulkResponse(request.numberOfActions(), failureCount, failedId)); + return null; + }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + + checkpointDao.flush(); + + doAnswer(invocation -> { + BulkRequest request = invocation.getArgument(1); + assertEquals(failureCount, request.numberOfActions()); + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(createBulkResponse(request.numberOfActions(), 0, null)); + return null; + }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + + checkpointDao.flush(); + + verify(clientUtil, times(2)).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + } + + @SuppressWarnings("unchecked") + @Test + public void bulk_all_failure() throws InterruptedException { + int writeRequests = maxBulkSize - 1; + for (int i = 0; i < writeRequests; i++) { + ModelState state = MLUtil.randomModelState(); + checkpointDao.write(state, state.getModelId(), true); + } + + 
doAnswer(invocation -> { + BulkRequest request = invocation.getArgument(1); + assertEquals(writeRequests, request.numberOfActions()); + ActionListener listener = invocation.getArgument(2); + + listener.onFailure(new RuntimeException("")); + return null; + }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + + checkpointDao.flush(); + + doAnswer(invocation -> { + BulkRequest request = invocation.getArgument(1); + assertEquals(writeRequests, request.numberOfActions()); + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(createBulkResponse(request.numberOfActions(), 0, null)); + return null; + }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + + checkpointDao.flush(); + + verify(clientUtil, times(2)).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + } + + @SuppressWarnings("unchecked") + @Test + public void checkpoint_saved_less_than_1_hr() { + ModelState state = MLUtil.randomModelState(); + state.setLastCheckpointTime(Instant.now()); + checkpointDao.write(state, state.getModelId()); + + checkpointDao.flush(); + + verify(clientUtil, never()).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + } + + @SuppressWarnings("unchecked") + @Test + public void checkpoint_coldstart_checkpoint() { + ModelState state = MLUtil.randomModelState(); + state.setLastCheckpointTime(Instant.now()); + // cold start checkpoint will save whatever + checkpointDao.write(state, state.getModelId(), true); + + checkpointDao.flush(); + + verify(clientUtil, times(1)).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + } + + @SuppressWarnings("unchecked") + @Test + public void restore() throws IOException { + ModelState state = MLUtil.randomNonEmptyModelState(); + EntityModel modelToSave = state.getModel(); + + checkpointDao = new CheckpointDao( + client, + 
clientUtil, + indexName, + new Gson(), + new RandomCutForestSerDe(), + thresholdingModelClass, + clock, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + indexUtil, + maxBulkSize, + 2 + ); + + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + Map source = new HashMap<>(); + source.put(CheckpointDao.DETECTOR_ID, state.getDetectorId()); + source.put(CheckpointDao.FIELD_MODEL, checkpointDao.toCheckpoint(modelToSave)); + source.put(CheckpointDao.TIMESTAMP, "2020-10-11T22:58:23.610392Z"); + when(getResponse.getSource()).thenReturn(source); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(getResponse); + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(BiConsumer.class), any(ActionListener.class)); + + ActionListener>> listener = mock(ActionListener.class); + checkpointDao.restoreModelCheckpoint(modelId, listener); + + ArgumentCaptor>> responseCaptor = ArgumentCaptor.forClass(Optional.class); + verify(listener).onResponse(responseCaptor.capture()); + Optional> response = responseCaptor.getValue(); + assertTrue(response.isPresent()); + Entry entry = response.get(); + OffsetDateTime utcTime = entry.getValue().atOffset(ZoneOffset.UTC); + assertEquals(2020, utcTime.getYear()); + assertEquals(Month.OCTOBER, utcTime.getMonth()); + assertEquals(11, utcTime.getDayOfMonth()); + assertEquals(22, utcTime.getHour()); + assertEquals(58, utcTime.getMinute()); + assertEquals(23, utcTime.getSecond()); + + EntityModel model = entry.getKey(); + Queue queue = model.getSamples(); + Queue samplesToSave = modelToSave.getSamples(); + assertEquals(samplesToSave.size(), queue.size()); + assertTrue(Arrays.equals(samplesToSave.peek(), queue.peek())); + logger.info(modelToSave.getRcf()); + logger.info(model.getRcf()); + assertEquals(modelToSave.getRcf().getTotalUpdates(), model.getRcf().getTotalUpdates()); + assertTrue(model.getThreshold() != null); + } } From 
55e01fa9864b929888d84dc4f50977b9ca072a04 Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Wed, 14 Oct 2020 15:32:27 -0700 Subject: [PATCH 2/5] Remove debug message --- .../opendistroforelasticsearch/ad/ml/CheckpointDao.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java index ffc5b6b5..8446c2cd 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java @@ -280,9 +280,7 @@ public void flush() { private void flush(BulkRequest bulkRequest) { clientUtil.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(r -> { - if (r.hasFailures() == false) { - logger.debug("Succeeded in bulking checkpoints"); - } else { + if (r.hasFailures()) { requests.addAll(BulkUtil.getIndexRequestToRetry(bulkRequest, r)); } }, e -> { From 9af2a95ba97810beae35a2904f52e4738dd9004a Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Wed, 14 Oct 2020 20:03:22 -0700 Subject: [PATCH 3/5] Add supporting classes --- .../ad/util/BulkUtil.java | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 src/main/java/com/amazon/opendistroforelasticsearch/ad/util/BulkUtil.java diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/BulkUtil.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/BulkUtil.java new file mode 100644 index 00000000..23a35057 --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/BulkUtil.java @@ -0,0 +1,50 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. 
+ * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.util; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; + +public class BulkUtil { + private static final Logger logger = LogManager.getLogger(BulkUtil.class); + + public static List> getIndexRequestToRetry(BulkRequest bulkRequest, BulkResponse bulkResponse) { + List> res = new ArrayList<>(); + + Set failedId = new HashSet<>(); + for (BulkItemResponse response : bulkResponse.getItems()) { + if (response.isFailed()) { + failedId.add(response.getId()); + } + } + + for (DocWriteRequest request : bulkRequest.requests()) { + if (failedId.contains(request.id())) { + res.add(request); + } + } + return res; + } +} From 4177fda5f21799d2d25ec84a6ed63f1c98f8ac78 Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Thu, 15 Oct 2020 13:35:37 -0700 Subject: [PATCH 4/5] refactor code --- .../ad/ml/CheckpointDao.java | 47 +++++++++---------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java index 8446c2cd..34986e8d 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java +++ 
b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java @@ -171,20 +171,7 @@ public void putModelCheckpoint(String modelId, String modelCheckpoint) { if (indexUtil.doesCheckpointIndexExist()) { saveModelCheckpointSync(source, modelId); } else { - indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { - if (initResponse.isAcknowledged()) { - saveModelCheckpointSync(source, modelId); - } else { - throw new RuntimeException("Creating checkpoint with mappings call not acknowledged."); - } - }, exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we sending the create request - saveModelCheckpointSync(source, modelId); - } else { - logger.error(String.format("Unexpected error creating index %s", indexName), exception); - } - })); + onCheckpointNotExist(source, modelId, false, null); } } @@ -206,21 +193,33 @@ public void putModelCheckpoint(String modelId, String modelCheckpoint, ActionLis if (indexUtil.doesCheckpointIndexExist()) { saveModelCheckpointAsync(source, modelId, listener); } else { - indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { - if (initResponse.isAcknowledged()) { + onCheckpointNotExist(source, modelId, true, listener); + } + } + + private void onCheckpointNotExist(Map source, String modelId, boolean isAsync, ActionListener listener) { + indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { + if (initResponse.isAcknowledged()) { + if (isAsync) { saveModelCheckpointAsync(source, modelId, listener); } else { - throw new RuntimeException("Creating checkpoint with mappings call not acknowledged."); + saveModelCheckpointSync(source, modelId); } - }, exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we sending the create request + } else { + throw new RuntimeException("Creating 
checkpoint with mappings call not acknowledged."); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + if (isAsync) { saveModelCheckpointAsync(source, modelId, listener); } else { - logger.error(String.format("Unexpected error creating index %s", indexName), exception); + saveModelCheckpointSync(source, modelId); } - })); - } + } else { + logger.error(String.format("Unexpected error creating index %s", indexName), exception); + } + })); } private void saveModelCheckpointAsync(Map source, String modelId, ActionListener listener) { @@ -325,7 +324,7 @@ public void write(ModelState modelState, String modelId, boolean co source.put(FIELD_MODEL, toCheckpoint(modelState.getModel())); source.put(TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); requests.add(new IndexRequest(indexName).id(modelId).source(source)); - modelState.setLastCheckpointTime(Instant.now()); + modelState.setLastCheckpointTime(clock.instant()); if (requests.size() >= maxBulkRequestSize) { flush(); } From 477bb20fb1dfa3f83590e79f7226692284ef76b1 Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Thu, 15 Oct 2020 15:17:15 -0700 Subject: [PATCH 5/5] Update comment indicating we are changing the argument --- .../opendistroforelasticsearch/ad/ml/CheckpointDao.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java index 34986e8d..a643b29e 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.java @@ -304,7 +304,9 @@ public void write(ModelState modelState, String modelId) { /** * Prepare bulking the input model state to the checkpoint index.
* We don't save checkpoints within checkpointInterval again, except this - * is from cold start. + * is from cold start. This method will update the input state's last + * checkpoint time if the checkpoint is staged (ready to be written in the + * next batch). * @param modelState Model state * @param modelId Model Id * @param coldStart whether the checkpoint comes from cold start