Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Add async clear #91

Merged
merged 1 commit into from
May 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.time.Instant;
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -556,6 +557,39 @@ public void clear(String detectorId) {
clearModels(detectorId, thresholds);
}

/**
* Permanently deletes models hosted in memory and persisted in index.
*
* @param detectorId id the of the detector for which models are to be permanently deleted
* @param listener onResponse is called with null when this operation is completed
*/
public void clear(String detectorId, ActionListener<Void> listener) {
clearModels(detectorId, forests, ActionListener.wrap(r -> clearModels(detectorId, thresholds, listener), listener::onFailure));
}

private void clearModels(String detectorId, Map<String, ?> models, ActionListener<Void> listener) {
Iterator<String> id = models.keySet().iterator();
clearModelForIterator(detectorId, models, id, listener);
}

private void clearModelForIterator(String detectorId, Map<String, ?> models, Iterator<String> idIter, ActionListener<Void> listener) {
if (idIter.hasNext()) {
String modelId = idIter.next();
if (getDetectorIdForModelId(modelId).equals(detectorId)) {
models.remove(modelId);
checkpointDao
.deleteModelCheckpoint(
modelId,
ActionListener.wrap(r -> clearModelForIterator(detectorId, models, idIter, listener), listener::onFailure)
);
} else {
clearModelForIterator(detectorId, models, idIter, listener);
}
} else {
listener.onResponse(null);
}
}

/**
* Trains and saves cold-start AD models.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,6 @@ public void estimateModelSize_returnExpected(RandomCutForest rcf, long expectedS

@Test
public void getRcfResult_returnExpected() {
String checkpoint = "testCheckpoint";
double[] point = new double[0];
RandomCutForest forest = mock(RandomCutForest.class);

Expand Down Expand Up @@ -363,7 +362,6 @@ public void getRcfResult_throwResourceNotFound_whenNoModelCheckpointFound() {
public void getRcfResult_throwLimitExceeded_whenHeapLimitReached() {
String detectorId = "testDetectorId";
String modelId = "testModelId";
String checkpoint = "testCheckpoint";

when(checkpointDao.getModelCheckpoint(modelId)).thenReturn(Optional.of(checkpoint));
when(rcfSerde.fromJson(checkpoint)).thenReturn(rcf);
Expand Down Expand Up @@ -443,8 +441,6 @@ public void getThresholdingResult_returnExpected() {
String modelId = "testModelId";
double score = 1.;

String checkpoint = "testCheckpoint";

double grade = 0.;
double confidence = 0.5;

Expand Down Expand Up @@ -514,7 +510,6 @@ public void getThresholdingResult_throwToListener_withNoCheckpoint() {

@Test
public void getAllModelIds_returnAllIds_forRcfAndThreshold() {
String checkpoint = "checkpoint";

when(checkpointDao.getModelCheckpoint(rcfModelId)).thenReturn(Optional.of(checkpoint));
when(rcfSerde.fromJson(checkpoint)).thenReturn(mock(RandomCutForest.class));
Expand All @@ -533,7 +528,6 @@ public void getAllModelIds_returnEmpty_forNoModels() {

@Test
public void stopModel_saveRcfCheckpoint() {
String checkpoint = "checkpoint";

RandomCutForest forest = mock(RandomCutForest.class);
when(checkpointDao.getModelCheckpoint(rcfModelId)).thenReturn(Optional.of(checkpoint));
Expand All @@ -549,7 +543,6 @@ public void stopModel_saveRcfCheckpoint() {

@Test
public void stopModel_saveThresholdCheckpoint() {
String checkpoint = "checkpoint";

when(checkpointDao.getModelCheckpoint(thresholdModelId)).thenReturn(Optional.of(checkpoint));
PowerMockito.doReturn(hybridThresholdingModel).when(gson).fromJson(checkpoint, thresholdingModelClass);
Expand All @@ -564,7 +557,6 @@ public void stopModel_saveThresholdCheckpoint() {

@Test
public void clear_deleteRcfCheckpoint() {
String checkpoint = "checkpoint";

RandomCutForest forest = mock(RandomCutForest.class);
when(checkpointDao.getModelCheckpoint(rcfModelId)).thenReturn(Optional.of(checkpoint));
Expand All @@ -578,7 +570,6 @@ public void clear_deleteRcfCheckpoint() {

@Test
public void clear_deleteThresholdCheckpoint() {
String checkpoint = "checkpoint";

when(checkpointDao.getModelCheckpoint(thresholdModelId)).thenReturn(Optional.of(checkpoint));
PowerMockito.doReturn(hybridThresholdingModel).when(gson).fromJson(checkpoint, thresholdingModelClass);
Expand All @@ -590,6 +581,66 @@ public void clear_deleteThresholdCheckpoint() {
verify(checkpointDao).deleteModelCheckpoint(thresholdModelId);
}

@Test
@SuppressWarnings("unchecked")
public void clear_callListener_whenRcfDeleted() {
String otherModelId = detectorId + rcfModelId;
RandomCutForest forest = mock(RandomCutForest.class);
when(checkpointDao.getModelCheckpoint(rcfModelId)).thenReturn(Optional.of(checkpoint));
when(checkpointDao.getModelCheckpoint(otherModelId)).thenReturn(Optional.of(checkpoint));
when(rcfSerde.fromJson(checkpoint)).thenReturn(forest);
modelManager.getRcfResult(detectorId, rcfModelId, new double[0]);
modelManager.getRcfResult(otherModelId, otherModelId, new double[0]);
doAnswer(invocation -> {
ActionListener<Void> listener = invocation.getArgument(1);
listener.onResponse(null);
return null;
}).when(checkpointDao).deleteModelCheckpoint(eq(rcfModelId), any(ActionListener.class));

ActionListener<Void> listener = mock(ActionListener.class);
modelManager.clear(detectorId, listener);

verify(listener).onResponse(null);
}

@Test
@SuppressWarnings("unchecked")
public void clear_callListener_whenThresholdDeleted() {
when(checkpointDao.getModelCheckpoint(thresholdModelId)).thenReturn(Optional.of(checkpoint));
PowerMockito.doReturn(hybridThresholdingModel).when(gson).fromJson(checkpoint, thresholdingModelClass);
PowerMockito.doReturn(checkpoint).when(gson).toJson(hybridThresholdingModel);
modelManager.getThresholdingResult(detectorId, thresholdModelId, 0);
doAnswer(invocation -> {
ActionListener<Void> listener = invocation.getArgument(1);
listener.onResponse(null);
return null;
}).when(checkpointDao).deleteModelCheckpoint(eq(thresholdModelId), any(ActionListener.class));

ActionListener<Void> listener = mock(ActionListener.class);
modelManager.clear(detectorId, listener);

verify(listener).onResponse(null);
}

@Test
@SuppressWarnings("unchecked")
public void clear_throwToListener_whenDeleteFail() {
RandomCutForest forest = mock(RandomCutForest.class);
when(checkpointDao.getModelCheckpoint(rcfModelId)).thenReturn(Optional.of(checkpoint));
when(rcfSerde.fromJson(checkpoint)).thenReturn(forest);
modelManager.getRcfResult(detectorId, rcfModelId, new double[0]);
doAnswer(invocation -> {
ActionListener<Void> listener = invocation.getArgument(1);
listener.onFailure(new RuntimeException());
return null;
}).when(checkpointDao).deleteModelCheckpoint(eq(rcfModelId), any(ActionListener.class));

ActionListener<Void> listener = mock(ActionListener.class);
modelManager.clear(detectorId, listener);

verify(listener).onFailure(any(Exception.class));
}

@Test
public void trainModel_putTrainedModels() {
double[][] trainData = new Random().doubles().limit(100).mapToObj(d -> new double[] { d }).toArray(double[][]::new);
Expand Down Expand Up @@ -725,7 +776,6 @@ public void maintenance_saveThresholdCheckpoint_skippingFailure() {
@Test
public void maintenance_stopInactiveRcfModel() {
String modelId = "testModelId";
String checkpoint = "testCheckpoint";
double[] point = new double[0];
RandomCutForest forest = mock(RandomCutForest.class);
when(checkpointDao.getModelCheckpoint(modelId)).thenReturn(Optional.of(checkpoint));
Expand All @@ -743,7 +793,6 @@ public void maintenance_stopInactiveRcfModel() {
@Test
public void maintenance_keepActiveRcfModel() {
String modelId = "testModelId";
String checkpoint = "testCheckpoint";
double[] point = new double[0];
RandomCutForest forest = mock(RandomCutForest.class);
when(checkpointDao.getModelCheckpoint(modelId)).thenReturn(Optional.of(checkpoint));
Expand All @@ -761,7 +810,6 @@ public void maintenance_keepActiveRcfModel() {
@Test
public void maintenance_stopInactiveThresholdModel() {
String modelId = "testModelId";
String checkpoint = "testCheckpoint";
double score = 1.;
when(checkpointDao.getModelCheckpoint(modelId)).thenReturn(Optional.of(checkpoint));
doReturn(hybridThresholdingModel).when(gson).fromJson(checkpoint, thresholdingModelClass);
Expand All @@ -778,7 +826,6 @@ public void maintenance_stopInactiveThresholdModel() {
@Test
public void maintenance_keepActiveThresholdModel() {
String modelId = "testModelId";
String checkpoint = "testCheckpoint";
double score = 1.;
when(checkpointDao.getModelCheckpoint(modelId)).thenReturn(Optional.of(checkpoint));
doReturn(hybridThresholdingModel).when(gson).fromJson(checkpoint, thresholdingModelClass);
Expand Down