diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java index f01c563a..edec2ace 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java @@ -290,6 +290,7 @@ public Collection createComponents( AnomalyDetectorSettings.NUM_TREES, AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, AnomalyDetectorSettings.TIME_DECAY, + AnomalyDetectorSettings.NUM_MIN_SAMPLES, AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, AnomalyDetectorSettings.THRESHOLD_MAX_RANK_ERROR, AnomalyDetectorSettings.THRESHOLD_MAX_SCORE, diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java index 1b94d3ca..bc6d2e25 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java @@ -91,6 +91,7 @@ public String getName() { private final int rcfNumTrees; private final int rcfNumSamplesInTree; private final double rcfTimeDecay; + private final int rcfNumMinSamples; private final double thresholdMinPvalue; private final double thresholdMaxRankError; private final double thresholdMaxScore; @@ -130,6 +131,7 @@ public String getName() { * @param rcfNumTrees number of trees used in RCF * @param rcfNumSamplesInTree number of samples in a RCF tree * @param rcfTimeDecay time decay for RCF + * @param rcfNumMinSamples minimum samples for RCF to score * @param thresholdMinPvalue min P-value for thresholding * @param thresholdMaxRankError max rank error for thresholding * @param thresholdMaxScore max RCF score to thresholding @@ -154,6 +156,7 @@ public ModelManager( int rcfNumTrees, int rcfNumSamplesInTree, double rcfTimeDecay, + int rcfNumMinSamples, double thresholdMinPvalue, double thresholdMaxRankError, double thresholdMaxScore, @@ -179,6 +182,7 @@ public ModelManager( this.rcfNumTrees = rcfNumTrees; this.rcfNumSamplesInTree = rcfNumSamplesInTree; this.rcfTimeDecay = rcfTimeDecay; + this.rcfNumMinSamples = rcfNumMinSamples; this.thresholdMinPvalue = thresholdMinPvalue; this.thresholdMaxRankError = thresholdMaxRankError; this.thresholdMaxScore = thresholdMaxScore; @@ -679,7 +683,7 @@ public void trainModel(AnomalyDetector anomalyDetector, double[][] dataPoints) { .sampleSize(rcfNumSamplesInTree) .numberOfTrees(forestSize) .lambda(rcfTimeDecay) - .outputAfter(rcfNumSamplesInTree) + .outputAfter(rcfNumMinSamples) .parallelExecutionEnabled(false) .build(); for (int j = 0; j < dataPoints.length; j++) { @@ -771,7 +775,7 @@ private void trainModelForStep( .sampleSize(rcfNumSamplesInTree) .numberOfTrees(forestSize) .lambda(rcfTimeDecay) - .outputAfter(rcfNumSamplesInTree) + .outputAfter(rcfNumMinSamples) .parallelExecutionEnabled(false) .build(); for (int j = 0; j < dataPoints.length; j++) { diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/settings/AnomalyDetectorSettings.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/settings/AnomalyDetectorSettings.java index 2e22dbe7..5b5b267b 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/settings/AnomalyDetectorSettings.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/settings/AnomalyDetectorSettings.java @@ -159,6 +159,8 @@ private AnomalyDetectorSettings() {} public static final double TIME_DECAY = 0.0001; + public static final int NUM_MIN_SAMPLES = 128; + public static final double DESIRED_MODEL_SIZE_PERCENTAGE = 0.0002; public static final double MODEL_MAX_SIZE_PERCENTAGE = 0.1; diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java index 25ab3b87..4477f481 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java @@ -110,6 +110,7 @@ public class ModelManagerTests { private int numSamples; private int numFeatures; private double rcfTimeDecay; + private int numMinSamples; private double thresholdMinPvalue; private double thresholdMaxRankError; private double thresholdMaxScore; @@ -143,6 +144,7 @@ public void setup() { numSamples = 10; numFeatures = 1; rcfTimeDecay = 1.0 / 1024; + numMinSamples = 1; thresholdMinPvalue = 0.95; thresholdMaxRankError = 1e-4; thresholdMaxScore = 8.0; @@ -174,6 +176,7 @@ public void setup() { numTrees, numSamples, rcfTimeDecay, + numMinSamples, thresholdMinPvalue, thresholdMaxRankError, thresholdMaxScore,