Fix Confidence Adjustment for Larger Shingle Sizes #407

Merged (2 commits) Aug 1, 2024
@@ -80,44 +80,6 @@ public float[] getScaledShingledInput(double[] inputPoint, long timestamp, int[]
return point;
}

/**
* the timestamps are now used to calculate the number of imputed tuples in the
* shingle
*
* @param timestamp the timestamp of the current input
*/
@Override
protected void updateTimestamps(long timestamp) {
/*
* For imputations done on timestamps other than the current one (specified by
* the timestamp parameter), the timestamp of the imputed tuple matches that of
* the input tuple, and we increment numberOfImputed. For imputations done at
* the current timestamp (if all input values are missing), the timestamp of the
* imputed tuple is the current timestamp, and we increment numberOfImputed.
*
* To check if imputed values are still present in the shingle, we use the first
@amitgalitz (Contributor) commented on Aug 1, 2024:

At least this part of the comment is still valid, right? It would be good to keep it, as I was pretty confused about what previousTimeStamps[0] != previousTimeStamps[1] was really meant to do; or maybe explain that part.
Additionally, so I understand correctly: we are saying that if the last 10 values were imputed and the shingle size is 8, we will most likely return false here and not allow an update until we get an actual value?

Author (Collaborator) replied:

Let me add the comment back.

> we are saying if the last 10 values were imputed and shingle size is 8, we will most likely return false here and not allow an update until we get an actual value?

Yes, your understanding is correct.

* condition (previousTimeStamps[0] == previousTimeStamps[1]). This works
* because previousTimeStamps has a size equal to the shingle size and is filled
* with the current timestamp. However, there are scenarios where we might miss
* decrementing numberOfImputed:
*
* 1. Not all values in the shingle are imputed. 2. We accumulated
* numberOfImputed when the current timestamp had missing values.
*
* As a result, this could cause the data quality measure to decrease
* continuously since we are always counting missing values that should
* eventually be reset to zero. The second condition <pre> timestamp >
* previousTimeStamps[previousTimeStamps.length-1] && numberOfImputed > 0 </pre>
* will decrement numberOfImputed when we move to a new timestamp, provided
* numberOfImputed is greater than zero.
*/
if (previousTimeStamps[0] == previousTimeStamps[1]
|| (timestamp > previousTimeStamps[previousTimeStamps.length - 1] && numberOfImputed > 0)) {
numberOfImputed = numberOfImputed - 1;
}
super.updateTimestamps(timestamp);
}
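The removed comment above hinges on the previousTimeStamps[0] == previousTimeStamps[1] test. A minimal sketch of that check, using a hypothetical imputedStillInShingle helper (not part of the library), makes it concrete: the shingle keeps one timestamp per slot, and a tuple imputed at the current timestamp reuses an existing timestamp, so equality of the first two entries signals that imputed tuples have not yet aged out of the shingle.

```java
// Hypothetical sketch, not library code: duplicate timestamps at the
// front of previousTimeStamps mean imputed tuples are still present.
public class TimestampCheckSketch {
    static boolean imputedStillInShingle(long[] previousTimeStamps) {
        return previousTimeStamps[0] == previousTimeStamps[1];
    }

    public static void main(String[] args) {
        // the two oldest slots share a timestamp: an imputed tuple remains
        if (!imputedStillInShingle(new long[] { 100L, 100L, 101L, 102L })) {
            throw new AssertionError("expected imputed tuples to be detected");
        }
        // all slots hold distinct observed timestamps: no imputed tuple left
        if (imputedStillInShingle(new long[] { 100L, 101L, 102L, 103L })) {
            throw new AssertionError("expected no imputed tuples");
        }
        System.out.println("ok");
    }
}
```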

/**
* decides if the forest should be updated, this is needed for imputation on the
* fly. The main goal of this function is to avoid runaway sequences where a
@@ -128,7 +90,10 @@ protected void updateTimestamps(long timestamp) {
*/
protected boolean updateAllowed() {
double fraction = numberOfImputed * 1.0 / (shingleSize);
- if (numberOfImputed == shingleSize - 1 && previousTimeStamps[0] != previousTimeStamps[1]
+ if (fraction > 1) {
+     fraction = 1;
+ }
+ if (numberOfImputed >= shingleSize - 1 && previousTimeStamps[0] != previousTimeStamps[1]
&& (transformMethod == DIFFERENCE || transformMethod == NORMALIZE_DIFFERENCE)) {
// this shingle is disconnected from the previously seen values
// these transformations will have little meaning
@@ -144,10 +109,19 @@ protected boolean updateAllowed() {
// two different points).
return false;
}

dataQuality[0].update(1 - fraction);
return (fraction < useImputedFraction && internalTimeStamp >= shingleSize);
}
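The clamp and gate above can be condensed into a small standalone sketch. The class name and the 0.5 value for useImputedFraction are assumptions for illustration; only the arithmetic mirrors the diff: the imputed fraction is capped at 1 before it feeds the data-quality measure, and updates are allowed only while the fraction stays below useImputedFraction and the shingle has filled at least once.

```java
// Hypothetical sketch of the updateAllowed gate, not the library code.
public class UpdateGateSketch {
    // cap the imputed fraction at 1, matching the new clamp in the diff
    static double clampedFraction(int numberOfImputed, int shingleSize) {
        double fraction = numberOfImputed * 1.0 / shingleSize;
        return Math.min(fraction, 1.0);
    }

    // allow forest updates only when most of the shingle is observed data
    static boolean updateAllowed(int numberOfImputed, int shingleSize,
            long internalTimeStamp, double useImputedFraction) {
        double fraction = clampedFraction(numberOfImputed, shingleSize);
        return fraction < useImputedFraction && internalTimeStamp >= shingleSize;
    }

    public static void main(String[] args) {
        // a long imputed run can push the raw ratio past 1; the clamp bounds it
        if (clampedFraction(10, 8) != 1.0) {
            throw new AssertionError("fraction should be clamped to 1");
        }
        // 1 imputed value in a shingle of 8 still allows updates
        if (!updateAllowed(1, 8, 100L, 0.5)) {
            throw new AssertionError("update should be allowed");
        }
        // half the shingle imputed blocks updates (0.5 is not < 0.5)
        if (updateAllowed(4, 8, 100L, 0.5)) {
            throw new AssertionError("update should be blocked");
        }
        System.out.println("ok");
    }
}
```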

@Override
protected void updateTimestamps(long timestamp) {
if (previousTimeStamps[0] == previousTimeStamps[1]) {
numberOfImputed = numberOfImputed - 1;
}
super.updateTimestamps(timestamp);
}

/**
* the following function mutates the forest, the lastShingledPoint,
* lastShingledInput as well as previousTimeStamps, and adds the shingled input
@@ -168,7 +142,9 @@ void updateForest(boolean changeForest, double[] input, long timestamp, RandomCu
updateShingle(input, scaledInput);
updateTimestamps(timestamp);
if (isFullyImputed) {
- numberOfImputed = numberOfImputed + 1;
+ numberOfImputed = Math.min(numberOfImputed + 1, shingleSize);
+ } else if (numberOfImputed > 0) {
+     numberOfImputed = numberOfImputed - 1;
Contributor commented:

At this point, are we guaranteed to have seen a new value that is not imputed?

Author (Collaborator) replied:

Yes.
}
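Taken together with the updateTimestamps change, the increment and decrement above implement a saturating counter. A hypothetical step function (names invented here, not the library API) shows the intended behavior: numberOfImputed can never exceed shingleSize, so a long imputed run no longer pushes the imputed fraction past 1, and each observed value walks the counter back toward zero.

```java
// Hypothetical sketch of the saturating numberOfImputed counter.
public class ImputedCounterSketch {
    // one update step: rise by 1 per fully imputed point (capped at
    // shingleSize), fall by 1 per observed point (floored at 0)
    static int step(int numberOfImputed, int shingleSize, boolean fullyImputed) {
        if (fullyImputed) {
            return Math.min(numberOfImputed + 1, shingleSize);
        }
        return Math.max(numberOfImputed - 1, 0);
    }

    public static void main(String[] args) {
        int shingleSize = 8;
        int n = 0;
        // 10 consecutive fully imputed points: counter saturates at 8
        for (int i = 0; i < 10; i++) {
            n = step(n, shingleSize, true);
        }
        if (n != 8) {
            throw new AssertionError("expected cap at shingle size, got " + n);
        }
        // 3 observed points: counter recovers toward zero
        for (int i = 0; i < 3; i++) {
            n = step(n, shingleSize, false);
        }
        if (n != 5) {
            throw new AssertionError("expected 5, got " + n);
        }
        System.out.println("ok");
    }
}
```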
if (changeForest) {
if (forest.isInternalShinglingEnabled()) {
@@ -190,7 +166,9 @@ public void update(double[] point, float[] rcfPoint, long timestamp, int[] missi
return;
}
generateShingle(point, timestamp, missing, getTimeFactor(timeStampDeviations[1]), true, forest);
- ++valuesSeen;
+ if (missing == null || missing.length != point.length) {
+     ++valuesSeen;
+ }
}
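The guard on ++valuesSeen can be read in isolation: a point counts as seen only when at least one of its dimensions carries a real value. A small sketch with an invented countsAsSeen helper (not a library method):

```java
// Hypothetical sketch of the valuesSeen guard: missing holds the indices
// of missing dimensions, so a point is fully imputed exactly when
// missing covers every dimension of the point.
public class ValuesSeenSketch {
    static boolean countsAsSeen(int[] missing, int pointLength) {
        return missing == null || missing.length != pointLength;
    }

    public static void main(String[] args) {
        // both of 2 dimensions missing: fully imputed, not counted as seen
        if (countsAsSeen(new int[] { 0, 1 }, 2)) {
            throw new AssertionError("fully imputed point should not count");
        }
        // one of 2 dimensions missing: still counted as seen
        if (!countsAsSeen(new int[] { 0 }, 2)) {
            throw new AssertionError("partially observed point should count");
        }
        // no missing indices at all: counted as seen
        if (!countsAsSeen(null, 2)) {
            throw new AssertionError("fully observed point should count");
        }
        System.out.println("ok");
    }
}
```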

protected double getTimeFactor(Deviation deviation) {
@@ -961,7 +961,8 @@ public void setLastScore(double[] score) {
}

void validateIgnore(double[] shift, int length) {
- checkArgument(shift.length == length, () -> String.format(Locale.ROOT, "has to be of length %d but is %d", length, shift.length));
+ checkArgument(shift.length == length,
+         () -> String.format(Locale.ROOT, "has to be of length %d but is %d", length, shift.length));
for (double element : shift) {
checkArgument(element >= 0, "has to be non-negative");
}
@@ -19,24 +19,38 @@
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import java.util.stream.Stream;

import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;

import com.amazon.randomcutforest.config.ForestMode;
import com.amazon.randomcutforest.config.ImputationMethod;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.config.TransformMethod;

public class MissingValueTest {
private static class EnumAndValueProvider implements ArgumentsProvider {
@Override
public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
return Stream.of(ImputationMethod.PREVIOUS, ImputationMethod.ZERO, ImputationMethod.FIXED_VALUES)
.flatMap(method -> Stream.of(4, 8, 16) // Example shingle sizes
.map(shingleSize -> Arguments.of(method, shingleSize)));
}
}
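EnumAndValueProvider yields the cartesian product of three imputation methods and three shingle sizes, nine (method, shingleSize) pairs in all. A self-contained sketch of the same flatMap pattern, with plain strings standing in for the ImputationMethod enum so it runs without JUnit:

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Sketch of the cartesian-product enumeration used by the test provider.
public class ProviderSketch {
    static List<String> cases() {
        return Stream.of("PREVIOUS", "ZERO", "FIXED_VALUES")
                // pair each method with each shingle size, as the provider does
                .flatMap(method -> Stream.of(4, 8, 16)
                        .map(shingleSize -> method + ":" + shingleSize))
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<String> cases = cases();
        // 3 methods x 3 shingle sizes = 9 parameterized cases
        if (cases.size() != 9) {
            throw new AssertionError("expected 9 cases, got " + cases.size());
        }
        if (!cases.get(0).equals("PREVIOUS:4")) {
            throw new AssertionError("unexpected first case: " + cases.get(0));
        }
        System.out.println("ok");
    }
}
```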

@ParameterizedTest
- @EnumSource(ImputationMethod.class)
- public void testConfidence(ImputationMethod method) {
+ @ArgumentsSource(EnumAndValueProvider.class)
+ public void testConfidence(ImputationMethod method, int shingleSize) {
// Create and populate a random cut forest

- int shingleSize = 4;
int numberOfTrees = 50;
int sampleSize = 256;
Precision precision = Precision.FLOAT_32;
@@ -45,11 +59,19 @@ public void testConfidence(ImputationMethod method) {
long count = 0;

int dimensions = baseDimensions * shingleSize;
- ThresholdedRandomCutForest forest = new ThresholdedRandomCutForest.Builder<>().compact(true)
+ ThresholdedRandomCutForest.Builder forestBuilder = new ThresholdedRandomCutForest.Builder<>().compact(true)
.dimensions(dimensions).randomSeed(0).numberOfTrees(numberOfTrees).shingleSize(shingleSize)
.sampleSize(sampleSize).precision(precision).anomalyRate(0.01).imputationMethod(method)
- .fillValues(new double[] { 3 }).forestMode(ForestMode.STREAMING_IMPUTE)
- .transformMethod(TransformMethod.NORMALIZE).autoAdjust(true).build();
+ .forestMode(ForestMode.STREAMING_IMPUTE).transformMethod(TransformMethod.NORMALIZE).autoAdjust(true);

if (method == ImputationMethod.FIXED_VALUES) {
// we cannot pass fillValues when the method is not FIXED_VALUES; otherwise we
// would impute the filled-in values regardless of the imputation method
forestBuilder.fillValues(new double[] { 3 });
}

ThresholdedRandomCutForest forest = forestBuilder.build();

// Define the size and range
int size = 400;
@@ -75,18 +97,38 @@
float[] rcfPoint = result.getRCFPoint();
double scale = result.getScale()[0];
double shift = result.getShift()[0];
- double[] actual = new double[] { (rcfPoint[3] * scale) + shift };
+ double[] actual = new double[] { (rcfPoint[shingleSize - 1] * scale) + shift };
if (method == ImputationMethod.ZERO) {
assertEquals(0, actual[0], 0.001d);
if (count == 300) {
assertTrue(result.getAnomalyGrade() > 0);
}
} else if (method == ImputationMethod.FIXED_VALUES) {
assertEquals(3.0d, actual[0], 0.001d);
if (count == 300) {
assertTrue(result.getAnomalyGrade() > 0);
}
} else if (method == ImputationMethod.PREVIOUS) {
assertEquals(0, result.getAnomalyGrade(), 0.001d,
"count: " + count + " actual: " + Arrays.toString(actual));
}
} else {
AnomalyDescriptor result = forest.process(point, newStamp);
- if ((count > 100 && count < 300) || count >= 326) {
+ // after 325 there is a period where confidence decreases; after that,
+ // confidence starts increasing again. We are not sure exactly where the
+ // increase resumes, so we only check the behavior after 325 + shingleSize.
+ int backupPoint = 325 + shingleSize;
+ if ((count > 100 && count < 300) || count >= backupPoint) {
// The first 65+ observations give 0 confidence.
// Confidence starts increasing after 1 observed point.
- assertTrue(result.getDataConfidence() > lastConfidence);
+ assertTrue(result.getDataConfidence() > lastConfidence,
+         String.format(Locale.ROOT, "count: %d, confidence: %f, last confidence: %f", count,
+                 result.getDataConfidence(), lastConfidence));
} else if (count < 325 && count > 300) {
assertTrue(result.getDataConfidence() < lastConfidence,
String.format(Locale.ROOT, "count: %d, confidence: %f, last confidence: %f", count,
result.getDataConfidence(), lastConfidence));
}
lastConfidence = result.getDataConfidence();
}
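The branch structure above encodes three expectation windows for data confidence. A hypothetical expectedTrend helper (invented for illustration, exercised below with shingle size 8, one of the parameterized cases) summarizes them: confidence should rise while real values stream in, fall during the imputed stretch between counts 300 and 325, and rise again once the shingle has refilled after 325 + shingleSize.

```java
// Sketch of the test's expectation windows, not library or test code.
public class ConfidenceWindowSketch {
    enum Trend { RISING, FALLING, UNCHECKED }

    static Trend expectedTrend(long count, int shingleSize) {
        int backupPoint = 325 + shingleSize;
        // real values streaming in, or recovered after the imputed stretch
        if ((count > 100 && count < 300) || count >= backupPoint) {
            return Trend.RISING;
        }
        // the imputed stretch drags confidence down
        if (count > 300 && count < 325) {
            return Trend.FALLING;
        }
        // transition regions where the test makes no assertion
        return Trend.UNCHECKED;
    }

    public static void main(String[] args) {
        if (expectedTrend(200, 8) != Trend.RISING) throw new AssertionError();
        if (expectedTrend(310, 8) != Trend.FALLING) throw new AssertionError();
        // with shingleSize 8, counts 325..332 are left unchecked
        if (expectedTrend(330, 8) != Trend.UNCHECKED) throw new AssertionError();
        if (expectedTrend(333, 8) != Trend.RISING) throw new AssertionError();
        System.out.println("ok");
    }
}
```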