Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
add async trainModel (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
wnbts authored Apr 10, 2020
1 parent 9af0080 commit f851f1a
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,116 @@ public void trainModel(AnomalyDetector anomalyDetector, double[][] dataPoints) {
checkpointDao.putModelCheckpoint(modelId, checkpoint);
}

/**
 * Trains and saves cold-start AD models, reporting the outcome through a listener.
 *
 * This implementation splits RCF models and trains them all.
 * As all model partitions have the same size, the scores from RCF models are merged by averaging.
 * Since RCF outputs 0 until it is ready, initial 0 scores are meaningless and therefore filtered out.
 * Filtered (non-zero) RCF scores are the training data for a single thresholding model.
 * All trained models are serialized and persisted to be hosted.
 *
 * @param anomalyDetector the detector for which models are trained
 * @param dataPoints M, N shape, where M is the number of samples for training and N is the number of features
 * @param listener onResponse is called with null when this operation is completed
 *                 onFailure is called with IllegalArgumentException when training data is null or empty
 *                 onFailure is called with LimitExceededException when a limit for training is exceeded
 */
public void trainModel(AnomalyDetector anomalyDetector, double[][] dataPoints, ActionListener<Void> listener) {
    // Null input is reported through the listener like empty input, instead of surfacing
    // as a NullPointerException thrown at the caller.
    if (dataPoints == null || dataPoints.length == 0 || dataPoints[0] == null || dataPoints[0].length == 0) {
        listener.onFailure(new IllegalArgumentException("Data points must not be empty."));
        return;
    }

    int rcfNumFeatures = dataPoints[0].length;
    try {
        // The forest built here is only handed to getPartitionedForestSizes, which yields
        // the number of partitions (key) and the number of trees per partition (value).
        Entry<Integer, Integer> partitionResults = getPartitionedForestSizes(
            RandomCutForest
                .builder()
                .dimensions(rcfNumFeatures)
                .sampleSize(rcfNumSamplesInTree)
                .numberOfTrees(rcfNumTrees)
                .outputAfter(rcfNumSamplesInTree)
                .parallelExecutionEnabled(false)
                .build(),
            anomalyDetector.getDetectorId()
        );
        int numForests = partitionResults.getKey();
        int forestSize = partitionResults.getValue();

        // Per-sample score accumulator shared by all partitions; a new double[] is already zero-filled.
        double[] scores = new double[dataPoints.length];

        // Train the partitions one after another, starting with partition 0.
        trainModelForStep(anomalyDetector, dataPoints, rcfNumFeatures, numForests, forestSize, scores, 0, listener);
    } catch (LimitExceededException e) {
        listener.onFailure(e);
    }
}

/**
 * Trains and persists the RCF partition with index {@code step}, then moves on to the next
 * partition once that checkpoint has been written. When {@code step} has reached
 * {@code numForests}, the accumulated scores are averaged, used to train the thresholding
 * model, and that model is persisted; only then is {@code listener.onResponse(null)} invoked.
 * A failed checkpoint write is forwarded to {@code listener.onFailure}.
 *
 * @param detector the detector whose models are being trained
 * @param dataPoints training samples, one row per sample
 * @param rcfNumFeatures number of dimensions each forest is built with
 * @param numForests total number of RCF partitions to train
 * @param forestSize number of trees in each partition
 * @param scores per-sample sum of anomaly scores over the partitions trained so far (mutated in place)
 * @param step index of the partition to train in this call
 * @param listener receives the final result or the first failure
 */
private void trainModelForStep(
    AnomalyDetector detector,
    double[][] dataPoints,
    int rcfNumFeatures,
    int numForests,
    int forestSize,
    final double[] scores,
    int step,
    ActionListener<Void> listener
) {
    if (step >= numForests) {
        // Every partition is trained: keep only positive summed scores and average them.
        double[] averagedScores = DoubleStream
            .of(scores)
            .filter(total -> total > 0)
            .map(total -> total / numForests)
            .toArray();

        // Train thresholding model
        ThresholdingModel thresholdModel = new HybridThresholdingModel(
            thresholdMinPvalue,
            thresholdMaxRankError,
            thresholdMaxScore,
            thresholdNumLogNormalQuantiles,
            thresholdDownsamples,
            thresholdMaxSamples
        );
        thresholdModel.train(averagedScores);

        // Persist thresholding model
        String thresholdModelId = getThresholdModelId(detector.getDetectorId());
        String thresholdCheckpoint = AccessController.doPrivileged((PrivilegedAction<String>) () -> gson.toJson(thresholdModel));
        checkpointDao
            .putModelCheckpoint(
                thresholdModelId,
                thresholdCheckpoint,
                ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure)
            );
        return;
    }

    RandomCutForest forest = RandomCutForest
        .builder()
        .dimensions(rcfNumFeatures)
        .sampleSize(rcfNumSamplesInTree)
        .numberOfTrees(forestSize)
        .lambda(rcfTimeDecay)
        .outputAfter(rcfNumSamplesInTree)
        .parallelExecutionEnabled(false)
        .build();

    // Score each sample before feeding it to the forest, adding the score to the running sum.
    int sampleIndex = 0;
    for (double[] point : dataPoints) {
        scores[sampleIndex] += forest.getAnomalyScore(point);
        forest.update(point);
        sampleIndex++;
    }

    String partitionModelId = getRcfModelId(detector.getDetectorId(), step);
    String partitionCheckpoint = AccessController.doPrivileged((PrivilegedAction<String>) () -> rcfSerde.toJson(forest));

    // Continue with the next partition only after this checkpoint has been saved.
    checkpointDao
        .putModelCheckpoint(
            partitionModelId,
            partitionCheckpoint,
            ActionListener
                .wrap(
                    r -> trainModelForStep(detector, dataPoints, rcfNumFeatures, numForests, forestSize, scores, step + 1, listener),
                    listener::onFailure
                )
        );
}

/**
* Returns the model ID for the RCF model partition.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,45 @@ public void trainModel_throwIllegalArgument_forInvalidInput(double[][] trainData
modelManager.trainModel(anomalyDetector, trainData);
}

/**
 * Happy path of the listener-based trainModel: with two RCF partitions and a checkpoint DAO
 * that acknowledges every write, the listener is completed with null and three checkpoints
 * are written (two RCF partitions plus one thresholding model).
 */
@Test
@SuppressWarnings("unchecked")
public void trainModel_returnExpectedToListener_putCheckpoints() {
    // Fixed seed so the 100 single-feature samples are identical on every run and failures are reproducible.
    double[][] trainData = new Random(42L).doubles().limit(100).mapToObj(d -> new double[] { d }).toArray(double[][]::new);

    // Partitioning stub: 2 forests of 10 trees each.
    doReturn(new SimpleEntry<>(2, 10)).when(modelManager).getPartitionedForestSizes(anyObject(), anyObject());

    // Every checkpoint write immediately reports success to the listener passed as the third argument.
    doAnswer(invocation -> {
        ActionListener<Void> listener = invocation.getArgument(2);
        listener.onResponse(null);
        return null;
    }).when(checkpointDao).putModelCheckpoint(any(), any(), any(ActionListener.class));

    ActionListener<Void> listener = mock(ActionListener.class);
    modelManager.trainModel(anomalyDetector, trainData, listener);

    verify(listener).onResponse(eq(null));
    verify(checkpointDao, times(3)).putModelCheckpoint(any(), any(), any());
}

/**
 * For each data set supplied by {@code trainModelIllegalArgumentData}, the listener-based
 * trainModel must report an IllegalArgumentException through onFailure.
 */
@Test
@Parameters(method = "trainModelIllegalArgumentData")
public void trainModel_throwIllegalArgumentToListener_forInvalidTrainData(double[][] trainData) {
    // given
    @SuppressWarnings("unchecked")
    ActionListener<Void> failureSink = mock(ActionListener.class);

    // when
    modelManager.trainModel(anomalyDetector, trainData, failureSink);

    // then
    verify(failureSink).onFailure(any(IllegalArgumentException.class));
}

/**
 * When partition sizing throws LimitExceededException, the listener-based trainModel must
 * pass that failure to the listener's onFailure.
 */
@Test
public void trainModel_throwLimitExceededToListener_whenLimitExceed() {
    // given: partition sizing always fails with a limit violation
    LimitExceededException limitExceeded = new LimitExceededException(null, null);
    doThrow(limitExceeded).when(modelManager).getPartitionedForestSizes(anyObject(), anyObject());
    double[][] singleSample = { { 0 } };
    @SuppressWarnings("unchecked")
    ActionListener<Void> failureSink = mock(ActionListener.class);

    // when
    modelManager.trainModel(anomalyDetector, singleSample, failureSink);

    // then
    verify(failureSink).onFailure(any(LimitExceededException.class));
}

@Test
public void getRcfModelId_returnNonEmptyString() {
String rcfModelId = modelManager.getRcfModelId(anomalyDetector.getDetectorId(), 0);
Expand Down

0 comments on commit f851f1a

Please sign in to comment.