Fix Confidence Adjustment for Larger Shingle Sizes #407

Merged (2 commits) on Aug 1, 2024
@@ -80,44 +80,6 @@ public float[] getScaledShingledInput(double[] inputPoint, long timestamp, int[]
return point;
}

/**
* the timestamps are now used to calculate the number of imputed tuples in the
* shingle
*
* @param timestamp the timestamp of the current input
*/
@Override
protected void updateTimestamps(long timestamp) {
/*
* For imputations done on timestamps other than the current one (specified by
* the timestamp parameter), the timestamp of the imputed tuple matches that of
* the input tuple, and we increment numberOfImputed. For imputations done at
* the current timestamp (if all input values are missing), the timestamp of the
* imputed tuple is the current timestamp, and we increment numberOfImputed.
*
* To check if imputed values are still present in the shingle, we use the first
amitgalitz (Contributor) commented on Aug 1, 2024:

At least this part of the comment is still valid, right? It would be good to have, as I was pretty confused about what previousTimeStamps[0] != previousTimeStamps[1] was really meant to do. Or maybe explain that part in
Additionally, so I understand correctly: we are saying that if the last 10 values were imputed and the shingle size is 8, we will most likely return false here and not allow an update until we get an actual value?

Collaborator (Author) replied:

Let me add the comment back.

"we are saying if the last 10 values were imputed and shingle size is 8, we will most likely return false here and not allow an update until we get an actual value?"

Yes, your understanding is correct.

* condition (previousTimeStamps[0] == previousTimeStamps[1]). This works
* because previousTimeStamps has a size equal to the shingle size and is filled
* with the current timestamp. However, there are scenarios where we might miss
* decrementing numberOfImputed:
*
* 1. Not all values in the shingle are imputed. 2. We accumulated
* numberOfImputed when the current timestamp had missing values.
*
* As a result, this could cause the data quality measure to decrease
* continuously since we are always counting missing values that should
* eventually be reset to zero. The second condition <pre> timestamp >
* previousTimeStamps[previousTimeStamps.length-1] && numberOfImputed > 0 </pre>
* will decrement numberOfImputed when we move to a new timestamp, provided
* numberOfImputed is greater than zero.
*/
if (previousTimeStamps[0] == previousTimeStamps[1]
|| (timestamp > previousTimeStamps[previousTimeStamps.length - 1] && numberOfImputed > 0)) {
numberOfImputed = numberOfImputed - 1;
}
super.updateTimestamps(timestamp);
}
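
The gating behavior confirmed in the thread above can be traced with a small standalone sketch. This is illustrative only: it collapses the increment and decrement bookkeeping described in these comments into a single loop, borrows the field names (numberOfImputed, shingleSize, useImputedFraction) from the diff, and uses the 0.5 threshold mentioned later in this file; none of it is the library's API.

```java
// Simplified trace of the update gate discussed in the review thread above;
// a standalone sketch, not part of the library.
public class UpdateGateTrace {
    public static void main(String[] args) {
        int shingleSize = 8;
        double useImputedFraction = 0.5; // forest updates only below this fraction
        int numberOfImputed = 0;

        // Ten consecutive imputed inputs followed by five observed ones.
        for (int t = 0; t < 15; t++) {
            boolean imputed = t < 10;
            if (imputed) {
                // cap introduced by this PR: the fraction can no longer exceed 1
                numberOfImputed = Math.min(numberOfImputed + 1, shingleSize);
            } else if (numberOfImputed > 0) {
                // decrement when an observed value arrives
                numberOfImputed = numberOfImputed - 1;
            }
            double fraction = Math.min(1.0, numberOfImputed * 1.0 / shingleSize);
            boolean updateAllowed = fraction < useImputedFraction;
            System.out.printf("t=%2d imputed=%-5b fraction=%.3f updateAllowed=%b%n",
                    t, imputed, fraction, updateAllowed);
        }
    }
}
```

With shingleSize = 8, ten imputed inputs pin the fraction at the cap of 1.0, and the gate stays closed until enough observed values bring the fraction back under 0.5 (here at t = 14).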

/**
* decides if the forest should be updated, this is needed for imputation on the
* fly. The main goal of this function is to avoid runaway sequences where a
@@ -128,7 +90,10 @@ protected void updateTimestamps(long timestamp) {
*/
protected boolean updateAllowed() {
double fraction = numberOfImputed * 1.0 / (shingleSize);
if (numberOfImputed == shingleSize - 1 && previousTimeStamps[0] != previousTimeStamps[1]
if (fraction > 1) {
fraction = 1;
}
if (numberOfImputed >= shingleSize - 1 && previousTimeStamps[0] != previousTimeStamps[1]
&& (transformMethod == DIFFERENCE || transformMethod == NORMALIZE_DIFFERENCE)) {
// this shingle is disconnected from the previously seen values
// these transformations will have little meaning
@@ -144,10 +109,57 @@ protected boolean updateAllowed() {
// two different points).
return false;
}

dataQuality[0].update(1 - fraction);
return (fraction < useImputedFraction && internalTimeStamp >= shingleSize);
}
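
A side note on the transform check above: when the shingle is disconnected from the previously seen values, a difference-based transform ends up comparing two unrelated points. Below is a minimal sketch of that effect, using plain differencing as a stand-in for the library's DIFFERENCE transform.

```java
// Illustrative only: plain differencing across a long gap in the input.
// The library's DIFFERENCE/NORMALIZE_DIFFERENCE transforms are more involved,
// but suffer the same problem when the shingle is disconnected.
public class DifferenceGapSketch {
    public static void main(String[] args) {
        double lastValueBeforeGap = 11.0;  // last observed value before the gap
        double firstValueAfterGap = 55.0;  // first observed value after the gap
        // The difference spans the gap, so it compares two essentially
        // unrelated points and produces a spurious jump:
        System.out.println(firstValueAfterGap - lastValueBeforeGap); // 44.0
    }
}
```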

@Override
protected void updateTimestamps(long timestamp) {
/*
* For imputations done on timestamps other than the current one (specified by
* the timestamp parameter), the timestamp of the imputed tuple matches that of
* the input tuple, and we increment numberOfImputed. For imputations done at
* the current timestamp (if all input values are missing), the timestamp of the
* imputed tuple is the current timestamp, and we increment numberOfImputed.
*
* To check if imputed values are still present in the shingle, we use the
* condition (previousTimeStamps[0] == previousTimeStamps[1]). This works
* because previousTimeStamps has a size equal to the shingle size and is filled
* with the current timestamp.
*
* For example, if the last 10 values were imputed and the shingle size is 8,
* the condition will most likely return false until all 10 imputed values are
* removed from the shingle.
*
* However, there are scenarios where we might miss decrementing
* numberOfImputed:
*
* 1. Not all values in the shingle are imputed. 2. We accumulated
* numberOfImputed when the current timestamp had missing values.
*
* As a result, this could cause the data quality measure to decrease
* continuously since we are always counting missing values that should
* eventually be reset to zero. To address the issue, we add code in method
* updateForest to decrement numberOfImputed when we move to a new timestamp,
* provided there is no imputation. This ensures the imputation fraction does
* not increase as long as the imputation is continuing. This also ensures that
* the forest update decision, which relies on the imputation fraction,
* functions correctly. The forest is updated only when the imputation fraction
* is below the threshold of 0.5.
*
* Also, why can't we combine the decrement code between updateTimestamps and
* updateForest together? This would cause Consistency.ImputeTest to fail when
* testing with and without imputation, as the RCF scores would not change. The
* method updateTimestamps is used in other places (e.g., updateState and
* dischargeInitial), not only in updateForest.
*/
if (previousTimeStamps[0] == previousTimeStamps[1]) {
numberOfImputed = numberOfImputed - 1;
}
super.updateTimestamps(timestamp);
}

/**
* the following function mutates the forest, the lastShingledPoint,
* lastShingledInput as well as previousTimeStamps, and adds the shingled input
@@ -168,7 +180,13 @@ void updateForest(boolean changeForest, double[] input, long timestamp, RandomCu
updateShingle(input, scaledInput);
updateTimestamps(timestamp);
if (isFullyImputed) {
numberOfImputed = numberOfImputed + 1;
// numImputed is now capped at the shingle size to ensure that the impute
// fraction, calculated as numberOfImputed * 1.0 / shingleSize, does not
// exceed 1.
numberOfImputed = Math.min(numberOfImputed + 1, shingleSize);
} else if (numberOfImputed > 0) {
// Decrement numberOfImputed when the new value is not imputed
numberOfImputed = numberOfImputed - 1;
Contributor commented:

At this point, are we guaranteed to have seen a new value that is not imputed?

Collaborator (Author) replied:

Yes.

}
if (changeForest) {
if (forest.isInternalShinglingEnabled()) {
@@ -190,7 +208,14 @@ public void update(double[] point, float[] rcfPoint, long timestamp, int[] missi
return;
}
generateShingle(point, timestamp, missing, getTimeFactor(timeStampDeviations[1]), true, forest);
++valuesSeen;
// The confidence formula depends on numImputed (the number of recent
// imputations seen) and seenValues (all values seen). To ensure confidence
// decreases when numImputed increases, we need to count only non-imputed
// values as seenValues.
if (missing == null || missing.length != point.length) {
++valuesSeen;
}
}
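
To see why only observed points should count, here is a toy, self-contained illustration. The ratio below is a stand-in for the real RCF confidence formula, which is not part of this diff; the names valuesSeen and numImputed are borrowed from the comment above.

```java
// Toy model only: pretends confidence is the fraction of observed
// (non-imputed) points among all recent inputs; the actual RCF confidence
// formula differs.
public class ConfidenceCountingSketch {

    static double toyConfidence(int valuesSeen, int numImputed) {
        return valuesSeen * 1.0 / (valuesSeen + numImputed);
    }

    public static void main(String[] args) {
        // Suppose 50 observed points are followed by 50 consecutive imputed ones.
        // If imputed points were also counted in valuesSeen, the numerator would
        // grow with the denominator and confidence would stay high:
        System.out.println(toyConfidence(50 + 50, 50)); // ~0.67
        // Counting only observed points lets confidence fall as intended:
        System.out.println(toyConfidence(50, 50)); // 0.50
    }
}
```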

protected double getTimeFactor(Deviation deviation) {
@@ -961,7 +961,8 @@ public void setLastScore(double[] score) {
}

void validateIgnore(double[] shift, int length) {
checkArgument(shift.length == length, () -> String.format(Locale.ROOT, "has to be of length %d but is %d", length, shift.length));
checkArgument(shift.length == length,
() -> String.format(Locale.ROOT, "has to be of length %d but is %d", length, shift.length));
for (double element : shift) {
checkArgument(element >= 0, "has to be non-negative");
}
@@ -19,24 +19,38 @@
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import java.util.stream.Stream;

import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;

import com.amazon.randomcutforest.config.ForestMode;
import com.amazon.randomcutforest.config.ImputationMethod;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.config.TransformMethod;

public class MissingValueTest {
private static class EnumAndValueProvider implements ArgumentsProvider {
@Override
public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
return Stream.of(ImputationMethod.PREVIOUS, ImputationMethod.ZERO, ImputationMethod.FIXED_VALUES)
.flatMap(method -> Stream.of(4, 8, 16) // Example shingle sizes
.map(shingleSize -> Arguments.of(method, shingleSize)));
}
}

@ParameterizedTest
@EnumSource(ImputationMethod.class)
public void testConfidence(ImputationMethod method) {
@ArgumentsSource(EnumAndValueProvider.class)
public void testConfidence(ImputationMethod method, int shingleSize) {
// Create and populate a random cut forest

int shingleSize = 4;
int numberOfTrees = 50;
int sampleSize = 256;
Precision precision = Precision.FLOAT_32;
@@ -45,11 +59,19 @@ public void testConfidence(ImputationMethod method) {
long count = 0;

int dimensions = baseDimensions * shingleSize;
ThresholdedRandomCutForest forest = new ThresholdedRandomCutForest.Builder<>().compact(true)
ThresholdedRandomCutForest.Builder forestBuilder = new ThresholdedRandomCutForest.Builder<>().compact(true)
.dimensions(dimensions).randomSeed(0).numberOfTrees(numberOfTrees).shingleSize(shingleSize)
.sampleSize(sampleSize).precision(precision).anomalyRate(0.01).imputationMethod(method)
.fillValues(new double[] { 3 }).forestMode(ForestMode.STREAMING_IMPUTE)
.transformMethod(TransformMethod.NORMALIZE).autoAdjust(true).build();
.forestMode(ForestMode.STREAMING_IMPUTE).transformMethod(TransformMethod.NORMALIZE).autoAdjust(true);

if (method == ImputationMethod.FIXED_VALUES) {
// We cannot pass fillValues when the method is not FIXED_VALUES; otherwise,
// we will impute filled-in values regardless of the imputation method.
forestBuilder.fillValues(new double[] { 3 });
}

ThresholdedRandomCutForest forest = forestBuilder.build();

// Define the size and range
int size = 400;
@@ -75,18 +97,38 @@
float[] rcfPoint = result.getRCFPoint();
double scale = result.getScale()[0];
double shift = result.getShift()[0];
double[] actual = new double[] { (rcfPoint[3] * scale) + shift };
double[] actual = new double[] { (rcfPoint[shingleSize - 1] * scale) + shift };
if (method == ImputationMethod.ZERO) {
assertEquals(0, actual[0], 0.001d);
if (count == 300) {
assertTrue(result.getAnomalyGrade() > 0);
}
} else if (method == ImputationMethod.FIXED_VALUES) {
assertEquals(3.0d, actual[0], 0.001d);
if (count == 300) {
assertTrue(result.getAnomalyGrade() > 0);
}
} else if (method == ImputationMethod.PREVIOUS) {
assertEquals(0, result.getAnomalyGrade(), 0.001d,
"count: " + count + " actual: " + Arrays.toString(actual));
}
} else {
AnomalyDescriptor result = forest.process(point, newStamp);
if ((count > 100 && count < 300) || count >= 326) {
// After 325, there is a period where confidence decreases; after that,
// confidence starts increasing again. We are not sure exactly where confidence
// will start increasing after the decrease, so we start checking the behavior
// after 325 + shingleSize.
int backupPoint = 325 + shingleSize;
if ((count > 100 && count < 300) || count >= backupPoint) {
// The first 65+ observations give 0 confidence.
// Confidence starts increasing after 1 observed point.
assertTrue(result.getDataConfidence() > lastConfidence);
assertTrue(result.getDataConfidence() > lastConfidence,
String.format(Locale.ROOT, "count: %d, confidence: %f, last confidence: %f", count,
result.getDataConfidence(), lastConfidence));
} else if (count < 325 && count > 300) {
assertTrue(result.getDataConfidence() < lastConfidence,
String.format(Locale.ROOT, "count: %d, confidence: %f, last confidence: %f", count,
result.getDataConfidence(), lastConfidence));
}
lastConfidence = result.getDataConfidence();
}