Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
add async maintenance (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
wnbts authored May 7, 2020
1 parent 5bc0ccd commit c8f4f96
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,56 @@ private <T> void maintenance(Map<String, ModelState<T>> models, Function<T, Stri
});
}

/**
* Does model maintenance.
*
* The implementation makes checkpoints for hosted models and stops hosting models not recently used.
*
* @param listener onResponse is called with null when this operation is completed.
*/
public void maintenance(ActionListener<Void> listener) {
maintenanceForIterator(
forests,
this::toCheckpoint,
forests.entrySet().iterator(),
ActionListener
.wrap(
r -> maintenanceForIterator(thresholds, this::toCheckpoint, thresholds.entrySet().iterator(), listener),
listener::onFailure
)
);
}

private <T> void maintenanceForIterator(
Map<String, ModelState<T>> models,
Function<T, String> toCheckpoint,
Iterator<Entry<String, ModelState<T>>> iter,
ActionListener<Void> listener
) {
if (iter.hasNext()) {
Entry<String, ModelState<T>> modelEntry = iter.next();
String modelId = modelEntry.getKey();
ModelState<T> modelState = modelEntry.getValue();
Instant now = clock.instant();
if (modelState.getLastUsedTime().plus(modelTtl).isBefore(now)) {
models.remove(modelId);
}
if (modelState.getLastCheckpointTime().plus(checkpointInterval).isBefore(now)) {
checkpointDao.putModelCheckpoint(modelId, toCheckpoint.apply(modelState.getModel()), ActionListener.wrap(r -> {
modelState.setLastCheckpointTime(now);
maintenanceForIterator(models, toCheckpoint, iter, listener);
}, e -> {
logger.warn("Failed to finish maintenance for model id " + modelId, e);
maintenanceForIterator(models, toCheckpoint, iter, listener);
}));
} else {
maintenanceForIterator(models, toCheckpoint, iter, listener);
}
} else {
listener.onResponse(null);
}
}

/**
* Returns computed anomaly results for preview data points.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,162 @@ public void maintenance_skipCheckpoint_whenLastCheckpointIsRecent() {
verify(checkpointDao, times(1)).putModelCheckpoint(eq(modelId), anyObject());
}

@Test
@SuppressWarnings("unchecked")
public void maintenance_returnExpectedToListener_forRcfModel() {
String successModelId = "testSuccessModelId";
String failModelId = "testFailModelId";
String successCheckpoint = "testSuccessCheckpoint";
String failCheckpoint = "testFailCheckpoint";
double[] point = new double[0];
RandomCutForest forest = mock(RandomCutForest.class);
RandomCutForest failForest = mock(RandomCutForest.class);

doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.of(successCheckpoint));
return null;
}).when(checkpointDao).getModelCheckpoint(eq(successModelId), any(ActionListener.class));
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.of(failCheckpoint));
return null;
}).when(checkpointDao).getModelCheckpoint(eq(failModelId), any(ActionListener.class));
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(2);
listener.onResponse(null);
return null;
}).when(checkpointDao).putModelCheckpoint(eq(successModelId), eq(successCheckpoint), any(ActionListener.class));
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(2);
listener.onFailure(new RuntimeException());
return null;
}).when(checkpointDao).putModelCheckpoint(eq(failModelId), eq(failCheckpoint), any(ActionListener.class));
when(rcfSerde.fromJson(successCheckpoint)).thenReturn(forest);
when(rcfSerde.fromJson(failCheckpoint)).thenReturn(failForest);
when(rcfSerde.toJson(forest)).thenReturn(successCheckpoint);
when(rcfSerde.toJson(failForest)).thenReturn(failCheckpoint);
when(clock.instant()).thenReturn(Instant.EPOCH);
ActionListener<RcfResult> scoreListener = mock(ActionListener.class);
modelManager.getRcfResult(detectorId, successModelId, point, scoreListener);
modelManager.getRcfResult(detectorId, failModelId, point, scoreListener);

ActionListener<Void> listener = mock(ActionListener.class);
modelManager.maintenance(listener);

verify(listener).onResponse(eq(null));
verify(checkpointDao, times(1)).putModelCheckpoint(eq(successModelId), eq(successCheckpoint), any(ActionListener.class));
verify(checkpointDao, times(1)).putModelCheckpoint(eq(failModelId), eq(failCheckpoint), any(ActionListener.class));
}

@Test
@SuppressWarnings("unchecked")
public void maintenance_returnExpectedToListener_forThresholdModel() {
String successModelId = "testSuccessModelId";
String failModelId = "testFailModelId";
String successCheckpoint = "testSuccessCheckpoint";
String failCheckpoint = "testFailCheckpoint";
double score = 1.;
HybridThresholdingModel failThresholdModel = mock(HybridThresholdingModel.class);
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.of(successCheckpoint));
return null;
}).when(checkpointDao).getModelCheckpoint(eq(successModelId), any(ActionListener.class));
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.of(failCheckpoint));
return null;
}).when(checkpointDao).getModelCheckpoint(eq(failModelId), any(ActionListener.class));
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(2);
listener.onResponse(null);
return null;
}).when(checkpointDao).putModelCheckpoint(eq(successModelId), eq(successCheckpoint), any(ActionListener.class));
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(2);
listener.onFailure(new RuntimeException());
return null;
}).when(checkpointDao).putModelCheckpoint(eq(failModelId), eq(failCheckpoint), any(ActionListener.class));
doReturn(hybridThresholdingModel).when(gson).fromJson(successCheckpoint, thresholdingModelClass);
doReturn(failThresholdModel).when(gson).fromJson(failCheckpoint, thresholdingModelClass);
doReturn(successCheckpoint).when(gson).toJson(hybridThresholdingModel);
doThrow(new RuntimeException()).when(gson).toJson(failThresholdModel);
when(clock.instant()).thenReturn(Instant.EPOCH);
ActionListener<ThresholdingResult> scoreListener = mock(ActionListener.class);
modelManager.getThresholdingResult(detectorId, successModelId, score, scoreListener);
modelManager.getThresholdingResult(detectorId, failModelId, score, scoreListener);

ActionListener<Void> listener = mock(ActionListener.class);
modelManager.maintenance(listener);

verify(listener).onResponse(eq(null));
verify(checkpointDao, times(1)).putModelCheckpoint(eq(successModelId), eq(successCheckpoint), any(ActionListener.class));
}

@Test
@SuppressWarnings("unchecked")
public void maintenance_returnExpectedToListener_stopModel() {
double[] point = new double[0];
RandomCutForest forest = mock(RandomCutForest.class);

doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.of(checkpoint));
return null;
}).when(checkpointDao).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class));
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(2);
listener.onResponse(null);
return null;
}).when(checkpointDao).putModelCheckpoint(eq(rcfModelId), eq(checkpoint), any(ActionListener.class));
when(rcfSerde.fromJson(checkpoint)).thenReturn(forest);
when(rcfSerde.toJson(forest)).thenReturn(checkpoint);
when(clock.instant()).thenReturn(Instant.EPOCH, Instant.EPOCH, Instant.EPOCH.plus(modelTtl.plusSeconds(1)));
ActionListener<RcfResult> scoreListener = mock(ActionListener.class);
modelManager.getRcfResult(detectorId, rcfModelId, point, scoreListener);

ActionListener<Void> listener = mock(ActionListener.class);
modelManager.maintenance(listener);
verify(listener).onResponse(eq(null));

modelManager.getRcfResult(detectorId, rcfModelId, point, scoreListener);
verify(checkpointDao, times(2)).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class));
}

@Test
@SuppressWarnings("unchecked")
public void maintenance_returnExpectedToListener_doNothing() {
double[] point = new double[0];
RandomCutForest forest = mock(RandomCutForest.class);

doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.of(checkpoint));
return null;
}).when(checkpointDao).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class));
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(2);
listener.onResponse(null);
return null;
}).when(checkpointDao).putModelCheckpoint(eq(rcfModelId), eq(checkpoint), any(ActionListener.class));
when(rcfSerde.fromJson(checkpoint)).thenReturn(forest);
when(rcfSerde.toJson(forest)).thenReturn(checkpoint);
when(clock.instant()).thenReturn(Instant.MIN);
ActionListener<RcfResult> scoreListener = mock(ActionListener.class);
modelManager.getRcfResult(detectorId, rcfModelId, point, scoreListener);
ActionListener<Void> listener = mock(ActionListener.class);
modelManager.maintenance(listener);
verify(listener).onResponse(eq(null));

listener = mock(ActionListener.class);
modelManager.maintenance(listener);
verify(listener).onResponse(eq(null));

modelManager.getRcfResult(detectorId, rcfModelId, point, scoreListener);
verify(checkpointDao, times(1)).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class));
}

@Test
public void getPreviewResults_returnNoAnomalies_forNoAnomalies() {
int numPoints = 1000;
Expand Down

0 comments on commit c8f4f96

Please sign in to comment.