add thresholded rcf (#215)
Signed-off-by: lai <[email protected]>
wnbts authored Sep 16, 2021
1 parent 087b3ba commit b0d8330
Showing 19 changed files with 565 additions and 657 deletions.
4 changes: 3 additions & 1 deletion build.gradle
@@ -591,11 +591,13 @@ dependencies {
// implementation scope puts the dependency on both the compile and runtime classpaths, but
// does not leak it through to clients (OpenSearch). Here we force the jackson version to whatever
// opensearch uses.
compile 'software.amazon.randomcutforest:randomcutforest-core:2.0.1'
//compile 'software.amazon.randomcutforest:randomcutforest-core:2.0.1'
implementation 'software.amazon.randomcutforest:randomcutforest-serialization:2.0.1'
implementation "com.fasterxml.jackson.core:jackson-core:${versions.jackson}"
implementation "com.fasterxml.jackson.core:jackson-databind:${versions.jackson}"
implementation "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}"
compile files('lib/randomcutforest-parkservices-2.0.1.jar')
compile files('lib/randomcutforest-core-2.0.1.jar')

// used for serializing/deserializing rcf models.
compile group: 'io.protostuff', name: 'protostuff-core', version: '1.7.4'
Binary file added lib/randomcutforest-core-2.0.1.jar
Binary file added lib/randomcutforest-parkservices-2.0.1.jar
8 changes: 8 additions & 0 deletions src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java
@@ -205,6 +205,8 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

import com.amazon.randomcutforest.parkservices.threshold.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.threshold.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV2StateConverter;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.amazon.randomcutforest.state.RandomCutForestState;
@@ -450,6 +452,12 @@ public PooledObject<LinkedBuffer> wrap(LinkedBuffer obj) {
mapper,
schema,
converter,
new ThresholdedRandomCutForestMapper(),
AccessController
.doPrivileged(
(PrivilegedAction<Schema<ThresholdedRandomCutForestState>>) () -> RuntimeSchema
.getSchema(ThresholdedRandomCutForestState.class)
),
HybridThresholdingModel.class,
anomalyDetectionIndices,
AnomalyDetectorSettings.MAX_CHECKPOINT_BYTES,
114 changes: 105 additions & 9 deletions src/main/java/org/opensearch/ad/ml/CheckpointDao.java
@@ -84,6 +84,9 @@

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.parkservices.threshold.ThresholdedRandomCutForest;
import com.amazon.randomcutforest.parkservices.threshold.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.threshold.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV2StateConverter;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.amazon.randomcutforest.state.RandomCutForestState;
@@ -116,6 +119,7 @@ public class CheckpointDao {
public static final String ENTITY_SAMPLE = "sp";
public static final String ENTITY_RCF = "rcf";
public static final String ENTITY_THRESHOLD = "th";
public static final String ENTITY_TRCF = "trcf";
public static final String FIELD_MODEL = "model";
public static final String FIELD_MODELV2 = "modelV2";
public static final String TIMESTAMP = "timestamp";
@@ -132,6 +136,8 @@ public class CheckpointDao {
private RandomCutForestMapper mapper;
private Schema<RandomCutForestState> schema;
private V1JsonToV2StateConverter converter;
private ThresholdedRandomCutForestMapper trcfMapper;
private Schema<ThresholdedRandomCutForestState> trcfSchema;

private final Class<? extends ThresholdingModel> thresholdingModelClass;

@@ -154,6 +160,8 @@ public class CheckpointDao {
* @param mapper RCF model serialization utility
* @param schema RandomCutForestState schema used by ProtoStuff
* @param converter converter from rcf v1 serde to protostuff based format
* @param trcfMapper TRCF serialization mapper
* @param trcfSchema TRCF serialization schema
* @param thresholdingModelClass thresholding model's class
* @param indexUtil Index utility methods
* @param maxCheckpointBytes max checkpoint size in bytes
@@ -168,6 +176,8 @@ public CheckpointDao(
RandomCutForestMapper mapper,
Schema<RandomCutForestState> schema,
V1JsonToV2StateConverter converter,
ThresholdedRandomCutForestMapper trcfMapper,
Schema<ThresholdedRandomCutForestState> trcfSchema,
Class<? extends ThresholdingModel> thresholdingModelClass,
AnomalyDetectionIndices indexUtil,
int maxCheckpointBytes,
@@ -181,6 +191,8 @@ public CheckpointDao(
this.mapper = mapper;
this.schema = schema;
this.converter = converter;
this.trcfMapper = trcfMapper;
this.trcfSchema = trcfSchema;
this.thresholdingModelClass = thresholdingModelClass;
this.indexUtil = indexUtil;
this.maxCheckpointBytes = maxCheckpointBytes;
@@ -337,10 +349,69 @@ public String toCheckpoint(EntityModel model) {
if (model.getThreshold() != null) {
json.addProperty(ENTITY_THRESHOLD, gson.toJson(model.getThreshold()));
}
if (model.getTrcf().isPresent()) {
json.addProperty(ENTITY_TRCF, toCheckpoint(model.getTrcf().get()));
}
return gson.toJson(json);
});
}
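(For orientation: an entity checkpoint serialized here is a JSON document keyed by the constants above, roughly {"sp": "<recent samples>", "trcf": "<base64 protostuff bytes>"} once a model carries a TRCF, while older checkpoints still hold separate "rcf" and "th" entries that the read path below continues to accept.)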

private String toCheckpoint(ThresholdedRandomCutForest trcf) {
String checkpoint = null;
Map.Entry<LinkedBuffer, Boolean> result = checkoutOrNewBuffer();
LinkedBuffer buffer = result.getKey();
boolean needCheckin = result.getValue();
try {
checkpoint = toCheckpoint(trcf, buffer);
} catch (RuntimeException e) {
logger.error("Failed to serialize model", e);
if (needCheckin) {
try {
serializeRCFBufferPool.invalidateObject(buffer);
needCheckin = false;
} catch (Exception x) {
logger.warn("Failed to invalidate buffer", x);
}
checkpoint = toCheckpoint(trcf, LinkedBuffer.allocate(serializeRCFBufferSize));
}
} finally {
if (needCheckin) {
try {
serializeRCFBufferPool.returnObject(buffer);
} catch (Exception e) {
logger.warn("Failed to return buffer to pool", e);
}
}
}
return checkpoint;
}

private Map.Entry<LinkedBuffer, Boolean> checkoutOrNewBuffer() {
LinkedBuffer buffer = null;
boolean isCheckout = true;
try {
buffer = serializeRCFBufferPool.borrowObject();
} catch (Exception e) {
logger.warn("Failed to borrow a buffer from pool", e);
}
if (buffer == null) {
buffer = LinkedBuffer.allocate(serializeRCFBufferSize);
isCheckout = false;
}
return new SimpleImmutableEntry<>(buffer, isCheckout);
}
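
A note on the pattern: checkoutOrNewBuffer degrades gracefully, falling back to a one-off LinkedBuffer when the pool cannot lend one, so serialization never fails merely because the pool is exhausted. For orientation, a commons-pool2 pool compatible with this code can be built along these lines (a minimal sketch; the factory mirrors the wrap method wired up in AnomalyDetectorPlugin above, and the buffer size is an illustrative stand-in for the plugin's setting):

import io.protostuff.LinkedBuffer;
import org.apache.commons.pool2.BasePooledObjectFactory;
import org.apache.commons.pool2.PooledObject;
import org.apache.commons.pool2.impl.DefaultPooledObject;
import org.apache.commons.pool2.impl.GenericObjectPool;

class SerializeBufferPoolSketch {
    // illustrative size; the plugin reads the real value from its settings
    private static final int BUFFER_BYTES = 512;

    static GenericObjectPool<LinkedBuffer> newPool() {
        return new GenericObjectPool<>(new BasePooledObjectFactory<LinkedBuffer>() {
            @Override
            public LinkedBuffer create() {
                // fresh buffer per pooled object
                return LinkedBuffer.allocate(BUFFER_BYTES);
            }

            @Override
            public PooledObject<LinkedBuffer> wrap(LinkedBuffer obj) {
                return new DefaultPooledObject<>(obj);
            }
        });
    }
}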

private String toCheckpoint(ThresholdedRandomCutForest trcf, LinkedBuffer buffer) {
try {
ThresholdedRandomCutForestState trcfState = trcfMapper.toState(trcf);
byte[] bytes = AccessController
.doPrivileged((PrivilegedAction<byte[]>) () -> ProtostuffIOUtil.toByteArray(trcfState, trcfSchema, buffer));
return Base64.getEncoder().encodeToString(bytes);
} finally {
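// clear the borrowed buffer so stale bytes are not reused on the next serialization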
buffer.clear();
}
}

private String rcfModelToCheckpoint(RandomCutForest model) {
LinkedBuffer borrowedBuffer = null;
try {
@@ -466,7 +537,6 @@ public Optional<Entry<EntityModel, Instant>> fromEntityModelCheckpoint(Map<Strin
logger.warn(new ParameterizedMessage("Empty model for [{}]", modelId));
return Optional.empty();
}

String model = (String) modelObj;
if (model.length() > maxCheckpointBytes) {
logger.warn(new ParameterizedMessage("[{}]'s model too large: [{}] bytes", modelId, model.length()));
@@ -483,15 +553,21 @@ public Optional<Entry<EntityModel, Instant>> fromEntityModelCheckpoint(Map<Strin
// avoid possible null pointer exception
samples = new ArrayDeque<>();
}
ThresholdedRandomCutForest trcf = null;
RandomCutForest rcf = null;
if (json.has(ENTITY_RCF)) {
String serializedRCF = json.getAsJsonPrimitive(ENTITY_RCF).getAsString();
rcf = deserializeRCFModel(serializedRCF);
}
ThresholdingModel threshold = null;
if (json.has(ENTITY_THRESHOLD)) {
// verified, don't need privileged call to get permission
threshold = this.gson.fromJson(json.getAsJsonPrimitive(ENTITY_THRESHOLD).getAsString(), thresholdingModelClass);
if (json.has(ENTITY_TRCF)) {
trcf = toTrcf(json.getAsJsonPrimitive(ENTITY_TRCF).getAsString());
} else {
// TODO: convert rcf and threshold to trcf
if (json.has(ENTITY_RCF)) {
String serializedRCF = json.getAsJsonPrimitive(ENTITY_RCF).getAsString();
rcf = deserializeRCFModel(serializedRCF);
}
if (json.has(ENTITY_THRESHOLD)) {
// verified, don't need privileged call to get permission
threshold = this.gson.fromJson(json.getAsJsonPrimitive(ENTITY_THRESHOLD).getAsString(), thresholdingModelClass);
}
}

String lastCheckpointTimeString = (String) (checkpoint.get(TIMESTAMP));
Expand All @@ -505,14 +581,34 @@ public Optional<Entry<EntityModel, Instant>> fromEntityModelCheckpoint(Map<Strin
logger.error(new ParameterizedMessage("fail to parse entity", serializedEntity), e);
}
}
return Optional.of(new SimpleImmutableEntry<>(new EntityModel(entity, samples, rcf, threshold), timestamp));
EntityModel entityModel = new EntityModel(entity, samples, rcf, threshold);
entityModel.setTrcf(trcf);
return Optional.of(new SimpleImmutableEntry<>(entityModel, timestamp));
});
} catch (Exception e) {
logger.warn("Exception while deserializing checkpoint", e);
throw e;
}
}

private ThresholdedRandomCutForest toTrcf(String checkpoint) {
ThresholdedRandomCutForest trcf = null;
if (checkpoint != null) {
try {
byte[] bytes = Base64.getDecoder().decode(checkpoint);
ThresholdedRandomCutForestState state = trcfSchema.newMessage();
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
ProtostuffIOUtil.mergeFrom(bytes, state, trcfSchema);
return null;
});
trcf = trcfMapper.toModel(state);
} catch (RuntimeException e) {
logger.error("Failed to deserialize TRCF model", e);
}
}
return trcf;
}
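
Since toCheckpoint(ThresholdedRandomCutForest) and toTrcf are inverses, a serialize/deserialize round trip is a natural sanity check for this change. A minimal sketch, with the mapper and schema built the same way AnomalyDetectorPlugin does (the plugin additionally wraps schema creation in doPrivileged; builder settings below are illustrative):

ThresholdedRandomCutForestMapper trcfMapper = new ThresholdedRandomCutForestMapper();
Schema<ThresholdedRandomCutForestState> trcfSchema =
    RuntimeSchema.getSchema(ThresholdedRandomCutForestState.class);

// train a small model on a synthetic stream
ThresholdedRandomCutForest.Builder<?> builder = ThresholdedRandomCutForest
    .builder()
    .dimensions(1)
    .shingleSize(1)
    .sampleSize(256);
ThresholdedRandomCutForest original = new ThresholdedRandomCutForest(builder);
for (int i = 0; i < 300; i++) {
    original.process(new double[] { Math.sin(i * 0.1) }, 0);
}

// serialize, as toCheckpoint(trcf, buffer) does
byte[] bytes = ProtostuffIOUtil
    .toByteArray(trcfMapper.toState(original), trcfSchema, LinkedBuffer.allocate(512));
String checkpoint = Base64.getEncoder().encodeToString(bytes);

// deserialize, as toTrcf does
ThresholdedRandomCutForestState state = trcfSchema.newMessage();
ProtostuffIOUtil.mergeFrom(Base64.getDecoder().decode(checkpoint), state, trcfSchema);
ThresholdedRandomCutForest copy = trcfMapper.toModel(state);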

private RandomCutForest deserializeRCFModel(String rcfCheckpoint) {
if (Strings.isEmpty(rcfCheckpoint)) {
return null;
63 changes: 9 additions & 54 deletions src/main/java/org/opensearch/ad/ml/EntityColdStarter.java
@@ -34,6 +34,7 @@
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -42,7 +43,6 @@
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;

import org.apache.logging.log4j.LogManager;
@@ -70,8 +70,8 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.threadpool.ThreadPool;

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.parkservices.threshold.ThresholdedRandomCutForest;

/**
* Training models for HCAD detectors
@@ -363,7 +363,7 @@ private void trainModelFromDataSegments(
}

int dimensions = dataPoints.get(0)[0].length * shingleSize;
RandomCutForest.Builder<?> rcfBuilder = RandomCutForest
ThresholdedRandomCutForest.Builder<?> rcfBuilder = ThresholdedRandomCutForest
.builder()
.dimensions(dimensions)
.sampleSize(rcfSampleSize)
@@ -380,73 +380,28 @@ private void trainModelFromDataSegments(
// vector is x1, x2, x3, x4, now we add x3, x4, x5, x6. RCF will recognize
// overlapping x3, x4, and only store x5, x6.
.shingleSize(shingleSize)
.internalShinglingEnabled(true);
.internalShinglingEnabled(true)
.anomalyRate(1 - this.thresholdMinPvalue);

if (rcfSeed > 0) {
rcfBuilder.randomSeed(rcfSeed);
}
RandomCutForest rcf = rcfBuilder.build();
List<double[]> allScores = new ArrayList<>();
int totalLength = 0;
// get continuous data points and send for training
for (double[][] continuousDataPoints : dataPoints) {
double[] scores = trainRCFModel(continuousDataPoints, rcf);
allScores.add(scores);
totalLength += scores.length;
}
ThresholdedRandomCutForest rcf = new ThresholdedRandomCutForest(rcfBuilder);

dataPoints.stream().flatMap(d -> Arrays.stream(d)).forEach(s -> rcf.process(s, 0));

EntityModel model = entityState.getModel();
if (model == null) {
model = new EntityModel(entity, new ArrayDeque<>(), null, null);
}
model.setRcf(rcf);
double[] joinedScores = new double[totalLength];

int destStart = 0;
for (double[] scores : allScores) {
System.arraycopy(scores, 0, joinedScores, destStart, scores.length);
destStart += scores.length;
}

// Train thresholding model
ThresholdingModel threshold = new HybridThresholdingModel(
thresholdMinPvalue,
thresholdMaxRankError,
thresholdMaxScore,
thresholdNumLogNormalQuantiles,
thresholdDownsamples,
thresholdMaxSamples
);
threshold.train(joinedScores);
model.setThreshold(threshold);
model.setTrcf(rcf);

entityState.setLastUsedTime(clock.instant());

// save to checkpoint
checkpointWriteQueue.write(entityState, true, RequestPriority.MEDIUM);
}

/**
* Train the RCF model using given data points
* @param dataPoints Data points
* @param rcf RCF model to be trained
* @return scores returned by RCF models
*/
private double[] trainRCFModel(double[][] dataPoints, RandomCutForest rcf) {
if (dataPoints.length == 0 || dataPoints[0].length == 0) {
throw new IllegalArgumentException("Data points must not be empty.");
}

double[] scores = new double[dataPoints.length];

for (int j = 0; j < dataPoints.length; j++) {
scores[j] = rcf.getAnomalyScore(dataPoints[j]);
rcf.update(dataPoints[j]);
}

return DoubleStream.of(scores).filter(score -> score > 0).toArray();
}
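
The net effect in this file: cold start no longer scores with a bare RandomCutForest and then fits a HybridThresholdingModel on the scores; training points are streamed straight into a ThresholdedRandomCutForest, which keeps its own threshold calibrated to the configured anomaly rate. A minimal usage sketch (getAnomalyGrade() on the returned descriptor follows the parkservices API and is an assumption here; settings are illustrative):

ThresholdedRandomCutForest.Builder<?> builder = ThresholdedRandomCutForest
    .builder()
    .dimensions(8)                  // baseDimensions * shingleSize
    .shingleSize(8)
    .internalShinglingEnabled(true)
    .anomalyRate(0.005);            // plays the role of 1 - thresholdMinPvalue
ThresholdedRandomCutForest forest = new ThresholdedRandomCutForest(builder);

for (int i = 0; i < 1000; i++) {
    // with internal shingling, the model takes raw (unshingled) points
    double[] point = { Math.sin(i * 0.1) };
    double grade = forest.process(point, 0).getAnomalyGrade();
    // grade stays 0 for expected points and rises toward 1 for anomalies
}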

/**
* Get training data for an entity.
*