diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/preprocessor/ImputePreprocessor.java b/Java/core/src/main/java/com/amazon/randomcutforest/preprocessor/ImputePreprocessor.java index d00a15ee..71f4a8ca 100644 --- a/Java/core/src/main/java/com/amazon/randomcutforest/preprocessor/ImputePreprocessor.java +++ b/Java/core/src/main/java/com/amazon/randomcutforest/preprocessor/ImputePreprocessor.java @@ -80,44 +80,6 @@ public float[] getScaledShingledInput(double[] inputPoint, long timestamp, int[] return point; } - /** - * the timestamps are now used to calculate the number of imputed tuples in the - * shingle - * - * @param timestamp the timestamp of the current input - */ - @Override - protected void updateTimestamps(long timestamp) { - /* - * For imputations done on timestamps other than the current one (specified by - * the timestamp parameter), the timestamp of the imputed tuple matches that of - * the input tuple, and we increment numberOfImputed. For imputations done at - * the current timestamp (if all input values are missing), the timestamp of the - * imputed tuple is the current timestamp, and we increment numberOfImputed. - * - * To check if imputed values are still present in the shingle, we use the first - * condition (previousTimeStamps[0] == previousTimeStamps[1]). This works - * because previousTimeStamps has a size equal to the shingle size and is filled - * with the current timestamp. However, there are scenarios where we might miss - * decrementing numberOfImputed: - * - * 1. Not all values in the shingle are imputed. 2. We accumulated - * numberOfImputed when the current timestamp had missing values. - * - * As a result, this could cause the data quality measure to decrease - * continuously since we are always counting missing values that should - * eventually be reset to zero. The second condition
timestamp > - * previousTimeStamps[previousTimeStamps.length-1] && numberOfImputed > 0- * will decrement numberOfImputed when we move to a new timestamp, provided - * numberOfImputed is greater than zero. - */ - if (previousTimeStamps[0] == previousTimeStamps[1] - || (timestamp > previousTimeStamps[previousTimeStamps.length - 1] && numberOfImputed > 0)) { - numberOfImputed = numberOfImputed - 1; - } - super.updateTimestamps(timestamp); - } - /** * decides if the forest should be updated, this is needed for imputation on the * fly. The main goal of this function is to avoid runaway sequences where a @@ -128,7 +90,10 @@ protected void updateTimestamps(long timestamp) { */ protected boolean updateAllowed() { double fraction = numberOfImputed * 1.0 / (shingleSize); - if (numberOfImputed == shingleSize - 1 && previousTimeStamps[0] != previousTimeStamps[1] + if (fraction > 1) { + fraction = 1; + } + if (numberOfImputed >= shingleSize - 1 && previousTimeStamps[0] != previousTimeStamps[1] && (transformMethod == DIFFERENCE || transformMethod == NORMALIZE_DIFFERENCE)) { // this shingle is disconnected from the previously seen values // these transformations will have little meaning @@ -144,10 +109,19 @@ protected boolean updateAllowed() { // two different points).
return false; } + dataQuality[0].update(1 - fraction); return (fraction < useImputedFraction && internalTimeStamp >= shingleSize); } + @Override + protected void updateTimestamps(long timestamp) { + if (previousTimeStamps[0] == previousTimeStamps[1]) { + numberOfImputed = numberOfImputed - 1; + } + super.updateTimestamps(timestamp); + } + /** * the following function mutates the forest, the lastShingledPoint, * lastShingledInput as well as previousTimeStamps, and adds the shingled input @@ -168,7 +142,9 @@ void updateForest(boolean changeForest, double[] input, long timestamp, RandomCu updateShingle(input, scaledInput); updateTimestamps(timestamp); if (isFullyImputed) { - numberOfImputed = numberOfImputed + 1; + numberOfImputed = Math.min(numberOfImputed + 1, shingleSize); + } else if (numberOfImputed > 0) { + numberOfImputed = numberOfImputed - 1; } if (changeForest) { if (forest.isInternalShinglingEnabled()) { @@ -190,7 +166,9 @@ public void update(double[] point, float[] rcfPoint, long timestamp, int[] missi return; } generateShingle(point, timestamp, missing, getTimeFactor(timeStampDeviations[1]), true, forest); - ++valuesSeen; + if (missing == null || missing.length != point.length) { + ++valuesSeen; + } } protected double getTimeFactor(Deviation deviation) { diff --git a/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/PredictorCorrector.java b/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/PredictorCorrector.java index a6100a01..08abbace 100644 --- a/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/PredictorCorrector.java +++ b/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/PredictorCorrector.java @@ -961,7 +961,8 @@ public void setLastScore(double[] score) { } void validateIgnore(double[] shift, int length) { - checkArgument(shift.length == length, () -> String.format(Locale.ROOT, "has to be of length %d but is %d", length, shift.length)); +
checkArgument(shift.length == length, + () -> String.format(Locale.ROOT, "has to be of length %d but is %d", length, shift.length)); for (double element : shift) { checkArgument(element >= 0, "has to be non-negative"); } diff --git a/Java/parkservices/src/test/java/com/amazon/randomcutforest/parkservices/MissingValueTest.java b/Java/parkservices/src/test/java/com/amazon/randomcutforest/parkservices/MissingValueTest.java index 829e4dc2..cc48bc3b 100644 --- a/Java/parkservices/src/test/java/com/amazon/randomcutforest/parkservices/MissingValueTest.java +++ b/Java/parkservices/src/test/java/com/amazon/randomcutforest/parkservices/MissingValueTest.java @@ -19,11 +19,17 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Random; +import java.util.stream.Stream; +import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; import com.amazon.randomcutforest.config.ForestMode; import com.amazon.randomcutforest.config.ImputationMethod; @@ -31,12 +37,20 @@ import com.amazon.randomcutforest.config.TransformMethod; public class MissingValueTest { + private static class EnumAndValueProvider implements ArgumentsProvider { + @Override + public Stream<? extends Arguments> provideArguments(ExtensionContext context) { + return Stream.of(ImputationMethod.PREVIOUS, ImputationMethod.ZERO, ImputationMethod.FIXED_VALUES) + .flatMap(method -> Stream.of(4, 8, 16) // Example shingle sizes + .map(shingleSize -> Arguments.of(method, shingleSize))); + } + } + @ParameterizedTest - @EnumSource(ImputationMethod.class) - public void testConfidence(ImputationMethod method) { +
@ArgumentsSource(EnumAndValueProvider.class) + public void testConfidence(ImputationMethod method, int shingleSize) { // Create and populate a random cut forest - int shingleSize = 4; int numberOfTrees = 50; int sampleSize = 256; Precision precision = Precision.FLOAT_32; @@ -45,11 +59,19 @@ public void testConfidence(ImputationMethod method) { long count = 0; int dimensions = baseDimensions * shingleSize; - ThresholdedRandomCutForest forest = new ThresholdedRandomCutForest.Builder<>().compact(true) + ThresholdedRandomCutForest.Builder forestBuilder = new ThresholdedRandomCutForest.Builder<>().compact(true) .dimensions(dimensions).randomSeed(0).numberOfTrees(numberOfTrees).shingleSize(shingleSize) .sampleSize(sampleSize).precision(precision).anomalyRate(0.01).imputationMethod(method) - .fillValues(new double[] { 3 }).forestMode(ForestMode.STREAMING_IMPUTE) - .transformMethod(TransformMethod.NORMALIZE).autoAdjust(true).build(); + .forestMode(ForestMode.STREAMING_IMPUTE).transformMethod(TransformMethod.NORMALIZE).autoAdjust(true); + + if (method == ImputationMethod.FIXED_VALUES) { + // we cannot pass fillValues when the method is not fixed values.
Otherwise, we + // will impute + // filled in values regardless of imputation method + forestBuilder.fillValues(new double[] { 3 }); + } + + ThresholdedRandomCutForest forest = forestBuilder.build(); // Define the size and range int size = 400; @@ -75,18 +97,38 @@ public void testConfidence(ImputationMethod method) { float[] rcfPoint = result.getRCFPoint(); double scale = result.getScale()[0]; double shift = result.getShift()[0]; - double[] actual = new double[] { (rcfPoint[3] * scale) + shift }; + double[] actual = new double[] { (rcfPoint[shingleSize - 1] * scale) + shift }; if (method == ImputationMethod.ZERO) { assertEquals(0, actual[0], 0.001d); + if (count == 300) { + assertTrue(result.getAnomalyGrade() > 0); + } } else if (method == ImputationMethod.FIXED_VALUES) { assertEquals(3.0d, actual[0], 0.001d); + if (count == 300) { + assertTrue(result.getAnomalyGrade() > 0); + } + } else if (method == ImputationMethod.PREVIOUS) { + assertEquals(0, result.getAnomalyGrade(), 0.001d, + "count: " + count + " actual: " + Arrays.toString(actual)); } } else { AnomalyDescriptor result = forest.process(point, newStamp); - if ((count > 100 && count < 300) || count >= 326) { + // after 325, we have a period of confidence decreasing. After that, confidence + // starts increasing again. + // We are not sure where the confidence will start increasing after decreasing. + // So we start checking the behavior after 325 + shingleSize. + int backupPoint = 325 + shingleSize; + if ((count > 100 && count < 300) || count >= backupPoint) { // The first 65+ observations gives 0 confidence.
// Confidence start increasing after 1 observed point - assertTrue(result.getDataConfidence() > lastConfidence); + assertTrue(result.getDataConfidence() > lastConfidence, + String.format(Locale.ROOT, "count: %d, confidence: %f, last confidence: %f", count, + result.getDataConfidence(), lastConfidence)); + } else if (count < 325 && count > 300) { + assertTrue(result.getDataConfidence() < lastConfidence, + String.format(Locale.ROOT, "count: %d, confidence: %f, last confidence: %f", count, + result.getDataConfidence(), lastConfidence)); } lastConfidence = result.getDataConfidence(); }