From b5b17cb99389004c9337532ec0294adc5c68068b Mon Sep 17 00:00:00 2001 From: Sudipto Guha Date: Fri, 15 Dec 2023 07:40:58 -0800 Subject: [PATCH] optimization and stability --- .../randomcutforest/RandomCutForest.java | 28 ++++++++++--------- .../preprocessor/PreprocessorTest.java | 4 +-- .../ThresholdedRandomCutForest.java | 2 +- .../parkservices/state/RCFCasterMapper.java | 1 + .../ThresholdedRandomCutForestMapper.java | 1 + 5 files changed, 20 insertions(+), 16 deletions(-) diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/RandomCutForest.java b/Java/core/src/main/java/com/amazon/randomcutforest/RandomCutForest.java index 7d0489a8..9a6cea74 100644 --- a/Java/core/src/main/java/com/amazon/randomcutforest/RandomCutForest.java +++ b/Java/core/src/main/java/com/amazon/randomcutforest/RandomCutForest.java @@ -1087,11 +1087,12 @@ public double[] imputeMissingValues(double[] point, int numberOfMissingValues, i * @return a forecasted time series. */ @Deprecated - public double[] extrapolateBasic(double[] point, int horizon, int blockSize, boolean cyclic, int shingleIndex) { + double[] extrapolateBasic(double[] point, int horizon, int blockSize, boolean cyclic, int shingleIndex) { return toDoubleArray(extrapolateBasic(toFloatArray(point), horizon, blockSize, cyclic, shingleIndex)); } - public float[] extrapolateBasic(float[] point, int horizon, int blockSize, boolean cyclic, int shingleIndex) { + @Deprecated + float[] extrapolateBasic(float[] point, int horizon, int blockSize, boolean cyclic, int shingleIndex) { return extrapolateWithRanges(point, horizon, blockSize, cyclic, shingleIndex, 1.0).values; } @@ -1122,7 +1123,9 @@ public RangeVector extrapolateWithRanges(float[] point, int horizon, int blockSi // external management of shingle; can function for both internal and external // shingling // however blocksize has to be externally managed - public RangeVector extrapolateFromShingle(float[] shingle, int horizon, int blockSize, double centrality) { + + @Deprecated + RangeVector extrapolateFromShingle(float[] shingle, int horizon, int blockSize, double centrality) { return extrapolateWithRanges(shingle, horizon, blockSize, isRotationEnabled(), ((int) nextSequenceIndex()) % shingleSize, centrality); } @@ -1142,11 +1145,11 @@ public RangeVector extrapolateFromShingle(float[] shingle, int horizon, int bloc * @return a forecasted time series. */ @Deprecated - public double[] extrapolateBasic(double[] point, int horizon, int blockSize, boolean cyclic) { - return extrapolateBasic(point, horizon, blockSize, cyclic, 0); + double[] extrapolateBasic(double[] point, int horizon, int blockSize, boolean cyclic) { + return toDoubleArray(extrapolateBasic(toFloatArray(point), horizon, blockSize, cyclic, 0)); } - public float[] extrapolateBasic(float[] point, int horizon, int blockSize, boolean cyclic) { + protected float[] extrapolateBasic(float[] point, int horizon, int blockSize, boolean cyclic) { return extrapolateBasic(point, horizon, blockSize, cyclic, 0); } @@ -1162,8 +1165,8 @@ public float[] extrapolateBasic(float[] point, int horizon, int blockSize, boole */ @Deprecated public double[] extrapolateBasic(ShingleBuilder builder, int horizon) { - return extrapolateBasic(builder.getShingle(), horizon, builder.getInputPointSize(), builder.isCyclic(), - builder.getShingleIndex()); + return toDoubleArray(extrapolateBasic(toFloatArray(builder.getShingle()), horizon, builder.getInputPointSize(), + builder.isCyclic(), builder.getShingleIndex())); } void extrapolateBasicSliding(RangeVector result, int horizon, int blockSize, float[] queryPoint, @@ -1180,12 +1183,11 @@ void extrapolateBasicSliding(RangeVector result, int horizon, int blockSize, flo System.arraycopy(queryPoint, blockSize, queryPoint, 0, dimensions - blockSize); SampleSummary imputedSummary = getConditionalFieldSummary(queryPoint, missingIndexes, 1, 0, false, false, - centrality, 1); + centrality, dimensions / blockSize); for (int y = 0; y < blockSize; y++) { - result.values[resultIndex] = queryPoint[dimensions - blockSize + y] = imputedSummary.median[dimensions - - blockSize + y]; - result.lower[resultIndex] = imputedSummary.lower[dimensions - blockSize + y]; - result.upper[resultIndex] = imputedSummary.upper[dimensions - blockSize + y]; + result.values[resultIndex] = queryPoint[dimensions - blockSize + y] = imputedSummary.median[y]; + result.lower[resultIndex] = imputedSummary.lower[y]; + result.upper[resultIndex] = imputedSummary.upper[y]; resultIndex++; } } diff --git a/Java/core/src/test/java/com/amazon/randomcutforest/preprocessor/PreprocessorTest.java b/Java/core/src/test/java/com/amazon/randomcutforest/preprocessor/PreprocessorTest.java index 194c8752..43fcf318 100644 --- a/Java/core/src/test/java/com/amazon/randomcutforest/preprocessor/PreprocessorTest.java +++ b/Java/core/src/test/java/com/amazon/randomcutforest/preprocessor/PreprocessorTest.java @@ -244,8 +244,8 @@ public void preprocessorPlusForest(int seed, ForestMode mode, TransformMethod me preprocessor.update(dataWithKey.data[i], input, timestamp, new int[1], forest); } if (shingleSize > 1) { - RangeVector rangeVector = forest.extrapolateFromShingle(preprocessor.getLastShingledPoint(), 1, - tempDimensions / shingleSize, 1.0); + RangeVector rangeVector = forest.extrapolateWithRanges(preprocessor.getLastShingledPoint(), 1, + tempDimensions / shingleSize, false, 0, 1.0); TimedRangeVector timedRanges = preprocessor.invertForecastRange(rangeVector, timestamp, null, false, timestamp); // error of lookahead diff --git a/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/ThresholdedRandomCutForest.java b/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/ThresholdedRandomCutForest.java index 7eb91f2e..12adb08d 100644 --- a/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/ThresholdedRandomCutForest.java +++ b/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/ThresholdedRandomCutForest.java @@ -383,7 +383,7 @@ public TimedRangeVector extrapolate(int horizon, boolean correct, double central preprocessor.getScale(), transformMethod, lastAnomalyDescriptor); } } - RangeVector answer = forest.extrapolateFromShingle(newPoint, horizon, blockSize, centrality); + RangeVector answer = forest.extrapolateWithRanges(newPoint, horizon, blockSize, false, 0, centrality); return preprocessor.invertForecastRange(answer, lastAnomalyDescriptor.getInputTimestamp(), lastAnomalyDescriptor.getDeltaShift(), lastAnomalyDescriptor.getExpectedRCFPoint() != null, lastAnomalyDescriptor.getExpectedTimeStamp()); diff --git a/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/state/RCFCasterMapper.java b/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/state/RCFCasterMapper.java index fc3093f5..837731e2 100644 --- a/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/state/RCFCasterMapper.java +++ b/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/state/RCFCasterMapper.java @@ -91,6 +91,7 @@ public RCFCaster toModel(RCFCasterState state, long seed) { descriptor.setTransformMethod(transformMethod); descriptor .setImputationMethod(ImputationMethod.valueOf(state.getPreprocessorStates()[0].getImputationMethod())); + descriptor.setShingleSize(preprocessor.getShingleSize()); PredictorCorrectorMapper mapper = new PredictorCorrectorMapper(); PredictorCorrector predictorCorrector = mapper.toModel(state.getPredictorCorrectorState()); diff --git a/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/state/ThresholdedRandomCutForestMapper.java b/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/state/ThresholdedRandomCutForestMapper.java index c06854fd..ad27ebcf 100644 --- a/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/state/ThresholdedRandomCutForestMapper.java +++ b/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/state/ThresholdedRandomCutForestMapper.java @@ -76,6 +76,7 @@ public ThresholdedRandomCutForest toModel(ThresholdedRandomCutForestState state, descriptor = new ComputeDescriptorMapper().toModel(state.getLastDescriptorState()); } + descriptor.setShingleSize(preprocessor.getShingleSize()); descriptor.setForestMode(forestMode); descriptor.setTransformMethod(transformMethod); descriptor.setScoringStrategy(scoringStrategy);