add thresholded rcf #215

Merged: 1 commit, Sep 16, 2021
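This PR replaces the plugin's two-model pipeline, a RandomCutForest for anomaly scores plus a separately trained HybridThresholdingModel for grading, with the library's combined ThresholdedRandomCutForest (TRCF), which scores and thresholds in a single call. Below is a minimal sketch of the TRCF calls the diff relies on; the builder options and process() signature appear in the diff itself, while the wrapper class and concrete values are illustrative assumptions, not the plugin's settings.

import com.amazon.randomcutforest.parkservices.threshold.ThresholdedRandomCutForest;

public class TrcfSketch {
    public static void main(String[] args) {
        // Builder options mirror those used in EntityColdStarter below;
        // the concrete values are placeholders.
        ThresholdedRandomCutForest.Builder<?> builder = ThresholdedRandomCutForest
            .builder()
            .dimensions(8)                   // shingleSize * number of features
            .sampleSize(256)
            .shingleSize(8)
            .internalShinglingEnabled(true)  // feed raw points; the forest shingles them
            .anomalyRate(0.005);             // the PR passes 1 - thresholdMinPvalue
        ThresholdedRandomCutForest forest = new ThresholdedRandomCutForest(builder);

        // process() scores a point and then absorbs it into the forest,
        // replacing the old getAnomalyScore()/update() two-step.
        for (int i = 0; i < 1000; i++) {
            forest.process(new double[] { Math.sin(i / 10.0) }, i);
        }
    }
}

The anomalyRate knob stands in for the threshold model's separately learned score distribution: a thresholdMinPvalue of 0.995, for instance, becomes an expected anomaly rate of 0.005.
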
4 changes: 3 additions & 1 deletion build.gradle
@@ -591,11 +591,13 @@ dependencies {
// implementation scope keeps the dependency on both the compile and runtime classpaths,
// but does not leak it through to clients (OpenSearch). Here we force the jackson version
// to whatever OpenSearch uses.
compile 'software.amazon.randomcutforest:randomcutforest-core:2.0.1'
//compile 'software.amazon.randomcutforest:randomcutforest-core:2.0.1'
implementation 'software.amazon.randomcutforest:randomcutforest-serialization:2.0.1'
implementation "com.fasterxml.jackson.core:jackson-core:${versions.jackson}"
implementation "com.fasterxml.jackson.core:jackson-databind:${versions.jackson}"
implementation "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}"
compile files('lib/randomcutforest-parkservices-2.0.1.jar')
compile files('lib/randomcutforest-core-2.0.1.jar')

// used for serializing/deserializing rcf models.
compile group: 'io.protostuff', name: 'protostuff-core', version: '1.7.4'
Binary file added lib/randomcutforest-core-2.0.1.jar
Binary file not shown.
Binary file added lib/randomcutforest-parkservices-2.0.1.jar
Binary file not shown.
8 changes: 8 additions & 0 deletions src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java
@@ -205,6 +205,8 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

import com.amazon.randomcutforest.parkservices.threshold.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.threshold.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV2StateConverter;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.amazon.randomcutforest.state.RandomCutForestState;
@@ -450,6 +452,12 @@ public PooledObject<LinkedBuffer> wrap(LinkedBuffer obj) {
mapper,
schema,
converter,
new ThresholdedRandomCutForestMapper(),
AccessController
.doPrivileged(
(PrivilegedAction<Schema<ThresholdedRandomCutForestState>>) () -> RuntimeSchema
.getSchema(ThresholdedRandomCutForestState.class)
),
HybridThresholdingModel.class,
anomalyDetectionIndices,
AnomalyDetectorSettings.MAX_CHECKPOINT_BYTES,
114 changes: 105 additions & 9 deletions src/main/java/org/opensearch/ad/ml/CheckpointDao.java
@@ -84,6 +84,9 @@

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.parkservices.threshold.ThresholdedRandomCutForest;
import com.amazon.randomcutforest.parkservices.threshold.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.threshold.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV2StateConverter;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.amazon.randomcutforest.state.RandomCutForestState;
@@ -116,6 +119,7 @@ public class CheckpointDao {
public static final String ENTITY_SAMPLE = "sp";
public static final String ENTITY_RCF = "rcf";
public static final String ENTITY_THRESHOLD = "th";
public static final String ENTITY_TRCF = "trcf";
public static final String FIELD_MODEL = "model";
public static final String FIELD_MODELV2 = "modelV2";
public static final String TIMESTAMP = "timestamp";
@@ -132,6 +136,8 @@ public class CheckpointDao {
private RandomCutForestMapper mapper;
private Schema<RandomCutForestState> schema;
private V1JsonToV2StateConverter converter;
private ThresholdedRandomCutForestMapper trcfMapper;
private Schema<ThresholdedRandomCutForestState> trcfSchema;

private final Class<? extends ThresholdingModel> thresholdingModelClass;

@@ -154,6 +160,8 @@ public class CheckpointDao {
* @param mapper RCF model serialization utility
* @param schema RandomCutForestState schema used by ProtoStuff
* @param converter converter from rcf v1 serde to protostuff based format
* @param trcfMapper TRCF serialization mapper
* @param trcfSchema TRCF serialization schema
* @param thresholdingModelClass thresholding model's class
* @param indexUtil Index utility methods
* @param maxCheckpointBytes max checkpoint size in bytes
@@ -168,6 +176,8 @@ public CheckpointDao(
RandomCutForestMapper mapper,
Schema<RandomCutForestState> schema,
V1JsonToV2StateConverter converter,
ThresholdedRandomCutForestMapper trcfMapper,
Schema<ThresholdedRandomCutForestState> trcfSchema,
Class<? extends ThresholdingModel> thresholdingModelClass,
AnomalyDetectionIndices indexUtil,
int maxCheckpointBytes,
@@ -181,6 +191,8 @@ public CheckpointDao(
this.mapper = mapper;
this.schema = schema;
this.converter = converter;
this.trcfMapper = trcfMapper;
this.trcfSchema = trcfSchema;
this.thresholdingModelClass = thresholdingModelClass;
this.indexUtil = indexUtil;
this.maxCheckpointBytes = maxCheckpointBytes;
@@ -337,10 +349,69 @@ public String toCheckpoint(EntityModel model) {
if (model.getThreshold() != null) {
json.addProperty(ENTITY_THRESHOLD, gson.toJson(model.getThreshold()));
}
if (model.getTrcf().isPresent()) {
json.addProperty(ENTITY_TRCF, toCheckpoint(model.getTrcf().get()));
}
return gson.toJson(json);
});
}

private String toCheckpoint(ThresholdedRandomCutForest trcf) {
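// Serialize with a pooled buffer when one is available; if serialization with
// a pooled buffer fails, invalidate that buffer and retry once with a fresh
// unpooled one, so a bad buffer cannot poison later checkpoints.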
String checkpoint = null;
Map.Entry<LinkedBuffer, Boolean> result = checkoutOrNewBuffer();
LinkedBuffer buffer = result.getKey();
boolean needCheckin = result.getValue();
try {
checkpoint = toCheckpoint(trcf, buffer);
} catch (RuntimeException e) {
logger.error("Failed to serialize model", e);
if (needCheckin) {
try {
serializeRCFBufferPool.invalidateObject(buffer);
needCheckin = false;
} catch (Exception x) {
logger.warn("Failed to invalidate buffer", x);
}
checkpoint = toCheckpoint(trcf, LinkedBuffer.allocate(serializeRCFBufferSize));
}
} finally {
if (needCheckin) {
try {
serializeRCFBufferPool.returnObject(buffer);
} catch (Exception e) {
logger.warn("Failed to return buffer to pool", e);
}
}
}
return checkpoint;
}

private Map.Entry<LinkedBuffer, Boolean> checkoutOrNewBuffer() {
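// Prefer a pooled buffer; if borrowing fails, fall back to a one-off buffer
// that must not be returned to the pool (signalled by the Boolean in the
// returned entry).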
LinkedBuffer buffer = null;
boolean isCheckout = true;
try {
buffer = serializeRCFBufferPool.borrowObject();
} catch (Exception e) {
logger.warn("Failed to borrow a buffer from pool", e);
}
if (buffer == null) {
buffer = LinkedBuffer.allocate(serializeRCFBufferSize);
isCheckout = false;
}
return new SimpleImmutableEntry<>(buffer, isCheckout);
}

private String toCheckpoint(ThresholdedRandomCutForest trcf, LinkedBuffer buffer) {
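// toState() snapshots the forest; the protostuff write runs inside
// doPrivileged so the reflective runtime schema can be used under the
// plugin's security policy. The buffer is always cleared for reuse.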
try {
ThresholdedRandomCutForestState trcfState = trcfMapper.toState(trcf);
byte[] bytes = AccessController
.doPrivileged((PrivilegedAction<byte[]>) () -> ProtostuffIOUtil.toByteArray(trcfState, trcfSchema, buffer));
return Base64.getEncoder().encodeToString(bytes);
} finally {
buffer.clear();
}
}

private String rcfModelToCheckpoint(RandomCutForest model) {
LinkedBuffer borrowedBuffer = null;
try {
@@ -466,7 +537,6 @@ public Optional<Entry<EntityModel, Instant>> fromEntityModelCheckpoint(Map<Strin
logger.warn(new ParameterizedMessage("Empty model for [{}]", modelId));
return Optional.empty();
}

String model = (String) modelObj;
if (model.length() > maxCheckpointBytes) {
logger.warn(new ParameterizedMessage("[{}]'s model too large: [{}] bytes", modelId, model.length()));
@@ -483,15 +553,21 @@ public Optional<Entry<EntityModel, Instant>> fromEntityModelCheckpoint(Map<Strin
// avoid possible null pointer exception
samples = new ArrayDeque<>();
}
ThresholdedRandomCutForest trcf = null;
RandomCutForest rcf = null;
if (json.has(ENTITY_RCF)) {
String serializedRCF = json.getAsJsonPrimitive(ENTITY_RCF).getAsString();
rcf = deserializeRCFModel(serializedRCF);
}
ThresholdingModel threshold = null;
if (json.has(ENTITY_THRESHOLD)) {
// verified, don't need privileged call to get permission
threshold = this.gson.fromJson(json.getAsJsonPrimitive(ENTITY_THRESHOLD).getAsString(), thresholdingModelClass);
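// New checkpoints store the combined TRCF model; older checkpoints fall back
// to the legacy separate RCF and threshold fields so they can still be read.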
if (json.has(ENTITY_TRCF)) {
trcf = toTrcf(json.getAsJsonPrimitive(ENTITY_TRCF).getAsString());
} else {
// TODO: convert rcf and threshold to trcf
if (json.has(ENTITY_RCF)) {
String serializedRCF = json.getAsJsonPrimitive(ENTITY_RCF).getAsString();
rcf = deserializeRCFModel(serializedRCF);
}
if (json.has(ENTITY_THRESHOLD)) {
// verified, don't need privileged call to get permission
threshold = this.gson.fromJson(json.getAsJsonPrimitive(ENTITY_THRESHOLD).getAsString(), thresholdingModelClass);
}
}

String lastCheckpointTimeString = (String) (checkpoint.get(TIMESTAMP));
@@ -505,14 +581,34 @@ public Optional<Entry<EntityModel, Instant>> fromEntityModelCheckpoint(Map<Strin
logger.error(new ParameterizedMessage("fail to parse entity", serializedEntity), e);
}
}
return Optional.of(new SimpleImmutableEntry<>(new EntityModel(entity, samples, rcf, threshold), timestamp));
EntityModel entityModel = new EntityModel(entity, samples, rcf, threshold);
entityModel.setTrcf(trcf);
return Optional.of(new SimpleImmutableEntry<>(entityModel, timestamp));
});
} catch (Exception e) {
logger.warn("Exception while deserializing checkpoint", e);
throw e;
}
}

private ThresholdedRandomCutForest toTrcf(String checkpoint) {
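// Reverse of toCheckpoint(trcf): Base64-decode, merge the bytes into a fresh
// protostuff state object, then let the mapper rebuild the model.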
ThresholdedRandomCutForest trcf = null;
if (checkpoint != null) {
try {
byte[] bytes = Base64.getDecoder().decode(checkpoint);
ThresholdedRandomCutForestState state = trcfSchema.newMessage();
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
ProtostuffIOUtil.mergeFrom(bytes, state, trcfSchema);
return null;
});
trcf = trcfMapper.toModel(state);
} catch (RuntimeException e) {
logger.error("Failed to deserialize TRCF model", e);
}
}
return trcf;
}

private RandomCutForest deserializeRCFModel(String rcfCheckpoint) {
if (Strings.isEmpty(rcfCheckpoint)) {
return null;
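For reference, a self-contained sketch of the checkpoint encoding the methods above implement: protostuff writes a ThresholdedRandomCutForestState snapshot to bytes, which are Base64-encoded into the checkpoint JSON, and reading reverses the steps. The mapper and schema types are the ones the PR wires into CheckpointDao; the buffer pooling, doPrivileged wrappers, and error handling of the real code are deliberately omitted, and the standalone class and values are illustrative.

import java.util.Base64;

import com.amazon.randomcutforest.parkservices.threshold.ThresholdedRandomCutForest;
import com.amazon.randomcutforest.parkservices.threshold.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.threshold.ThresholdedRandomCutForestState;

import io.protostuff.LinkedBuffer;
import io.protostuff.ProtostuffIOUtil;
import io.protostuff.Schema;
import io.protostuff.runtime.RuntimeSchema;

public class TrcfCheckpointRoundTrip {
    public static void main(String[] args) {
        ThresholdedRandomCutForestMapper mapper = new ThresholdedRandomCutForestMapper();
        Schema<ThresholdedRandomCutForestState> schema =
            RuntimeSchema.getSchema(ThresholdedRandomCutForestState.class);

        // Train a small model so there is real state to serialize.
        ThresholdedRandomCutForest.Builder<?> builder = ThresholdedRandomCutForest
            .builder()
            .dimensions(4)
            .shingleSize(4)
            .internalShinglingEnabled(true);
        ThresholdedRandomCutForest trcf = new ThresholdedRandomCutForest(builder);
        for (int i = 0; i < 300; i++) {
            trcf.process(new double[] { i % 7 }, i);
        }

        // Write: model -> state -> protostuff bytes -> Base64 string.
        LinkedBuffer buffer = LinkedBuffer.allocate(1024);
        String checkpoint;
        try {
            ThresholdedRandomCutForestState state = mapper.toState(trcf);
            checkpoint = Base64.getEncoder()
                .encodeToString(ProtostuffIOUtil.toByteArray(state, schema, buffer));
        } finally {
            buffer.clear(); // a shared buffer must be cleared before reuse
        }

        // Read: Base64 string -> bytes -> state -> model.
        byte[] bytes = Base64.getDecoder().decode(checkpoint);
        ThresholdedRandomCutForestState restored = schema.newMessage();
        ProtostuffIOUtil.mergeFrom(bytes, restored, schema);
        ThresholdedRandomCutForest roundTripped = mapper.toModel(restored);
        System.out.println("checkpoint length: " + checkpoint.length());
    }
}
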
63 changes: 9 additions & 54 deletions src/main/java/org/opensearch/ad/ml/EntityColdStarter.java
@@ -34,6 +34,7 @@
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -42,7 +43,6 @@
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;

import org.apache.logging.log4j.LogManager;
@@ -70,8 +70,8 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.threadpool.ThreadPool;

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.parkservices.threshold.ThresholdedRandomCutForest;

/**
* Training models for HCAD detectors
@@ -363,7 +363,7 @@ private void trainModelFromDataSegments(
}

int dimensions = dataPoints.get(0)[0].length * shingleSize;
RandomCutForest.Builder<?> rcfBuilder = RandomCutForest
ThresholdedRandomCutForest.Builder<?> rcfBuilder = ThresholdedRandomCutForest
.builder()
.dimensions(dimensions)
.sampleSize(rcfSampleSize)
@@ -380,73 +380,28 @@ private void trainModelFromDataSegments(
// vector is x1, x2, x3, x4, now we add x3, x4, x5, x6. RCF will recognize
// overlapping x3, x4, and only store x5, x6.
.shingleSize(shingleSize)
.internalShinglingEnabled(true);
.internalShinglingEnabled(true)
.anomalyRate(1 - this.thresholdMinPvalue);

if (rcfSeed > 0) {
rcfBuilder.randomSeed(rcfSeed);
}
RandomCutForest rcf = rcfBuilder.build();
List<double[]> allScores = new ArrayList<>();
int totalLength = 0;
// get continuous data points and send for training
for (double[][] continuousDataPoints : dataPoints) {
double[] scores = trainRCFModel(continuousDataPoints, rcf);
allScores.add(scores);
totalLength += scores.length;
}
ThresholdedRandomCutForest rcf = new ThresholdedRandomCutForest(rcfBuilder);

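// process() both scores a point and absorbs it, so streaming every training
// point through the forest replaces the removed two-pass flow of collecting
// RCF scores and fitting a HybridThresholdingModel on them.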
dataPoints.stream().flatMap(d -> Arrays.stream(d)).forEach(s -> rcf.process(s, 0));

EntityModel model = entityState.getModel();
if (model == null) {
model = new EntityModel(entity, new ArrayDeque<>(), null, null);
}
model.setRcf(rcf);
double[] joinedScores = new double[totalLength];

int destStart = 0;
for (double[] scores : allScores) {
System.arraycopy(scores, 0, joinedScores, destStart, scores.length);
destStart += scores.length;
}

// Train thresholding model
ThresholdingModel threshold = new HybridThresholdingModel(
thresholdMinPvalue,
thresholdMaxRankError,
thresholdMaxScore,
thresholdNumLogNormalQuantiles,
thresholdDownsamples,
thresholdMaxSamples
);
threshold.train(joinedScores);
model.setThreshold(threshold);
model.setTrcf(rcf);

entityState.setLastUsedTime(clock.instant());

// save to checkpoint
checkpointWriteQueue.write(entityState, true, RequestPriority.MEDIUM);
}

/**
* Train the RCF model using given data points
* @param dataPoints Data points
* @param rcf RCF model to be trained
* @return scores returned by RCF models
*/
private double[] trainRCFModel(double[][] dataPoints, RandomCutForest rcf) {
if (dataPoints.length == 0 || dataPoints[0].length == 0) {
throw new IllegalArgumentException("Data points must not be empty.");
}

double[] scores = new double[dataPoints.length];

for (int j = 0; j < dataPoints.length; j++) {
scores[j] = rcf.getAnomalyScore(dataPoints[j]);
rcf.update(dataPoints[j]);
}

return DoubleStream.of(scores).filter(score -> score > 0).toArray();
}

/**
* Get training data for an entity.
*