From 85872394cf57cd5d6ff9e9cc58d6aae730d683d7 Mon Sep 17 00:00:00 2001 From: Lai Date: Tue, 28 Apr 2020 13:43:29 -0700 Subject: [PATCH] add async maintenance --- .../ad/ml/ModelManager.java | 51 ++++++ .../ad/ml/ModelManagerTests.java | 156 ++++++++++++++++++ 2 files changed, 207 insertions(+) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java index e334645e..1cb653b3 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java @@ -22,6 +22,7 @@ import java.time.Instant; import java.util.AbstractMap.SimpleImmutableEntry; import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; @@ -825,6 +826,56 @@ private void maintenance(Map> models, Function listener) { + maintenanceForIterator( + forests, + this::toCheckpoint, + forests.entrySet().iterator(), + ActionListener + .wrap( + r -> maintenanceForIterator(thresholds, this::toCheckpoint, thresholds.entrySet().iterator(), listener), + listener::onFailure + ) + ); + } + + private void maintenanceForIterator( + Map> models, + Function toCheckpoint, + Iterator>> iter, + ActionListener listener + ) { + if (iter.hasNext()) { + Entry> modelEntry = iter.next(); + String modelId = modelEntry.getKey(); + ModelState modelState = modelEntry.getValue(); + Instant now = clock.instant(); + if (modelState.getLastUsedTime().plus(modelTtl).isBefore(now)) { + models.remove(modelId); + } + if (modelState.getLastCheckpointTime().plus(checkpointInterval).isBefore(now)) { + checkpointDao.putModelCheckpoint(modelId, toCheckpoint.apply(modelState.getModel()), ActionListener.wrap(r -> { + modelState.setLastCheckpointTime(now); + maintenanceForIterator(models, toCheckpoint, iter, listener); + }, e -> { + logger.warn("Failed to finish maintenance for model id " + modelId, e); + maintenanceForIterator(models, toCheckpoint, iter, listener); + })); + } else { + maintenanceForIterator(models, toCheckpoint, iter, listener); + } + } else { + listener.onResponse(null); + } + } + /** * Returns computed anomaly results for preview data points. * diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java index 81242a49..b56b904f 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java @@ -810,6 +810,162 @@ public void maintenance_skipCheckpoint_whenLastCheckpointIsRecent() { verify(checkpointDao, times(1)).putModelCheckpoint(eq(modelId), anyObject()); } + @Test + @SuppressWarnings("unchecked") + public void maintenance_returnExpectedToListener_forRcfModel() { + String successModelId = "testSuccessModelId"; + String failModelId = "testFailModelId"; + String successCheckpoint = "testSuccessCheckpoint"; + String failCheckpoint = "testFailCheckpoint"; + double[] point = new double[0]; + RandomCutForest forest = mock(RandomCutForest.class); + RandomCutForest failForest = mock(RandomCutForest.class); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(successCheckpoint)); + return null; + }).when(checkpointDao).getModelCheckpoint(eq(successModelId), any(ActionListener.class)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(failCheckpoint)); + return null; + }).when(checkpointDao).getModelCheckpoint(eq(failModelId), any(ActionListener.class)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(checkpointDao).putModelCheckpoint(eq(successModelId), eq(successCheckpoint), any(ActionListener.class)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException()); + return null; + }).when(checkpointDao).putModelCheckpoint(eq(failModelId), eq(failCheckpoint), any(ActionListener.class)); + when(rcfSerde.fromJson(successCheckpoint)).thenReturn(forest); + when(rcfSerde.fromJson(failCheckpoint)).thenReturn(failForest); + when(rcfSerde.toJson(forest)).thenReturn(successCheckpoint); + when(rcfSerde.toJson(failForest)).thenReturn(failCheckpoint); + when(clock.instant()).thenReturn(Instant.EPOCH); + ActionListener scoreListener = mock(ActionListener.class); + modelManager.getRcfResult(detectorId, successModelId, point, scoreListener); + modelManager.getRcfResult(detectorId, failModelId, point, scoreListener); + + ActionListener listener = mock(ActionListener.class); + modelManager.maintenance(listener); + + verify(listener).onResponse(eq(null)); + verify(checkpointDao, times(1)).putModelCheckpoint(eq(successModelId), eq(successCheckpoint), any(ActionListener.class)); + verify(checkpointDao, times(1)).putModelCheckpoint(eq(failModelId), eq(failCheckpoint), any(ActionListener.class)); + } + + @Test + @SuppressWarnings("unchecked") + public void maintenance_returnExpectedToListener_forThresholdModel() { + String successModelId = "testSuccessModelId"; + String failModelId = "testFailModelId"; + String successCheckpoint = "testSuccessCheckpoint"; + String failCheckpoint = "testFailCheckpoint"; + double score = 1.; + HybridThresholdingModel failThresholdModel = mock(HybridThresholdingModel.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(successCheckpoint)); + return null; + }).when(checkpointDao).getModelCheckpoint(eq(successModelId), any(ActionListener.class)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(failCheckpoint)); + return null; + }).when(checkpointDao).getModelCheckpoint(eq(failModelId), any(ActionListener.class)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(checkpointDao).putModelCheckpoint(eq(successModelId), eq(successCheckpoint), any(ActionListener.class)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException()); + return null; + }).when(checkpointDao).putModelCheckpoint(eq(failModelId), eq(failCheckpoint), any(ActionListener.class)); + doReturn(hybridThresholdingModel).when(gson).fromJson(successCheckpoint, thresholdingModelClass); + doReturn(failThresholdModel).when(gson).fromJson(failCheckpoint, thresholdingModelClass); + doReturn(successCheckpoint).when(gson).toJson(hybridThresholdingModel); + doThrow(new RuntimeException()).when(gson).toJson(failThresholdModel); + when(clock.instant()).thenReturn(Instant.EPOCH); + ActionListener scoreListener = mock(ActionListener.class); + modelManager.getThresholdingResult(detectorId, successModelId, score, scoreListener); + modelManager.getThresholdingResult(detectorId, failModelId, score, scoreListener); + + ActionListener listener = mock(ActionListener.class); + modelManager.maintenance(listener); + + verify(listener).onResponse(eq(null)); + verify(checkpointDao, times(1)).putModelCheckpoint(eq(successModelId), eq(successCheckpoint), any(ActionListener.class)); + } + + @Test + @SuppressWarnings("unchecked") + public void maintenance_returnExpectedToListener_stopModel() { + double[] point = new double[0]; + RandomCutForest forest = mock(RandomCutForest.class); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(checkpoint)); + return null; + }).when(checkpointDao).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(checkpointDao).putModelCheckpoint(eq(rcfModelId), eq(checkpoint), any(ActionListener.class)); + when(rcfSerde.fromJson(checkpoint)).thenReturn(forest); + when(rcfSerde.toJson(forest)).thenReturn(checkpoint); + when(clock.instant()).thenReturn(Instant.EPOCH, Instant.EPOCH, Instant.EPOCH.plus(modelTtl.plusSeconds(1))); + ActionListener scoreListener = mock(ActionListener.class); + modelManager.getRcfResult(detectorId, rcfModelId, point, scoreListener); + + ActionListener listener = mock(ActionListener.class); + modelManager.maintenance(listener); + verify(listener).onResponse(eq(null)); + + modelManager.getRcfResult(detectorId, rcfModelId, point, scoreListener); + verify(checkpointDao, times(2)).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class)); + } + + @Test + @SuppressWarnings("unchecked") + public void maintenance_returnExpectedToListener_doNothing() { + double[] point = new double[0]; + RandomCutForest forest = mock(RandomCutForest.class); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(checkpoint)); + return null; + }).when(checkpointDao).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(checkpointDao).putModelCheckpoint(eq(rcfModelId), eq(checkpoint), any(ActionListener.class)); + when(rcfSerde.fromJson(checkpoint)).thenReturn(forest); + when(rcfSerde.toJson(forest)).thenReturn(checkpoint); + when(clock.instant()).thenReturn(Instant.MIN); + ActionListener scoreListener = mock(ActionListener.class); + modelManager.getRcfResult(detectorId, rcfModelId, point, scoreListener); + ActionListener listener = mock(ActionListener.class); + modelManager.maintenance(listener); + verify(listener).onResponse(eq(null)); + + listener = mock(ActionListener.class); + modelManager.maintenance(listener); + verify(listener).onResponse(eq(null)); + + modelManager.getRcfResult(detectorId, rcfModelId, point, scoreListener); + verify(checkpointDao, times(1)).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class)); + } + @Test public void getPreviewResults_returnNoAnomalies_forNoAnomalies() { int numPoints = 1000;