Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

add async maintenance #94

Merged
merged 1 commit into from
May 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.time.Instant;
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -825,6 +826,56 @@ private <T> void maintenance(Map<String, ModelState<T>> models, Function<T, Stri
});
}

/**
* Does model maintenance.
*
* The implementation makes checkpoints for hosted models and stops hosting models not recently used.
*
* @param listener onResponse is called with null when this operation is completed.
*/
public void maintenance(ActionListener<Void> listener) {
maintenanceForIterator(
forests,
this::toCheckpoint,
forests.entrySet().iterator(),
ActionListener
.wrap(
r -> maintenanceForIterator(thresholds, this::toCheckpoint, thresholds.entrySet().iterator(), listener),
listener::onFailure
)
);
}

private <T> void maintenanceForIterator(
Map<String, ModelState<T>> models,
Function<T, String> toCheckpoint,
Iterator<Entry<String, ModelState<T>>> iter,
ActionListener<Void> listener
) {
if (iter.hasNext()) {
Entry<String, ModelState<T>> modelEntry = iter.next();
String modelId = modelEntry.getKey();
ModelState<T> modelState = modelEntry.getValue();
Instant now = clock.instant();
if (modelState.getLastUsedTime().plus(modelTtl).isBefore(now)) {
models.remove(modelId);
}
if (modelState.getLastCheckpointTime().plus(checkpointInterval).isBefore(now)) {
checkpointDao.putModelCheckpoint(modelId, toCheckpoint.apply(modelState.getModel()), ActionListener.wrap(r -> {
modelState.setLastCheckpointTime(now);
maintenanceForIterator(models, toCheckpoint, iter, listener);
}, e -> {
logger.warn("Failed to finish maintenance for model id " + modelId, e);
maintenanceForIterator(models, toCheckpoint, iter, listener);
}));
} else {
maintenanceForIterator(models, toCheckpoint, iter, listener);
}
} else {
listener.onResponse(null);
}
}

/**
* Returns computed anomaly results for preview data points.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,162 @@ public void maintenance_skipCheckpoint_whenLastCheckpointIsRecent() {
verify(checkpointDao, times(1)).putModelCheckpoint(eq(modelId), anyObject());
}

@Test
@SuppressWarnings("unchecked")
public void maintenance_returnExpectedToListener_forRcfModel() {
String successModelId = "testSuccessModelId";
String failModelId = "testFailModelId";
String successCheckpoint = "testSuccessCheckpoint";
String failCheckpoint = "testFailCheckpoint";
double[] point = new double[0];
RandomCutForest forest = mock(RandomCutForest.class);
RandomCutForest failForest = mock(RandomCutForest.class);

doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.of(successCheckpoint));
return null;
}).when(checkpointDao).getModelCheckpoint(eq(successModelId), any(ActionListener.class));
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.of(failCheckpoint));
return null;
}).when(checkpointDao).getModelCheckpoint(eq(failModelId), any(ActionListener.class));
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(2);
listener.onResponse(null);
return null;
}).when(checkpointDao).putModelCheckpoint(eq(successModelId), eq(successCheckpoint), any(ActionListener.class));
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(2);
listener.onFailure(new RuntimeException());
return null;
}).when(checkpointDao).putModelCheckpoint(eq(failModelId), eq(failCheckpoint), any(ActionListener.class));
when(rcfSerde.fromJson(successCheckpoint)).thenReturn(forest);
when(rcfSerde.fromJson(failCheckpoint)).thenReturn(failForest);
when(rcfSerde.toJson(forest)).thenReturn(successCheckpoint);
when(rcfSerde.toJson(failForest)).thenReturn(failCheckpoint);
when(clock.instant()).thenReturn(Instant.EPOCH);
ActionListener<RcfResult> scoreListener = mock(ActionListener.class);
modelManager.getRcfResult(detectorId, successModelId, point, scoreListener);
modelManager.getRcfResult(detectorId, failModelId, point, scoreListener);

ActionListener<Void> listener = mock(ActionListener.class);
modelManager.maintenance(listener);

verify(listener).onResponse(eq(null));
verify(checkpointDao, times(1)).putModelCheckpoint(eq(successModelId), eq(successCheckpoint), any(ActionListener.class));
verify(checkpointDao, times(1)).putModelCheckpoint(eq(failModelId), eq(failCheckpoint), any(ActionListener.class));
}

@Test
@SuppressWarnings("unchecked")
public void maintenance_returnExpectedToListener_forThresholdModel() {
String successModelId = "testSuccessModelId";
String failModelId = "testFailModelId";
String successCheckpoint = "testSuccessCheckpoint";
String failCheckpoint = "testFailCheckpoint";
double score = 1.;
HybridThresholdingModel failThresholdModel = mock(HybridThresholdingModel.class);
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.of(successCheckpoint));
return null;
}).when(checkpointDao).getModelCheckpoint(eq(successModelId), any(ActionListener.class));
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.of(failCheckpoint));
return null;
}).when(checkpointDao).getModelCheckpoint(eq(failModelId), any(ActionListener.class));
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(2);
listener.onResponse(null);
return null;
}).when(checkpointDao).putModelCheckpoint(eq(successModelId), eq(successCheckpoint), any(ActionListener.class));
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(2);
listener.onFailure(new RuntimeException());
return null;
}).when(checkpointDao).putModelCheckpoint(eq(failModelId), eq(failCheckpoint), any(ActionListener.class));
doReturn(hybridThresholdingModel).when(gson).fromJson(successCheckpoint, thresholdingModelClass);
doReturn(failThresholdModel).when(gson).fromJson(failCheckpoint, thresholdingModelClass);
doReturn(successCheckpoint).when(gson).toJson(hybridThresholdingModel);
doThrow(new RuntimeException()).when(gson).toJson(failThresholdModel);
when(clock.instant()).thenReturn(Instant.EPOCH);
ActionListener<ThresholdingResult> scoreListener = mock(ActionListener.class);
modelManager.getThresholdingResult(detectorId, successModelId, score, scoreListener);
modelManager.getThresholdingResult(detectorId, failModelId, score, scoreListener);

ActionListener<Void> listener = mock(ActionListener.class);
modelManager.maintenance(listener);

verify(listener).onResponse(eq(null));
verify(checkpointDao, times(1)).putModelCheckpoint(eq(successModelId), eq(successCheckpoint), any(ActionListener.class));
}

@Test
@SuppressWarnings("unchecked")
public void maintenance_returnExpectedToListener_stopModel() {
double[] point = new double[0];
RandomCutForest forest = mock(RandomCutForest.class);

doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.of(checkpoint));
return null;
}).when(checkpointDao).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class));
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(2);
listener.onResponse(null);
return null;
}).when(checkpointDao).putModelCheckpoint(eq(rcfModelId), eq(checkpoint), any(ActionListener.class));
when(rcfSerde.fromJson(checkpoint)).thenReturn(forest);
when(rcfSerde.toJson(forest)).thenReturn(checkpoint);
when(clock.instant()).thenReturn(Instant.EPOCH, Instant.EPOCH, Instant.EPOCH.plus(modelTtl.plusSeconds(1)));
ActionListener<RcfResult> scoreListener = mock(ActionListener.class);
modelManager.getRcfResult(detectorId, rcfModelId, point, scoreListener);

ActionListener<Void> listener = mock(ActionListener.class);
modelManager.maintenance(listener);
verify(listener).onResponse(eq(null));

modelManager.getRcfResult(detectorId, rcfModelId, point, scoreListener);
verify(checkpointDao, times(2)).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class));
}

@Test
@SuppressWarnings("unchecked")
public void maintenance_returnExpectedToListener_doNothing() {
double[] point = new double[0];
RandomCutForest forest = mock(RandomCutForest.class);

doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.of(checkpoint));
return null;
}).when(checkpointDao).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class));
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(2);
listener.onResponse(null);
return null;
}).when(checkpointDao).putModelCheckpoint(eq(rcfModelId), eq(checkpoint), any(ActionListener.class));
when(rcfSerde.fromJson(checkpoint)).thenReturn(forest);
when(rcfSerde.toJson(forest)).thenReturn(checkpoint);
when(clock.instant()).thenReturn(Instant.MIN);
ActionListener<RcfResult> scoreListener = mock(ActionListener.class);
modelManager.getRcfResult(detectorId, rcfModelId, point, scoreListener);
ActionListener<Void> listener = mock(ActionListener.class);
modelManager.maintenance(listener);
verify(listener).onResponse(eq(null));

listener = mock(ActionListener.class);
modelManager.maintenance(listener);
verify(listener).onResponse(eq(null));

modelManager.getRcfResult(detectorId, rcfModelId, point, scoreListener);
verify(checkpointDao, times(1)).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class));
}

@Test
public void getPreviewResults_returnNoAnomalies_forNoAnomalies() {
int numPoints = 1000;
Expand Down