diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 8a16b0790..910f6638a 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -19,7 +19,7 @@ jobs: strategy: matrix: # each test scenario (rule, hc, single_stream) is treated as a separate job. - test: [rule, hc, single_stream] + test: [rule, hc, single_stream,missing] fail-fast: false concurrency: # The concurrency setting is used to limit the concurrency of each test scenario group to ensure they do not run concurrently on the same machine. @@ -48,11 +48,16 @@ jobs: chown -R 1000:1000 `pwd` case ${{ matrix.test }} in rule) - su `id -un 1000` -c "./gradlew integTest --tests 'org.opensearch.ad.e2e.RuleModelPerfIT' \ + su `id -un 1000` -c "./gradlew integTest --tests 'org.opensearch.ad.e2e.RealTimeRuleModelPerfIT' \ -Dtests.seed=B4BA12CCF1D9E825 -Dtests.security.manager=false \ -Dtests.jvm.argline='-XX:TieredStopAtLevel=1 -XX:ReservedCodeCacheSize=64m' \ -Dtests.locale=ar-JO -Dtests.timezone=Asia/Samarkand -Dmodel-benchmark=true \ -Dtests.timeoutSuite=3600000! -Dtest.logs=true" + su `id -un 1000` -c "./gradlew integTest --tests 'org.opensearch.ad.e2e.HistoricalRuleModelPerfIT' \ + -Dtests.seed=B4BA12CCF1D9E825 -Dtests.security.manager=false \ + -Dtests.jvm.argline='-XX:TieredStopAtLevel=1 -XX:ReservedCodeCacheSize=64m' \ + -Dtests.locale=ar-JO -Dtests.timezone=Asia/Samarkand -Dmodel-benchmark=true \ + -Dtests.timeoutSuite=3600000! -Dtest.logs=true" ;; hc) su `id -un 1000` -c "./gradlew ':test' --tests 'org.opensearch.ad.ml.HCADModelPerfTests' \ @@ -66,4 +71,10 @@ jobs: -Dtests.locale=kab-DZ -Dtests.timezone=Asia/Hebron -Dtest.logs=true \ -Dtests.timeoutSuite=3600000! -Dmodel-benchmark=true" ;; + missing) + su `id -un 1000` -c "./gradlew integTest --tests 'org.opensearch.ad.e2e.RealTimeMissingSingleFeatureModelPerfIT' \ + -Dtests.seed=60CDDB34427ACD0C -Dtests.security.manager=false \ + -Dtests.locale=kab-DZ -Dtests.timezone=Asia/Hebron -Dtest.logs=true \ + -Dtests.timeoutSuite=3600000! -Dmodel-benchmark=true" + ;; esac diff --git a/build.gradle b/build.gradle index b31561343..2812fdeb0 100644 --- a/build.gradle +++ b/build.gradle @@ -126,9 +126,12 @@ dependencies { implementation group: 'com.yahoo.datasketches', name: 'memory', version: '0.12.2' implementation group: 'commons-lang', name: 'commons-lang', version: '2.6' implementation group: 'org.apache.commons', name: 'commons-pool2', version: '2.12.0' - implementation 'software.amazon.randomcutforest:randomcutforest-serialization:4.0.0' - implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:4.0.0' - implementation 'software.amazon.randomcutforest:randomcutforest-core:4.0.0' + // implementation 'software.amazon.randomcutforest:randomcutforest-serialization:4.0.0' + // implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:4.0.0' + // implementation 'software.amazon.randomcutforest:randomcutforest-core:4.0.0' + implementation files('lib/randomcutforest-core-4.1.0.jar') + implementation files('lib/randomcutforest-parkservices-4.1.0.jar') + implementation files('lib/randomcutforest-serialization-4.1.0.jar') // we inherit jackson-core from opensearch core implementation "com.fasterxml.jackson.core:jackson-databind:2.16.1" @@ -356,8 +359,7 @@ integTest { if (System.getProperty("model-benchmark") == null || System.getProperty("model-benchmark") == "false") { filter { - excludeTestsMatching "org.opensearch.ad.e2e.SingleStreamModelPerfIT" - excludeTestsMatching "org.opensearch.ad.e2e.RuleModelPerfIT" + excludeTestsMatching "org.opensearch.ad.e2e.*ModelPerfIT" } } @@ -676,11 +678,6 @@ List jacocoExclusions = [ // rest layer is tested in integration testing mostly, difficult to mock all of it 'org.opensearch.ad.rest.*', - 'org.opensearch.ad.model.ModelProfileOnNode', - 'org.opensearch.ad.model.InitProgressProfile', - 'org.opensearch.ad.rest.*', - 'org.opensearch.ad.AnomalyDetectorJobRunner', - // Class containing just constants. Don't need to test 'org.opensearch.ad.constant.*', 'org.opensearch.forecast.constant.*', @@ -688,22 +685,49 @@ List jacocoExclusions = [ 'org.opensearch.timeseries.settings.TimeSeriesSettings', 'org.opensearch.forecast.settings.ForecastSettings', - 'org.opensearch.ad.transport.CronRequest', - 'org.opensearch.ad.AnomalyDetectorRunner', - // related to transport actions added for security 'org.opensearch.ad.transport.DeleteAnomalyDetectorTransportAction.1', // TODO: unified flow caused coverage drop 'org.opensearch.ad.transport.DeleteAnomalyResultsTransportAction', - // TODO: fix unstable code coverage caused by null NodeClient issue - // https://github.com/opensearch-project/anomaly-detection/issues/241 - 'org.opensearch.ad.task.ADBatchTaskRunner', - 'org.opensearch.ad.task.ADTaskManager', - // TODO: add forecast test coverage before release + + // TODO: add test coverage (kaituo) 'org.opensearch.forecast.*', - 'org.opensearch.timeseries.*', - 'org.opensearch.ad.*', + 'org.opensearch.ad.transport.GetAnomalyDetectorTransportAction', + 'org.opensearch.ad.ml.ADColdStart', + 'org.opensearch.ad.transport.ADHCImputeNodesResponse', + 'org.opensearch.timeseries.transport.BooleanNodeResponse', + 'org.opensearch.timeseries.ml.TimeSeriesSingleStreamCheckpointDao', + 'org.opensearch.timeseries.transport.JobRequest', + 'org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler', + 'org.opensearch.timeseries.ml.Inferencer', + 'org.opensearch.timeseries.transport.SingleStreamResultRequest', + 'org.opensearch.timeseries.transport.BooleanResponse', + 'org.opensearch.timeseries.rest.handler.IndexJobActionHandler.1', + 'org.opensearch.timeseries.transport.SuggestConfigParamResponse', + 'org.opensearch.timeseries.transport.SuggestConfigParamRequest', + 'org.opensearch.timeseries.ml.MemoryAwareConcurrentHashmap', + 'org.opensearch.timeseries.transport.ResultBulkTransportAction', + 'org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler', + 'org.opensearch.timeseries.transport.handler.ResultIndexingHandler', + 'org.opensearch.ad.transport.ADHCImputeNodeResponse', + 'org.opensearch.timeseries.ml.Sample', + 'org.opensearch.timeseries.ratelimit.FeatureRequest', + 'org.opensearch.ad.transport.ADHCImputeNodeRequest', + 'org.opensearch.timeseries.model.ModelProfileOnNode', + 'org.opensearch.timeseries.transport.ValidateConfigRequest', + 'org.opensearch.timeseries.transport.ResultProcessor.PageListener.1', + 'org.opensearch.ad.transport.ADHCImputeRequest', + 'org.opensearch.timeseries.transport.BaseDeleteConfigTransportAction.1', + 'org.opensearch.timeseries.transport.BaseSuggestConfigParamTransportAction', + 'org.opensearch.timeseries.rest.AbstractSearchAction.1', + 'org.opensearch.ad.transport.ADSingleStreamResultTransportAction', + 'org.opensearch.timeseries.ratelimit.RateLimitedRequestWorker.RequestQueue', + 'org.opensearch.timeseries.rest.RestStatsAction', + 'org.opensearch.ad.ml.ADCheckpointDao', + 'org.opensearch.timeseries.transport.CronRequest', + 'org.opensearch.ad.task.ADBatchTaskCache', + 'org.opensearch.timeseries.ratelimit.RateLimitedRequestWorker', ] diff --git a/lib/randomcutforest-core-4.1.0.jar b/lib/randomcutforest-core-4.1.0.jar new file mode 100644 index 000000000..5d76ddf6a Binary files /dev/null and b/lib/randomcutforest-core-4.1.0.jar differ diff --git a/lib/randomcutforest-parkservices-4.1.0.jar b/lib/randomcutforest-parkservices-4.1.0.jar new file mode 100644 index 000000000..5810ff420 Binary files /dev/null and b/lib/randomcutforest-parkservices-4.1.0.jar differ diff --git a/lib/randomcutforest-serialization-4.1.0.jar b/lib/randomcutforest-serialization-4.1.0.jar new file mode 100644 index 000000000..6d3ceaccd Binary files /dev/null and b/lib/randomcutforest-serialization-4.1.0.jar differ diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java index fa16ed2c7..a68645396 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java @@ -102,8 +102,7 @@ public void executeDetector( startTime.toEpochMilli(), endTime.toEpochMilli(), ActionListener.wrap(features -> { - List entityResults = modelManager - .getPreviewResults(features, detector.getShingleSize(), detector.getTimeDecay()); + List entityResults = modelManager.getPreviewResults(features, detector); List sampledEntityResults = sample( parsePreviewResult(detector, features, entityResults, entity), maxPreviewResults @@ -116,8 +115,7 @@ public void executeDetector( } else { featureManager.getPreviewFeatures(detector, startTime.toEpochMilli(), endTime.toEpochMilli(), ActionListener.wrap(features -> { try { - List results = modelManager - .getPreviewResults(features, detector.getShingleSize(), detector.getTimeDecay()); + List results = modelManager.getPreviewResults(features, detector); listener.onResponse(sample(parsePreviewResult(detector, features, results, null), maxPreviewResults)); } catch (Exception e) { onFailure(e, listener, detector.getId()); diff --git a/src/main/java/org/opensearch/ad/ml/ADColdStart.java b/src/main/java/org/opensearch/ad/ml/ADColdStart.java index d2db383f8..b4f329efa 100644 --- a/src/main/java/org/opensearch/ad/ml/ADColdStart.java +++ b/src/main/java/org/opensearch/ad/ml/ADColdStart.java @@ -13,6 +13,7 @@ import java.time.Clock; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import org.apache.logging.log4j.LogManager; @@ -171,7 +172,7 @@ protected List trainModelFromDataSegments( double[] firstPoint = pointSamples.get(0).getValueList(); if (firstPoint == null || firstPoint.length == 0) { - logger.info("Return early since data points must not be empty."); + logger.info("Return early since the first data point must not be empty."); return null; } @@ -216,6 +217,31 @@ protected List trainModelFromDataSegments( } AnomalyDetector detector = (AnomalyDetector) config; + applyRule(rcfBuilder, detector); + + // use build instead of new TRCF(Builder) because build method did extra validation and initialization + ThresholdedRandomCutForest trcf = rcfBuilder.build(); + + List imputed = new ArrayList<>(); + for (int i = 0; i < pointSamples.size(); i++) { + Sample dataSample = pointSamples.get(i); + double[] dataValue = dataSample.getValueList(); + // We don't keep missing values during cold start as the actual data may not be reconstructed during the early stage. + trcf.process(dataValue, dataSample.getDataEndTime().getEpochSecond()); + imputed.add(new Sample(dataValue, dataSample.getDataStartTime(), dataSample.getDataEndTime())); + } + + entityState.setModel(trcf); + + entityState.setLastUsedTime(clock.instant()); + + // save to checkpoint + checkpointWriteWorker.write(entityState, true, RequestPriority.MEDIUM); + + return pointSamples; + } + + public static void applyRule(ThresholdedRandomCutForest.Builder rcfBuilder, AnomalyDetector detector) { ThresholdArrays thresholdArrays = IgnoreSimilarExtractor.processDetectorRules(detector); if (thresholdArrays != null) { @@ -235,23 +261,5 @@ protected List trainModelFromDataSegments( rcfBuilder.ignoreNearExpectedFromBelowByRatio(thresholdArrays.ignoreSimilarFromBelowByRatio); } } - - // use build instead of new TRCF(Builder) because build method did extra validation and initialization - ThresholdedRandomCutForest trcf = rcfBuilder.build(); - - for (int i = 0; i < pointSamples.size(); i++) { - Sample dataSample = pointSamples.get(i); - double[] dataValue = dataSample.getValueList(); - trcf.process(dataValue, dataSample.getDataEndTime().getEpochSecond()); - } - - entityState.setModel(trcf); - - entityState.setLastUsedTime(clock.instant()); - - // save to checkpoint - checkpointWriteWorker.write(entityState, true, RequestPriority.MEDIUM); - - return pointSamples; } } diff --git a/src/main/java/org/opensearch/ad/ml/ADInferencer.java b/src/main/java/org/opensearch/ad/ml/ADInferencer.java new file mode 100644 index 000000000..26e6c032f --- /dev/null +++ b/src/main/java/org/opensearch/ad/ml/ADInferencer.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.ml; + +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME; + +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ADColdStartWorker; +import org.opensearch.ad.ratelimit.ADSaveResultStrategy; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.ml.Inferencer; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class ADInferencer extends + Inferencer { + + public ADInferencer( + ADModelManager modelManager, + Stats stats, + ADCheckpointDao checkpointDao, + ADColdStartWorker coldStartWorker, + ADSaveResultStrategy resultWriteWorker, + ADCacheProvider cache, + ThreadPool threadPool + ) { + super( + modelManager, + stats, + StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), + checkpointDao, + coldStartWorker, + resultWriteWorker, + cache, + threadPool, + AD_THREAD_POOL_NAME + ); + } + +} diff --git a/src/main/java/org/opensearch/ad/ml/ADModelManager.java b/src/main/java/org/opensearch/ad/ml/ADModelManager.java index aa553a7bf..354b02557 100644 --- a/src/main/java/org/opensearch/ad/ml/ADModelManager.java +++ b/src/main/java/org/opensearch/ad/ml/ADModelManager.java @@ -33,7 +33,9 @@ import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.model.ImputedFeatureResult; import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; @@ -47,13 +49,17 @@ import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.feature.Features; import org.opensearch.timeseries.ml.MemoryAwareConcurrentHashmap; +import org.opensearch.timeseries.ml.ModelColdStart; import org.opensearch.timeseries.ml.ModelManager; import org.opensearch.timeseries.ml.ModelState; import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.DateUtils; +import org.opensearch.timeseries.util.ModelUtil; import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.config.ForestMode; import com.amazon.randomcutforest.config.Precision; import com.amazon.randomcutforest.config.TransformMethod; import com.amazon.randomcutforest.parkservices.AnomalyDescriptor; @@ -137,7 +143,11 @@ public ADModelManager( this.initialAcceptFraction = rcfNumMinSamples * 1.0d / rcfNumSamplesInTree; } + @Deprecated /** + * used in RCFResultTransportAction to handle request from old node request. + * In the new logic, we switch to SingleStreamResultAction. + * * Returns to listener the RCF anomaly result using the specified model. * * @param detectorId ID of the detector @@ -194,7 +204,9 @@ private void getTRcfResult( result.getExpectedValuesList(), result.getLikelihoodOfValues(), result.getThreshold(), - result.getNumberOfTrees() + result.getNumberOfTrees(), + point, + null ) ); } catch (Exception e) { @@ -513,11 +525,10 @@ private void maintenanceForIterator( * Returns computed anomaly results for preview data points. * * @param features features of preview data points - * @param shingleSize model shingle size - * @return rcfTimeDecay rcf time decay + * @param detector Anomaly detector * @throws IllegalArgumentException when preview data points are not valid */ - public List getPreviewResults(Features features, int shingleSize, double rcfTimeDecay) { + public List getPreviewResults(Features features, AnomalyDetector detector) { double[][] dataPoints = features.getUnprocessedFeatures(); if (dataPoints.length < minPreviewSize) { throw new IllegalArgumentException("Insufficient data for preview results. Minimum required: " + minPreviewSize); @@ -528,11 +539,15 @@ public List getPreviewResults(Features features, int shingle String.format(Locale.ROOT, "time range size %d does not match data points size %d", timeRanges.size(), dataPoints.length) ); } + + int shingleSize = detector.getShingleSize(); + double rcfTimeDecay = detector.getTimeDecay(); + // Train RCF models and collect non-zero scores int baseDimension = dataPoints[0].length; // speed is important in preview. We don't want cx to wait too long. // thus use the default value of boundingBoxCacheFraction = 1 - ThresholdedRandomCutForest trcf = ThresholdedRandomCutForest + ThresholdedRandomCutForest.Builder trcfBuilder = ThresholdedRandomCutForest .builder() .randomSeed(0L) .dimensions(baseDimension * shingleSize) @@ -550,27 +565,32 @@ public List getPreviewResults(Features features, int shingle .transformMethod(TransformMethod.NORMALIZE) .alertOnce(true) .autoAdjust(true) - .internalShinglingEnabled(true) - .build(); + .internalShinglingEnabled(true); + + if (shingleSize > 1) { + trcfBuilder.forestMode(ForestMode.STREAMING_IMPUTE); + trcfBuilder = ModelColdStart.applyImputationMethod(detector, trcfBuilder); + } else { + // imputation with shingle size 1 is not meaningful + trcfBuilder.forestMode(ForestMode.STANDARD); + } + + ADColdStart.applyRule(trcfBuilder, detector); + + ThresholdedRandomCutForest trcf = trcfBuilder.build(); return IntStream.range(0, dataPoints.length).mapToObj(i -> { + // we don't have missing values in preview data. We have already filtered them out. double[] point = dataPoints[i]; // Get the data end epoch milliseconds corresponding to this index and convert it to seconds long timestampSecs = timeRanges.get(i).getValue() / 1000; AnomalyDescriptor descriptor = trcf.process(point, timestampSecs); // Use the timestamp here - return new ThresholdingResult( - descriptor.getAnomalyGrade(), - descriptor.getDataConfidence(), - descriptor.getRCFScore(), - descriptor.getTotalUpdates(), - descriptor.getRelativeIndex(), - normalizeAttribution(trcf.getForest(), descriptor.getRelevantAttribution()), - descriptor.getPastValues(), - descriptor.getExpectedValuesList(), - descriptor.getLikelihoodOfValues(), - descriptor.getThreshold(), - rcfNumTrees - ); + + if (descriptor != null) { + return toResult(trcf.getForest(), descriptor, point, false, detector); + } + + return null; }).collect(Collectors.toList()); } @@ -623,7 +643,15 @@ protected ThresholdingResult createEmptyResult() { } @Override - protected ThresholdingResult toResult(RandomCutForest rcf, AnomalyDescriptor anomalyDescriptor) { + protected ThresholdingResult toResult( + RandomCutForest rcf, + AnomalyDescriptor anomalyDescriptor, + double[] point, + boolean isImputed, + Config config + ) { + ImputedFeatureResult result = ModelUtil.calculateImputedFeatures(anomalyDescriptor, point, isImputed, config); + return new ThresholdingResult( anomalyDescriptor.getAnomalyGrade(), anomalyDescriptor.getDataConfidence(), @@ -635,7 +663,9 @@ protected ThresholdingResult toResult(RandomCutForest rcf, AnomalyDescriptor ano anomalyDescriptor.getExpectedValuesList(), anomalyDescriptor.getLikelihoodOfValues(), anomalyDescriptor.getThreshold(), - rcfNumTrees + rcfNumTrees, + result.getActual(), + result.getIsFeatureImputed() ); } } diff --git a/src/main/java/org/opensearch/ad/ml/ThresholdingResult.java b/src/main/java/org/opensearch/ad/ml/ThresholdingResult.java index a2da03f51..f4b7c1fb0 100644 --- a/src/main/java/org/opensearch/ad/ml/ThresholdingResult.java +++ b/src/main/java/org/opensearch/ad/ml/ThresholdingResult.java @@ -136,6 +136,11 @@ public class ThresholdingResult extends IntermediateResult { protected final double confidence; + // actual or imputed data + private double[] currentData; + + private boolean[] featureImputed; + /** * Constructor for default empty value or backward compatibility. * In terms of bwc, when an old node sends request for threshold results, @@ -148,7 +153,7 @@ public class ThresholdingResult extends IntermediateResult { * saving or not. */ public ThresholdingResult(double grade, double confidence, double rcfScore) { - this(grade, confidence, rcfScore, 0, 0, null, null, null, null, 0, 0); + this(grade, confidence, rcfScore, 0, 0, null, null, null, null, 0, 0, null, null); } public ThresholdingResult( @@ -162,7 +167,9 @@ public ThresholdingResult( double[][] expectedValuesList, double[] likelihoodOfValues, double threshold, - int forestSize + int forestSize, + double[] currentData, + boolean[] featureImputed ) { super(totalUpdates, rcfScore); this.confidence = confidence; @@ -175,6 +182,9 @@ public ThresholdingResult( this.likelihoodOfValues = likelihoodOfValues; this.threshold = threshold; this.forestSize = forestSize; + this.currentData = currentData; + this.featureImputed = featureImputed; + } /** @@ -223,12 +233,22 @@ public int getForestSize() { return forestSize; } + public double[] getCurrentData() { + return currentData; + } + + public boolean isFeatureImputed(int i) { + return featureImputed[i]; + } + @Override public boolean equals(Object o) { - if (!super.equals(o)) + if (!super.equals(o)) { return false; - if (getClass() != o.getClass()) + } + if (getClass() != o.getClass()) { return false; + } ThresholdingResult that = (ThresholdingResult) o; return Double.doubleToLongBits(confidence) == Double.doubleToLongBits(that.confidence) && Double.doubleToLongBits(this.grade) == Double.doubleToLongBits(that.grade) @@ -238,7 +258,9 @@ public boolean equals(Object o) { && Arrays.deepEquals(expectedValuesList, that.expectedValuesList) && Arrays.equals(likelihoodOfValues, that.likelihoodOfValues) && Double.doubleToLongBits(threshold) == Double.doubleToLongBits(that.threshold) - && forestSize == that.forestSize; + && forestSize == that.forestSize + && Arrays.equals(currentData, that.currentData) + && Arrays.equals(featureImputed, that.featureImputed); } @Override @@ -254,7 +276,9 @@ public int hashCode() { Arrays.deepHashCode(expectedValuesList), Arrays.hashCode(likelihoodOfValues), threshold, - forestSize + forestSize, + Arrays.hashCode(currentData), + Arrays.hashCode(featureImputed) ); } @@ -271,6 +295,8 @@ public String toString() { .append("likelihoodOfValues", Arrays.toString(likelihoodOfValues)) .append("threshold", threshold) .append("forestSize", forestSize) + .append("currentData", Arrays.toString(currentData)) + .append("featureImputed", Arrays.toString(featureImputed)) .toString(); } @@ -330,7 +356,9 @@ public List toIndexableResults( pastValues, expectedValuesList, likelihoodOfValues, - threshold + threshold, + currentData, + featureImputed ) ); } diff --git a/src/main/java/org/opensearch/ad/model/AnomalyResult.java b/src/main/java/org/opensearch/ad/model/AnomalyResult.java index bdfe4eb3c..868317bab 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyResult.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyResult.java @@ -40,6 +40,7 @@ import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.FeatureData; import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.util.DataUtil; import org.opensearch.timeseries.util.ParseUtils; import com.google.common.base.Objects; @@ -66,6 +67,7 @@ public class AnomalyResult extends IndexableResult { public static final String THRESHOLD_FIELD = "threshold"; // unused currently. added since odfe 1.4 public static final String IS_ANOMALY_FIELD = "is_anomaly"; + public static final String FEATURE_IMPUTED = "feature_imputed"; private final Double anomalyScore; private final Double anomalyGrade; @@ -193,6 +195,9 @@ So if we detect anomaly late, we get the baseDimension values from the past (cur */ private final String modelId; + // whether a feature value is imputed or not + private List featureImputed; + // used when indexing exception or error or an empty result public AnomalyResult( String detectorId, @@ -228,6 +233,7 @@ public AnomalyResult( null, null, null, + null, null ); } @@ -252,7 +258,8 @@ public AnomalyResult( List relevantAttribution, List pastValues, List expectedValuesList, - Double threshold + Double threshold, + List featureImputed ) { super( configId, @@ -276,6 +283,7 @@ public AnomalyResult( this.pastValues = pastValues; this.expectedValuesList = expectedValuesList; this.threshold = threshold; + this.featureImputed = featureImputed; } /** @@ -302,6 +310,8 @@ public AnomalyResult( * @param expectedValuesList Expected values * @param likelihoodOfValues Likelihood of the expected values * @param threshold Current threshold + * @param currentData imputed data if any + * @param featureImputed whether feature is imputed or not * @return the converted AnomalyResult instance */ public static AnomalyResult fromRawTRCFResult( @@ -326,14 +336,17 @@ public static AnomalyResult fromRawTRCFResult( double[] pastValues, double[][] expectedValuesList, double[] likelihoodOfValues, - Double threshold + Double threshold, + double[] currentData, + boolean[] featureImputed ) { List convertedRelevantAttribution = null; List convertedPastValuesList = null; List convertedExpectedValues = null; + int featureSize = featureData == null ? 0 : featureData.size(); + if (grade > 0) { - int featureSize = featureData.size(); if (relevantAttribution != null) { if (relevantAttribution.length == featureSize) { convertedRelevantAttribution = new ArrayList<>(featureSize); @@ -352,11 +365,21 @@ public static AnomalyResult fromRawTRCFResult( } } - if (pastValues != null) { + // it is possible pastValues is not null but relativeIndex is null. It would happen when the imputation ends in a continuous + // area. + if (pastValues != null && relativeIndex != null && relativeIndex < 0) { if (pastValues.length == featureSize) { convertedPastValuesList = new ArrayList<>(featureSize); for (int j = 0; j < featureSize; j++) { - convertedPastValuesList.add(new DataByFeatureId(featureData.get(j).getFeatureId(), pastValues[j])); + // When impute missing values, the first imputation will generate NaN value, but OS's double type won't accept NaN + // value. + // So we will break out of the loop and not save a past value. + if (Double.isNaN(pastValues[j]) || Double.isInfinite(pastValues[j])) { + convertedPastValuesList = null; + break; + } else { + convertedPastValuesList.add(new DataByFeatureId(featureData.get(j).getFeatureId(), pastValues[j])); + } } } else { LOG @@ -404,6 +427,20 @@ public static AnomalyResult fromRawTRCFResult( } } + List featureImputedList = new ArrayList<>(); + if (featureImputed != null) { + for (int i = 0; i < featureImputed.length; i++) { + FeatureData featureItem = featureData.get(i); + // round to 3rd decimal places + if (featureImputed[i]) { + featureItem.setData(DataUtil.roundDouble(currentData[i], 3)); + } else { + featureItem.setData(DataUtil.roundDouble(featureItem.getData(), 3)); + } + featureImputedList.add(new FeatureImputed(featureItem.getFeatureId(), featureImputed[i])); + } + } + return new AnomalyResult( detectorId, taskId, @@ -420,13 +457,14 @@ public static AnomalyResult fromRawTRCFResult( user, schemaVersion, modelId, - (relativeIndex == null || dataStartTime == null) + (relativeIndex == null || dataStartTime == null || relativeIndex >= 0) ? null : Instant.ofEpochMilli(dataStartTime.toEpochMilli() + relativeIndex * intervalMillis), convertedRelevantAttribution, convertedPastValuesList, convertedExpectedValues, - threshold + threshold, + featureImputedList ); } @@ -470,6 +508,16 @@ public AnomalyResult(StreamInput input) throws IOException { } this.threshold = input.readOptionalDouble(); + + int inputLength = input.readVInt(); + if (inputLength > 0) { + this.featureImputed = new ArrayList<>(); + for (int i = 0; i < inputLength; i++) { + featureImputed.add(new FeatureImputed(input)); + } + } else { + this.featureImputed = null; + } } @Override @@ -545,6 +593,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (threshold != null && !threshold.isNaN()) { xContentBuilder.field(THRESHOLD_FIELD, threshold); } + if (featureImputed != null && featureImputed.size() > 0) { + xContentBuilder.array(FEATURE_IMPUTED, featureImputed.toArray()); + } return xContentBuilder.endObject(); } @@ -569,6 +620,7 @@ public static AnomalyResult parse(XContentParser parser) throws IOException { List pastValues = new ArrayList<>(); List expectedValues = new ArrayList<>(); Double threshold = null; + List featureImputed = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -648,6 +700,13 @@ public static AnomalyResult parse(XContentParser parser) throws IOException { case THRESHOLD_FIELD: threshold = parser.doubleValue(); break; + case FEATURE_IMPUTED: + featureImputed = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + featureImputed.add(FeatureImputed.parse(parser)); + } + break; default: parser.skipChildren(); break; @@ -674,17 +733,20 @@ public static AnomalyResult parse(XContentParser parser) throws IOException { relavantAttribution, pastValues, expectedValues, - threshold + threshold, + featureImputed ); } @Generated @Override public boolean equals(Object o) { - if (!super.equals(o)) + if (!super.equals(o)) { return false; - if (getClass() != o.getClass()) + } + if (getClass() != o.getClass()) { return false; + } AnomalyResult that = (AnomalyResult) o; return Objects.equal(modelId, that.modelId) && Objects.equal(confidence, that.confidence) @@ -694,7 +756,8 @@ public boolean equals(Object o) { && Objects.equal(relevantAttribution, that.relevantAttribution) && Objects.equal(pastValues, that.pastValues) && Objects.equal(expectedValuesList, that.expectedValuesList) - && Objects.equal(threshold, that.threshold); + && Objects.equal(threshold, that.threshold) + && Objects.equal(featureImputed, that.featureImputed); } @Generated @@ -712,7 +775,8 @@ public int hashCode() { relevantAttribution, pastValues, expectedValuesList, - threshold + threshold, + featureImputed ); return result; } @@ -732,6 +796,7 @@ public String toString() { .append("pastValues", pastValues) .append("expectedValuesList", StringUtils.join(expectedValuesList, "|")) .append("threshold", threshold) + .append("featureImputed", featureImputed) .toString(); } @@ -775,6 +840,10 @@ public String getModelId() { return modelId; } + public List getFeatureImputed() { + return featureImputed; + } + /** * Anomaly result index consists of overwhelmingly (99.5%) zero-grade non-error documents. * This function exclude the majority case. @@ -825,6 +894,15 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeOptionalDouble(threshold); + + if (featureImputed != null) { + out.writeVInt(featureImputed.size()); + for (FeatureImputed imputed : featureImputed) { + imputed.writeTo(out); + } + } else { + out.writeVInt(0); + } } public static AnomalyResult getDummyResult() { diff --git a/src/main/java/org/opensearch/ad/model/FeatureImputed.java b/src/main/java/org/opensearch/ad/model/FeatureImputed.java new file mode 100644 index 000000000..c6c7aeada --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/FeatureImputed.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import com.google.common.base.Objects; + +/** + * Feature imputed and its Id + * + */ +public class FeatureImputed implements ToXContentObject, Writeable { + + public static final String FEATURE_ID_FIELD = "feature_id"; + public static final String IMPUTED_FIELD = "imputed"; + + protected String featureId; + protected Boolean imputed; + + public FeatureImputed(String featureId, Boolean imputed) { + this.featureId = featureId; + this.imputed = imputed; + } + + public FeatureImputed(StreamInput input) throws IOException { + this.featureId = input.readString(); + this.imputed = input.readBoolean(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject().field(FEATURE_ID_FIELD, featureId).field(IMPUTED_FIELD, imputed); + return xContentBuilder.endObject(); + } + + public static FeatureImputed parse(XContentParser parser) throws IOException { + String featureId = null; + Boolean imputed = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case FEATURE_ID_FIELD: + featureId = parser.text(); + break; + case IMPUTED_FIELD: + imputed = parser.booleanValue(); + break; + default: + // the unknown field and it's children should be ignored + parser.skipChildren(); + break; + } + } + return new FeatureImputed(featureId, imputed); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FeatureImputed that = (FeatureImputed) o; + return Objects.equal(getFeatureId(), that.getFeatureId()) && Objects.equal(isImputed(), that.isImputed()); + } + + @Override + public int hashCode() { + return Objects.hashCode(getFeatureId(), isImputed()); + } + + public String getFeatureId() { + return featureId; + } + + public Boolean isImputed() { + return imputed; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(featureId); + out.writeBoolean(imputed); + } + +} diff --git a/src/main/java/org/opensearch/ad/model/ImputedFeatureResult.java b/src/main/java/org/opensearch/ad/model/ImputedFeatureResult.java new file mode 100644 index 000000000..5913be4e5 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/ImputedFeatureResult.java @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.model; + +public class ImputedFeatureResult { + boolean[] isFeatureImputed; + double[] actual; + + public ImputedFeatureResult(boolean[] isFeatureImputed, double[] actual) { + this.isFeatureImputed = isFeatureImputed; + this.actual = actual; + } + + public boolean[] getIsFeatureImputed() { + return isFeatureImputed; + } + + public double[] getActual() { + return actual; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointReadWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointReadWorker.java index 40ea61ae6..7a0fe75c8 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointReadWorker.java +++ b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointReadWorker.java @@ -24,10 +24,10 @@ import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.ml.ADCheckpointDao; import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADInferencer; import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.stats.ADStats; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Provider; import org.opensearch.common.settings.Setting; @@ -38,7 +38,6 @@ import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.ratelimit.CheckpointReadWorker; -import org.opensearch.timeseries.stats.StatNames; import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; @@ -55,7 +54,7 @@ * */ public class ADCheckpointReadWorker extends - CheckpointReadWorker { + CheckpointReadWorker { public static final String WORKER_NAME = "ad-checkpoint-read"; public ADCheckpointReadWorker( @@ -77,12 +76,10 @@ public ADCheckpointReadWorker( ADCheckpointDao checkpointDao, ADColdStartWorker entityColdStartQueue, NodeStateManager stateManager, - ADIndexManagement indexUtil, Provider cacheProvider, Duration stateTtl, ADCheckpointWriteWorker checkpointWriteQueue, - ADStats adStats, - ADSaveResultStrategy resultWriteWorker + ADInferencer inferencer ) { super( WORKER_NAME, @@ -105,17 +102,14 @@ public ADCheckpointReadWorker( checkpointDao, entityColdStartQueue, stateManager, - indexUtil, cacheProvider, stateTtl, checkpointWriteQueue, - adStats, AD_CHECKPOINT_READ_QUEUE_CONCURRENCY, AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, ADCommonName.CHECKPOINT_INDEX_NAME, - StatNames.AD_MODEL_CORRUTPION_COUNT, AnalysisType.AD, - resultWriteWorker + inferencer ); } } diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADColdEntityWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADColdEntityWorker.java index 0abd7527d..aa92704fe 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ADColdEntityWorker.java +++ b/src/main/java/org/opensearch/ad/ratelimit/ADColdEntityWorker.java @@ -23,6 +23,7 @@ import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.ml.ADCheckpointDao; import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADInferencer; import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyResult; @@ -54,7 +55,7 @@ * */ public class ADColdEntityWorker extends - ColdEntityWorker { + ColdEntityWorker { public static final String WORKER_NAME = "ad-cold-entity"; public ADColdEntityWorker( diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADColdStartWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADColdStartWorker.java index 09cd82a4a..7f88df4c1 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ADColdStartWorker.java +++ b/src/main/java/org/opensearch/ad/ratelimit/ADColdStartWorker.java @@ -118,7 +118,7 @@ protected ModelState createEmptyState(FeatureRequest null, modelId, configId, - ModelManager.ModelType.RCFCASTER.getName(), + ModelManager.ModelType.TRCF.getName(), clock, 0, request.getEntity(), @@ -131,6 +131,9 @@ protected AnomalyResult createIndexableResult(Config config, String taskId, Stri return new AnomalyResult( config.getId(), taskId, + Double.NaN, + Double.NaN, + Double.NaN, ParseUtils.getFeatureData(entry.getValueList(), config), entry.getDataStartTime(), entry.getDataEndTime(), @@ -140,7 +143,13 @@ protected AnomalyResult createIndexableResult(Config config, String taskId, Stri entity, config.getUser(), config.getSchemaVersion(), - modelId + modelId, + null, + null, + null, + null, + null, + null ); } } diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADSaveResultStrategy.java b/src/main/java/org/opensearch/ad/ratelimit/ADSaveResultStrategy.java index d82eab34a..cac437523 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ADSaveResultStrategy.java +++ b/src/main/java/org/opensearch/ad/ratelimit/ADSaveResultStrategy.java @@ -71,7 +71,6 @@ public void saveResult( taskId, null ); - for (AnomalyResult r : indexableResults) { saveResult(r, config); } diff --git a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java index ef2b7d754..1219107c4 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java @@ -92,6 +92,11 @@ public abstract class AbstractAnomalyDetectorActionHandler REQUEST_TIMEOUT = Setting .positiveTimeSetting( "opendistro.anomaly_detection.request_timeout", - TimeValue.timeValueSeconds(10), + TimeValue.timeValueSeconds(60), Setting.Property.NodeScope, Setting.Property.Dynamic, Setting.Property.Deprecated diff --git a/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java b/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java index acccb0dd2..4623ad4e3 100644 --- a/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java +++ b/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java @@ -15,6 +15,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import org.opensearch.ad.ml.ADColdStart; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.timeseries.ml.ModelColdStart; @@ -85,6 +86,8 @@ protected ADBatchTaskCache(ADTask adTask) { rcfBuilder.forestMode(ForestMode.STANDARD); } + ADColdStart.applyRule(rcfBuilder, detector); + rcfModel = rcfBuilder.build(); this.thresholdModelTrained = false; } diff --git a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java index 6ff20287c..20c87d1b6 100644 --- a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java +++ b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java @@ -25,6 +25,7 @@ import java.time.Clock; import java.time.Instant; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -43,6 +44,7 @@ import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; @@ -87,6 +89,7 @@ import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.feature.SearchFeatureDao; import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.FeatureData; @@ -105,7 +108,6 @@ import org.opensearch.transport.TransportService; import com.amazon.randomcutforest.RandomCutForest; -import com.amazon.randomcutforest.parkservices.AnomalyDescriptor; import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -1087,62 +1089,61 @@ private void detectAnomaly( List anomalyResults = new ArrayList<>(); long intervalEndTime = pieceStartTime; + AnomalyDetector detector = adTask.getDetector(); for (int i = 0; i < pieceSize && intervalEndTime < dataEndTime; i++) { Optional dataPoint = dataPoints.containsKey(intervalEndTime) ? dataPoints.get(intervalEndTime) : Optional.empty(); intervalEndTime = intervalEndTime + interval; + Instant pieceDataStartTime = Instant.ofEpochMilli(intervalEndTime - interval); + Instant pieceDataEndTime = Instant.ofEpochMilli(intervalEndTime); - if (dataPoint.isEmpty()) { + if (dataPoint.isEmpty() && detector.getImputationOption() == null) { AnomalyResult anomalyResult = new AnomalyResult( adTask.getConfigId(), adTask.getConfigLevelTaskId(), null, - Instant.ofEpochMilli(intervalEndTime - interval), - Instant.ofEpochMilli(intervalEndTime), + pieceDataStartTime, + pieceDataEndTime, executeStartTime, Instant.now(), "No data in current detection window", Optional.ofNullable(adTask.getEntity()), - adTask.getDetector().getUser(), + detector.getUser(), anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT), adTask.getEntityModelId() ); anomalyResults.add(anomalyResult); } else { - List featureData = ParseUtils.getFeatureData(dataPoint.get(), adTask.getDetector()); - // 0 is placeholder for timestamp. In the future, we will add - // data time stamp there. - AnomalyDescriptor descriptor = trcf.process(dataPoint.get(), intervalEndTime); - double score = descriptor.getRCFScore(); - if (!adTaskCacheManager.isThresholdModelTrained(taskId) && score > 0) { + // dataPoint is empty or or dataPoint may have partial results (missing feature value is Double.NaN) or imputation option is + // not null + double[] toScore = null; + if (dataPoint.isEmpty()) { + toScore = new double[detector.getEnabledFeatureIds().size()]; + Arrays.fill(toScore, Double.NaN); + } else { + toScore = dataPoint.get(); + } + ThresholdingResult thresholdingResult = modelManager + .score(new Sample(toScore, pieceDataStartTime, pieceDataEndTime), detector, trcf); + if (!adTaskCacheManager.isThresholdModelTrained(taskId) && thresholdingResult.getRcfScore() > 0) { adTaskCacheManager.setThresholdModelTrained(taskId, true); } + List featureData = ParseUtils.getFeatureData(toScore, adTask.getDetector()); - AnomalyResult anomalyResult = AnomalyResult - .fromRawTRCFResult( - adTask.getConfigId(), - adTask.getDetector().getIntervalInMilliseconds(), - adTask.getConfigLevelTaskId(), - score, - descriptor.getAnomalyGrade(), - descriptor.getDataConfidence(), - featureData, - Instant.ofEpochMilli(intervalEndTime - interval), - Instant.ofEpochMilli(intervalEndTime), + List indexableResults = thresholdingResult + .toIndexableResults( + detector, + pieceDataStartTime, + pieceDataEndTime, executeStartTime, Instant.now(), - null, + featureData, Optional.ofNullable(adTask.getEntity()), - adTask.getDetector().getUser(), anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT), adTask.getEntityModelId(), - modelManager.normalizeAttribution(trcf.getForest(), descriptor.getRelevantAttribution()), - descriptor.getRelativeIndex(), - descriptor.getPastValues(), - descriptor.getExpectedValuesList(), - descriptor.getLikelihoodOfValues(), - descriptor.getThreshold() + adTask.getConfigLevelTaskId(), + null ); - anomalyResults.add(anomalyResult); + anomalyResults.addAll(indexableResults); } } @@ -1300,7 +1301,14 @@ private void runNextPiece( ); }, TimeValue.timeValueSeconds(pieceIntervalSeconds), AD_BATCH_TASK_THREAD_POOL_NAME); } else { - logger.info("AD task finished for detector {}, task id: {}", detectorId, taskId); + logger + .info( + "AD task finished for detector {}, task id: {}, pieceStartTime: {}, dataEndTime: {}", + detectorId, + taskId, + pieceStartTime, + dataEndTime + ); adTaskCacheManager.remove(taskId, detectorId, detectorTaskId); adTaskManager .updateTask( diff --git a/src/main/java/org/opensearch/ad/transport/ADHCImputeAction.java b/src/main/java/org/opensearch/ad/transport/ADHCImputeAction.java new file mode 100644 index 000000000..da1537ca5 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADHCImputeAction.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.ADCommonValue; + +public class ADHCImputeAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "impute/hc"; + public static final ADHCImputeAction INSTANCE = new ADHCImputeAction(); + + private ADHCImputeAction() { + super(NAME, ADHCImputeNodesResponse::new); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADHCImputeNodeRequest.java b/src/main/java/org/opensearch/ad/transport/ADHCImputeNodeRequest.java new file mode 100644 index 000000000..2e85d0800 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADHCImputeNodeRequest.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +public class ADHCImputeNodeRequest extends TransportRequest { + private final ADHCImputeRequest request; + + public ADHCImputeNodeRequest(StreamInput in) throws IOException { + super(in); + this.request = new ADHCImputeRequest(in); + } + + public ADHCImputeNodeRequest(ADHCImputeRequest request) { + this.request = request; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + request.writeTo(out); + } + + public ADHCImputeRequest getRequest() { + return request; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADHCImputeNodeResponse.java b/src/main/java/org/opensearch/ad/transport/ADHCImputeNodeResponse.java new file mode 100644 index 000000000..ef62ab491 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADHCImputeNodeResponse.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class ADHCImputeNodeResponse extends BaseNodeResponse { + + Exception previousException; + + public ADHCImputeNodeResponse(DiscoveryNode node, Exception previousException) { + super(node); + this.previousException = previousException; + } + + public ADHCImputeNodeResponse(StreamInput in) throws IOException { + super(in); + if (in.readBoolean()) { + this.previousException = in.readException(); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + if (previousException != null) { + out.writeBoolean(true); + out.writeException(previousException); + } else { + out.writeBoolean(false); + } + } + + public static ADHCImputeNodeResponse readNodeResponse(StreamInput in) throws IOException { + return new ADHCImputeNodeResponse(in); + } + + public Exception getPreviousException() { + return previousException; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADHCImputeNodesResponse.java b/src/main/java/org/opensearch/ad/transport/ADHCImputeNodesResponse.java new file mode 100644 index 000000000..87874d13d --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADHCImputeNodesResponse.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class ADHCImputeNodesResponse extends BaseNodesResponse { + public ADHCImputeNodesResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(ADHCImputeNodeResponse::readNodeResponse), in.readList(FailedNodeException::new)); + } + + public ADHCImputeNodesResponse(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + } + + @Override + protected List readNodesFrom(StreamInput in) throws IOException { + return in.readList(ADHCImputeNodeResponse::readNodeResponse); + } + + @Override + protected void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ADHCImputeRequest.java b/src/main/java/org/opensearch/ad/transport/ADHCImputeRequest.java new file mode 100644 index 000000000..9a082b538 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADHCImputeRequest.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class ADHCImputeRequest extends BaseNodesRequest { + private final String configId; + private final String taskId; + private final long dataStartMillis; + private final long dataEndMillis; + + public ADHCImputeRequest(String configId, String taskId, long startMillis, long endMillis, DiscoveryNode... nodes) { + super(nodes); + this.configId = configId; + this.taskId = taskId; + this.dataStartMillis = startMillis; + this.dataEndMillis = endMillis; + } + + public ADHCImputeRequest(StreamInput in) throws IOException { + super(in); + this.configId = in.readString(); + this.taskId = in.readOptionalString(); + this.dataStartMillis = in.readLong(); + this.dataEndMillis = in.readLong(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(configId); + out.writeOptionalString(taskId); + out.writeLong(dataStartMillis); + out.writeLong(dataEndMillis); + } + + public String getConfigId() { + return configId; + } + + public String getTaskId() { + return taskId; + } + + public long getDataStartMillis() { + return dataStartMillis; + } + + public long getDataEndMillis() { + return dataEndMillis; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADHCImputeTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADHCImputeTransportAction.java new file mode 100644 index 000000000..6f0e442bf --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADHCImputeTransportAction.java @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.ml.ADInferencer; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.util.ActionListenerExecutor; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class ADHCImputeTransportAction extends + TransportNodesAction { + private static final Logger LOG = LogManager.getLogger(ADHCImputeTransportAction.class); + + private ADCacheProvider cache; + private NodeStateManager nodeStateManager; + private ADInferencer adInferencer; + + @Inject + public ADHCImputeTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + ADCacheProvider priorityCache, + NodeStateManager nodeStateManager, + ADInferencer adInferencer + ) { + super( + ADHCImputeAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + ADHCImputeRequest::new, + ADHCImputeNodeRequest::new, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + ADHCImputeNodeResponse.class + ); + this.cache = priorityCache; + this.nodeStateManager = nodeStateManager; + this.adInferencer = adInferencer; + } + + @Override + protected ADHCImputeNodeRequest newNodeRequest(ADHCImputeRequest request) { + return new ADHCImputeNodeRequest(request); + } + + @Override + protected ADHCImputeNodeResponse newNodeResponse(StreamInput response) throws IOException { + return new ADHCImputeNodeResponse(response); + } + + @Override + protected ADHCImputeNodesResponse newResponse( + ADHCImputeRequest request, + List responses, + List failures + ) { + return new ADHCImputeNodesResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected ADHCImputeNodeResponse nodeOperation(ADHCImputeNodeRequest nodeRequest) { + String configId = nodeRequest.getRequest().getConfigId(); + nodeStateManager.getConfig(configId, AnalysisType.AD, ActionListenerExecutor.wrap(configOptional -> { + if (configOptional.isEmpty()) { + LOG.warn(String.format(Locale.ROOT, "cannot find config %s", configId)); + return; + } + Config config = configOptional.get(); + long windowDelayMillis = ((IntervalTimeConfiguration) config.getWindowDelay()).toDuration().toMillis(); + int featureSize = config.getEnabledFeatureIds().size(); + long dataEndMillis = nodeRequest.getRequest().getDataEndMillis(); + long dataStartMillis = nodeRequest.getRequest().getDataStartMillis(); + long executionEndTime = dataEndMillis + windowDelayMillis; + String taskId = nodeRequest.getRequest().getTaskId(); + for (ModelState modelState : cache.get().getAllModels(configId)) { + // execution end time (when job starts execution in this interval) > last used time => the model state is updated in + // previous intervals + if (executionEndTime > modelState.getLastUsedTime().toEpochMilli()) { + double[] nanArray = new double[featureSize]; + Arrays.fill(nanArray, Double.NaN); + adInferencer + .process( + new Sample(nanArray, Instant.ofEpochMilli(dataStartMillis), Instant.ofEpochMilli(dataEndMillis)), + modelState, + config, + taskId + ); + } + } + }, e -> nodeStateManager.setException(configId, e), threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME))); + + Optional previousException = nodeStateManager.fetchExceptionAndClear(configId); + + if (previousException.isPresent()) { + return new ADHCImputeNodeResponse(clusterService.localNode(), previousException.get()); + } else { + return new ADHCImputeNodeResponse(clusterService.localNode(), null); + } + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ADResultProcessor.java b/src/main/java/org/opensearch/ad/transport/ADResultProcessor.java index b0f01b996..b7564b38e 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultProcessor.java +++ b/src/main/java/org/opensearch/ad/transport/ADResultProcessor.java @@ -22,10 +22,12 @@ import org.opensearch.ad.task.ADTaskManager; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AnalysisType; @@ -45,7 +47,6 @@ public class ADResultProcessor extends public ADResultProcessor( Setting requestTimeoutSetting, - float intervalRatioForRequests, String entityResultAction, StatNames hcRequestCountStat, Settings settings, @@ -65,7 +66,6 @@ public ADResultProcessor( ) { super( requestTimeoutSetting, - intervalRatioForRequests, entityResultAction, hcRequestCountStat, settings, @@ -102,4 +102,33 @@ protected AnomalyResultResponse createResultResponse( ) { return new AnomalyResultResponse(features, error, rcfTotalUpdates, configInterval, isHC, taskId); } + + @Override + protected void imputeHC(long dataStartTime, long dataEndTime, String configID, String taskId) { + LOG + .info( + "Sending an HC impute request to process data from timestamp {} to {} for config {}", + dataStartTime, + dataEndTime, + configID + ); + + DiscoveryNode[] dataNodes = hashRing.getNodesWithSameLocalVersion(); + + client + .execute( + ADHCImputeAction.INSTANCE, + new ADHCImputeRequest(configID, taskId, dataStartTime, dataEndTime, dataNodes), + ActionListener.wrap(hcImputeResponse -> { + for (final ADHCImputeNodeResponse nodeResponse : hcImputeResponse.getNodes()) { + if (nodeResponse.getPreviousException() != null) { + nodeStateManager.setException(configID, nodeResponse.getPreviousException()); + } + } + }, e -> { + LOG.warn("fail to HC impute", e); + nodeStateManager.setException(configID, e); + }) + ); + } } diff --git a/src/main/java/org/opensearch/ad/transport/ADSingleStreamResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADSingleStreamResultTransportAction.java index 983c089b5..6995c5b36 100644 --- a/src/main/java/org/opensearch/ad/transport/ADSingleStreamResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADSingleStreamResultTransportAction.java @@ -13,6 +13,7 @@ import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.ml.ADCheckpointDao; import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADInferencer; import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyResult; @@ -21,24 +22,22 @@ import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; import org.opensearch.ad.ratelimit.ADColdStartWorker; import org.opensearch.ad.ratelimit.ADResultWriteRequest; -import org.opensearch.ad.ratelimit.ADResultWriteWorker; import org.opensearch.ad.ratelimit.ADSaveResultStrategy; -import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.transport.handler.ADIndexMemoryPressureAwareResultHandler; import org.opensearch.common.inject.Inject; +import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.ratelimit.RequestPriority; -import org.opensearch.timeseries.stats.StatNames; import org.opensearch.timeseries.transport.AbstractSingleStreamResultTransportAction; import org.opensearch.transport.TransportService; import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; public class ADSingleStreamResultTransportAction extends - AbstractSingleStreamResultTransportAction { + AbstractSingleStreamResultTransportAction { @Inject public ADSingleStreamResultTransportAction( @@ -48,11 +47,8 @@ public ADSingleStreamResultTransportAction( ADCacheProvider cache, NodeStateManager stateManager, ADCheckpointReadWorker checkpointReadQueue, - ADModelManager modelManager, - ADIndexManagement indexUtil, - ADResultWriteWorker resultWriteQueue, - ADStats stats, - ADColdStartWorker adColdStartQueue + ADInferencer inferencer, + ThreadPool threadPool ) { super( transportService, @@ -61,15 +57,11 @@ public ADSingleStreamResultTransportAction( cache, stateManager, checkpointReadQueue, - modelManager, - indexUtil, - resultWriteQueue, - stats, - adColdStartQueue, ADSingleStreamResultAction.NAME, - ADIndex.RESULT, AnalysisType.AD, - StatNames.AD_MODEL_CORRUTPION_COUNT.getName() + inferencer, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ); } diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java index 8708cb92a..072795ef2 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java @@ -352,7 +352,12 @@ public List toIndexableResults( pastValues, expectedValuesList, likelihoodOfValues, - threshold + threshold, + // Starting from version 2.15, this class is used to store job execution errors, not actual results, + // as the single stream has been changed to async mode. The job no longer waits for results before returning. + // Therefore, we set the following two fields to null, as we will not record any imputed fields. + null, + null ) ); } diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java index 1ee5b3ab8..d8a726d5d 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java @@ -44,7 +44,6 @@ import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.feature.FeatureManager; -import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; import org.opensearch.timeseries.transport.ResultProcessor; import org.opensearch.timeseries.util.SecurityClientUtil; @@ -85,7 +84,6 @@ public AnomalyResultTransportAction( super(AnomalyResultAction.NAME, transportService, actionFilters, AnomalyResultRequest::new); this.resultProcessor = new ADResultProcessor( AnomalyDetectorSettings.AD_REQUEST_TIMEOUT, - TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS, EntityADResultAction.NAME, StatNames.AD_HC_EXECUTE_REQUEST_COUNT, settings, diff --git a/src/main/java/org/opensearch/ad/transport/EntityADResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/EntityADResultTransportAction.java index b5d644c21..3712b9f1f 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityADResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/EntityADResultTransportAction.java @@ -24,6 +24,7 @@ import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.ml.ADCheckpointDao; import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADInferencer; import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyResult; @@ -32,7 +33,6 @@ import org.opensearch.ad.ratelimit.ADColdEntityWorker; import org.opensearch.ad.ratelimit.ADColdStartWorker; import org.opensearch.ad.ratelimit.ADSaveResultStrategy; -import org.opensearch.ad.stats.ADStats; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; import org.opensearch.tasks.Task; @@ -44,7 +44,6 @@ import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.constant.CommonMessages; -import org.opensearch.timeseries.stats.StatNames; import org.opensearch.timeseries.transport.EntityResultProcessor; import org.opensearch.timeseries.transport.EntityResultRequest; import org.opensearch.timeseries.util.ExceptionUtil; @@ -78,21 +77,17 @@ public class EntityADResultTransportAction extends HandledTransportAction cache; private final NodeStateManager stateManager; private ThreadPool threadPool; - private EntityResultProcessor intervalDataProcessor; + private EntityResultProcessor intervalDataProcessor; private final ADCacheProvider entityCache; - private final ADModelManager manager; - private final ADStats timeSeriesStats; - private final ADColdStartWorker entityColdStartWorker; private final ADCheckpointReadWorker checkpointReadQueue; private final ADColdEntityWorker coldEntityQueue; - private final ADSaveResultStrategy adSaveResultStategy; + private final ADInferencer inferencer; @Inject public EntityADResultTransportAction( ActionFilters actionFilters, TransportService transportService, - ADModelManager manager, CircuitBreakerService adCircuitBreakerService, ADCacheProvider entityCache, NodeStateManager stateManager, @@ -100,9 +95,7 @@ public EntityADResultTransportAction( ADCheckpointReadWorker checkpointReadQueue, ADColdEntityWorker coldEntityQueue, ThreadPool threadPool, - ADColdStartWorker entityColdStartWorker, - ADStats timeSeriesStats, - ADSaveResultStrategy adSaveResultStategy + ADInferencer inferencer ) { super(EntityADResultAction.NAME, transportService, actionFilters, EntityResultRequest::new); this.adCircuitBreakerService = adCircuitBreakerService; @@ -111,13 +104,10 @@ public EntityADResultTransportAction( this.threadPool = threadPool; this.entityCache = entityCache; - this.manager = manager; - this.timeSeriesStats = timeSeriesStats; - this.entityColdStartWorker = entityColdStartWorker; this.checkpointReadQueue = checkpointReadQueue; this.coldEntityQueue = coldEntityQueue; - this.adSaveResultStategy = adSaveResultStategy; this.intervalDataProcessor = null; + this.inferencer = inferencer; } @Override @@ -151,13 +141,11 @@ protected void doExecute(Task task, EntityResultRequest request, ActionListener< this.intervalDataProcessor = new EntityResultProcessor<>( entityCache, - manager, - timeSeriesStats, - entityColdStartWorker, checkpointReadQueue, coldEntityQueue, - adSaveResultStategy, - StatNames.AD_MODEL_CORRUTPION_COUNT + inferencer, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ); stateManager diff --git a/src/main/java/org/opensearch/ad/transport/ForwardADTaskRequest.java b/src/main/java/org/opensearch/ad/transport/ForwardADTaskRequest.java index 417696609..5aa362e1d 100644 --- a/src/main/java/org/opensearch/ad/transport/ForwardADTaskRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ForwardADTaskRequest.java @@ -17,6 +17,7 @@ import java.util.List; import java.util.Objects; +import org.apache.commons.lang.builder.ToStringBuilder; import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; @@ -27,6 +28,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.timeseries.annotation.Generated; import org.opensearch.timeseries.common.exception.VersionException; import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.model.DateRange; @@ -197,10 +199,12 @@ public Integer getAvailableTaskSLots() { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (o == null || getClass() != o.getClass()) + } + if (o == null || getClass() != o.getClass()) { return false; + } ForwardADTaskRequest request = (ForwardADTaskRequest) o; return Objects.equals(detector, request.detector) && Objects.equals(adTask, request.adTask) @@ -215,4 +219,18 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(detector, adTask, detectionDateRange, staleRunningEntities, user, availableTaskSlots, adTaskAction); } + + @Generated + @Override + public String toString() { + return new ToStringBuilder(this) + .append("detector", detector) + .append("adTask", adTask) + .append("detectionDateRange", detectionDateRange) + .append("staleRunningEntities", staleRunningEntities) + .append("user", user) + .append("availableTaskSlots", availableTaskSlots) + .append("adTaskAction", adTaskAction) + .toString(); + } } diff --git a/src/main/java/org/opensearch/ad/transport/handler/ADIndexMemoryPressureAwareResultHandler.java b/src/main/java/org/opensearch/ad/transport/handler/ADIndexMemoryPressureAwareResultHandler.java index 1ac134768..149c934f2 100644 --- a/src/main/java/org/opensearch/ad/transport/handler/ADIndexMemoryPressureAwareResultHandler.java +++ b/src/main/java/org/opensearch/ad/transport/handler/ADIndexMemoryPressureAwareResultHandler.java @@ -47,6 +47,7 @@ protected void bulk(ADResultBulkRequest currentBulkRequest, ActionListenerwrap(response -> { LOG.debug(CommonMessages.SUCCESS_SAVING_RESULT_MSG); listener.onResponse(response); diff --git a/src/main/java/org/opensearch/forecast/ml/ForecastInferencer.java b/src/main/java/org/opensearch/forecast/ml/ForecastInferencer.java new file mode 100644 index 000000000..793b21995 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ml/ForecastInferencer.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.ml; + +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME; + +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.forecast.ratelimit.ForecastColdStartWorker; +import org.opensearch.forecast.ratelimit.ForecastSaveResultStrategy; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.ml.Inferencer; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastInferencer extends + Inferencer { + + public ForecastInferencer( + ForecastModelManager modelManager, + Stats stats, + ForecastCheckpointDao checkpointDao, + ForecastColdStartWorker coldStartWorker, + ForecastSaveResultStrategy resultWriteWorker, + ForecastCacheProvider cache, + ThreadPool threadPool + ) { + super( + modelManager, + stats, + StatNames.FORECAST_MODEL_CORRUTPION_COUNT.getName(), + checkpointDao, + coldStartWorker, + resultWriteWorker, + cache, + threadPool, + FORECAST_THREAD_POOL_NAME + ); + } + +} diff --git a/src/main/java/org/opensearch/forecast/ml/ForecastModelManager.java b/src/main/java/org/opensearch/forecast/ml/ForecastModelManager.java index 438c9bdde..3e15014d5 100644 --- a/src/main/java/org/opensearch/forecast/ml/ForecastModelManager.java +++ b/src/main/java/org/opensearch/forecast/ml/ForecastModelManager.java @@ -21,6 +21,7 @@ import org.opensearch.timeseries.MemoryTracker; import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.model.Config; import com.amazon.randomcutforest.RandomCutForest; import com.amazon.randomcutforest.parkservices.AnomalyDescriptor; @@ -49,7 +50,13 @@ protected RCFCasterResult createEmptyResult() { } @Override - protected RCFCasterResult toResult(RandomCutForest forecast, RCFDescriptor castDescriptor) { + protected RCFCasterResult toResult( + RandomCutForest forecast, + RCFDescriptor castDescriptor, + double[] point, + boolean isImputed, + Config config + ) { if (castDescriptor instanceof ForecastDescriptor) { ForecastDescriptor forecastDescriptor = (ForecastDescriptor) castDescriptor; // Use forecastDescriptor in the rest of your method diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointReadWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointReadWorker.java index 5bbcb4e1e..652b38ec2 100644 --- a/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointReadWorker.java +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointReadWorker.java @@ -22,22 +22,21 @@ import org.opensearch.forecast.indices.ForecastIndexManagement; import org.opensearch.forecast.ml.ForecastCheckpointDao; import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastInferencer; import org.opensearch.forecast.ml.ForecastModelManager; import org.opensearch.forecast.ml.RCFCasterResult; import org.opensearch.forecast.model.ForecastResult; -import org.opensearch.forecast.stats.ForecastStats; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.ratelimit.CheckpointReadWorker; -import org.opensearch.timeseries.stats.StatNames; import com.amazon.randomcutforest.parkservices.RCFCaster; public class ForecastCheckpointReadWorker extends - CheckpointReadWorker { + CheckpointReadWorker { public static final String WORKER_NAME = "forecast-checkpoint-read"; public ForecastCheckpointReadWorker( @@ -59,12 +58,10 @@ public ForecastCheckpointReadWorker( ForecastCheckpointDao checkpointDao, ForecastColdStartWorker entityColdStartQueue, NodeStateManager stateManager, - ForecastIndexManagement indexUtil, Provider cacheProvider, Duration stateTtl, ForecastCheckpointWriteWorker checkpointWriteQueue, - ForecastStats forecastStats, - ForecastSaveResultStrategy saveResultStrategy + ForecastInferencer inferencer ) { super( WORKER_NAME, @@ -87,17 +84,14 @@ public ForecastCheckpointReadWorker( checkpointDao, entityColdStartQueue, stateManager, - indexUtil, cacheProvider, stateTtl, checkpointWriteQueue, - forecastStats, FORECAST_CHECKPOINT_READ_QUEUE_CONCURRENCY, FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE, ForecastCommonName.FORECAST_CHECKPOINT_INDEX_NAME, - StatNames.FORECAST_MODEL_CORRUTPION_COUNT, AnalysisType.FORECAST, - saveResultStrategy + inferencer ); } } diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdEntityWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdEntityWorker.java index 43831f8df..dcc1dca6f 100644 --- a/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdEntityWorker.java +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdEntityWorker.java @@ -20,6 +20,7 @@ import org.opensearch.forecast.indices.ForecastIndexManagement; import org.opensearch.forecast.ml.ForecastCheckpointDao; import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastInferencer; import org.opensearch.forecast.ml.ForecastModelManager; import org.opensearch.forecast.ml.RCFCasterResult; import org.opensearch.forecast.model.ForecastResult; @@ -48,7 +49,7 @@ * */ public class ForecastColdEntityWorker extends - ColdEntityWorker { + ColdEntityWorker { public static final String WORKER_NAME = "forecast-cold-entity"; public ForecastColdEntityWorker( diff --git a/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java b/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java index 92bbf9325..15e30ef76 100644 --- a/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java +++ b/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java @@ -226,7 +226,7 @@ protected String getNoDocsInUserIndexErrorMsg(String suppliedIndices) { } @Override - protected String getDuplicateConfigErrorMsg(String name) { + public String getDuplicateConfigErrorMsg(String name) { return String.format(Locale.ROOT, DUPLICATE_FORECASTER_MSG, name); } diff --git a/src/main/java/org/opensearch/forecast/transport/EntityForecastResultTransportAction.java b/src/main/java/org/opensearch/forecast/transport/EntityForecastResultTransportAction.java index d638b3bae..9d58ec049 100644 --- a/src/main/java/org/opensearch/forecast/transport/EntityForecastResultTransportAction.java +++ b/src/main/java/org/opensearch/forecast/transport/EntityForecastResultTransportAction.java @@ -25,6 +25,7 @@ import org.opensearch.forecast.indices.ForecastIndexManagement; import org.opensearch.forecast.ml.ForecastCheckpointDao; import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastInferencer; import org.opensearch.forecast.ml.ForecastModelManager; import org.opensearch.forecast.ml.RCFCasterResult; import org.opensearch.forecast.model.ForecastResult; @@ -34,7 +35,6 @@ import org.opensearch.forecast.ratelimit.ForecastColdStartWorker; import org.opensearch.forecast.ratelimit.ForecastResultWriteWorker; import org.opensearch.forecast.ratelimit.ForecastSaveResultStrategy; -import org.opensearch.forecast.stats.ForecastStats; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.NodeStateManager; @@ -44,7 +44,6 @@ import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.constant.CommonMessages; -import org.opensearch.timeseries.stats.StatNames; import org.opensearch.timeseries.transport.EntityResultProcessor; import org.opensearch.timeseries.transport.EntityResultRequest; import org.opensearch.timeseries.util.ExceptionUtil; @@ -78,21 +77,17 @@ public class EntityForecastResultTransportAction extends HandledTransportAction< private CacheProvider cache; private final NodeStateManager stateManager; private ThreadPool threadPool; - private EntityResultProcessor intervalDataProcessor; + private EntityResultProcessor intervalDataProcessor; private final ForecastCacheProvider entityCache; - private final ForecastModelManager manager; - private final ForecastStats timeSeriesStats; - private final ForecastColdStartWorker entityColdStartWorker; private final ForecastCheckpointReadWorker checkpointReadQueue; private final ForecastColdEntityWorker coldEntityQueue; - private final ForecastSaveResultStrategy forecastSaveResultStategy; + private final ForecastInferencer inferencer; @Inject public EntityForecastResultTransportAction( ActionFilters actionFilters, TransportService transportService, - ForecastModelManager manager, CircuitBreakerService adCircuitBreakerService, ForecastCacheProvider entityCache, NodeStateManager stateManager, @@ -101,9 +96,7 @@ public EntityForecastResultTransportAction( ForecastCheckpointReadWorker checkpointReadQueue, ForecastColdEntityWorker coldEntityQueue, ThreadPool threadPool, - ForecastColdStartWorker entityColdStartWorker, - ForecastStats timeSeriesStats, - ForecastSaveResultStrategy forecastSaveResultStategy + ForecastInferencer inferencer ) { super(EntityForecastResultAction.NAME, transportService, actionFilters, EntityResultRequest::new); this.circuitBreakerService = adCircuitBreakerService; @@ -112,12 +105,9 @@ public EntityForecastResultTransportAction( this.threadPool = threadPool; this.intervalDataProcessor = null; this.entityCache = entityCache; - this.manager = manager; - this.timeSeriesStats = timeSeriesStats; - this.entityColdStartWorker = entityColdStartWorker; this.checkpointReadQueue = checkpointReadQueue; this.coldEntityQueue = coldEntityQueue; - this.forecastSaveResultStategy = forecastSaveResultStategy; + this.inferencer = inferencer; } @Override @@ -151,13 +141,11 @@ protected void doExecute(Task task, EntityResultRequest request, ActionListener< intervalDataProcessor = new EntityResultProcessor<>( entityCache, - manager, - timeSeriesStats, - entityColdStartWorker, checkpointReadQueue, coldEntityQueue, - forecastSaveResultStategy, - StatNames.FORECAST_MODEL_CORRUTPION_COUNT + inferencer, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME ); stateManager diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastImputeMissingValueAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastImputeMissingValueAction.java new file mode 100644 index 000000000..8e6c59aaa --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastImputeMissingValueAction.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class ForecastImputeMissingValueAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "impute"; + public static final ForecastImputeMissingValueAction INSTANCE = new ForecastImputeMissingValueAction(); + + public ForecastImputeMissingValueAction() { + super(NAME, AcknowledgedResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultProcessor.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultProcessor.java index 8489bf47f..bc8241e5a 100644 --- a/src/main/java/org/opensearch/forecast/transport/ForecastResultProcessor.java +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultProcessor.java @@ -46,7 +46,6 @@ public class ForecastResultProcessor extends public ForecastResultProcessor( Setting requestTimeoutSetting, - float intervalRatioForRequests, String entityResultAction, StatNames hcRequestCountStat, Settings settings, @@ -68,7 +67,6 @@ public ForecastResultProcessor( ) { super( requestTimeoutSetting, - intervalRatioForRequests, entityResultAction, hcRequestCountStat, settings, @@ -106,4 +104,8 @@ protected ForecastResultResponse createResultResponse( return new ForecastResultResponse(features, error, rcfTotalUpdates, configInterval, isHC, taskId); } + @Override + protected void imputeHC(long dataStartTime, long dataEndTime, String configID, String taskId) { + // no imputation for forecasting as on the fly imputation and error estimation should not mix + } } diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultTransportAction.java index 1db61e5d9..ba5dc64ea 100644 --- a/src/main/java/org/opensearch/forecast/transport/ForecastResultTransportAction.java +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultTransportAction.java @@ -43,7 +43,6 @@ import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.feature.FeatureManager; -import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; import org.opensearch.timeseries.task.TaskCacheManager; import org.opensearch.timeseries.transport.ResultProcessor; @@ -149,7 +148,6 @@ protected void doExecute(Task task, ForecastResultRequest request, ActionListene this.resultProcessor = new ForecastResultProcessor( ForecastSettings.FORECAST_REQUEST_TIMEOUT, - TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS, EntityForecastResultAction.NAME, StatNames.FORECAST_HC_EXECUTE_REQUEST_COUNT, settings, diff --git a/src/main/java/org/opensearch/timeseries/transport/ForecastRunOnceProfileNodeRequest.java b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileNodeRequest.java similarity index 88% rename from src/main/java/org/opensearch/timeseries/transport/ForecastRunOnceProfileNodeRequest.java rename to src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileNodeRequest.java index 4c2895378..5c84d0668 100644 --- a/src/main/java/org/opensearch/timeseries/transport/ForecastRunOnceProfileNodeRequest.java +++ b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileNodeRequest.java @@ -3,13 +3,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.timeseries.transport; +package org.opensearch.forecast.transport; import java.io.IOException; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.forecast.transport.ForecastRunOnceProfileRequest; import org.opensearch.transport.TransportRequest; public class ForecastRunOnceProfileNodeRequest extends TransportRequest { diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileTransportAction.java index a9fe218a8..156a34ff6 100644 --- a/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileTransportAction.java +++ b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileTransportAction.java @@ -20,7 +20,6 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.transport.BooleanNodeResponse; import org.opensearch.timeseries.transport.BooleanResponse; -import org.opensearch.timeseries.transport.ForecastRunOnceProfileNodeRequest; import org.opensearch.transport.TransportService; public class ForecastRunOnceProfileTransportAction extends diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceTransportAction.java index 49eb2b995..f157c4d6f 100644 --- a/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceTransportAction.java +++ b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceTransportAction.java @@ -68,7 +68,6 @@ import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.TaskState; import org.opensearch.timeseries.model.TimeSeriesTask; -import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; import org.opensearch.timeseries.task.TaskCacheManager; import org.opensearch.timeseries.transport.ResultProcessor; @@ -303,7 +302,6 @@ private void triggerRunOnce(String forecastID, ForecastResultRequest request, Ac try { resultProcessor = new ForecastResultProcessor( ForecastSettings.FORECAST_REQUEST_TIMEOUT, - TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS, EntityForecastResultAction.NAME, StatNames.FORECAST_HC_EXECUTE_REQUEST_COUNT, settings, diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultTransportAction.java index 3672a3e43..5a0aee36b 100644 --- a/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultTransportAction.java +++ b/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultTransportAction.java @@ -16,6 +16,7 @@ import org.opensearch.forecast.indices.ForecastIndexManagement; import org.opensearch.forecast.ml.ForecastCheckpointDao; import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastInferencer; import org.opensearch.forecast.ml.ForecastModelManager; import org.opensearch.forecast.ml.RCFCasterResult; import org.opensearch.forecast.model.ForecastResult; @@ -24,23 +25,21 @@ import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; import org.opensearch.forecast.ratelimit.ForecastColdStartWorker; import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; -import org.opensearch.forecast.ratelimit.ForecastResultWriteWorker; import org.opensearch.forecast.ratelimit.ForecastSaveResultStrategy; -import org.opensearch.forecast.stats.ForecastStats; -import org.opensearch.forecast.transport.handler.ForecastIndexMemoryPressureAwareResultHandler; +import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.ratelimit.RequestPriority; -import org.opensearch.timeseries.stats.StatNames; import org.opensearch.timeseries.transport.AbstractSingleStreamResultTransportAction; import org.opensearch.transport.TransportService; import com.amazon.randomcutforest.parkservices.RCFCaster; public class ForecastSingleStreamResultTransportAction extends - AbstractSingleStreamResultTransportAction { + AbstractSingleStreamResultTransportAction { private static final Logger LOG = LogManager.getLogger(ForecastSingleStreamResultTransportAction.class); @@ -52,11 +51,8 @@ public ForecastSingleStreamResultTransportAction( ForecastCacheProvider cache, NodeStateManager stateManager, ForecastCheckpointReadWorker checkpointReadQueue, - ForecastModelManager modelManager, - ForecastIndexManagement indexUtil, - ForecastResultWriteWorker resultWriteQueue, - ForecastStats stats, - ForecastColdStartWorker forecastColdStartQueue + ForecastInferencer inferencer, + ThreadPool threadPool ) { super( transportService, @@ -65,15 +61,11 @@ public ForecastSingleStreamResultTransportAction( cache, stateManager, checkpointReadQueue, - modelManager, - indexUtil, - resultWriteQueue, - stats, - forecastColdStartQueue, ForecastSingleStreamResultAction.NAME, - ForecastIndex.RESULT, AnalysisType.FORECAST, - StatNames.FORECAST_MODEL_CORRUTPION_COUNT.getName() + inferencer, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME ); } diff --git a/src/main/java/org/opensearch/forecast/transport/handler/ForecastIndexMemoryPressureAwareResultHandler.java b/src/main/java/org/opensearch/forecast/transport/handler/ForecastIndexMemoryPressureAwareResultHandler.java index df5ed807b..95ea64ef0 100644 --- a/src/main/java/org/opensearch/forecast/transport/handler/ForecastIndexMemoryPressureAwareResultHandler.java +++ b/src/main/java/org/opensearch/forecast/transport/handler/ForecastIndexMemoryPressureAwareResultHandler.java @@ -47,6 +47,7 @@ public void bulk(ForecastResultBulkRequest currentBulkRequest, ActionListenerwrap(response -> { LOG.debug(CommonMessages.SUCCESS_SAVING_RESULT_MSG); listener.onResponse(response); diff --git a/src/main/java/org/opensearch/timeseries/ExecuteResultResponseRecorder.java b/src/main/java/org/opensearch/timeseries/ExecuteResultResponseRecorder.java index 102eb2f6d..69a6caaf2 100644 --- a/src/main/java/org/opensearch/timeseries/ExecuteResultResponseRecorder.java +++ b/src/main/java/org/opensearch/timeseries/ExecuteResultResponseRecorder.java @@ -104,7 +104,6 @@ public void indexResult( ) { String configId = config.getId(); try { - if (!response.shouldSave()) { updateRealtimeTask(response, configId); return; @@ -115,7 +114,7 @@ public void indexResult( User user = config.getUser(); if (response.getError() != null) { - log.info("Result action run successfully for {} with error {}", configId, response.getError()); + log.info("Result action run for {} with error {}", configId, response.getError()); } List analysisResults = response diff --git a/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java b/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java index d935fd6ba..c1ce884a4 100644 --- a/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java +++ b/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java @@ -13,19 +13,11 @@ import static java.util.Collections.unmodifiableList; import static org.opensearch.ad.constant.ADCommonName.ANOMALY_RESULT_INDEX_ALIAS; -import static org.opensearch.ad.constant.ADCommonName.CHECKPOINT_INDEX_NAME; -import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX; -import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_COOLDOWN_MINUTES; -import static org.opensearch.forecast.constant.ForecastCommonName.FORECAST_CHECKPOINT_INDEX_NAME; -import static org.opensearch.forecast.constant.ForecastCommonName.FORECAST_STATE_INDEX; -import static org.opensearch.timeseries.constant.CommonName.CONFIG_INDEX; -import static org.opensearch.timeseries.constant.CommonName.JOB_INDEX; import java.security.AccessController; import java.security.PrivilegedAction; import java.time.Clock; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; @@ -54,6 +46,7 @@ import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.ml.ADCheckpointDao; import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADInferencer; import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.HybridThresholdingModel; import org.opensearch.ad.model.AnomalyDetector; @@ -99,6 +92,8 @@ import org.opensearch.ad.transport.ADCancelTaskTransportAction; import org.opensearch.ad.transport.ADEntityProfileAction; import org.opensearch.ad.transport.ADEntityProfileTransportAction; +import org.opensearch.ad.transport.ADHCImputeAction; +import org.opensearch.ad.transport.ADHCImputeTransportAction; import org.opensearch.ad.transport.ADProfileAction; import org.opensearch.ad.transport.ADProfileTransportAction; import org.opensearch.ad.transport.ADResultBulkAction; @@ -181,6 +176,7 @@ import org.opensearch.forecast.indices.ForecastIndexManagement; import org.opensearch.forecast.ml.ForecastCheckpointDao; import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastInferencer; import org.opensearch.forecast.ml.ForecastModelManager; import org.opensearch.forecast.model.ForecastResult; import org.opensearch.forecast.model.Forecaster; @@ -257,7 +253,6 @@ import org.opensearch.forecast.transport.ValidateForecasterTransportAction; import org.opensearch.forecast.transport.handler.ForecastIndexMemoryPressureAwareResultHandler; import org.opensearch.forecast.transport.handler.ForecastSearchHandler; -import org.opensearch.indices.SystemIndexDescriptor; import org.opensearch.jobscheduler.spi.JobSchedulerExtension; import org.opensearch.jobscheduler.spi.ScheduledJobParser; import org.opensearch.jobscheduler.spi.ScheduledJobRunner; @@ -266,7 +261,6 @@ import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.plugins.ScriptPlugin; -import org.opensearch.plugins.SystemIndexPlugin; import org.opensearch.repositories.RepositoriesService; import org.opensearch.rest.RestController; import org.opensearch.rest.RestHandler; @@ -323,7 +317,7 @@ /** * Entry point of time series analytics plugin. */ -public class TimeSeriesAnalyticsPlugin extends Plugin implements ActionPlugin, ScriptPlugin, SystemIndexPlugin, JobSchedulerExtension { +public class TimeSeriesAnalyticsPlugin extends Plugin implements ActionPlugin, ScriptPlugin, JobSchedulerExtension { private static final Logger LOG = LogManager.getLogger(TimeSeriesAnalyticsPlugin.class); @@ -817,7 +811,10 @@ public PooledObject wrap(LinkedBuffer obj) { StatNames.CONFIG_INDEX_STATUS.getName(), new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.CONFIG_INDEX)) ) - .put(StatNames.JOB_INDEX_STATUS.getName(), new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, JOB_INDEX))) + .put( + StatNames.JOB_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.JOB_INDEX)) + ) .put( StatNames.MODEL_COUNT.getName(), new TimeSeriesStat<>(false, new ADModelsOnNodeCountSupplier(adModelManager, adCacheProvider)) @@ -826,6 +823,16 @@ public PooledObject wrap(LinkedBuffer obj) { adStats = new ADStats(adStatsMap); + ADInferencer adInferencer = new ADInferencer( + adModelManager, + adStats, + adCheckpoint, + adColdstartQueue, + adSaveResultStrategy, + adCacheProvider, + threadPool + ); + ADCheckpointReadWorker adCheckpointReadQueue = new ADCheckpointReadWorker( heapSizeBytes, TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, @@ -845,12 +852,10 @@ public PooledObject wrap(LinkedBuffer obj) { adCheckpoint, adColdstartQueue, stateManager, - anomalyDetectionIndices, adCacheProvider, TimeSeriesSettings.HOURLY_MAINTENANCE, adCheckpointWriteQueue, - adStats, - adSaveResultStrategy + adInferencer ); ADColdEntityWorker adColdEntityQueue = new ADColdEntityWorker( @@ -1199,12 +1204,25 @@ public PooledObject wrap(LinkedBuffer obj) { StatNames.CONFIG_INDEX_STATUS.getName(), new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.CONFIG_INDEX)) ) - .put(StatNames.JOB_INDEX_STATUS.getName(), new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, JOB_INDEX))) + .put( + StatNames.JOB_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.JOB_INDEX)) + ) .put(StatNames.MODEL_COUNT.getName(), new TimeSeriesStat<>(false, new ForecastModelsOnNodeCountSupplier(forecastCacheProvider))) .build(); forecastStats = new ForecastStats(forecastStatsMap); + ForecastInferencer forecastInferencer = new ForecastInferencer( + forecastModelManager, + forecastStats, + forecastCheckpoint, + forecastColdstartQueue, + forecastSaveResultStrategy, + forecastCacheProvider, + threadPool + ); + ForecastCheckpointReadWorker forecastCheckpointReadQueue = new ForecastCheckpointReadWorker( heapSizeBytes, TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, @@ -1224,12 +1242,10 @@ public PooledObject wrap(LinkedBuffer obj) { forecastCheckpoint, forecastColdstartQueue, stateManager, - forecastIndices, forecastCacheProvider, TimeSeriesSettings.HOURLY_MAINTENANCE, forecastCheckpointWriteQueue, - forecastStats, - forecastSaveResultStrategy + forecastInferencer ); ForecastColdEntityWorker forecastColdEntityQueue = new ForecastColdEntityWorker( @@ -1350,6 +1366,7 @@ public PooledObject wrap(LinkedBuffer obj) { adIndexJobActionHandler, adSaveResultStrategy, new ADTaskProfileRunner(hashRing, client), + adInferencer, // forecast components forecastIndices, forecastStats, @@ -1368,12 +1385,13 @@ public PooledObject wrap(LinkedBuffer obj) { forecastIndexJobActionHandler, forecastTaskCacheManager, forecastSaveResultStrategy, - new ForecastTaskProfileRunner() + new ForecastTaskProfileRunner(), + forecastInferencer ); } /** - * createComponents doesn't work for Clock as ES process cannot start + * createComponents doesn't work for Clock as OS process cannot start * complaining it cannot find Clock instances for transport actions constructors. * @return a UTC clock */ @@ -1650,6 +1668,7 @@ public List getNamedXContent() { new ActionHandler<>(SearchTopAnomalyResultAction.INSTANCE, SearchTopAnomalyResultTransportAction.class), new ActionHandler<>(ValidateAnomalyDetectorAction.INSTANCE, ValidateAnomalyDetectorTransportAction.class), new ActionHandler<>(ADSingleStreamResultAction.INSTANCE, ADSingleStreamResultTransportAction.class), + new ActionHandler<>(ADHCImputeAction.INSTANCE, ADHCImputeTransportAction.class), // forecast new ActionHandler<>(IndexForecasterAction.INSTANCE, IndexForecasterTransportAction.class), new ActionHandler<>(ForecastResultAction.INSTANCE, ForecastResultTransportAction.class), @@ -1676,19 +1695,6 @@ public List getNamedXContent() { ); } - @Override - public Collection getSystemIndexDescriptors(Settings settings) { - List systemIndexDescriptors = new ArrayList<>(); - systemIndexDescriptors.add(new SystemIndexDescriptor(CONFIG_INDEX, "Time Series Analytics config index")); - systemIndexDescriptors.add(new SystemIndexDescriptor(ALL_AD_RESULTS_INDEX_PATTERN, "AD result index pattern")); - systemIndexDescriptors.add(new SystemIndexDescriptor(CHECKPOINT_INDEX_NAME, "AD Checkpoints index")); - systemIndexDescriptors.add(new SystemIndexDescriptor(DETECTION_STATE_INDEX, "AD State index")); - systemIndexDescriptors.add(new SystemIndexDescriptor(FORECAST_CHECKPOINT_INDEX_NAME, "Forecast Checkpoints index")); - systemIndexDescriptors.add(new SystemIndexDescriptor(FORECAST_STATE_INDEX, "Forecast state index")); - systemIndexDescriptors.add(new SystemIndexDescriptor(JOB_INDEX, "Time Series Analytics job index")); - return systemIndexDescriptors; - } - @Override public String getJobType() { return TIME_SERIES_JOB_TYPE; @@ -1696,7 +1702,7 @@ public String getJobType() { @Override public String getJobIndex() { - return JOB_INDEX; + return CommonName.JOB_INDEX; } @Override diff --git a/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java b/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java index 6f4f9dcfd..043c197cf 100644 --- a/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java +++ b/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java @@ -761,6 +761,21 @@ public List> getAllModels() { return states; } + /** + * Gets a config's modelStates hosted on a node + * + * @return list of modelStates + */ + @Override + public List> getAllModels(String configId) { + List> states = new ArrayList<>(); + CacheBufferType cacheBuffer = activeEnities.get(configId); + if (cacheBuffer != null) { + states.addAll(cacheBuffer.getAllModelStates()); + } + return states; + } + /** * Gets all of a config's model sizes hosted on a node * diff --git a/src/main/java/org/opensearch/timeseries/caching/TimeSeriesCache.java b/src/main/java/org/opensearch/timeseries/caching/TimeSeriesCache.java index f1a94d588..9d335a719 100644 --- a/src/main/java/org/opensearch/timeseries/caching/TimeSeriesCache.java +++ b/src/main/java/org/opensearch/timeseries/caching/TimeSeriesCache.java @@ -82,6 +82,13 @@ public interface TimeSeriesCache> getAllModels(); + /** + * Gets a config's modelStates hosted on a node + * + * @return list of modelStates + */ + List> getAllModels(String configId); + /** * Get the number of active entities of a config * @param configId Config Id diff --git a/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java b/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java index 330833a7e..8e0a7a537 100644 --- a/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java +++ b/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java @@ -43,7 +43,7 @@ public static String getTooManyCategoricalFieldErr(int limit) { public static String FAIL_TO_FIND_CONFIG_MSG = "Can't find config with id: "; public static final String CAN_NOT_CHANGE_CATEGORY_FIELD = "Can't change category field"; public static final String CAN_NOT_CHANGE_CUSTOM_RESULT_INDEX = "Can't change custom result index"; - public static final String CATEGORICAL_FIELD_TYPE_ERR_MSG = "A categorical field must be of type keyword or ip."; + public static final String CATEGORICAL_FIELD_TYPE_ERR_MSG = "Categorical field %s must be of type keyword or ip."; // Modifying message for FEATURE below may break the parseADValidationException method of ValidateAnomalyDetectorTransportAction public static final String FEATURE_INVALID_MSG_PREFIX = "Feature has an invalid query"; public static final String FEATURE_WITH_EMPTY_DATA_MSG = FEATURE_INVALID_MSG_PREFIX + " returning empty aggregated data: "; @@ -73,6 +73,7 @@ public static String getTooManyCategoricalFieldErr(int limit) { + TimeSeriesSettings.MAX_DESCRIPTION_LENGTH + " characters."; public static final String INDEX_NOT_FOUND = "index does not exist"; + public static final String FAIL_TO_GET_MAPPING_MSG = "Fail to get the index mapping of %s"; // ====================================== // Index message diff --git a/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationMethod.java b/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationMethod.java index 90494862c..564aea2bf 100644 --- a/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationMethod.java +++ b/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationMethod.java @@ -18,8 +18,4 @@ public enum ImputationMethod { * This method replaces missing values with the last known value in the respective input dimension. It's a commonly used method for time series data, where temporal continuity is expected. */ PREVIOUS, - /** - * This method estimates missing values by interpolating linearly between known values in the respective input dimension. This method assumes that the data follows a linear trend. - */ - LINEAR } diff --git a/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationOption.java b/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationOption.java index b163662a4..7147c753c 100644 --- a/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationOption.java +++ b/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationOption.java @@ -8,12 +8,10 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; +import java.util.HashMap; import java.util.Locale; +import java.util.Map; import java.util.Objects; -import java.util.Optional; import org.apache.commons.lang.builder.ToStringBuilder; import org.opensearch.core.common.io.stream.StreamInput; @@ -22,53 +20,49 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.model.DataByFeatureId; +import org.opensearch.timeseries.model.Feature; public class ImputationOption implements Writeable, ToXContent { // field name in toXContent public static final String METHOD_FIELD = "method"; public static final String DEFAULT_FILL_FIELD = "defaultFill"; - public static final String INTEGER_SENSITIVE_FIELD = "integerSensitive"; private final ImputationMethod method; - private final Optional defaultFill; - private final boolean integerSentive; + private final Map defaultFill; - public ImputationOption(ImputationMethod method, Optional defaultFill, boolean integerSentive) { + public ImputationOption(ImputationMethod method, Map defaultFill) { this.method = method; this.defaultFill = defaultFill; - this.integerSentive = integerSentive; } public ImputationOption(ImputationMethod method) { - this(method, Optional.empty(), false); + this(method, null); } public ImputationOption(StreamInput in) throws IOException { this.method = in.readEnum(ImputationMethod.class); if (in.readBoolean()) { - this.defaultFill = Optional.of(in.readDoubleArray()); + this.defaultFill = in.readMap(StreamInput::readString, StreamInput::readDouble); } else { - this.defaultFill = Optional.empty(); + this.defaultFill = null; } - this.integerSentive = in.readBoolean(); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeEnum(method); - if (defaultFill.isEmpty()) { + if (defaultFill == null || defaultFill.isEmpty()) { out.writeBoolean(false); } else { out.writeBoolean(true); - out.writeDoubleArray(defaultFill.get()); + out.writeMap(defaultFill, StreamOutput::writeString, StreamOutput::writeDouble); } - out.writeBoolean(integerSentive); } public static ImputationOption parse(XContentParser parser) throws IOException { ImputationMethod method = ImputationMethod.ZERO; - List defaultFill = null; - Boolean integerSensitive = null; + Map defaultFill = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -79,24 +73,40 @@ public static ImputationOption parse(XContentParser parser) throws IOException { method = ImputationMethod.valueOf(parser.text().toUpperCase(Locale.ROOT)); break; case DEFAULT_FILL_FIELD: + defaultFill = new HashMap<>(); ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - defaultFill = new ArrayList<>(); - while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - defaultFill.add(parser.doubleValue()); + while ((parser.nextToken()) != XContentParser.Token.END_ARRAY) { + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + + String featureName = null; + Double fillValue = null; + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fillFieldName = parser.currentName(); + parser.nextToken(); + + switch (fillFieldName) { + case Feature.FEATURE_NAME_FIELD: + featureName = parser.text(); + break; + case DataByFeatureId.DATA_FIELD: + fillValue = parser.doubleValue(); + break; + default: + // the unknown field and it's children should be ignored + parser.skipChildren(); + break; + } + } + + defaultFill.put(featureName, fillValue); } break; - case INTEGER_SENSITIVE_FIELD: - integerSensitive = parser.booleanValue(); - break; default: break; } } - return new ImputationOption( - method, - Optional.ofNullable(defaultFill).map(list -> list.stream().mapToDouble(Double::doubleValue).toArray()), - integerSensitive - ); + return new ImputationOption(method, defaultFill); } public XContentBuilder toXContent(XContentBuilder builder) throws IOException { @@ -109,10 +119,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(METHOD_FIELD, method); - if (!defaultFill.isEmpty()) { - builder.array(DEFAULT_FILL_FIELD, defaultFill.get()); + if (defaultFill != null && !defaultFill.isEmpty()) { + builder.startArray(DEFAULT_FILL_FIELD); + for (Map.Entry fill : defaultFill.entrySet()) { + builder.startObject(); + builder.field(Feature.FEATURE_NAME_FIELD, fill.getKey()); + builder.field(DataByFeatureId.DATA_FIELD, fill.getValue()); + builder.endObject(); + } + builder.endArray(); } - builder.field(INTEGER_SENSITIVE_FIELD, integerSentive); return xContentBuilder.endObject(); } @@ -126,34 +142,24 @@ public boolean equals(Object o) { } ImputationOption other = (ImputationOption) o; - return method == other.method - && (defaultFill.isEmpty() ? other.defaultFill.isEmpty() : Arrays.equals(defaultFill.get(), other.defaultFill.get())) - && integerSentive == other.integerSentive; + return method == other.method && Objects.equals(defaultFill, other.defaultFill); } @Override public int hashCode() { - return Objects.hash(method, (defaultFill.isEmpty() ? 0 : Arrays.hashCode(defaultFill.get())), integerSentive); + return Objects.hash(method, defaultFill); } @Override public String toString() { - return new ToStringBuilder(this) - .append("method", method) - .append("defaultFill", (defaultFill.isEmpty() ? null : Arrays.toString(defaultFill.get()))) - .append("integerSentive", integerSentive) - .toString(); + return new ToStringBuilder(this).append("method", method).append("defaultFill", defaultFill).toString(); } public ImputationMethod getMethod() { return method; } - public Optional getDefaultFill() { + public Map getDefaultFill() { return defaultFill; } - - public boolean isIntegerSentive() { - return integerSentive; - } } diff --git a/src/main/java/org/opensearch/timeseries/feature/AbstractRetriever.java b/src/main/java/org/opensearch/timeseries/feature/AbstractRetriever.java index 5f2609ed5..b5b747ae0 100644 --- a/src/main/java/org/opensearch/timeseries/feature/AbstractRetriever.java +++ b/src/main/java/org/opensearch/timeseries/feature/AbstractRetriever.java @@ -88,11 +88,11 @@ would produce an InternalFilter (a subtype of InternalSingleBucketAggregation) w .orElseThrow(() -> new EndRunException("Failed to parse aggregation " + aggregation, true).countedInStats(false)); } - protected Optional parseBucket(MultiBucketsAggregation.Bucket bucket, List featureIds) { - return parseAggregations(Optional.ofNullable(bucket).map(b -> b.getAggregations()), featureIds); + protected Optional parseBucket(MultiBucketsAggregation.Bucket bucket, List featureIds, boolean keepMissingValue) { + return parseAggregations(Optional.ofNullable(bucket).map(b -> b.getAggregations()), featureIds, keepMissingValue); } - protected Optional parseAggregations(Optional aggregations, List featureIds) { + protected Optional parseAggregations(Optional aggregations, List featureIds, boolean keepMissingValue) { return aggregations .map(aggs -> aggs.asMap()) .map( @@ -101,7 +101,16 @@ protected Optional parseAggregations(Optional aggregatio .mapToDouble(id -> Optional.ofNullable(map.get(id)).map(this::parseAggregation).orElse(Double.NaN)) .toArray() ) - .filter(result -> Arrays.stream(result).noneMatch(d -> Double.isNaN(d) || Double.isInfinite(d))); + .flatMap(result -> { + if (keepMissingValue) { + // Convert Double.isInfinite values to Double.NaN + return Optional.of(Arrays.stream(result).map(d -> Double.isInfinite(d) ? Double.NaN : d).toArray()); + } else { + // Return the array only if it contains no Double.NaN or Double.isInfinite + boolean noneNaNOrInfinite = Arrays.stream(result).noneMatch(d -> Double.isNaN(d) || Double.isInfinite(d)); + return noneNaNOrInfinite ? Optional.of(result) : Optional.empty(); + } + }); } protected void updateSourceAfterKey(Map afterKey, SearchSourceBuilder search) { diff --git a/src/main/java/org/opensearch/timeseries/feature/CompositeRetriever.java b/src/main/java/org/opensearch/timeseries/feature/CompositeRetriever.java index f4cae0c0e..767c6d2fe 100644 --- a/src/main/java/org/opensearch/timeseries/feature/CompositeRetriever.java +++ b/src/main/java/org/opensearch/timeseries/feature/CompositeRetriever.java @@ -296,7 +296,7 @@ private Page analyzePage(SearchResponse response) { } */ for (Bucket bucket : composite.getBuckets()) { - Optional featureValues = parseBucket(bucket, config.getEnabledFeatureIds()); + Optional featureValues = parseBucket(bucket, config.getEnabledFeatureIds(), true); // bucket.getKey() returns a map of categorical field like "host" and its value like "server_1" if (featureValues.isPresent() && bucket.getKey() != null) { results.put(Entity.createEntityByReordering(bucket.getKey()), featureValues.get()); diff --git a/src/main/java/org/opensearch/timeseries/feature/FeatureManager.java b/src/main/java/org/opensearch/timeseries/feature/FeatureManager.java index 3de670ffb..5130b3d2b 100644 --- a/src/main/java/org/opensearch/timeseries/feature/FeatureManager.java +++ b/src/main/java/org/opensearch/timeseries/feature/FeatureManager.java @@ -116,7 +116,7 @@ public void getCurrentFeatures( ) { List> missingRanges = Collections.singletonList(new SimpleImmutableEntry<>(startTime, endTime)); try { - searchFeatureDao.getFeatureSamplesForPeriods(config, missingRanges, context, ActionListener.wrap(points -> { + searchFeatureDao.getFeatureSamplesForPeriods(config, missingRanges, context, true, ActionListener.wrap(points -> { // we only have one point if (points.size() == 1) { Optional point = points.get(0); @@ -169,6 +169,7 @@ private void getColdStartSamples( config, sampleRanges, context, + false, new ThreadedActionListener<>( logger, threadPool, @@ -342,7 +343,7 @@ public void getPreviewFeatures(AnomalyDetector detector, long startMilli, long e int stride = sampleRangeResults.getValue(); int shingleSize = detector.getShingleSize(); - getSamplesForRanges(detector, sampleRanges, getFeatureSamplesListener(stride, shingleSize, listener)); + getSamplesForPreview(detector, sampleRanges, getFeatureSamplesListener(stride, shingleSize, listener)); } /** @@ -417,13 +418,13 @@ private ActionListener>> getSamplesRangesListener( * @param listener handle search results map: key is time ranges, value is corresponding search results * @throws IOException if a user gives wrong query input when defining a detector */ - void getSamplesForRanges( + void getSamplesForPreview( AnomalyDetector detector, List> sampleRanges, ActionListener>, double[][]>> listener ) throws IOException { searchFeatureDao - .getFeatureSamplesForPeriods(detector, sampleRanges, AnalysisType.AD, getSamplesRangesListener(sampleRanges, listener)); + .getFeatureSamplesForPeriods(detector, sampleRanges, AnalysisType.AD, false, getSamplesRangesListener(sampleRanges, listener)); } /** diff --git a/src/main/java/org/opensearch/timeseries/feature/SearchFeatureDao.java b/src/main/java/org/opensearch/timeseries/feature/SearchFeatureDao.java index 89ad90926..8c865a337 100644 --- a/src/main/java/org/opensearch/timeseries/feature/SearchFeatureDao.java +++ b/src/main/java/org/opensearch/timeseries/feature/SearchFeatureDao.java @@ -51,7 +51,6 @@ import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.Aggregations; import org.opensearch.search.aggregations.PipelineAggregatorBuilders; -import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; import org.opensearch.search.aggregations.bucket.composite.InternalComposite; import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; @@ -150,10 +149,10 @@ public SearchFeatureDao( } /** - * Returns to listener the epoch time of the latset data under the detector. + * Returns to listener the epoch time of the latest data under the detector. * * @param config info about the data - * @param listener onResponse is called with the epoch time of the latset data under the detector + * @param listener onResponse is called with the epoch time of the latest data under the detector */ public void getLatestDataTime(Config config, Optional entity, AnalysisType context, ActionListener> listener) { BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery(); @@ -525,7 +524,7 @@ private Optional parseMinDataTime(SearchResponse searchResponse) { public void getFeaturesForPeriod(AnomalyDetector detector, long startTime, long endTime, ActionListener> listener) { SearchRequest searchRequest = createFeatureSearchRequest(detector, startTime, endTime, Optional.empty()); final ActionListener searchResponseListener = ActionListener - .wrap(response -> listener.onResponse(parseResponse(response, detector.getEnabledFeatureIds())), listener::onFailure); + .wrap(response -> listener.onResponse(parseResponse(response, detector.getEnabledFeatureIds(), true)), listener::onFailure); // using the original context in listener as user roles have no permissions for internal operations like fetching a // checkpoint clientUtil @@ -551,7 +550,7 @@ public void getFeaturesForPeriodByBatch( SearchRequest searchRequest = new SearchRequest(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder); final ActionListener searchResponseListener = ActionListener.wrap(response -> { - listener.onResponse(parseBucketAggregationResponse(response, detector.getEnabledFeatureIds())); + listener.onResponse(parseBucketAggregationResponse(response, detector.getEnabledFeatureIds(), true)); }, listener::onFailure); // inject user role while searching. clientUtil @@ -565,31 +564,41 @@ public void getFeaturesForPeriodByBatch( ); } - private Map> parseBucketAggregationResponse(SearchResponse response, List featureIds) { + private Map> parseBucketAggregationResponse( + SearchResponse response, + List featureIds, + boolean keepMissingValue + ) { Map> dataPoints = new HashMap<>(); List aggregations = response.getAggregations().asList(); logger.debug("Feature aggregation result size {}", aggregations.size()); for (Aggregation agg : aggregations) { List buckets = ((InternalComposite) agg).getBuckets(); buckets.forEach(bucket -> { - Optional featureData = parseAggregations(Optional.ofNullable(bucket.getAggregations()), featureIds); + Optional featureData = parseAggregations( + Optional.ofNullable(bucket.getAggregations()), + featureIds, + keepMissingValue + ); dataPoints.put((Long) bucket.getKey().get(CommonName.DATE_HISTOGRAM), featureData); }); } return dataPoints; } - public Optional parseResponse(SearchResponse response, List featureIds) { - return parseAggregations(Optional.ofNullable(response).map(resp -> resp.getAggregations()), featureIds); + public Optional parseResponse(SearchResponse response, List featureIds, boolean keepMissingData) { + return parseAggregations(Optional.ofNullable(response).map(resp -> resp.getAggregations()), featureIds, keepMissingData); } /** - * Gets samples of features for the time ranges. + * Gets features for the time ranges. * - * Sampled features are not true features. They are intended to be approximate results produced at low costs. + * If called by preview API, sampled features are not true features. + * They are intended to be approximate results produced at low costs. * * @param config info about the indices, documents, feature query * @param ranges list of time ranges + * @param keepMissingValues whether to keep missing values or not in the result * @param listener handle approximate features for the time ranges * @throws IOException if a user gives wrong query input when defining a detector */ @@ -597,9 +606,10 @@ public void getFeatureSamplesForPeriods( Config config, List> ranges, AnalysisType context, + boolean keepMissingValues, ActionListener>> listener ) throws IOException { - SearchRequest request = createPreviewSearchRequest(config, ranges); + SearchRequest request = createRangeSearchRequest(config, ranges); final ActionListener searchResponseListener = ActionListener.wrap(response -> { Aggregations aggs = response.getAggregations(); if (aggs == null) { @@ -613,7 +623,7 @@ public void getFeatureSamplesForPeriods( .stream() .filter(InternalDateRange.class::isInstance) .flatMap(agg -> ((InternalDateRange) agg).getBuckets().stream()) - .map(bucket -> parseBucket(bucket, config.getEnabledFeatureIds())) + .map(bucket -> parseBucket(bucket, config.getEnabledFeatureIds(), keepMissingValues)) .collect(Collectors.toList()) ); }, listener::onFailure); @@ -640,9 +650,9 @@ private SearchRequest createFeatureSearchRequest(AnomalyDetector detector, long } } - private SearchRequest createPreviewSearchRequest(Config config, List> ranges) throws IOException { + private SearchRequest createRangeSearchRequest(Config config, List> ranges) throws IOException { try { - SearchSourceBuilder searchSourceBuilder = ParseUtils.generatePreviewQuery(config, ranges, xContent); + SearchSourceBuilder searchSourceBuilder = ParseUtils.generateRangeQuery(config, ranges, xContent); return new SearchRequest(config.getIndices().toArray(new String[0]), searchSourceBuilder); } catch (IOException e) { logger.warn("Failed to create feature search request for " + config.getId() + " for preview", e); @@ -752,7 +762,7 @@ public List> parseColdStartSampleResp(SearchResponse response .filter(bucket -> bucket.getFrom() != null && bucket.getFrom() instanceof ZonedDateTime) .filter(bucket -> bucket.getDocCount() > docCountThreshold) .sorted(Comparator.comparing((Bucket bucket) -> (ZonedDateTime) bucket.getFrom())) - .map(bucket -> parseBucket(bucket, config.getEnabledFeatureIds())) + .map(bucket -> parseBucket(bucket, config.getEnabledFeatureIds(), false)) .collect(Collectors.toList()); } @@ -801,7 +811,7 @@ public List parseColdStartSampleTimestamp(SearchResponse response, boolean .flatMap(agg -> ((InternalDateRange) agg).getBuckets().stream()) .filter(bucket -> bucket.getFrom() != null && bucket.getFrom() instanceof ZonedDateTime) .filter(bucket -> bucket.getDocCount() > docCountThreshold) - .filter(bucket -> parseBucket(bucket, config.getEnabledFeatureIds()).isPresent()) + .filter(bucket -> parseBucket(bucket, config.getEnabledFeatureIds(), false).isPresent()) .sorted(Comparator.comparing((Bucket bucket) -> (ZonedDateTime) bucket.getFrom())) .map(bucket -> ((ZonedDateTime) bucket.getFrom()).toInstant().toEpochMilli()) .collect(Collectors.toList()); @@ -851,11 +861,6 @@ public SearchRequest createColdStartFeatureSearchRequestForSingleFeature( } } - @Override - public Optional parseBucket(MultiBucketsAggregation.Bucket bucket, List featureIds) { - return parseAggregations(Optional.ofNullable(bucket).map(b -> b.getAggregations()), featureIds); - } - /** * Get train samples within a time range. * diff --git a/src/main/java/org/opensearch/timeseries/ml/Inferencer.java b/src/main/java/org/opensearch/timeseries/ml/Inferencer.java new file mode 100644 index 000000000..ff7cdca3a --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/Inferencer.java @@ -0,0 +1,190 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.ml; + +import java.util.Collections; +import java.util.Locale; +import java.util.Map; +import java.util.WeakHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; +import org.opensearch.timeseries.ratelimit.ColdStartWorker; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.SaveResultStrategy; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.util.TimeUtil; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public abstract class Inferencer, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriterType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart, ModelManagerType extends ModelManager, SaveResultStrategyType extends SaveResultStrategy, CacheType extends TimeSeriesCache, ColdStartWorkerType extends ColdStartWorker> { + private static final Logger LOG = LogManager.getLogger(Inferencer.class); + protected ModelManagerType modelManager; + protected Stats stats; + private String modelCorruptionStat; + protected CheckpointDaoType checkpointDao; + protected ColdStartWorkerType coldStartWorker; + protected SaveResultStrategyType resultWriteWorker; + private CacheProvider cache; + private Map modelLocks = Collections.synchronizedMap(new WeakHashMap<>()); + private ThreadPool threadPool; + private String threadPoolName; + + public Inferencer( + ModelManagerType modelManager, + Stats stats, + String modelCorruptionStat, + CheckpointDaoType checkpointDao, + ColdStartWorkerType coldStartWorker, + SaveResultStrategyType resultWriteWorker, + CacheProvider cache, + ThreadPool threadPool, + String threadPoolName + ) { + this.modelManager = modelManager; + this.stats = stats; + this.modelCorruptionStat = modelCorruptionStat; + this.checkpointDao = checkpointDao; + this.coldStartWorker = coldStartWorker; + this.resultWriteWorker = resultWriteWorker; + this.cache = cache; + this.threadPool = threadPool; + this.threadPoolName = threadPoolName; + // WeakHashMap allows for automatic removal of entries when the key is no longer referenced elsewhere. + // This helps prevent memory leaks as the garbage collector can reclaim memory when modelId is no + // longer in use. + this.modelLocks = Collections.synchronizedMap(new WeakHashMap<>()); + } + + /** + * + * @param sample Sample to process + * @param modelState model state + * @param config Config accessor + * @param taskId task Id for batch analysis + * @return whether process succeeds or not + */ + public boolean process(Sample sample, ModelState modelState, Config config, String taskId) { + long expiryEpoch = TimeUtil.calculateTimeoutMillis(config, sample.getDataEndTime().toEpochMilli()); + return processWithTimeout(sample, modelState, config, taskId, expiryEpoch); + } + + private boolean processWithTimeout(Sample sample, ModelState modelState, Config config, String taskId, long expiryEpoch) { + String modelId = modelState.getModelId(); + ReentrantLock lock = (ReentrantLock) modelLocks.computeIfAbsent(modelId, k -> new ReentrantLock()); + + if (lock.tryLock()) { + try { + tryProcess(sample, modelState, config, taskId); + } finally { + if (lock.isHeldByCurrentThread()) { + lock.unlock(); + } + } + return true; + } else { + if (System.currentTimeMillis() >= expiryEpoch) { + LOG.warn("Timeout reached, not retrying."); + } else { + // Schedule a retry in one second + threadPool + .schedule( + () -> processWithTimeout(sample, modelState, config, taskId, expiryEpoch), + new TimeValue(1, TimeUnit.SECONDS), + threadPoolName + ); + } + + return false; + } + } + + private boolean tryProcess(Sample sample, ModelState modelState, Config config, String taskId) { + String modelId = modelState.getModelId(); + try { + RCFResultType result = modelManager.getResult(sample, modelState, modelId, config, taskId); + resultWriteWorker + .saveResult( + result, + config, + sample.getDataStartTime(), + sample.getDataEndTime(), + modelId, + sample.getValueList(), + modelState.getEntity(), + taskId + ); + } catch (IllegalArgumentException e) { + if (e.getMessage() != null && e.getMessage().contains("incorrect ordering of time")) { + // ignore current timestamp. + LOG + .warn( + String + .format( + Locale.ROOT, + "incorrect ordering of time for config %s model %s at data end time %d", + config.getId(), + modelState.getModelId(), + sample.getDataEndTime().toEpochMilli() + ) + ); + } else { + reColdStart(config, modelId, e, sample, taskId); + } + return false; + } catch (Exception e) { + // e.g., null pointer exception when there is a bug in RCF + reColdStart(config, modelId, e, sample, taskId); + } + return true; + } + + private void reColdStart(Config config, String modelId, Exception e, Sample sample, String taskId) { + // fail to score likely due to model corruption. Re-cold start to recover. + LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", modelId), e); + stats.getStat(modelCorruptionStat).increment(); + cache.get().removeModel(config.getId(), modelId); + if (null != modelId) { + checkpointDao + .deleteModelCheckpoint( + modelId, + ActionListener + .wrap( + r -> LOG.debug(new ParameterizedMessage("Succeeded in deleting checkpoint [{}].", modelId)), + ex -> LOG.error(new ParameterizedMessage("Failed to delete checkpoint [{}].", modelId), ex) + ) + ); + } + + coldStartWorker + .put( + new FeatureRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + config.getId(), + RequestPriority.MEDIUM, + modelId, + sample.getValueList(), + sample.getDataStartTime().toEpochMilli(), + taskId + ) + ); + } +} diff --git a/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java b/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java index 3ccd00428..2cb0f0b17 100644 --- a/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java +++ b/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java @@ -463,8 +463,10 @@ private void getFeatures( // make sure the following logic making sense via checking lastRoundFirstStartTime > 0 if (lastRounddataSample != null && lastRounddataSample.size() > 0) { concatenatedDataSample = new ArrayList<>(); - concatenatedDataSample.addAll(lastRounddataSample); + // since we move farther in history in current one, last round data should come + // after current round data to keep time in sequence. concatenatedDataSample.addAll(samples); + concatenatedDataSample.addAll(lastRounddataSample); } else { concatenatedDataSample = samples; } @@ -525,12 +527,17 @@ public static > T applyImputatio case ZERO: return builder.imputationMethod(ImputationMethod.ZERO); case FIXED_VALUES: - // we did validate default fill is not empty and size matches enabled feature number in Config's constructor - return builder.imputationMethod(ImputationMethod.FIXED_VALUES).fillValues(imputationOption.getDefaultFill().get()); + // we did validate default fill is not empty, size matches enabled feature number in Config's constructor, + // and feature names matches existing features + List enabledFeatureName = config.getEnabledFeatureNames(); + double[] fillValues = new double[enabledFeatureName.size()]; + Map defaultFillMap = imputationOption.getDefaultFill(); + for (int i = 0; i < enabledFeatureName.size(); i++) { + fillValues[i] = defaultFillMap.get(enabledFeatureName.get(i)); + } + return builder.imputationMethod(ImputationMethod.FIXED_VALUES).fillValues(fillValues); case PREVIOUS: return builder.imputationMethod(ImputationMethod.PREVIOUS); - case LINEAR: - return builder.imputationMethod(ImputationMethod.LINEAR); default: // by default using last known value return builder.imputationMethod(ImputationMethod.PREVIOUS); diff --git a/src/main/java/org/opensearch/timeseries/ml/ModelManager.java b/src/main/java/org/opensearch/timeseries/ml/ModelManager.java index 273a580d9..d2e557be3 100644 --- a/src/main/java/org/opensearch/timeseries/ml/ModelManager.java +++ b/src/main/java/org/opensearch/timeseries/ml/ModelManager.java @@ -28,6 +28,7 @@ import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.IndexableResult; import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; +import org.opensearch.timeseries.util.DataUtil; import com.amazon.randomcutforest.RandomCutForest; import com.amazon.randomcutforest.parkservices.AnomalyDescriptor; @@ -130,15 +131,12 @@ protected void clearModelForIterator(String detectorId, Map models, I } } - @SuppressWarnings("unchecked") public IntermediateResultType score( Sample sample, String modelId, ModelState modelState, Config config ) { - - IntermediateResultType result = createEmptyResult(); Optional model = modelState.getModel(); try { if (model != null && model.isPresent()) { @@ -147,23 +145,22 @@ public IntermediateResultType score( if (!modelState.getSamples().isEmpty()) { for (Sample unProcessedSample : modelState.getSamples()) { // we are sure that the process method will indeed return an instance of RCFDescriptor. - rcfModel.process(unProcessedSample.getValueList(), unProcessedSample.getDataEndTime().getEpochSecond()); + double[] unProcessedPoint = unProcessedSample.getValueList(); + int[] missingIndices = DataUtil.generateMissingIndicesArray(unProcessedPoint); + rcfModel.process(unProcessedPoint, unProcessedSample.getDataEndTime().getEpochSecond(), missingIndices); } modelState.clearSamples(); } - RCFDescriptor lastResult = (RCFDescriptor) rcfModel - .process(sample.getValueList(), sample.getDataEndTime().getEpochSecond()); - if (lastResult != null) { - result = toResult(rcfModel.getForest(), lastResult); - } + return score(sample, config, rcfModel); } } catch (Exception e) { LOG .error( new ParameterizedMessage( - "Fail to score for [{}]: model Id [{}], feature [{}]", + "Fail to score for [{}] at [{}]: model Id [{}], feature [{}]", modelState.getEntity().isEmpty() ? modelState.getConfigId() : modelState.getEntity().get(), + sample.getDataEndTime().getEpochSecond(), modelId, Arrays.toString(sample.getValueList()) ), @@ -173,13 +170,28 @@ public IntermediateResultType score( } finally { modelState.setLastUsedTime(clock.instant()); } - return result; + return createEmptyResult(); + } + + @SuppressWarnings("unchecked") + public IntermediateResultType score(Sample sample, Config config, RCFModelType rcfModel) { + double[] point = sample.getValueList(); + + int[] missingValues = DataUtil.generateMissingIndicesArray(point); + RCFDescriptor lastResult = (RCFDescriptor) rcfModel.process(point, sample.getDataEndTime().getEpochSecond(), missingValues); + if (lastResult != null) { + return toResult(rcfModel.getForest(), lastResult, point, missingValues != null, config); + } + return createEmptyResult(); } protected abstract IntermediateResultType createEmptyResult(); protected abstract IntermediateResultType toResult( RandomCutForest forecast, - RCFDescriptor castDescriptor + RCFDescriptor castDescriptor, + double[] point, + boolean featureImputed, + Config config ); } diff --git a/src/main/java/org/opensearch/timeseries/ml/Sample.java b/src/main/java/org/opensearch/timeseries/ml/Sample.java index bc1212596..0089206ee 100644 --- a/src/main/java/org/opensearch/timeseries/ml/Sample.java +++ b/src/main/java/org/opensearch/timeseries/ml/Sample.java @@ -18,6 +18,7 @@ import java.util.List; import java.util.Map; +import org.apache.commons.lang.builder.ToStringBuilder; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.annotation.Generated; @@ -31,7 +32,6 @@ public class Sample implements ToXContentObject { private final Instant dataEndTime; public Sample(double[] data, Instant dataStartTime, Instant dataEndTime) { - super(); this.data = data; this.dataStartTime = dataStartTime; this.dataEndTime = dataEndTime; @@ -82,6 +82,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws * Key: value_list, Value type: java.util.ArrayList * Item type: java.lang.Double * Value: 8840.0, Type: java.lang.Double + * Key: feature_imputed, Value type: java.util.ArrayList + * Item type: java.lang.Boolean + * Value: true, Type: java.lang.Boolean * @return a Sample. */ public static Sample extractSample(Map map) { @@ -102,7 +105,6 @@ public static Sample extractSample(Map map) { Instant dataEndTime = Instant.ofEpochMilli(dataEndTimeLong); Instant dataStartTime = Instant.ofEpochMilli(dataStartTimeLong); - // Create a new Sample object and return it return new Sample(data, dataStartTime, dataEndTime); } @@ -112,7 +114,11 @@ public boolean isInvalid() { @Override public String toString() { - return "Sample [data=" + Arrays.toString(data) + ", dataStartTime=" + dataStartTime + ", dataEndTime=" + dataEndTime + "]"; + return new ToStringBuilder(this) + .append("data", Arrays.toString(data)) + .append("dataStartTime", dataStartTime) + .append("dataEndTime", dataEndTime) + .toString(); } @Generated @@ -125,11 +131,6 @@ public boolean equals(Object o) { return false; } Sample sample = (Sample) o; - // a few fields not included: - // 1)didn't include uiMetadata since toXContent/parse will produce a map of map - // and cause the parsed one not equal to the original one. This can be confusing. - // 2)didn't include id, schemaVersion, and lastUpdateTime as we deemed equality based on contents. - // Including id fails tests like AnomalyDetectorExecutionInput.testParseAnomalyDetectorExecutionInput. return Arrays.equals(data, sample.data) && dataStartTime.truncatedTo(ChronoUnit.MILLIS).equals(sample.dataStartTime.truncatedTo(ChronoUnit.MILLIS)) && dataEndTime.truncatedTo(ChronoUnit.MILLIS).equals(sample.dataEndTime.truncatedTo(ChronoUnit.MILLIS)); @@ -138,6 +139,7 @@ public boolean equals(Object o) { @Generated @Override public int hashCode() { - return Objects.hashCode(data, dataStartTime.truncatedTo(ChronoUnit.MILLIS), dataEndTime.truncatedTo(ChronoUnit.MILLIS)); + return Objects + .hashCode(Arrays.hashCode(data), dataStartTime.truncatedTo(ChronoUnit.MILLIS), dataEndTime.truncatedTo(ChronoUnit.MILLIS)); } } diff --git a/src/main/java/org/opensearch/timeseries/model/Config.java b/src/main/java/org/opensearch/timeseries/model/Config.java index 1dfdfb54c..b1a968cd1 100644 --- a/src/main/java/org/opensearch/timeseries/model/Config.java +++ b/src/main/java/org/opensearch/timeseries/model/Config.java @@ -190,31 +190,6 @@ protected Config( return; } - if (imputationOption != null && imputationOption.getMethod() == ImputationMethod.FIXED_VALUES) { - Optional defaultFill = imputationOption.getDefaultFill(); - if (defaultFill.isEmpty()) { - issueType = ValidationIssueType.IMPUTATION; - errorMessage = "No given values for fixed value interpolation"; - return; - } - - // Calculate the number of enabled features - long expectedFeatures = features == null ? 0 : features.stream().filter(Feature::getEnabled).count(); - - // Check if the length of the defaultFill array matches the number of expected features - if (defaultFill.get().length != expectedFeatures) { - issueType = ValidationIssueType.IMPUTATION; - errorMessage = String - .format( - Locale.ROOT, - "Incorrect number of values to fill. Got: %d. Expected: %d.", - defaultFill.get().length, - expectedFeatures - ); - return; - } - } - if (recencyEmphasis != null && (recencyEmphasis <= 0)) { issueType = ValidationIssueType.RECENCY_EMPHASIS; errorMessage = "recency emphasis has to be a positive integer"; @@ -240,12 +215,50 @@ protected Config( return; } + if (imputationOption != null && imputationOption.getMethod() == ImputationMethod.FIXED_VALUES) { + Map defaultFill = imputationOption.getDefaultFill(); + if (defaultFill.isEmpty()) { + issueType = ValidationIssueType.IMPUTATION; + errorMessage = "No given values for fixed value interpolation"; + return; + } + + // Calculate the number of enabled features + List enabledFeatures = features == null + ? null + : features.stream().filter(Feature::getEnabled).collect(Collectors.toList()); + + // Check if the length of the defaultFill array matches the number of expected features + if (enabledFeatures == null || defaultFill.size() != enabledFeatures.size()) { + issueType = ValidationIssueType.IMPUTATION; + errorMessage = String + .format( + Locale.ROOT, + "Incorrect number of values to fill. Got: %d. Expected: %d.", + defaultFill.size(), + enabledFeatures == null ? 0 : enabledFeatures.size() + ); + return; + } + + Map defaultFills = imputationOption.getDefaultFill(); + + for (int i = 0; i < enabledFeatures.size(); i++) { + if (!defaultFills.containsKey(enabledFeatures.get(i).getName())) { + issueType = ValidationIssueType.IMPUTATION; + errorMessage = String.format(Locale.ROOT, "Missing feature name: %s.", enabledFeatures.get(i).getName()); + return; + } + } + } + this.id = id; this.version = version; this.name = name; this.description = description; this.timeField = timeField; this.indices = indices; + // we validate empty or no enabled features when starting config (Read IndexJobActionHandler.validateConfig) this.featureAttributes = features == null ? ImmutableList.of() : ImmutableList.copyOf(features); this.filterQuery = filterQuery; this.interval = interval; @@ -733,29 +746,27 @@ public static List findRedundantNames(List features) { @Generated @Override public String toString() { - return super.toString() - + ", " - + new ToStringBuilder(this) - .append("name", name) - .append("description", description) - .append("timeField", timeField) - .append("indices", indices) - .append("featureAttributes", featureAttributes) - .append("filterQuery", filterQuery) - .append("interval", interval) - .append("windowDelay", windowDelay) - .append("shingleSize", shingleSize) - .append("categoryFields", categoryFields) - .append("schemaVersion", schemaVersion) - .append("user", user) - .append("customResultIndex", customResultIndexOrAlias) - .append("imputationOption", imputationOption) - .append("recencyEmphasis", recencyEmphasis) - .append("seasonIntervals", seasonIntervals) - .append("historyIntervals", historyIntervals) - .append("customResultIndexMinSize", customResultIndexMinSize) - .append("customResultIndexMinAge", customResultIndexMinAge) - .append("customResultIndexTTL", customResultIndexTTL) - .toString(); + return new ToStringBuilder(this) + .append("name", name) + .append("description", description) + .append("timeField", timeField) + .append("indices", indices) + .append("featureAttributes", featureAttributes) + .append("filterQuery", filterQuery) + .append("interval", interval) + .append("windowDelay", windowDelay) + .append("shingleSize", shingleSize) + .append("categoryFields", categoryFields) + .append("schemaVersion", schemaVersion) + .append("user", user) + .append("customResultIndex", customResultIndexOrAlias) + .append("imputationOption", imputationOption) + .append("recencyEmphasis", recencyEmphasis) + .append("seasonIntervals", seasonIntervals) + .append("historyIntervals", historyIntervals) + .append("customResultIndexMinSize", customResultIndexMinSize) + .append("customResultIndexMinAge", customResultIndexMinAge) + .append("customResultIndexTTL", customResultIndexTTL) + .toString(); } } diff --git a/src/main/java/org/opensearch/timeseries/model/DataByFeatureId.java b/src/main/java/org/opensearch/timeseries/model/DataByFeatureId.java index c74679214..4098fb8d4 100644 --- a/src/main/java/org/opensearch/timeseries/model/DataByFeatureId.java +++ b/src/main/java/org/opensearch/timeseries/model/DataByFeatureId.java @@ -9,12 +9,14 @@ import java.io.IOException; +import org.apache.commons.lang.builder.ToStringBuilder; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.annotation.Generated; import com.google.common.base.Objects; @@ -79,10 +81,12 @@ public static DataByFeatureId parse(XContentParser parser) throws IOException { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (o == null || getClass() != o.getClass()) + } + if (o == null || getClass() != o.getClass()) { return false; + } DataByFeatureId that = (DataByFeatureId) o; return Objects.equal(getFeatureId(), that.getFeatureId()) && Objects.equal(getData(), that.getData()); } @@ -100,10 +104,19 @@ public Double getData() { return data; } + public void setData(Double data) { + this.data = data; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(featureId); out.writeDouble(data); } + @Generated + @Override + public String toString() { + return super.toString() + ", " + new ToStringBuilder(this).append("featureId", featureId).append("data", data).toString(); + } } diff --git a/src/main/java/org/opensearch/timeseries/model/Feature.java b/src/main/java/org/opensearch/timeseries/model/Feature.java index 045a6b96b..b3aec52a4 100644 --- a/src/main/java/org/opensearch/timeseries/model/Feature.java +++ b/src/main/java/org/opensearch/timeseries/model/Feature.java @@ -35,7 +35,7 @@ public class Feature implements Writeable, ToXContentObject { private static final String FEATURE_ID_FIELD = "feature_id"; - private static final String FEATURE_NAME_FIELD = "feature_name"; + public static final String FEATURE_NAME_FIELD = "feature_name"; private static final String FEATURE_ENABLED_FIELD = "feature_enabled"; private static final String AGGREGATION_QUERY = "aggregation_query"; @@ -135,10 +135,12 @@ public static Feature parse(XContentParser parser) throws IOException { @Generated @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (o == null || getClass() != o.getClass()) + } + if (o == null || getClass() != o.getClass()) { return false; + } Feature feature = (Feature) o; return Objects.equal(getId(), feature.getId()) && Objects.equal(getName(), feature.getName()) diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointReadWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointReadWorker.java index 2485def8a..41f145b72 100644 --- a/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointReadWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointReadWorker.java @@ -44,6 +44,7 @@ import org.opensearch.timeseries.indices.IndexManagement; import org.opensearch.timeseries.indices.TimeSeriesIndex; import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.Inferencer; import org.opensearch.timeseries.ml.IntermediateResult; import org.opensearch.timeseries.ml.ModelColdStart; import org.opensearch.timeseries.ml.ModelManager; @@ -51,13 +52,12 @@ import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.IndexableResult; -import org.opensearch.timeseries.stats.StatNames; -import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.util.ActionListenerExecutor; import org.opensearch.timeseries.util.ExceptionUtil; import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; -public abstract class CheckpointReadWorker, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart, ModelManagerType extends ModelManager, CacheType extends TimeSeriesCache, SaveResultStrategyType extends SaveResultStrategy, ColdStartWorkerType extends ColdStartWorker> +public abstract class CheckpointReadWorker, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart, ModelManagerType extends ModelManager, CacheType extends TimeSeriesCache, SaveResultStrategyType extends SaveResultStrategy, ColdStartWorkerType extends ColdStartWorker, InferencerType extends Inferencer> extends BatchWorker { private static final Logger LOG = LogManager.getLogger(CheckpointReadWorker.class); @@ -65,13 +65,10 @@ public abstract class CheckpointReadWorker> cacheProvider; protected final String checkpointIndexName; - protected final StatNames modelCorruptionStat; + protected final InferencerType inferencer; public CheckpointReadWorker( String workerName, @@ -94,17 +91,14 @@ public CheckpointReadWorker( CheckpointType checkpointDao, ColdStartWorkerType entityColdStartWorker, NodeStateManager stateManager, - IndexManagementType indexUtil, Provider> cacheProvider, Duration stateTtl, CheckpointWriteWorkerType checkpointWriteWorker, - Stats timeSeriesStats, Setting concurrencySetting, Setting batchSizeSetting, String checkpointIndexName, - StatNames modelCorruptionStat, AnalysisType context, - SaveResultStrategyType resultWriteWorker + InferencerType inferencer ) { super( workerName, @@ -133,13 +127,10 @@ public CheckpointReadWorker( this.modelManager = modelManager; this.checkpointDao = checkpointDao; this.coldStartWorker = entityColdStartWorker; - this.indexUtil = indexUtil; this.cacheProvider = cacheProvider; this.checkpointWriteWorker = checkpointWriteWorker; - this.timeSeriesStats = timeSeriesStats; this.checkpointIndexName = checkpointIndexName; - this.modelCorruptionStat = modelCorruptionStat; - this.resultWriteWorker = resultWriteWorker; + this.inferencer = inferencer; } @Override @@ -347,7 +338,7 @@ protected ActionListener> processIterationUsingConfig ModelState restoredModelState, String modelId ) { - return ActionListener.wrap(configOptional -> { + return ActionListenerExecutor.wrap(configOptional -> { if (configOptional.isEmpty()) { LOG.warn(new ParameterizedMessage("Config [{}] is not available.", configId)); processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); @@ -356,51 +347,26 @@ protected ActionListener> processIterationUsingConfig Config config = configOptional.get(); - RCFResultType result = null; - try { - result = modelManager - .getResult( - new Sample( - origRequest.getCurrentFeature(), - Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()), - Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + config.getIntervalInMilliseconds()) - ), - restoredModelState, - modelId, - config, - origRequest.getTaskId() - ); - } catch (IllegalArgumentException e) { - // fail to score likely due to model corruption. Re-cold start to recover. - LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", origRequest.getModelId()), e); - timeSeriesStats.getStat(modelCorruptionStat.getName()).increment(); - if (null != origRequest.getModelId()) { - String entityModelId = origRequest.getModelId(); - checkpointDao - .deleteModelCheckpoint( - entityModelId, - ActionListener - .wrap( - r -> LOG.debug(new ParameterizedMessage("Succeeded in deleting checkpoint [{}].", entityModelId)), - ex -> LOG.error(new ParameterizedMessage("Failed to delete checkpoint [{}].", entityModelId), ex) - ) - ); + boolean processed = inferencer + .process( + new Sample( + origRequest.getCurrentFeature(), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + config.getIntervalInMilliseconds()) + ), + restoredModelState, + config, + origRequest.getTaskId() + ); + if (processed) { + // try to load to cache + boolean loaded = cacheProvider.get().hostIfPossible(config, restoredModelState); + + if (false == loaded) { + // not in memory. Maybe cold entities or some other entities + // have filled the slot while waiting for loading checkpoints. + checkpointWriteWorker.write(restoredModelState, true, RequestPriority.LOW); } - - coldStartWorker.put(origRequest); - processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); - return; - } - - resultWriteWorker.saveResult(result, config, origRequest, modelId); - - // try to load to cache - boolean loaded = cacheProvider.get().hostIfPossible(config, restoredModelState); - - if (false == loaded) { - // not in memory. Maybe cold entities or some other entities - // have filled the slot while waiting for loading checkpoints. - checkpointWriteWorker.write(restoredModelState, true, RequestPriority.LOW); } processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); @@ -408,6 +374,6 @@ protected ActionListener> processIterationUsingConfig LOG.error(new ParameterizedMessage("fail to get checkpoint [{}]", modelId, exception)); nodeStateManager.setException(configId, exception); processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); - }); + }, threadPool.executor(threadPoolName)); } } diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/ColdEntityWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ColdEntityWorker.java index 703360a3f..5e3f20196 100644 --- a/src/main/java/org/opensearch/timeseries/ratelimit/ColdEntityWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ColdEntityWorker.java @@ -28,6 +28,7 @@ import org.opensearch.timeseries.indices.IndexManagement; import org.opensearch.timeseries.indices.TimeSeriesIndex; import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.Inferencer; import org.opensearch.timeseries.ml.IntermediateResult; import org.opensearch.timeseries.ml.ModelColdStart; import org.opensearch.timeseries.ml.ModelManager; @@ -35,7 +36,7 @@ import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; -public class ColdEntityWorker & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, RCFResultType extends IntermediateResult, ModelManagerType extends ModelManager, CheckpointWriteWorkerType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart, CacheType extends TimeSeriesCache, SaveResultStrategyType extends SaveResultStrategy, ColdStartWorkerType extends ColdStartWorker, CheckpointReadWorkerType extends CheckpointReadWorker> +public class ColdEntityWorker & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, RCFResultType extends IntermediateResult, ModelManagerType extends ModelManager, CheckpointWriteWorkerType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart, CacheType extends TimeSeriesCache, SaveResultStrategyType extends SaveResultStrategy, ColdStartWorkerType extends ColdStartWorker, InferencerType extends Inferencer, CheckpointReadWorkerType extends CheckpointReadWorker> extends ScheduledWorker { public ColdEntityWorker( diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java b/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java index d37c60480..251512cff 100644 --- a/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java +++ b/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java @@ -411,7 +411,7 @@ private void onGetConfigResponse(GetResponse response, boolean indexingDryRun, S // If single-category HC changed category field from IP to error type, the AD result page may show both IP and error type // in top N entities list. That's confusing. // So we decide to block updating detector category field. - // for forecasting, we will not show results after forecaster configuration change (excluding changes like description) + // for forecasting, we will not show results after forecaster configuration change // thus it is safe to allow updating everything. In the future, we might change AD to allow such behavior. if (!canUpdateEverything) { if (!ParseUtils.listEqualsWithoutConsideringOrder(existingConfig.getCategoryFields(), config.getCategoryFields())) { @@ -435,15 +435,15 @@ private void onGetConfigResponse(GetResponse response, boolean indexingDryRun, S ); handler.confirmBatchRunning(id, batchTasks, confirmBatchRunningListener); - } catch (IOException e) { - String message = "Failed to parse anomaly detector " + id; + } catch (Exception e) { + String message = "Failed to parse config " + id; logger.error(message, e); listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); } } - protected void validateAgainstExistingHCConfig(String detectorId, boolean indexingDryRun, ActionListener listener) { + protected void validateAgainstExistingHCConfig(String configId, boolean indexingDryRun, ActionListener listener) { if (timeSeriesIndices.doesConfigIndexExist()) { QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.existsQuery(Config.CATEGORY_FIELD)); @@ -455,12 +455,12 @@ protected void validateAgainstExistingHCConfig(String detectorId, boolean indexi searchRequest, ActionListener .wrap( - response -> onSearchHCConfigResponse(response, detectorId, indexingDryRun, listener), + response -> onSearchHCConfigResponse(response, configId, indexingDryRun, listener), exception -> listener.onFailure(exception) ) ); } else { - validateCategoricalField(detectorId, indexingDryRun, listener); + validateCategoricalField(configId, indexingDryRun, listener); } } @@ -527,25 +527,14 @@ protected void onSearchHCConfigResponse(SearchResponse response, String detector } @SuppressWarnings("unchecked") - protected void validateCategoricalField(String detectorId, boolean indexingDryRun, ActionListener listener) { + protected void validateCategoricalField(String configId, boolean indexingDryRun, ActionListener listener) { List categoryField = config.getCategoryFields(); - if (categoryField == null) { - searchConfigInputIndices(detectorId, indexingDryRun, listener); - return; - } + // categoryField should have at least 1 element. Otherwise, we won't reach here. // we only support a certain number of categorical field // If there is more fields than required, Config's constructor - // throws validation exception before reaching this line - int maxCategoryFields = maxCategoricalFields; - if (categoryField.size() > maxCategoryFields) { - listener - .onFailure( - createValidationException(CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields), ValidationIssueType.CATEGORY) - ); - return; - } + // throws validation exception before reaching here String categoryField0 = categoryField.get(0); @@ -585,10 +574,8 @@ protected void validateCategoricalField(String detectorId, boolean indexingDryRu Map metadataMap = (Map) type; String typeName = (String) metadataMap.get(CommonName.TYPE); if (!typeName.equals(CommonName.KEYWORD_TYPE) && !typeName.equals(CommonName.IP_TYPE)) { - listener - .onFailure( - createValidationException(CATEGORICAL_FIELD_TYPE_ERR_MSG, ValidationIssueType.CATEGORY) - ); + String error = String.format(Locale.ROOT, CATEGORICAL_FIELD_TYPE_ERR_MSG, field2Metadata.getKey()); + listener.onFailure(createValidationException(error, ValidationIssueType.CATEGORY)); return; } } @@ -610,9 +597,9 @@ protected void validateCategoricalField(String detectorId, boolean indexingDryRu return; } - searchConfigInputIndices(detectorId, indexingDryRun, listener); + searchConfigInputIndices(configId, indexingDryRun, listener); }, error -> { - String message = String.format(Locale.ROOT, "Fail to get the index mapping of %s", config.getIndices()); + String message = String.format(Locale.ROOT, CommonMessages.FAIL_TO_GET_MAPPING_MSG, config.getIndices()); logger.error(message, error); listener.onFailure(new IllegalArgumentException(message)); }); @@ -621,7 +608,7 @@ protected void validateCategoricalField(String detectorId, boolean indexingDryRu .executeWithInjectedSecurity(GetFieldMappingsAction.INSTANCE, getMappingsRequest, user, client, context, mappingsListener); } - protected void searchConfigInputIndices(String detectorId, boolean indexingDryRun, ActionListener listener) { + protected void searchConfigInputIndices(String configId, boolean indexingDryRun, ActionListener listener) { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() .query(QueryBuilders.matchAllQuery()) .size(0) @@ -631,7 +618,7 @@ protected void searchConfigInputIndices(String detectorId, boolean indexingDryRu ActionListener searchResponseListener = ActionListener .wrap( - searchResponse -> onSearchConfigInputIndicesResponse(searchResponse, detectorId, indexingDryRun, listener), + searchResponse -> onSearchConfigInputIndicesResponse(searchResponse, configId, indexingDryRun, listener), exception -> listener.onFailure(exception) ); @@ -640,7 +627,7 @@ protected void searchConfigInputIndices(String detectorId, boolean indexingDryRu protected void onSearchConfigInputIndicesResponse( SearchResponse response, - String detectorId, + String configId, boolean indexingDryRun, ActionListener listener ) throws IOException { @@ -653,7 +640,7 @@ protected void onSearchConfigInputIndicesResponse( } listener.onFailure(new IllegalArgumentException(errorMsg)); } else { - validateConfigFeatures(detectorId, indexingDryRun, listener); + validateConfigFeatures(configId, indexingDryRun, listener); } } @@ -724,7 +711,6 @@ protected void finishConfigValidationOrContinueToModelValidation(ActionListener< } } - @SuppressWarnings("unchecked") protected void indexConfig(String id, ActionListener listener) throws IOException { Config copiedConfig = copyConfig(user, config); IndexRequest indexRequest = new IndexRequest(CommonName.CONFIG_INDEX) @@ -776,7 +762,7 @@ protected void onCreateMappingsResponse(CreateIndexResponse response, boolean in } } - protected String checkShardsFailure(IndexResponse response) { + public String checkShardsFailure(IndexResponse response) { StringBuilder failureReasons = new StringBuilder(); if (response.getShardInfo().getFailed() > 0) { for (ReplicationResponse.ShardInfo.Failure failure : response.getShardInfo().getFailures()) { @@ -832,7 +818,7 @@ protected void validateConfigFeatures(String id, boolean indexingDryRun, ActionL ssb.aggregation(internalAgg.getAggregatorFactories().iterator().next()); SearchRequest searchRequest = new SearchRequest().indices(config.getIndices().toArray(new String[0])).source(ssb); ActionListener searchResponseListener = ActionListener.wrap(response -> { - Optional aggFeatureResult = searchFeatureDao.parseResponse(response, Arrays.asList(feature.getId())); + Optional aggFeatureResult = searchFeatureDao.parseResponse(response, Arrays.asList(feature.getId()), false); if (aggFeatureResult.isPresent()) { multiFeatureQueriesResponseListener .onResponse( diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/AggregationPrep.java b/src/main/java/org/opensearch/timeseries/rest/handler/AggregationPrep.java index 534e4b3b9..cafc3a021 100644 --- a/src/main/java/org/opensearch/timeseries/rest/handler/AggregationPrep.java +++ b/src/main/java/org/opensearch/timeseries/rest/handler/AggregationPrep.java @@ -72,6 +72,10 @@ public double getBucketHitRate(SearchResponse response, IntervalTimeConfiguratio } public double getHistorgramBucketHitRate(SearchResponse response) { + int numberOfSamples = getNumberOfSamples(); + if (numberOfSamples == 0) { + return 0; + } Histogram histogram = validateAndRetrieveHistogramAggregation(response); if (histogram == null || histogram.getBuckets() == null) { logger.warn("Empty histogram buckets"); @@ -80,7 +84,7 @@ public double getHistorgramBucketHitRate(SearchResponse response) { // getBuckets returns non-empty bucket (e.g., doc_count > 0) int bucketCount = histogram.getBuckets().size(); - return bucketCount / getNumberOfSamples(); + return bucketCount / numberOfSamples; } public List getTimestamps(SearchResponse response) { diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java b/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java index 20c090d4b..5a174fe59 100644 --- a/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java +++ b/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java @@ -136,13 +136,13 @@ private int increaseAndGetNewInterval(IntervalTimeConfiguration oldInterval) { * Bucket aggregation with different interval lengths are executed one by one to check if the data is dense enough * We only need to execute the next query if the previous one led to data that is too sparse. */ - class IntervalRecommendationListener implements ActionListener { + public class IntervalRecommendationListener implements ActionListener { private final ActionListener intervalListener; IntervalTimeConfiguration currentIntervalToTry; private final long expirationEpochMs; private LongBounds currentTimeStampBounds; - IntervalRecommendationListener( + public IntervalRecommendationListener( ActionListener intervalListener, SearchSourceBuilder searchSourceBuilder, IntervalTimeConfiguration currentIntervalToTry, diff --git a/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java b/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java index c99ea519a..86ee34e88 100644 --- a/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java +++ b/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java @@ -98,11 +98,6 @@ public class TimeSeriesSettings { // ====================================== public static final int MAX_BATCH_TASK_PIECE_SIZE = 10_000; - // within an interval, how many percents are used to process requests. - // 1.0 means we use all of the detection interval to process requests. - // to ensure we don't block next interval, it is better to set it less than 1.0. - public static final float INTERVAL_RATIO_FOR_REQUESTS = 0.9f; - public static final Duration HOURLY_MAINTENANCE = Duration.ofHours(1); // Maximum number of deleted tasks can keep in cache. diff --git a/src/main/java/org/opensearch/timeseries/stats/StatNames.java b/src/main/java/org/opensearch/timeseries/stats/StatNames.java index 8ea32dffe..e06773015 100644 --- a/src/main/java/org/opensearch/timeseries/stats/StatNames.java +++ b/src/main/java/org/opensearch/timeseries/stats/StatNames.java @@ -20,8 +20,9 @@ */ public enum StatNames { // common stats - CONFIG_INDEX_STATUS("config_index_status", StatType.TIMESERIES), - JOB_INDEX_STATUS("job_index_status", StatType.TIMESERIES), + // keep the name the same for bwc + CONFIG_INDEX_STATUS("anomaly_detectors_index_status", StatType.TIMESERIES), + JOB_INDEX_STATUS("anomaly_detection_job_index_status", StatType.TIMESERIES), // AD stats AD_EXECUTE_REQUEST_COUNT("ad_execute_request_count", StatType.AD), AD_EXECUTE_FAIL_COUNT("ad_execute_failure_count", StatType.AD), @@ -31,7 +32,8 @@ public enum StatNames { SINGLE_STREAM_DETECTOR_COUNT("single_stream_detector_count", StatType.AD), HC_DETECTOR_COUNT("hc_detector_count", StatType.AD), ANOMALY_RESULTS_INDEX_STATUS("anomaly_results_index_status", StatType.AD), - AD_MODELS_CHECKPOINT_INDEX_STATUS("anomaly_models_checkpoint_index_status", StatType.AD), + // keep the name the same for bwc + AD_MODELS_CHECKPOINT_INDEX_STATUS("models_checkpoint_index_status", StatType.AD), ANOMALY_DETECTION_STATE_STATUS("anomaly_detection_state_status", StatType.AD), MODEL_INFORMATION("models", StatType.AD), AD_EXECUTING_BATCH_TASK_COUNT("ad_executing_batch_task_count", StatType.AD), diff --git a/src/main/java/org/opensearch/timeseries/transport/AbstractSingleStreamResultTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/AbstractSingleStreamResultTransportAction.java index b5f08785a..3cf3aa8dc 100644 --- a/src/main/java/org/opensearch/timeseries/transport/AbstractSingleStreamResultTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/AbstractSingleStreamResultTransportAction.java @@ -6,7 +6,6 @@ package org.opensearch.timeseries.transport; import java.time.Instant; -import java.util.List; import java.util.Optional; import org.apache.logging.log4j.LogManager; @@ -17,6 +16,7 @@ import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.breaker.CircuitBreakerService; @@ -30,6 +30,7 @@ import org.opensearch.timeseries.indices.IndexManagement; import org.opensearch.timeseries.indices.TimeSeriesIndex; import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.Inferencer; import org.opensearch.timeseries.ml.IntermediateResult; import org.opensearch.timeseries.ml.ModelColdStart; import org.opensearch.timeseries.ml.ModelManager; @@ -44,31 +45,24 @@ import org.opensearch.timeseries.ratelimit.FeatureRequest; import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.ratelimit.ResultWriteRequest; -import org.opensearch.timeseries.ratelimit.ResultWriteWorker; import org.opensearch.timeseries.ratelimit.SaveResultStrategy; -import org.opensearch.timeseries.stats.Stats; -import org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler; +import org.opensearch.timeseries.util.ActionListenerExecutor; import org.opensearch.timeseries.util.ExceptionUtil; -import org.opensearch.timeseries.util.ParseUtils; import org.opensearch.transport.TransportService; import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; -public abstract class AbstractSingleStreamResultTransportAction & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriterType extends CheckpointWriteWorker, CheckpointMaintainerType extends CheckpointMaintainWorker, CacheBufferType extends CacheBuffer, PriorityCacheType extends PriorityCache, CacheProviderType extends CacheProvider, ResultType extends IndexableResult, RCFResultType extends IntermediateResult, ColdStarterType extends ModelColdStart, ModelManagerType extends ModelManager, CacheType extends TimeSeriesCache, SaveResultStrategyType extends SaveResultStrategy, ColdStartWorkerType extends ColdStartWorker, CheckpointReadWorkerType extends CheckpointReadWorker, ResultWriteRequestType extends ResultWriteRequest, BatchRequestType extends ResultBulkRequest, ResultHandlerType extends IndexMemoryPressureAwareResultHandler, ResultWriteWorkerType extends ResultWriteWorker> +public abstract class AbstractSingleStreamResultTransportAction & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriterType extends CheckpointWriteWorker, CheckpointMaintainerType extends CheckpointMaintainWorker, CacheBufferType extends CacheBuffer, PriorityCacheType extends PriorityCache, CacheProviderType extends CacheProvider, ResultType extends IndexableResult, RCFResultType extends IntermediateResult, ColdStarterType extends ModelColdStart, ModelManagerType extends ModelManager, CacheType extends TimeSeriesCache, SaveResultStrategyType extends SaveResultStrategy, ColdStartWorkerType extends ColdStartWorker, InferencerType extends Inferencer, CheckpointReadWorkerType extends CheckpointReadWorker, ResultWriteRequestType extends ResultWriteRequest> extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(AbstractSingleStreamResultTransportAction.class); protected CircuitBreakerService circuitBreakerService; protected CacheProviderType cache; protected final NodeStateManager stateManager; protected CheckpointReadWorkerType checkpointReadQueue; - protected ModelManagerType modelManager; - protected IndexManagementType indexUtil; - protected ResultWriteWorkerType resultWriteQueue; - protected Stats stats; - protected ColdStartWorkerType coldStartWorker; - protected IndexType resultIndex; protected AnalysisType analysisType; - private String modelCorrptionStat; + private InferencerType inferencer; + private ThreadPool threadPool; + private String threadPoolName; public AbstractSingleStreamResultTransportAction( TransportService transportService, @@ -77,29 +71,21 @@ public AbstractSingleStreamResultTransportAction( CacheProviderType cache, NodeStateManager stateManager, CheckpointReadWorkerType checkpointReadQueue, - ModelManagerType modelManager, - IndexManagementType indexUtil, - ResultWriteWorkerType resultWriteQueue, - Stats stats, - ColdStartWorkerType coldStartQueue, String resultAction, - IndexType resultIndex, AnalysisType analysisType, - String modelCorrptionStat + InferencerType inferencer, + ThreadPool threadPool, + String threadPoolName ) { super(resultAction, transportService, actionFilters, SingleStreamResultRequest::new); this.circuitBreakerService = circuitBreakerService; this.cache = cache; this.stateManager = stateManager; this.checkpointReadQueue = checkpointReadQueue; - this.modelManager = modelManager; - this.indexUtil = indexUtil; - this.resultWriteQueue = resultWriteQueue; - this.stats = stats; - this.coldStartWorker = coldStartQueue; - this.resultIndex = resultIndex; this.analysisType = analysisType; - this.modelCorrptionStat = modelCorrptionStat; + this.inferencer = inferencer; + this.threadPool = threadPool; + this.threadPoolName = threadPoolName; } @Override @@ -141,7 +127,7 @@ public ActionListener> onGetConfig( SingleStreamResultRequest request, Optional prevException ) { - return ActionListener.wrap(configOptional -> { + return ActionListenerExecutor.wrap(configOptional -> { if (!configOptional.isPresent()) { listener.onFailure(new EndRunException(configId, "Config " + configId + " is not available.", false)); return; @@ -149,8 +135,6 @@ public ActionListener> onGetConfig( Config config = configOptional.get(); - Instant executionStartTime = Instant.now(); - String modelId = request.getModelId(); double[] datapoint = request.getDataPoint(); ModelState modelState = cache.get().get(modelId, config); @@ -169,54 +153,13 @@ public ActionListener> onGetConfig( ) ); } else { - try { - RCFResultType result = modelManager - .getResult( - new Sample(datapoint, Instant.ofEpochMilli(request.getStart()), Instant.ofEpochMilli(request.getEnd())), - modelState, - modelId, - config, - request.getTaskId() - ); - // result.getRcfScore() = 0 means the model is not initialized - if (result.getRcfScore() > 0) { - List indexableResults = result - .toIndexableResults( - config, - Instant.ofEpochMilli(request.getStart()), - Instant.ofEpochMilli(request.getEnd()), - executionStartTime, - Instant.now(), - ParseUtils.getFeatureData(datapoint, config), - Optional.empty(), - indexUtil.getSchemaVersion(resultIndex), - modelId, - null, - null - ); - - for (ResultType r : indexableResults) { - resultWriteQueue.put(createResultWriteRequest(config, r)); - } - } - } catch (IllegalArgumentException e) { - // fail to score likely due to model corruption. Re-cold start to recover. - LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", modelId), e); - stats.getStat(modelCorrptionStat).increment(); - cache.get().removeModel(configId, modelId); - coldStartWorker - .put( - new FeatureRequest( - System.currentTimeMillis() + config.getIntervalInMilliseconds(), - configId, - RequestPriority.MEDIUM, - modelId, - datapoint, - request.getStart(), - request.getTaskId() - ) - ); - } + inferencer + .process( + new Sample(datapoint, Instant.ofEpochMilli(request.getStart()), Instant.ofEpochMilli(request.getEnd())), + modelState, + config, + request.getTaskId() + ); } // respond back @@ -229,7 +172,7 @@ public ActionListener> onGetConfig( LOG .error( new ParameterizedMessage( - "fail to get entity's result for config [{}]: start: [{}], end: [{}]", + "fail to get single stream result for config [{}]: start: [{}], end: [{}]", configId, request.getStart(), request.getEnd() @@ -237,7 +180,7 @@ public ActionListener> onGetConfig( exception ); listener.onFailure(exception); - }); + }, threadPool.executor(threadPoolName)); } public abstract ResultWriteRequestType createResultWriteRequest(Config config, ResultType result); diff --git a/src/main/java/org/opensearch/timeseries/transport/CronTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/CronTransportAction.java index 0e4176273..25b7a2170 100644 --- a/src/main/java/org/opensearch/timeseries/transport/CronTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/CronTransportAction.java @@ -33,14 +33,12 @@ import org.opensearch.forecast.task.ForecastTaskManager; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.transport.TransportService; public class CronTransportAction extends TransportNodesAction { private final Logger LOG = LogManager.getLogger(CronTransportAction.class); private NodeStateManager transportStateManager; private ADModelManager adModelManager; - private FeatureManager featureManager; private ADCacheProvider adCacheProvider; private ForecastCacheProvider forecastCacheProvider; private ADColdStart adEntityColdStarter; @@ -56,7 +54,6 @@ public CronTransportAction( ActionFilters actionFilters, NodeStateManager tarnsportStatemanager, ADModelManager adModelManager, - FeatureManager featureManager, ADCacheProvider adCacheProvider, ForecastCacheProvider forecastCacheProvider, ADColdStart adEntityColdStarter, @@ -77,7 +74,6 @@ public CronTransportAction( ); this.transportStateManager = tarnsportStatemanager; this.adModelManager = adModelManager; - this.featureManager = featureManager; this.adCacheProvider = adCacheProvider; this.forecastCacheProvider = forecastCacheProvider; this.adEntityColdStarter = adEntityColdStarter; diff --git a/src/main/java/org/opensearch/timeseries/transport/EntityResultProcessor.java b/src/main/java/org/opensearch/timeseries/transport/EntityResultProcessor.java index a0df0da92..484265b49 100644 --- a/src/main/java/org/opensearch/timeseries/transport/EntityResultProcessor.java +++ b/src/main/java/org/opensearch/timeseries/transport/EntityResultProcessor.java @@ -19,6 +19,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.core.action.ActionListener; +import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.caching.CacheProvider; import org.opensearch.timeseries.caching.TimeSeriesCache; @@ -27,6 +28,7 @@ import org.opensearch.timeseries.indices.IndexManagement; import org.opensearch.timeseries.indices.TimeSeriesIndex; import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.Inferencer; import org.opensearch.timeseries.ml.IntermediateResult; import org.opensearch.timeseries.ml.ModelColdStart; import org.opensearch.timeseries.ml.ModelManager; @@ -42,8 +44,7 @@ import org.opensearch.timeseries.ratelimit.FeatureRequest; import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.ratelimit.SaveResultStrategy; -import org.opensearch.timeseries.stats.StatNames; -import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.util.ActionListenerExecutor; import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; @@ -52,56 +53,50 @@ * (e.g., EntityForecastResultTransportAction) * */ -public class EntityResultProcessor, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ModelColdStartType extends ModelColdStart, ModelManagerType extends ModelManager, CacheType extends TimeSeriesCache, SaveResultStrategyType extends SaveResultStrategy, ColdStartWorkerType extends ColdStartWorker, HCCheckpointReadWorkerType extends CheckpointReadWorker, ColdEntityWorkerType extends ColdEntityWorker> { +public class EntityResultProcessor, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ModelColdStartType extends ModelColdStart, ModelManagerType extends ModelManager, CacheType extends TimeSeriesCache, SaveResultStrategyType extends SaveResultStrategy, ColdStartWorkerType extends ColdStartWorker, InferencerType extends Inferencer, HCCheckpointReadWorkerType extends CheckpointReadWorker, ColdEntityWorkerType extends ColdEntityWorker> { private static final Logger LOG = LogManager.getLogger(EntityResultProcessor.class); private CacheProvider cache; - private ModelManagerType modelManager; - private Stats stats; - private ColdStartWorkerType entityColdStartWorker; private HCCheckpointReadWorkerType checkpointReadQueue; private ColdEntityWorkerType coldEntityQueue; - private SaveResultStrategyType saveResultStrategy; - private StatNames modelCorruptionStat; + private InferencerType inferencer; + private ThreadPool threadPool; + private String threadPoolName; public EntityResultProcessor( CacheProvider cache, - ModelManagerType manager, - Stats stats, - ColdStartWorkerType entityColdStartWorker, HCCheckpointReadWorkerType checkpointReadQueue, ColdEntityWorkerType coldEntityQueue, - SaveResultStrategyType saveResultStrategy, - StatNames modelCorruptionStat + InferencerType inferencer, + ThreadPool threadPool, + String threadPoolName ) { this.cache = cache; - this.modelManager = manager; - this.stats = stats; - this.entityColdStartWorker = entityColdStartWorker; this.checkpointReadQueue = checkpointReadQueue; this.coldEntityQueue = coldEntityQueue; - this.saveResultStrategy = saveResultStrategy; - this.modelCorruptionStat = modelCorruptionStat; + this.inferencer = inferencer; + this.threadPool = threadPool; + this.threadPoolName = threadPoolName; } public ActionListener> onGetConfig( ActionListener listener, - String forecasterId, + String configId, EntityResultRequest request, Optional prevException, AnalysisType analysisType ) { - return ActionListener.wrap(configOptional -> { + return ActionListenerExecutor.wrap(configOptional -> { if (!configOptional.isPresent()) { - listener.onFailure(new EndRunException(forecasterId, "Config " + forecasterId + " is not available.", false)); + listener.onFailure(new EndRunException(configId, "Config " + configId + " is not available.", false)); return; } Config config = configOptional.get(); if (request.getEntities() == null) { - listener.onFailure(new EndRunException(forecasterId, "Fail to get any entities from request.", false)); + listener.onFailure(new EndRunException(configId, "Fail to get any entities from request.", false)); return; } @@ -115,7 +110,7 @@ public ActionListener> onGetConfig( entity = Entity.createSingleAttributeEntity(config.getCategoryFields().get(0), attrValues.get(CommonName.EMPTY_FIELD)); } - Optional modelIdOptional = entity.getModelId(forecasterId); + Optional modelIdOptional = entity.getModelId(configId); if (modelIdOptional.isEmpty()) { continue; } @@ -128,51 +123,19 @@ public ActionListener> onGetConfig( cacheMissEntities.put(entity, datapoint); continue; } - try { - IntermediateResultType result = modelManager - .getResult( - new Sample(datapoint, Instant.ofEpochMilli(request.getStart()), Instant.ofEpochMilli(request.getEnd())), - entityModel, - modelId, - config, - request.getTaskId() - ); - - saveResultStrategy - .saveResult( - result, - config, - Instant.ofEpochMilli(request.getStart()), - Instant.ofEpochMilli(request.getEnd()), - modelId, - datapoint, - Optional.of(entity), - request.getTaskId() - ); - } catch (IllegalArgumentException e) { - // fail to score likely due to model corruption. Re-cold start to recover. - LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", modelId), e); - stats.getStat(modelCorruptionStat.getName()).increment(); - cache.get().removeModel(forecasterId, modelId); - entityColdStartWorker - .put( - new FeatureRequest( - System.currentTimeMillis() + config.getIntervalInMilliseconds(), - forecasterId, - RequestPriority.MEDIUM, - datapoint, - request.getStart(), - entity, - request.getTaskId() - ) - ); - } + inferencer + .process( + new Sample(datapoint, Instant.ofEpochMilli(request.getStart()), Instant.ofEpochMilli(request.getEnd())), + entityModel, + config, + request.getTaskId() + ); } // split hot and cold entities Pair, List> hotColdEntities = cache .get() - .selectUpdateCandidate(cacheMissEntities.keySet(), forecasterId, config); + .selectUpdateCandidate(cacheMissEntities.keySet(), configId, config); List hotEntityRequests = new ArrayList<>(); List coldEntityRequests = new ArrayList<>(); @@ -187,7 +150,7 @@ public ActionListener> onGetConfig( .add( new FeatureRequest( System.currentTimeMillis() + config.getIntervalInMilliseconds(), - forecasterId, + configId, // hot entities has MEDIUM priority RequestPriority.MEDIUM, hotEntityValue, @@ -208,7 +171,7 @@ public ActionListener> onGetConfig( .add( new FeatureRequest( System.currentTimeMillis() + config.getIntervalInMilliseconds(), - forecasterId, + configId, // cold entities has LOW priority RequestPriority.LOW, coldEntityValue, @@ -232,14 +195,14 @@ public ActionListener> onGetConfig( .error( new ParameterizedMessage( "fail to get entity's analysis result for config [{}]: start: [{}], end: [{}]", - forecasterId, + configId, request.getStart(), request.getEnd() ), exception ); listener.onFailure(exception); - }); + }, threadPool.executor(threadPoolName)); } /** diff --git a/src/main/java/org/opensearch/timeseries/transport/JobRequest.java b/src/main/java/org/opensearch/timeseries/transport/JobRequest.java index 98b56930f..ab00bfa2f 100644 --- a/src/main/java/org/opensearch/timeseries/transport/JobRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/JobRequest.java @@ -22,6 +22,7 @@ public class JobRequest extends ActionRequest { private String configID; + // data start/end time. See ADBatchTaskRunner.getDateRangeOfSourceData. private DateRange dateRange; private boolean historical; private String rawPath; diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java b/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java index ffee7717f..b244ee4ac 100644 --- a/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java +++ b/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java @@ -12,7 +12,9 @@ package org.opensearch.timeseries.transport; import java.net.ConnectException; +import java.time.Instant; import java.util.ArrayList; +import java.util.Arrays; import java.util.Iterator; import java.util.List; import java.util.Locale; @@ -20,6 +22,8 @@ import java.util.Map.Entry; import java.util.Optional; import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -79,8 +83,10 @@ import org.opensearch.timeseries.stats.Stats; import org.opensearch.timeseries.task.TaskCacheManager; import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.DataUtil; import org.opensearch.timeseries.util.ExceptionUtil; import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.timeseries.util.TimeUtil; import org.opensearch.transport.ActionNotFoundTransportException; import org.opensearch.transport.ConnectTransportException; import org.opensearch.transport.NodeNotConnectedException; @@ -111,20 +117,16 @@ public abstract class ResultProcessor transportResultResponseClazz; private StatNames hcRequestCountStat; private String threadPoolName; - // within an interval, how many percents are used to process requests. - // 1.0 means we use all of the detection interval to process requests. - // to ensure we don't block next interval, it is better to set it less than 1.0. - private final float intervalRatioForRequest; private int maxEntitiesPerInterval; private int pageSize; protected final ThreadPool threadPool; - private final HashRing hashRing; + protected final HashRing hashRing; protected final NodeStateManager nodeStateManager; protected final TransportService transportService; private final Stats timeSeriesStats; private final TaskManagerType realTimeTaskManager; private NamedXContentRegistry xContentRegistry; - private final Client client; + protected final Client client; private final SecurityClientUtil clientUtil; private Settings settings; private final IndexNameExpressionResolver indexNameExpressionResolver; @@ -137,7 +139,6 @@ public abstract class ResultProcessor requestTimeoutSetting, - float intervalRatioForRequests, String entityResultAction, StatNames hcRequestCountStat, Settings settings, @@ -166,8 +167,6 @@ public ResultProcessor( .withType(TransportRequestOptions.Type.REG) .withTimeout(requestTimeoutSetting.get(settings)) .build(); - this.intervalRatioForRequest = intervalRatioForRequests; - this.maxEntitiesPerInterval = maxEntitiesPerIntervalSetting.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(maxEntitiesPerIntervalSetting, it -> maxEntitiesPerInterval = it); @@ -205,35 +204,40 @@ public ResultProcessor( class PageListener implements ActionListener { private PageIterator pageIterator; private String configId; + private Config config; private long dataStartTime; private long dataEndTime; - private Runnable finishRunnable; private String taskId; + private AtomicInteger receivedPages; + private AtomicInteger sentOutPages; - PageListener( - PageIterator pageIterator, - String detectorId, - long dataStartTime, - long dataEndTime, - Runnable finishRunnable, - String taskId - ) { + PageListener(PageIterator pageIterator, Config config, long dataStartTime, long dataEndTime, String taskId) { this.pageIterator = pageIterator; - this.configId = detectorId; + this.configId = config.getId(); + this.config = config; this.dataStartTime = dataStartTime; this.dataEndTime = dataEndTime; - this.finishRunnable = finishRunnable; this.taskId = taskId; + this.receivedPages = new AtomicInteger(); + this.sentOutPages = new AtomicInteger(); } @Override public void onResponse(CompositeRetriever.Page entityFeatures) { + // start processing next page after sending out features for previous page if (pageIterator.hasNext()) { pageIterator.next(this); - } else { - finishRunnable.run(); } if (entityFeatures != null && false == entityFeatures.isEmpty()) { + sentOutPages.incrementAndGet(); + + LOG + .info( + "Sending an HC request to process data from timestamp {} to {} for config {}", + dataStartTime, + dataEndTime, + configId + ); // wrap expensive operation inside ad threadpool threadPool.executor(threadPoolName).execute(() -> { try { @@ -269,7 +273,7 @@ public void onResponse(CompositeRetriever.Page entityFeatures) { String .format( Locale.ROOT, - ResultProcessor.NODE_UNRESPONSIVE_ERR_MSG + " %s for detector %s", + ResultProcessor.NODE_UNRESPONSIVE_ERR_MSG + " %s for config %s", modelNodeId, configId ) @@ -279,6 +283,7 @@ public void onResponse(CompositeRetriever.Page entityFeatures) { } final AtomicReference failure = new AtomicReference<>(); + node2Entities.stream().forEach(nodeEntity -> { DiscoveryNode node = nodeEntity.getKey(); transportService @@ -295,7 +300,7 @@ public void onResponse(CompositeRetriever.Page entityFeatures) { ), option, new ActionListenerResponseHandler<>( - new ErrorResponseListener(node.getId(), configId, failure), + new ErrorResponseListener(node.getId(), configId, failure, receivedPages), AcknowledgedResponse::new, ThreadPool.Names.SAME ) @@ -308,17 +313,23 @@ public void onResponse(CompositeRetriever.Page entityFeatures) { } }); } + + if (!pageIterator.hasNext() && config.getImputationOption() != null) { + if (sentOutPages.get() > 0) { + // at least 1 page sent out. Wait until all responses are back. + scheduleImputeHCTask(); + } else { + // no data in current interval. Send out impute request right away. + imputeHC(dataStartTime, dataEndTime, configId, taskId); + } + + } } @Override public void onFailure(Exception e) { - try { - LOG.error("Unexpetected exception", e); - handleException(e); - } finally { - // make sure we return listener - finishRunnable.run(); - } + LOG.error("Unexpetected exception", e); + handleException(e); } private void handleException(Exception e) { @@ -329,6 +340,47 @@ private void handleException(Exception e) { } nodeStateManager.setException(configId, convertedException); } + + /** + * Schedules imputeHC to after sent pages are equal to received pages at a fixed interval. + * + * We need to send impute request after ensuring it happens after all other entity feature requests. + * otherwise, we may rescore the same entity. + * + * If the condition is not met, it checks the condition regularly. The checker task is automatically + * canceled and the scheduler is shut down after a specified timeout period. + */ + private void scheduleImputeHCTask() { + AtomicReference cancellable = new AtomicReference<>(); + AtomicBoolean sent = new AtomicBoolean(); + + final Runnable checkerTask = new Runnable() { + private final long timeoutMillis = TimeUtil.calculateTimeoutMillis(config, dataEndTime); + + @Override + public void run() { + if (sentOutPages.get() == receivedPages.get()) { + if (!sent.get()) { + // since we don't know when cancel will succeed, need sent to ensure imputeHC is only called once + sent.set(true); + imputeHC(dataStartTime, dataEndTime, configId, taskId); + } + + if (cancellable.get() != null) { + cancellable.get().cancel(); + } + } else if (Instant.now().toEpochMilli() >= timeoutMillis) { + LOG.warn("Scheduled impute HC task is cancelled due to timeout"); + if (cancellable != null) { + cancellable.get().cancel(); + } + } + } + }; + + // Schedule the task at a 2 second interval + cancellable.set(threadPool.scheduleWithFixedDelay(checkerTask, TimeValue.timeValueSeconds(2), threadPoolName)); + } } public ActionListener> onGetConfig( @@ -437,7 +489,7 @@ private void executeAnalysis( } // assume request are in epoch milliseconds - long nextDetectionStartTime = request.getEnd() + (long) (config.getIntervalInMilliseconds() * intervalRatioForRequest); + long nextDetectionStartTime = request.getEnd() + config.getIntervalInMilliseconds(); CompositeRetriever compositeRetriever = new CompositeRetriever( dataStartTime, @@ -464,38 +516,23 @@ private void executeAnalysis( return; } - Runnable finishRunnable = () -> { - // When pagination finishes or the time is up, - // return response or exceptions. - if (previousException.isPresent()) { - listener.onFailure(previousException.get()); - } else { - listener - .onResponse( - createResultResponse(new ArrayList(), null, null, config.getIntervalInMinutes(), true, taskId) - ); - } - }; + PageListener getEntityFeatureslistener = new PageListener(pageIterator, config, dataStartTime, dataEndTime, taskId); - PageListener getEntityFeatureslistener = new PageListener( - pageIterator, - configID, - dataStartTime, - dataEndTime, - finishRunnable, - taskId - ); + // hasNext is always true unless time is up at this point (won't happen in normal cases) if (pageIterator.hasNext()) { - LOG - .info( - "Sending an HC request to process data from timestamp {} to {} for config {}", - dataStartTime, - dataEndTime, - configID - ); pageIterator.next(getEntityFeatureslistener); + } else if (config.getImputationOption() != null) { + imputeHC(dataStartTime, dataEndTime, configID, taskId); + } + + // return early to not wait for completion of all entities so we won't block next interval + if (previousException.isPresent()) { + listener.onFailure(previousException.get()); } else { - finishRunnable.run(); + listener + .onResponse( + createResultResponse(new ArrayList(), null, null, config.getIntervalInMinutes(), true, taskId) + ); } return; @@ -513,6 +550,7 @@ private void executeAnalysis( DiscoveryNode rcfNode = asRCFNode.get(); + // early return listener in shouldStart if (!shouldStart(listener, configID, config, rcfNode.getId(), rcfModelID)) { return; } @@ -587,9 +625,9 @@ protected void findException(Throwable cause, String configID, AtomicReference actualException = NotSerializedExceptionName .convertWrappedTimeSeriesException((NotSerializableExceptionWrapper) causeException, configID); if (actualException.isPresent()) { - TimeSeriesException adException = actualException.get(); - failure.set(adException); - if (adException instanceof ResourceNotFoundException) { + TimeSeriesException tsException = actualException.get(); + failure.set(tsException); + if (tsException instanceof ResourceNotFoundException) { // During a rolling upgrade or blue/green deployment, ResourceNotFoundException might be caused by old node using RCF // 1.0 // cannot recognize new checkpoint produced by the coordinating node using compact RCF. Add pressure to mute the node @@ -764,16 +802,19 @@ public class ErrorResponseListener implements ActionListener failure; + private AtomicInteger receivedPages; - public ErrorResponseListener(String nodeId, String configId, AtomicReference failure) { + public ErrorResponseListener(String nodeId, String configId, AtomicReference failure, AtomicInteger receivedPage) { this.nodeId = nodeId; this.configId = configId; this.failure = failure; + this.receivedPages = receivedPage; } @Override public void onResponse(AcknowledgedResponse response) { try { + receivedPages.incrementAndGet(); if (response.isAcknowledged() == false) { LOG.error("Cannot send entities' features to {} for {}", nodeId, configId); nodeStateManager.addPressure(nodeId, configId); @@ -789,6 +830,7 @@ public void onResponse(AcknowledgedResponse response) { @Override public void onFailure(Exception e) { try { + receivedPages.incrementAndGet(); // e.g., we have connection issues with all of the nodes while restarting clusters LOG.error(new ParameterizedMessage("Cannot send entities' features to {} for {}", nodeId, configId), e); @@ -832,34 +874,68 @@ protected ActionListener> onFeatureResponseForSingleStreamCon } } - if (featureOptional.isEmpty()) { + if ((featureOptional.isEmpty() || DataUtil.areAnyElementsNaN(featureOptional.get())) && config.getImputationOption() == null) { // Feature not available is common when we have data holes. Respond empty response // and don't log to avoid bloating our logs. - LOG.debug("No data in current window between {} and {} for {}", dataStartTime, dataEndTime, configId); listener - .onResponse(createResultResponse(new ArrayList(), "No data in current window", null, null, false, taskId)); + .onResponse( + createResultResponse( + new ArrayList(), + String + .format( + Locale.ROOT, + "No data in current window between %d and %d for %s", + dataStartTime, + dataEndTime, + configId + ), + null, + null, + false, + taskId + ) + ); return; } final AtomicReference failure = new AtomicReference(); - LOG - .info( - "Sending a single stream request to node {} to process data from timestamp {} to {} for config {}", - rcfNode.getId(), - dataStartTime, - dataEndTime, - configId - ); + double[] point = null; + if (featureOptional.isPresent()) { + point = featureOptional.get(); + } else { + int featureSize = config.getEnabledFeatureIds().size(); + point = new double[featureSize]; + Arrays.fill(point, Double.NaN); + } + if (DataUtil.areAnyElementsNaN(point)) { + LOG + .info( + "Sending a single stream request to node {} to impute/process data from timestamp {} to {} for config {}", + rcfNode.getId(), + dataStartTime, + dataEndTime, + configId + ); + } else { + LOG + .info( + "Sending a single stream request to node {} to process data from timestamp {} to {} for config {}", + rcfNode.getId(), + dataStartTime, + dataEndTime, + configId + ); + } transportService .sendRequest( rcfNode, singleStreamActionName, - new SingleStreamResultRequest(configId, rcfModelId, dataStartTime, dataEndTime, featureOptional.get(), taskId), + new SingleStreamResultRequest(configId, rcfModelId, dataStartTime, dataEndTime, point, taskId), option, new ActionListenerResponseHandler<>( - new ErrorResponseListener(rcfNode.getId(), configId, failure), + new ErrorResponseListener(rcfNode.getId(), configId, failure, new AtomicInteger()), AcknowledgedResponse::new, ThreadPool.Names.SAME ) @@ -867,18 +943,13 @@ protected ActionListener> onFeatureResponseForSingleStreamCon if (previousException.isPresent()) { listener.onFailure(previousException.get()); - } else if (featureOptional.isEmpty()) { - // Feature not available is common when we have data holes. Respond empty response - // and don't log to avoid bloating our logs. - LOG.debug("No data in current window between {} and {} for {}", dataStartTime, dataEndTime, configId); - listener - .onResponse(createResultResponse(new ArrayList(), "No data in current window", null, null, false, taskId)); } else { listener .onResponse( createResultResponse(new ArrayList(), null, null, config.getIntervalInMinutes(), true, taskId) ); } + }, exception -> { handleQueryFailure(exception, listener, configId); }); } @@ -890,4 +961,6 @@ protected abstract ResultResponseType createResultResponse( Boolean isHC, String taskId ); + + protected abstract void imputeHC(long dataStartTime, long dataEndTime, String configID, String taskId); } diff --git a/src/main/java/org/opensearch/timeseries/util/ActionListenerExecutor.java b/src/main/java/org/opensearch/timeseries/util/ActionListenerExecutor.java new file mode 100644 index 000000000..b4cea8ebb --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/util/ActionListenerExecutor.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.util; + +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; + +import org.opensearch.common.CheckedConsumer; +import org.opensearch.core.action.ActionListener; + +public class ActionListenerExecutor { + /* + * Private constructor to avoid Jacoco complaining about public constructor + * not covered: https://tinyurl.com/yetc7tra + */ + private ActionListenerExecutor() {} + + /** + * Wraps the provided response and failure handlers in an ActionListener that executes the + * response handler asynchronously using the provided ExecutorService. + * + * @param the type of the response + * @param onResponse a CheckedConsumer that handles the response; it can throw an exception + * @param onFailure a Consumer that handles any exceptions thrown by the onResponse handler or the onFailure method + * @param executorService the ExecutorService used to execute the onResponse handler asynchronously + * @return an ActionListener that handles the response and failure cases + */ + public static ActionListener wrap( + CheckedConsumer onResponse, + Consumer onFailure, + ExecutorService executorService + ) { + return new ActionListener() { + @Override + public void onResponse(Response response) { + executorService.execute(() -> { + try { + onResponse.accept(response); + } catch (Exception e) { + onFailure(e); + } + }); + } + + @Override + public void onFailure(Exception e) { + onFailure.accept(e); + } + }; + } +} diff --git a/src/main/java/org/opensearch/timeseries/util/BulkUtil.java b/src/main/java/org/opensearch/timeseries/util/BulkUtil.java index c2b275a1f..6b801da7c 100644 --- a/src/main/java/org/opensearch/timeseries/util/BulkUtil.java +++ b/src/main/java/org/opensearch/timeseries/util/BulkUtil.java @@ -36,8 +36,11 @@ public static List getFailedIndexRequest(BulkRequest bulkRequest, Set failedId = new HashSet<>(); for (BulkItemResponse response : bulkResponse.getItems()) { - if (response.isFailed() && ExceptionUtil.isRetryAble(response.getFailure().getStatus())) { - failedId.add(response.getId()); + if (response.isFailed()) { + logger.info("bulk indexing failure: {}", response.getFailureMessage()); + if (ExceptionUtil.isRetryAble(response.getFailure().getStatus())) { + failedId.add(response.getId()); + } } } diff --git a/src/main/java/org/opensearch/timeseries/util/DataUtil.java b/src/main/java/org/opensearch/timeseries/util/DataUtil.java index 4f417e4f7..42dd16396 100644 --- a/src/main/java/org/opensearch/timeseries/util/DataUtil.java +++ b/src/main/java/org/opensearch/timeseries/util/DataUtil.java @@ -5,7 +5,11 @@ package org.opensearch.timeseries.util; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; public class DataUtil { /** @@ -45,4 +49,43 @@ public static double[][] ltrim(double[][] arr) { return Arrays.copyOfRange(arr, startIndex, arr.length); } + public static int[] generateMissingIndicesArray(double[] point) { + List intArray = new ArrayList<>(); + for (int i = 0; i < point.length; i++) { + if (Double.isNaN(point[i])) { + intArray.add(i); + } + } + // Return null if the array is empty + if (intArray.size() == 0) { + return null; + } + return intArray.stream().mapToInt(Integer::intValue).toArray(); + } + + public static boolean areAnyElementsNaN(double[] array) { + return Arrays.stream(array).anyMatch(Double::isNaN); + } + + /** + * Rounds the given double value to the specified number of decimal places. + * + * This method uses BigDecimal for precise rounding. It rounds using the + * HALF_UP rounding mode, which means it rounds towards the "nearest neighbor" + * unless both neighbors are equidistant, in which case it rounds up. + * + * @param value the double value to be rounded + * @param places the number of decimal places to round to + * @return the rounded double value + * @throws IllegalArgumentException if the specified number of decimal places is negative + */ + public static double roundDouble(double value, int places) { + if (places < 0) { + throw new IllegalArgumentException(); + } + + BigDecimal bd = new BigDecimal(Double.toString(value)); + bd = bd.setScale(places, RoundingMode.HALF_UP); + return bd.doubleValue(); + } } diff --git a/src/main/java/org/opensearch/timeseries/util/ModelUtil.java b/src/main/java/org/opensearch/timeseries/util/ModelUtil.java new file mode 100644 index 000000000..21527e785 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/util/ModelUtil.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.util; + +import java.util.List; +import java.util.Map; + +import org.opensearch.ad.model.ImputedFeatureResult; +import org.opensearch.timeseries.dataprocessor.ImputationMethod; +import org.opensearch.timeseries.dataprocessor.ImputationOption; +import org.opensearch.timeseries.model.Config; + +import com.amazon.randomcutforest.parkservices.AnomalyDescriptor; + +public class ModelUtil { + public static ImputedFeatureResult calculateImputedFeatures( + AnomalyDescriptor anomalyDescriptor, + double[] point, + boolean isImputed, + Config config + ) { + int inputLength = anomalyDescriptor.getInputLength(); + boolean[] isFeatureImputed = null; + double[] actual = point; + + if (isImputed) { + actual = new double[inputLength]; + isFeatureImputed = new boolean[inputLength]; + + ImputationOption imputationOption = config.getImputationOption(); + if (imputationOption != null && imputationOption.getMethod() == ImputationMethod.ZERO) { + for (int i = 0; i < point.length; i++) { + if (Double.isNaN(point[i])) { + isFeatureImputed[i] = true; + actual[i] = 0; + } + } + } else if (imputationOption != null && imputationOption.getMethod() == ImputationMethod.FIXED_VALUES) { + Map defaultFills = imputationOption.getDefaultFill(); + List enabledFeatureNames = config.getEnabledFeatureNames(); + for (int i = 0; i < point.length; i++) { + if (Double.isNaN(point[i])) { + isFeatureImputed[i] = true; + actual[i] = defaultFills.get(enabledFeatureNames.get(i)); + } + } + } else { + float[] rcfPoint = anomalyDescriptor.getRCFPoint(); + if (rcfPoint == null) { + return new ImputedFeatureResult(isFeatureImputed, actual); + } + float[] transformedInput = new float[inputLength]; + System.arraycopy(rcfPoint, rcfPoint.length - inputLength, transformedInput, 0, inputLength); + + double[] scale = anomalyDescriptor.getScale(); + double[] shift = anomalyDescriptor.getShift(); + + for (int i = 0; i < point.length; i++) { + if (Double.isNaN(point[i])) { + isFeatureImputed[i] = true; + actual[i] = (transformedInput[i] * scale[i]) + shift[i]; + } + } + } + } + + return new ImputedFeatureResult(isFeatureImputed, actual); + } +} diff --git a/src/main/java/org/opensearch/timeseries/util/ParseUtils.java b/src/main/java/org/opensearch/timeseries/util/ParseUtils.java index 119e21dab..3b64be259 100644 --- a/src/main/java/org/opensearch/timeseries/util/ParseUtils.java +++ b/src/main/java/org/opensearch/timeseries/util/ParseUtils.java @@ -331,7 +331,7 @@ public static SearchSourceBuilder generateInternalFeatureQuery( return internalSearchSourceBuilder; } - public static SearchSourceBuilder generatePreviewQuery( + public static SearchSourceBuilder generateRangeQuery( Config config, List> ranges, NamedXContentRegistry xContentRegistry diff --git a/src/main/java/org/opensearch/timeseries/util/TimeUtil.java b/src/main/java/org/opensearch/timeseries/util/TimeUtil.java new file mode 100644 index 000000000..22ce700be --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/util/TimeUtil.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.util; + +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; + +public class TimeUtil { + public static long calculateTimeoutMillis(Config config, long dataEndTimeMillis) { + long windowDelayMillis = config.getWindowDelay() == null + ? 0 + : ((IntervalTimeConfiguration) config.getWindowDelay()).toDuration().toMillis(); + long nextExecutionEnd = dataEndTimeMillis + config.getIntervalInMilliseconds() + windowDelayMillis; + return nextExecutionEnd; + } +} diff --git a/src/main/resources/mappings/anomaly-checkpoint.json b/src/main/resources/mappings/anomaly-checkpoint.json index af485860a..3c9532d81 100644 --- a/src/main/resources/mappings/anomaly-checkpoint.json +++ b/src/main/resources/mappings/anomaly-checkpoint.json @@ -1,7 +1,7 @@ { "dynamic": true, "_meta": { - "schema_version": 4 + "schema_version": 5 }, "properties": { "detectorId": { @@ -57,6 +57,17 @@ "data_end_time": { "type": "date", "format": "strict_date_time||epoch_millis" + }, + "feature_imputed": { + "type": "nested", + "properties": { + "feature_id": { + "type": "keyword" + }, + "imputed": { + "type": "boolean" + } + } } } } diff --git a/src/main/resources/mappings/anomaly-results.json b/src/main/resources/mappings/anomaly-results.json index 3fad67ec2..105fad141 100644 --- a/src/main/resources/mappings/anomaly-results.json +++ b/src/main/resources/mappings/anomaly-results.json @@ -1,7 +1,7 @@ { "dynamic": false, "_meta": { - "schema_version": 6 + "schema_version": 7 }, "properties": { "detector_id": { @@ -157,6 +157,17 @@ }, "threshold": { "type": "double" + }, + "feature_imputed": { + "type": "nested", + "properties": { + "feature_id": { + "type": "keyword" + }, + "imputed": { + "type": "boolean" + } + } } } } diff --git a/src/main/resources/mappings/config.json b/src/main/resources/mappings/config.json index c64a697e7..2dc4954c9 100644 --- a/src/main/resources/mappings/config.json +++ b/src/main/resources/mappings/config.json @@ -1,172 +1,234 @@ { - "dynamic": false, - "_meta": { - "schema_version": 5 - }, - "properties": { - "schema_version": { - "type": "integer" + "dynamic": false, + "_meta": { + "schema_version": 6 }, - "name": { - "type": "text", - "fields": { - "keyword": { - "type": "keyword", - "ignore_above": 256 - } - } - }, - "description": { - "type": "text" - }, - "time_field": { - "type": "keyword" - }, - "indices": { - "type": "text", - "fields": { - "keyword": { - "type": "keyword", - "ignore_above": 256 - } - } - }, - "result_index": { - "type": "keyword" - }, - "filter_query": { - "type": "object", - "enabled": false - }, - "feature_attributes": { - "type": "nested", - "properties": { - "feature_id": { - "type": "keyword", - "fields": { - "keyword": { - "type": "keyword", - "ignore_above": 256 + "properties": { + "schema_version": { + "type": "integer" + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } } - } - }, - "feature_name": { - "type": "text", - "fields": { - "keyword": { - "type": "keyword", - "ignore_above": 256 + }, + "description": { + "type": "text" + }, + "time_field": { + "type": "keyword" + }, + "indices": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } } - } }, - "feature_enabled": { - "type": "boolean" + "result_index": { + "type": "keyword" }, - "aggregation_query": { - "type": "object", - "enabled": false - } - } - }, - "detection_interval": { - "properties": { - "period": { - "properties": { - "interval": { - "type": "integer" - }, - "unit": { - "type": "keyword" + "filter_query": { + "type": "object", + "enabled": false + }, + "feature_attributes": { + "type": "nested", + "properties": { + "feature_id": { + "type": "keyword", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "feature_name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "feature_enabled": { + "type": "boolean" + }, + "aggregation_query": { + "type": "object", + "enabled": false + } } - } - } - } - }, - "window_delay": { - "properties": { - "period": { - "properties": { - "interval": { - "type": "integer" - }, - "unit": { - "type": "keyword" + }, + "detection_interval": { + "properties": { + "period": { + "properties": { + "interval": { + "type": "integer" + }, + "unit": { + "type": "keyword" + } + } + } } - } - } - } - }, - "shingle_size": { - "type": "integer" - }, - "last_update_time": { - "type": "date", - "format": "strict_date_time||epoch_millis" - }, - "ui_metadata": { - "type": "object", - "enabled": false - }, - "user": { - "type": "nested", - "properties": { - "name": { - "type": "text", - "fields": { - "keyword": { - "type": "keyword", - "ignore_above": 256 + }, + "window_delay": { + "properties": { + "period": { + "properties": { + "interval": { + "type": "integer" + }, + "unit": { + "type": "keyword" + } + } + } } - } }, - "backend_roles": { - "type" : "text", - "fields" : { - "keyword" : { - "type" : "keyword" + "shingle_size": { + "type": "integer" + }, + "last_update_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "ui_metadata": { + "type": "object", + "enabled": false + }, + "user": { + "type": "nested", + "properties": { + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "backend_roles": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "roles": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "custom_attribute_names": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + } } - } }, - "roles": { - "type" : "text", - "fields" : { - "keyword" : { - "type" : "keyword" + "category_field": { + "type": "keyword" + }, + "detector_type": { + "type": "keyword" + }, + "forecast_interval": { + "properties": { + "period": { + "properties": { + "interval": { + "type": "integer" + }, + "unit": { + "type": "keyword" + } + } + } } - } }, - "custom_attribute_names": { - "type" : "text", - "fields" : { - "keyword" : { - "type" : "keyword" + "horizon": { + "type": "integer" + }, + "imputation_option": { + "type": "nested", + "properties": { + "method": { + "type": "keyword" + }, + "defaultFill": { + "type": "nested", + "properties": { + "feature_name": { + "type": "keyword" + }, + "data": { + "type": "double" + } + } + } } - } - } - } - }, - "category_field": { - "type": "keyword" - }, - "detector_type": { - "type": "keyword" - }, - "forecast_interval": { - "properties": { - "period": { - "properties": { - "interval": { - "type": "integer" - }, - "unit": { - "type": "keyword" + }, + "suggested_seasonality": { + "type": "integer" + }, + "recency_emphasis": { + "type": "integer" + }, + "history": { + "type": "integer" + }, + "result_index_min_size": { + "type": "integer" + }, + "result_index_min_age": { + "type": "integer" + }, + "result_index_ttl": { + "type": "integer" + }, + "rules": { + "type": "nested", + "properties": { + "action": { + "type": "keyword" + }, + "conditions": { + "type": "nested", + "properties": { + "feature_name": { + "type": "keyword" + }, + "threshold_type": { + "type": "keyword" + }, + "operator": { + "type": "keyword" + }, + "value": { + "type": "double" + } + } + } } - } } - } - }, - "horizon": { - "type": "integer" } - } -} +} \ No newline at end of file diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/AbstractForecasterActionHandlerTestCase.java b/src/test/java/org/opensearch/action/admin/indices/mapping/get/AbstractForecasterActionHandlerTestCase.java new file mode 100644 index 000000000..5bee84cba --- /dev/null +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/AbstractForecasterActionHandlerTestCase.java @@ -0,0 +1,130 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.action.admin.indices.mapping.get; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.time.Clock; +import java.util.Arrays; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.rest.RestRequest; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.ValidateConfigResponse; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableList; + +public class AbstractForecasterActionHandlerTestCase extends AbstractTimeSeriesTest { + + protected ClusterService clusterService; + protected ActionListener channel; + protected TransportService transportService; + protected ForecastIndexManagement forecastISM; + protected String forecasterId; + protected Long seqNo; + protected Long primaryTerm; + protected Forecaster forecaster; + protected WriteRequest.RefreshPolicy refreshPolicy; + protected TimeValue requestTimeout; + protected Integer maxSingleStreamForecasters; + protected Integer maxHCForecasters; + protected Integer maxForecastFeatures; + protected Integer maxCategoricalFields; + protected Settings settings; + protected RestRequest.Method method; + protected TaskManager forecastTaskManager; + protected SearchFeatureDao searchFeatureDao; + protected Clock clock; + @Mock + protected Client clientMock; + @Mock + protected ThreadPool threadPool; + protected ThreadContext threadContext; + protected SecurityClientUtil clientUtil; + protected String categoricalField; + + @SuppressWarnings("unchecked") + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.initMocks(this); + + settings = Settings.EMPTY; + + clusterService = mock(ClusterService.class); + ClusterName clusterName = new ClusterName("test"); + ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build(); + when(clusterService.state()).thenReturn(clusterState); + + channel = mock(ActionListener.class); + transportService = mock(TransportService.class); + + forecastISM = mock(ForecastIndexManagement.class); + when(forecastISM.doesConfigIndexExist()).thenReturn(true); + + forecasterId = "123"; + seqNo = 0L; + primaryTerm = 0L; + clock = mock(Clock.class); + + refreshPolicy = WriteRequest.RefreshPolicy.IMMEDIATE; + + categoricalField = "a"; + forecaster = TestHelpers.ForecasterBuilder + .newInstance() + .setConfigId(forecasterId) + .setTimeField("timestamp") + .setIndices(ImmutableList.of("test-index")) + .setCategoryFields(Arrays.asList(categoricalField)) + .build(); + + requestTimeout = new TimeValue(1000L); + maxSingleStreamForecasters = 1000; + maxHCForecasters = 10; + maxForecastFeatures = 5; + maxCategoricalFields = 10; + method = RestRequest.Method.POST; + forecastTaskManager = mock(ForecastTaskManager.class); + searchFeatureDao = mock(SearchFeatureDao.class); + + threadContext = new ThreadContext(settings); + Mockito.doReturn(threadPool).when(clientMock).threadPool(); + Mockito.doReturn(threadContext).when(threadPool).getThreadContext(); + + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); + } + +} diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java index e1e86e1ec..bd1159047 100644 --- a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java @@ -12,8 +12,8 @@ package org.opensearch.action.admin.indices.mapping.get; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; @@ -22,6 +22,7 @@ import static org.mockito.Mockito.when; import java.io.IOException; +import java.time.Instant; import java.util.Arrays; import java.util.Locale; import java.util.concurrent.CountDownLatch; @@ -32,6 +33,7 @@ import org.junit.Before; import org.junit.BeforeClass; import org.junit.Ignore; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionType; import org.opensearch.action.get.GetAction; @@ -116,7 +118,12 @@ public void setUp() throws Exception { super.setUp(); settings = Settings.EMPTY; + clusterService = mock(ClusterService.class); + ClusterName clusterName = new ClusterName("test"); + ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build(); + when(clusterService.state()).thenReturn(clusterState); + clientMock = spy(new NodeClient(settings, threadPool)); NodeStateManager nodeStateManager = mock(NodeStateManager.class); clientUtil = new SecurityClientUtil(nodeStateManager, settings); @@ -311,7 +318,8 @@ public void doE assertTrue("should throw eror", false); inProgressLatch.countDown(); }, e -> { - assertTrue(e.getMessage().contains(CommonMessages.CATEGORICAL_FIELD_TYPE_ERR_MSG)); + String error = String.format(Locale.ROOT, CommonMessages.CATEGORICAL_FIELD_TYPE_ERR_MSG, field); + assertTrue("actual: " + e.getMessage(), e.getMessage().contains(error)); inProgressLatch.countDown(); })); assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); @@ -505,7 +513,8 @@ public void doE if (fieldTypeName.equals(CommonName.IP_TYPE) || fieldTypeName.equals(CommonName.KEYWORD_TYPE)) { assertTrue(e.getMessage().contains(IndexAnomalyDetectorActionHandler.NO_DOCS_IN_USER_INDEX_MSG)); } else { - assertTrue(e.getMessage().contains(CommonMessages.CATEGORICAL_FIELD_TYPE_ERR_MSG)); + String error = String.format(Locale.ROOT, CommonMessages.CATEGORICAL_FIELD_TYPE_ERR_MSG, field); + assertTrue("actual: " + e.getMessage(), e.getMessage().contains(error)); } inProgressLatch.countDown(); })); @@ -799,4 +808,102 @@ public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOEx verify(clientMock, times(0)).search(any(SearchRequest.class), any()); verify(clientMock, times(1)).get(any(GetRequest.class), any()); } + + public void testUpdateDifferentCategoricalField() throws InterruptedException { + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (action.equals(GetFieldMappingsAction.INSTANCE)) { + try { + GetFieldMappingsResponse response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(detector.getIndices().get(0), "timestamp", "date") + ); + listener.onResponse((Response) response); + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } else if (action.equals(GetAction.INSTANCE)) { + // Serialize the object + AnomalyDetector clone = new AnomalyDetector( + detector.getId(), + detector.getVersion(), + detector.getName(), + detector.getDescription(), + detector.getTimeField(), + detector.getIndices(), + detector.getFeatureAttributes(), + detector.getFilterQuery(), + detector.getInterval(), + detector.getWindowDelay(), + detector.getShingleSize(), + detector.getUiMetadata(), + detector.getSchemaVersion(), + Instant.now(), + detector.getCategoryFields(), + detector.getUser(), + "opensearch-ad-plugin-result-blah", + detector.getImputationOption(), + detector.getRecencyEmphasis(), + detector.getSeasonIntervals(), + detector.getHistoryIntervals(), + null, + detector.getCustomResultIndexMinSize(), + detector.getCustomResultIndexMinAge(), + detector.getCustomResultIndexTTL() + ); + try { + listener.onResponse((Response) TestHelpers.createGetResponse(clone, clone.getId(), CommonName.CONFIG_INDEX)); + } catch (IOException e) { + LOG.error(e); + } + } else { + assertTrue("should not reach here", false); + } + } + }; + NodeClient clientSpy = spy(client); + + method = RestRequest.Method.PUT; + + handler = new IndexAnomalyDetectorActionHandler( + clusterService, + clientSpy, + clientUtil, + transportService, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + detector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + maxCategoricalFields, + RestRequest.Method.PUT, + xContentRegistry(), + null, + adTaskManager, + searchFeatureDao, + Settings.EMPTY + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + assertTrue("actual: " + e, e instanceof OpenSearchStatusException); + OpenSearchStatusException statusException = (OpenSearchStatusException) e; + assertTrue(statusException.getMessage().contains(CommonMessages.CAN_NOT_CHANGE_CUSTOM_RESULT_INDEX)); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(10, TimeUnit.SECONDS)); + verify(clientSpy, times(1)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + verify(clientSpy, times(1)).execute(eq(GetAction.INSTANCE), any(), any()); + } } diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexForecasterActionHandlerTests.java b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexForecasterActionHandlerTests.java new file mode 100644 index 000000000..e78b154ea --- /dev/null +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexForecasterActionHandlerTests.java @@ -0,0 +1,1405 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.action.admin.indices.mapping.get; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Locale; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.apache.lucene.search.TotalHits; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionType; +import org.opensearch.action.DocWriteResponse.Result; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.get.GetAction; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexAction; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchAction; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.action.support.replication.ReplicationResponse; +import org.opensearch.action.support.replication.ReplicationResponse.ShardInfo; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.routing.AllocationId; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.forecast.rest.handler.IndexForecasterActionHandler; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.index.get.GetResult; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.histogram.Histogram; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.rest.handler.AggregationPrep; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableList; + +public class IndexForecasterActionHandlerTests extends AbstractForecasterActionHandlerTestCase { + protected IndexForecasterActionHandler handler; + + public void testCreateOrUpdateConfigException() throws InterruptedException, IOException { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listner = (ActionListener) args[0]; + listner.onFailure(new IllegalArgumentException()); + return null; + }).when(forecastISM).initConfigIndex(any()); + when(forecastISM.doesConfigIndexExist()).thenReturn(false); + + handler = new IndexForecasterActionHandler( + clusterService, + clientMock, + clientUtil, + mock(TransportService.class), + forecastISM, + forecaster.getId(), + null, + null, + null, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry(), + null, + null, + searchFeatureDao, + settings + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + assertTrue(e instanceof IllegalArgumentException); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(10, TimeUnit.SECONDS)); + } + + public void testUpdateConfigException() throws InterruptedException { + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (action.equals(GetFieldMappingsAction.INSTANCE)) { + try { + GetFieldMappingsResponse response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), "timestamp", "date") + ); + listener.onResponse((Response) response); + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } else if (action.equals(GetAction.INSTANCE)) { + listener.onFailure(new IllegalArgumentException()); + } else { + assertTrue("should not reach here", false); + } + } + }; + NodeClient clientSpy = spy(client); + + method = RestRequest.Method.PUT; + + handler = new IndexForecasterActionHandler( + clusterService, + clientSpy, + clientUtil, + mock(TransportService.class), + forecastISM, + forecaster.getId(), + null, + null, + null, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry(), + null, + null, + searchFeatureDao, + settings + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + assertTrue(e instanceof IllegalArgumentException); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + verify(clientSpy, times(1)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + verify(clientSpy, times(1)).execute(eq(GetAction.INSTANCE), any(), any()); + } + + public void testGetConfigNotExists() throws InterruptedException { + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (action.equals(GetFieldMappingsAction.INSTANCE)) { + try { + GetFieldMappingsResponse response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), "timestamp", "date") + ); + listener.onResponse((Response) response); + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } else if (action.equals(GetAction.INSTANCE)) { + GetResult notFoundResult = new GetResult("ab", "_doc", UNASSIGNED_SEQ_NO, 0, -1, false, null, null, null); + listener.onResponse((Response) new GetResponse(notFoundResult)); + } else { + assertTrue("should not reach here", false); + } + } + }; + NodeClient clientSpy = spy(client); + + method = RestRequest.Method.PUT; + + handler = new IndexForecasterActionHandler( + clusterService, + clientSpy, + clientUtil, + mock(TransportService.class), + forecastISM, + forecaster.getId(), + null, + null, + null, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry(), + null, + null, + searchFeatureDao, + settings + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + assertTrue(e instanceof OpenSearchStatusException); + OpenSearchStatusException statusException = (OpenSearchStatusException) e; + assertTrue(statusException.getMessage().contains(CommonMessages.FAIL_TO_FIND_CONFIG_MSG)); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + verify(clientSpy, times(1)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + verify(clientSpy, times(1)).execute(eq(GetAction.INSTANCE), any(), any()); + } + + public void testFaiToParse() throws InterruptedException { + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (action.equals(GetFieldMappingsAction.INSTANCE)) { + try { + GetFieldMappingsResponse response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), "timestamp", "date") + ); + listener.onResponse((Response) response); + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } else if (action.equals(GetAction.INSTANCE)) { + try { + listener + .onResponse( + (Response) TestHelpers + .createGetResponse(AllocationId.newInitializing(), forecaster.getId(), CommonName.CONFIG_INDEX) + ); + } catch (IOException e) { + LOG.error(e); + } + } else { + assertTrue("should not reach here", false); + } + } + }; + NodeClient clientSpy = spy(client); + + method = RestRequest.Method.PUT; + + handler = new IndexForecasterActionHandler( + clusterService, + clientSpy, + clientUtil, + mock(TransportService.class), + forecastISM, + forecaster.getId(), + null, + null, + null, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry(), + null, + mock(ForecastTaskManager.class), + searchFeatureDao, + settings + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + assertTrue("actual: " + e, e instanceof OpenSearchStatusException); + OpenSearchStatusException statusException = (OpenSearchStatusException) e; + assertTrue(statusException.getMessage().contains("Failed to parse config")); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(10, TimeUnit.SECONDS)); + verify(clientSpy, times(1)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + verify(clientSpy, times(1)).execute(eq(GetAction.INSTANCE), any(), any()); + } + + public void testSearchHCForecasterException() throws InterruptedException { + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (action.equals(GetFieldMappingsAction.INSTANCE)) { + try { + GetFieldMappingsResponse response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), "timestamp", "date") + ); + listener.onResponse((Response) response); + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } else if (action.equals(SearchAction.INSTANCE)) { + listener.onFailure(new IllegalArgumentException()); + } else { + assertTrue("should not reach here", false); + } + } + }; + NodeClient clientSpy = spy(client); + + method = RestRequest.Method.POST; + + handler = new IndexForecasterActionHandler( + clusterService, + clientSpy, + clientUtil, + mock(TransportService.class), + forecastISM, + forecaster.getId(), + null, + null, + null, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry(), + null, + mock(ForecastTaskManager.class), + searchFeatureDao, + settings + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + assertTrue("actual: " + e, e instanceof IllegalArgumentException); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(10, TimeUnit.SECONDS)); + verify(clientSpy, times(1)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + verify(clientSpy, times(1)).execute(eq(SearchAction.INSTANCE), any(), any()); + } + + public void testSearchSingleStreamForecasterException() throws InterruptedException, IOException { + forecaster = TestHelpers.ForecasterBuilder + .newInstance() + .setConfigId(forecasterId) + .setTimeField("timestamp") + .setIndices(ImmutableList.of("test-index")) + .setCategoryFields(Arrays.asList()) + .build(); + + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (action.equals(GetFieldMappingsAction.INSTANCE)) { + GetFieldMappingsRequest getMappingsRequest = (GetFieldMappingsRequest) request; + try { + GetFieldMappingsResponse response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), "timestamp", "date") + ); + + listener.onResponse((Response) response); + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } else if (action.equals(SearchAction.INSTANCE)) { + listener.onFailure(new IllegalArgumentException()); + } else { + assertTrue("should not reach here", false); + } + } + }; + NodeClient clientSpy = spy(client); + + method = RestRequest.Method.POST; + + handler = new IndexForecasterActionHandler( + clusterService, + clientSpy, + clientUtil, + mock(TransportService.class), + forecastISM, + forecaster.getId(), + null, + null, + null, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry(), + null, + mock(ForecastTaskManager.class), + searchFeatureDao, + settings + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + assertTrue("actual: " + e, e instanceof IllegalArgumentException); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(10, TimeUnit.SECONDS)); + verify(clientSpy, times(1)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + verify(clientSpy, times(1)).execute(eq(SearchAction.INSTANCE), any(), any()); + } + + public void testValidateCategoricalFieldException() throws InterruptedException, IOException { + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (action.equals(GetFieldMappingsAction.INSTANCE)) { + GetFieldMappingsRequest getMappingsRequest = (GetFieldMappingsRequest) request; + try { + GetFieldMappingsResponse response = null; + if (getMappingsRequest.fields()[0].equals(categoricalField)) { + listener.onFailure(new IllegalArgumentException()); + } else { + response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), "timestamp", "date") + ); + } + + listener.onResponse((Response) response); + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } else if (action.equals(SearchAction.INSTANCE)) { + Histogram histogram = mock(Histogram.class); + when(histogram.getName()).thenReturn(AggregationPrep.AGGREGATION); + Aggregations aggs = new Aggregations(Arrays.asList(histogram)); + SearchResponseSections sections = new SearchResponseSections( + new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0), + aggs, + null, + false, + null, + null, + 1 + ); + listener + .onResponse( + (Response) new SearchResponse( + sections, + null, + 0, + 0, + 0, + 0L, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ) + ); + } else { + assertTrue("should not reach here", false); + } + } + }; + NodeClient clientSpy = spy(client); + + method = RestRequest.Method.POST; + + handler = new IndexForecasterActionHandler( + clusterService, + clientSpy, + clientUtil, + mock(TransportService.class), + forecastISM, + forecaster.getId(), + null, + null, + null, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry(), + null, + mock(ForecastTaskManager.class), + searchFeatureDao, + settings + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + String message = String.format(Locale.ROOT, CommonMessages.FAIL_TO_GET_MAPPING_MSG, forecaster.getIndices()); + assertTrue("actual: " + e, e instanceof IllegalArgumentException); + assertTrue("actual: " + message, e.getMessage().contains(message)); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(10, TimeUnit.SECONDS)); + // once for timestamp, once for categorical field + verify(clientSpy, times(2)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + verify(clientSpy, times(1)).execute(eq(SearchAction.INSTANCE), any(), any()); + } + + public void testSearchConfigInputException() throws InterruptedException, IOException { + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (action.equals(GetFieldMappingsAction.INSTANCE)) { + GetFieldMappingsRequest getMappingsRequest = (GetFieldMappingsRequest) request; + try { + GetFieldMappingsResponse response = null; + if (getMappingsRequest.fields()[0].equals(categoricalField)) { + response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), categoricalField, "keyword") + ); + } else { + response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), "timestamp", "date") + ); + } + + listener.onResponse((Response) response); + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } else if (action.equals(SearchAction.INSTANCE)) { + LOG.info(Thread.currentThread().getName()); + StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace(); + for (StackTraceElement element : stackTrace) { + LOG.info(element); + } + SearchRequest searchRequest = (SearchRequest) request; + if (searchRequest.indices()[0].equals(CommonName.CONFIG_INDEX)) { + Histogram histogram = mock(Histogram.class); + when(histogram.getName()).thenReturn(AggregationPrep.AGGREGATION); + Aggregations aggs = new Aggregations(Arrays.asList(histogram)); + SearchResponseSections sections = new SearchResponseSections( + new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0), + aggs, + null, + false, + null, + null, + 1 + ); + listener + .onResponse( + (Response) new SearchResponse( + sections, + null, + 0, + 0, + 0, + 0L, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ) + ); + } else { + listener.onFailure(new IllegalArgumentException()); + } + + } else { + assertTrue("should not reach here", false); + } + } + }; + NodeClient clientSpy = spy(client); + + method = RestRequest.Method.POST; + + handler = new IndexForecasterActionHandler( + clusterService, + clientSpy, + clientUtil, + mock(TransportService.class), + forecastISM, + forecaster.getId(), + null, + null, + null, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry(), + null, + mock(ForecastTaskManager.class), + searchFeatureDao, + settings + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + assertTrue("actual: " + e, e instanceof IllegalArgumentException); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(10, TimeUnit.SECONDS)); + // once for timestamp, once for categorical field + verify(clientSpy, times(2)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + // validateAgainstExistingHCConfig, validateCategoricalField/searchConfigInputIndices + verify(clientSpy, times(2)).execute(eq(SearchAction.INSTANCE), any(), any()); + } + + public void testCheckConfigNameExistsException() throws InterruptedException, IOException { + forecaster = TestHelpers.ForecasterBuilder + .newInstance() + .setConfigId(forecasterId) + .setTimeField("timestamp") + .setIndices(ImmutableList.of("test-index")) + .setFeatureAttributes(Arrays.asList()) + .setCategoryFields(Arrays.asList(categoricalField)) + .setNullImputationOption() + .build(); + + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (action.equals(GetFieldMappingsAction.INSTANCE)) { + GetFieldMappingsRequest getMappingsRequest = (GetFieldMappingsRequest) request; + try { + GetFieldMappingsResponse response = null; + if (getMappingsRequest.fields()[0].equals(categoricalField)) { + response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), categoricalField, "keyword") + ); + } else { + response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), "timestamp", "date") + ); + } + + listener.onResponse((Response) response); + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } else if (action.equals(SearchAction.INSTANCE)) { + SearchRequest searchRequest = (SearchRequest) request; + Histogram histogram = mock(Histogram.class); + when(histogram.getName()).thenReturn(AggregationPrep.AGGREGATION); + Aggregations aggs = new Aggregations(Arrays.asList(histogram)); + SearchResponseSections sections = null; + if (searchRequest.indices()[0].equals(CommonName.CONFIG_INDEX)) { + BoolQueryBuilder boolQuery = (BoolQueryBuilder) searchRequest.source().query(); + if (boolQuery.must() != null && boolQuery.must().size() > 0) { + // checkConfigNameExists + listener.onFailure(new IllegalArgumentException()); + return; + } else { + // validateAgainstExistingHCConfig + sections = new SearchResponseSections( + new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0), + aggs, + null, + false, + null, + null, + 1 + ); + } + } else { + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(randomIntBetween(1, Integer.MAX_VALUE)); + + sections = new SearchResponseSections( + new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 0), + aggs, + null, + false, + null, + null, + 1 + ); + } + listener + .onResponse( + (Response) new SearchResponse( + sections, + null, + 0, + 0, + 0, + 0L, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ) + ); + } else { + assertTrue("should not reach here", false); + } + } + }; + NodeClient clientSpy = spy(client); + + method = RestRequest.Method.POST; + + handler = new IndexForecasterActionHandler( + clusterService, + clientSpy, + clientUtil, + mock(TransportService.class), + forecastISM, + forecaster.getId(), + null, + null, + null, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry(), + null, + mock(ForecastTaskManager.class), + searchFeatureDao, + settings + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + assertTrue("actual: " + e, e instanceof IllegalArgumentException); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(10, TimeUnit.SECONDS)); + // once for timestamp, once for categorical field + verify(clientSpy, times(2)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + // validateAgainstExistingHCConfig, checkConfigNameExists, validateCategoricalField/searchConfigInputIndices + verify(clientSpy, times(3)).execute(eq(SearchAction.INSTANCE), any(), any()); + } + + public void testRedundantNames() throws InterruptedException, IOException { + forecaster = TestHelpers.ForecasterBuilder + .newInstance() + .setConfigId(forecasterId) + .setTimeField("timestamp") + .setIndices(ImmutableList.of("test-index")) + .setFeatureAttributes(Arrays.asList()) + .setCategoryFields(Arrays.asList(categoricalField)) + .setNullImputationOption() + .build(); + + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (action.equals(GetFieldMappingsAction.INSTANCE)) { + GetFieldMappingsRequest getMappingsRequest = (GetFieldMappingsRequest) request; + try { + GetFieldMappingsResponse response = null; + if (getMappingsRequest.fields()[0].equals(categoricalField)) { + response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), categoricalField, "keyword") + ); + } else { + response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), "timestamp", "date") + ); + } + + listener.onResponse((Response) response); + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } else if (action.equals(SearchAction.INSTANCE)) { + SearchRequest searchRequest = (SearchRequest) request; + Histogram histogram = mock(Histogram.class); + when(histogram.getName()).thenReturn(AggregationPrep.AGGREGATION); + Aggregations aggs = new Aggregations(Arrays.asList(histogram)); + SearchResponseSections sections = null; + if (searchRequest.indices()[0].equals(CommonName.CONFIG_INDEX)) { + BoolQueryBuilder boolQuery = (BoolQueryBuilder) searchRequest.source().query(); + if (boolQuery.must() != null && boolQuery.must().size() > 0) { + // checkConfigNameExists + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(randomIntBetween(1, Integer.MAX_VALUE)); + + sections = new SearchResponseSections( + new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 0), + aggs, + null, + false, + null, + null, + 1 + ); + } else { + // validateAgainstExistingHCConfig + sections = new SearchResponseSections( + new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0), + aggs, + null, + false, + null, + null, + 1 + ); + } + } else { + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(randomIntBetween(1, Integer.MAX_VALUE)); + + sections = new SearchResponseSections( + new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 0), + aggs, + null, + false, + null, + null, + 1 + ); + } + listener + .onResponse( + (Response) new SearchResponse( + sections, + null, + 0, + 0, + 0, + 0L, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ) + ); + } else { + assertTrue("should not reach here", false); + } + } + }; + NodeClient clientSpy = spy(client); + + method = RestRequest.Method.POST; + + handler = new IndexForecasterActionHandler( + clusterService, + clientSpy, + clientUtil, + mock(TransportService.class), + forecastISM, + forecaster.getId(), + null, + null, + null, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry(), + null, + mock(ForecastTaskManager.class), + searchFeatureDao, + settings + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + assertTrue("actual: " + e, e instanceof OpenSearchStatusException); + String error = handler.getDuplicateConfigErrorMsg(forecaster.getName()); + assertTrue("actual: " + e.getMessage(), e.getMessage().contains(error)); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(10, TimeUnit.SECONDS)); + // once for timestamp, once for categorical field + verify(clientSpy, times(2)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + // validateAgainstExistingHCConfig, checkConfigNameExists, validateCategoricalField/searchConfigInputIndices + verify(clientSpy, times(3)).execute(eq(SearchAction.INSTANCE), any(), any()); + } + + public void testIndexConfigVersionConflict() throws InterruptedException, IOException { + forecaster = TestHelpers.ForecasterBuilder + .newInstance() + .setConfigId(forecasterId) + .setTimeField("timestamp") + .setIndices(ImmutableList.of("test-index")) + .setFeatureAttributes(Arrays.asList()) + .setCategoryFields(Arrays.asList(categoricalField)) + .setNullImputationOption() + .build(); + + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (action.equals(GetFieldMappingsAction.INSTANCE)) { + GetFieldMappingsRequest getMappingsRequest = (GetFieldMappingsRequest) request; + try { + GetFieldMappingsResponse response = null; + if (getMappingsRequest.fields()[0].equals(categoricalField)) { + response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), categoricalField, "keyword") + ); + } else { + response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), "timestamp", "date") + ); + } + + listener.onResponse((Response) response); + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } else if (action.equals(SearchAction.INSTANCE)) { + SearchRequest searchRequest = (SearchRequest) request; + Histogram histogram = mock(Histogram.class); + when(histogram.getName()).thenReturn(AggregationPrep.AGGREGATION); + Aggregations aggs = new Aggregations(Arrays.asList(histogram)); + SearchResponseSections sections = null; + if (searchRequest.indices()[0].equals(CommonName.CONFIG_INDEX)) { + sections = new SearchResponseSections( + new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0), + aggs, + null, + false, + null, + null, + 1 + ); + } else { + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(randomIntBetween(1, Integer.MAX_VALUE)); + + sections = new SearchResponseSections( + new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 0), + aggs, + null, + false, + null, + null, + 1 + ); + } + listener + .onResponse( + (Response) new SearchResponse( + sections, + null, + 0, + 0, + 0, + 0L, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ) + ); + } else if (action.equals(IndexAction.INSTANCE)) { + listener.onFailure(new IllegalArgumentException("version conflict")); + } else { + assertTrue("should not reach here", false); + } + } + }; + NodeClient clientSpy = spy(client); + + method = RestRequest.Method.POST; + + handler = new IndexForecasterActionHandler( + clusterService, + clientSpy, + clientUtil, + mock(TransportService.class), + forecastISM, + forecaster.getId(), + 1L, + 1L, + null, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry(), + null, + mock(ForecastTaskManager.class), + searchFeatureDao, + settings + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + assertTrue("actual: " + e, e instanceof IllegalArgumentException); + String error = "There was a problem updating the config:[" + forecaster.getId() + "]"; + assertTrue("actual: " + e.getMessage(), e.getMessage().contains(error)); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(10, TimeUnit.SECONDS)); + // once for timestamp, once for categorical field + verify(clientSpy, times(2)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + // validateAgainstExistingHCConfig, checkConfigNameExists, validateCategoricalField/searchConfigInputIndices + verify(clientSpy, times(3)).execute(eq(SearchAction.INSTANCE), any(), any()); + // indexConfig + verify(clientSpy, times(1)).execute(eq(IndexAction.INSTANCE), any(), any()); + } + + public void testIndexConfigException() throws InterruptedException, IOException { + forecaster = TestHelpers.ForecasterBuilder + .newInstance() + .setConfigId(forecasterId) + .setTimeField("timestamp") + .setIndices(ImmutableList.of("test-index")) + .setFeatureAttributes(Arrays.asList()) + .setCategoryFields(Arrays.asList(categoricalField)) + .setNullImputationOption() + .build(); + + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (action.equals(GetFieldMappingsAction.INSTANCE)) { + GetFieldMappingsRequest getMappingsRequest = (GetFieldMappingsRequest) request; + try { + GetFieldMappingsResponse response = null; + if (getMappingsRequest.fields()[0].equals(categoricalField)) { + response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), categoricalField, "keyword") + ); + } else { + response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), "timestamp", "date") + ); + } + + listener.onResponse((Response) response); + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } else if (action.equals(SearchAction.INSTANCE)) { + SearchRequest searchRequest = (SearchRequest) request; + Histogram histogram = mock(Histogram.class); + when(histogram.getName()).thenReturn(AggregationPrep.AGGREGATION); + Aggregations aggs = new Aggregations(Arrays.asList(histogram)); + SearchResponseSections sections = null; + if (searchRequest.indices()[0].equals(CommonName.CONFIG_INDEX)) { + sections = new SearchResponseSections( + new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0), + aggs, + null, + false, + null, + null, + 1 + ); + } else { + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(randomIntBetween(1, Integer.MAX_VALUE)); + + sections = new SearchResponseSections( + new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 0), + aggs, + null, + false, + null, + null, + 1 + ); + } + listener + .onResponse( + (Response) new SearchResponse( + sections, + null, + 0, + 0, + 0, + 0L, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ) + ); + } else if (action.equals(IndexAction.INSTANCE)) { + listener.onFailure(new IllegalArgumentException()); + } else { + assertTrue("should not reach here", false); + } + } + }; + NodeClient clientSpy = spy(client); + + method = RestRequest.Method.POST; + + handler = new IndexForecasterActionHandler( + clusterService, + clientSpy, + clientUtil, + mock(TransportService.class), + forecastISM, + forecaster.getId(), + 1L, + 1L, + null, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry(), + null, + mock(ForecastTaskManager.class), + searchFeatureDao, + settings + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + assertTrue("actual: " + e, e instanceof IllegalArgumentException); + assertEquals("actual: " + e.getMessage(), null, e.getMessage()); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(10, TimeUnit.SECONDS)); + // once for timestamp, once for categorical field + verify(clientSpy, times(2)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + // validateAgainstExistingHCConfig, checkConfigNameExists, validateCategoricalField/searchConfigInputIndices + verify(clientSpy, times(3)).execute(eq(SearchAction.INSTANCE), any(), any()); + // indexConfig + verify(clientSpy, times(1)).execute(eq(IndexAction.INSTANCE), any(), any()); + } + + public void testIndexShardFailure() throws InterruptedException, IOException { + forecaster = TestHelpers.ForecasterBuilder + .newInstance() + .setConfigId(forecasterId) + .setTimeField("timestamp") + .setIndices(ImmutableList.of("test-index")) + .setFeatureAttributes(Arrays.asList()) + .setCategoryFields(Arrays.asList(categoricalField)) + .setNullImputationOption() + .build(); + + IndexResponse.Builder notCreatedResponse = new IndexResponse.Builder(); + notCreatedResponse.setResult(Result.CREATED); + notCreatedResponse.setShardId(new ShardId("index", "_uuid", 0)); + notCreatedResponse.setId("blah"); + notCreatedResponse.setVersion(1L); + + ReplicationResponse.ShardInfo.Failure[] failures = new ReplicationResponse.ShardInfo.Failure[1]; + failures[0] = new ReplicationResponse.ShardInfo.Failure( + new ShardId("index", "_uuid", 1), + null, + new Exception("shard failed"), + RestStatus.GATEWAY_TIMEOUT, + false + ); + notCreatedResponse.setShardInfo(new ShardInfo(2, 1, failures)); + IndexResponse indexResponse = notCreatedResponse.build(); + + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (action.equals(GetFieldMappingsAction.INSTANCE)) { + GetFieldMappingsRequest getMappingsRequest = (GetFieldMappingsRequest) request; + try { + GetFieldMappingsResponse response = null; + if (getMappingsRequest.fields()[0].equals(categoricalField)) { + response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), categoricalField, "keyword") + ); + } else { + response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(forecaster.getIndices().get(0), "timestamp", "date") + ); + } + + listener.onResponse((Response) response); + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } else if (action.equals(SearchAction.INSTANCE)) { + SearchRequest searchRequest = (SearchRequest) request; + Histogram histogram = mock(Histogram.class); + when(histogram.getName()).thenReturn(AggregationPrep.AGGREGATION); + Aggregations aggs = new Aggregations(Arrays.asList(histogram)); + SearchResponseSections sections = null; + if (searchRequest.indices()[0].equals(CommonName.CONFIG_INDEX)) { + sections = new SearchResponseSections( + new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0), + aggs, + null, + false, + null, + null, + 1 + ); + } else { + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(randomIntBetween(1, Integer.MAX_VALUE)); + + sections = new SearchResponseSections( + new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 0), + aggs, + null, + false, + null, + null, + 1 + ); + } + listener + .onResponse( + (Response) new SearchResponse( + sections, + null, + 0, + 0, + 0, + 0L, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ) + ); + } else if (action.equals(IndexAction.INSTANCE)) { + + listener.onResponse((Response) indexResponse); + } else { + assertTrue("should not reach here", false); + } + } + }; + NodeClient clientSpy = spy(client); + + method = RestRequest.Method.POST; + + handler = new IndexForecasterActionHandler( + clusterService, + clientSpy, + clientUtil, + mock(TransportService.class), + forecastISM, + forecaster.getId(), + 1L, + 1L, + null, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry(), + null, + mock(ForecastTaskManager.class), + searchFeatureDao, + settings + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + assertTrue("actual: " + e, e instanceof OpenSearchStatusException); + String errorMsg = handler.checkShardsFailure(indexResponse); + assertEquals("actual: " + e.getMessage(), errorMsg, e.getMessage()); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(10, TimeUnit.SECONDS)); + // once for timestamp, once for categorical field + verify(clientSpy, times(2)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + // validateAgainstExistingHCConfig, checkConfigNameExists, validateCategoricalField/searchConfigInputIndices + verify(clientSpy, times(3)).execute(eq(SearchAction.INSTANCE), any(), any()); + // indexConfig + verify(clientSpy, times(1)).execute(eq(IndexAction.INSTANCE), any(), any()); + } + + public void testCreateMappingException() throws InterruptedException, IOException { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listner = (ActionListener) args[0]; + listner.onResponse(new CreateIndexResponse(false, false, "blah")); + return null; + }).when(forecastISM).initConfigIndex(any()); + when(forecastISM.doesIndexExist(anyString())).thenReturn(false); + when(forecastISM.doesAliasExist(anyString())).thenReturn(false); + when(forecastISM.doesConfigIndexExist()).thenReturn(false); + + handler = new IndexForecasterActionHandler( + clusterService, + clientMock, + clientUtil, + mock(TransportService.class), + forecastISM, + forecaster.getId(), + null, + null, + null, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry(), + null, + null, + searchFeatureDao, + settings + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + assertTrue("actual: " + e, e instanceof OpenSearchStatusException); + assertEquals( + "actual: " + e.getMessage(), + "Created " + CommonName.CONFIG_INDEX + "with mappings call not acknowledged.", + e.getMessage() + ); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(10, TimeUnit.SECONDS)); + } +} diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateForecasterActionHandlerTests.java b/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateForecasterActionHandlerTests.java new file mode 100644 index 000000000..e179a326f --- /dev/null +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateForecasterActionHandlerTests.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.action.admin.indices.mapping.get; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionType; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.forecast.rest.handler.ValidateForecasterActionHandler; +import org.opensearch.timeseries.model.ValidationAspect; + +public class ValidateForecasterActionHandlerTests extends AbstractForecasterActionHandlerTestCase { + protected ValidateForecasterActionHandler handler; + + public void testCreateOrUpdateConfigException() throws InterruptedException { + doThrow(IllegalArgumentException.class).when(forecastISM).doesConfigIndexExist(); + handler = new ValidateForecasterActionHandler( + clusterService, + clientMock, + clientUtil, + forecastISM, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry(), + null, + searchFeatureDao, + ValidationAspect.FORECASTER.getName(), + clock, + settings + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + assertTrue(e instanceof IllegalArgumentException); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testValidateTimeFieldException() throws InterruptedException { + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (action.equals(GetFieldMappingsAction.INSTANCE)) { + listener.onFailure(new IllegalArgumentException()); + } else { + assertTrue("should not reach here", false); + } + } + }; + NodeClient clientSpy = spy(client); + + handler = new ValidateForecasterActionHandler( + clusterService, + clientSpy, + clientUtil, + forecastISM, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry(), + null, + searchFeatureDao, + ValidationAspect.FORECASTER.getName(), + clock, + settings + ); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> { + assertTrue("should not reach here", false); + inProgressLatch.countDown(); + }, e -> { + assertTrue(e instanceof IllegalArgumentException); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + verify(clientSpy, times(1)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + } +} diff --git a/src/test/java/org/opensearch/ad/AbstractADSyntheticDataTest.java b/src/test/java/org/opensearch/ad/AbstractADSyntheticDataTest.java index 0adfb7c0f..ef5afa1c5 100644 --- a/src/test/java/org/opensearch/ad/AbstractADSyntheticDataTest.java +++ b/src/test/java/org/opensearch/ad/AbstractADSyntheticDataTest.java @@ -11,38 +11,79 @@ package org.opensearch.ad; -import static org.opensearch.timeseries.TestHelpers.toHttpEntity; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.nio.charset.Charset; +import java.time.Duration; import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Map.Entry; +import java.util.TreeMap; -import org.apache.hc.core5.http.HttpHeaders; -import org.apache.hc.core5.http.message.BasicHeader; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.core.Logger; +import org.junit.Before; import org.opensearch.client.Request; +import org.opensearch.client.Response; import org.opensearch.client.RestClient; import org.opensearch.timeseries.AbstractSyntheticDataTest; -import org.opensearch.timeseries.TestHelpers; -import com.google.common.collect.ImmutableList; import com.google.gson.JsonArray; import com.google.gson.JsonElement; import com.google.gson.JsonObject; import com.google.gson.JsonParser; public class AbstractADSyntheticDataTest extends AbstractSyntheticDataTest { + protected static class TrainResult { + public String detectorId; + public List data; + // actual index of training data. As we have multiple entities, + // trainTestSplit means how many groups of entities are used for training. + // rawDataTrainTestSplit is the actual index of training data. + public int rawDataTrainTestSplit; + public Duration windowDelay; + public Instant trainTime; + // first data time in data + public Instant firstDataTime; + // last data time in data + public Instant finalDataTime; + + public TrainResult(String detectorId, List data, int rawDataTrainTestSplit, Duration windowDelay, Instant trainTime) { + this.detectorId = detectorId; + this.data = data; + this.rawDataTrainTestSplit = rawDataTrainTestSplit; + this.windowDelay = windowDelay; + this.trainTime = trainTime; + + this.firstDataTime = getDataTime(0); + this.finalDataTime = getDataTime(data.size() - 1); + } + + private Instant getDataTime(int index) { + String finalTimeStr = data.get(index).get("timestamp").getAsString(); + return Instant.ofEpochMilli(Long.parseLong(finalTimeStr)); + } + } + public static final Logger LOG = (Logger) LogManager.getLogger(AbstractADSyntheticDataTest.class); - private static int batchSize = 1000; + protected static final double EPSILON = 1e-3; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + // increase the AD memory percentage. Since enabling jacoco coverage instrumentation, + // the memory is not enough to finish HistoricalAnalysisRestApiIT. + updateClusterSettings(AD_MODEL_MAX_SIZE_PERCENTAGE.getKey(), 0.5); + } protected void runDetectionResult(String detectorId, Instant begin, Instant end, RestClient client, int entitySize) throws IOException, InterruptedException { @@ -59,11 +100,80 @@ protected void runDetectionResult(String detectorId, Instant begin, Instant end, Thread.sleep(50 * entitySize); } - protected List getAnomalyResult(String detectorId, Instant end, int entitySize, RestClient client) - throws InterruptedException { + protected void startHistorical(String detectorId, Instant begin, Instant end, RestClient client, int entitySize) throws IOException, + InterruptedException { + // trigger run in current interval + Request request = new Request( + "POST", + String.format(Locale.ROOT, "/_opendistro/_anomaly_detection/detectors/%s/_start", detectorId) + ); + request + .setJsonEntity( + String.format(Locale.ROOT, "{ \"start_time\": %d, \"end_time\": %d }", begin.toEpochMilli(), end.toEpochMilli()) + ); + int statusCode = client.performRequest(request).getStatusLine().getStatusCode(); + assert (statusCode >= 200 && statusCode < 300); + + // wait for 50 milliseconds per entity before next query + Thread.sleep(50 * entitySize); + } + + protected Map preview(String detector, Instant begin, Instant end, RestClient client) throws IOException, + InterruptedException { + LOG.info("preview detector {}", detector); + // trigger run in current interval + Request request = new Request("POST", "/_plugins/_anomaly_detection/detectors/_preview"); + request + .setJsonEntity( + String + .format( + Locale.ROOT, + "{ \"period_start\": %d, \"period_end\": %d, \"detector\": %s }", + begin.toEpochMilli(), + end.toEpochMilli(), + detector + ) + ); + Response response = client.performRequest(request); + int statusCode = response.getStatusLine().getStatusCode(); + assert (statusCode >= 200 && statusCode < 300); + + return entityAsMap(response); + } + + protected Map previewWithFailure(String detector, Instant begin, Instant end, RestClient client) throws IOException, + InterruptedException { + // trigger run in current interval + Request request = new Request("POST", "/_plugins/_anomaly_detection/detectors/_preview"); + request + .setJsonEntity( + String + .format( + Locale.ROOT, + "{ \"period_start\": %d, \"period_end\": %d, \"detector\": %s }", + begin.toEpochMilli(), + end.toEpochMilli(), + detector + ) + ); + Response response = client.performRequest(request); + int statusCode = response.getStatusLine().getStatusCode(); + assert (statusCode == 400); + + return entityAsMap(response); + } + + protected List getAnomalyResult( + String detectorId, + Instant end, + int entitySize, + RestClient client, + boolean approximateDataEndTime, + long intervalMillis + ) throws InterruptedException { Request request = new Request("POST", "/_plugins/_anomaly_detection/detectors/results/_search"); - String jsonTemplate = "{\n" + String jsonTemplatePrefix = "{\n" + " \"query\": {\n" + " \"bool\": {\n" + " \"filter\": [\n" @@ -81,19 +191,38 @@ protected List getAnomalyResult(String detectorId, Instant end, int + " },\n" + " {\n" + " \"range\": {\n" - + " \"data_end_time\": {\n" - + " \"gte\": %d,\n" - + " \"lte\": %d\n" - + " }\n" - + " }\n" - + " }\n" - + " ]\n" - + " }\n" - + " }\n" - + "}"; + + " \"data_end_time\": {\n"; + + StringBuilder jsonTemplate = new StringBuilder(); + jsonTemplate.append(jsonTemplatePrefix); + + if (approximateDataEndTime) { + // we may get two interval results if using gte + jsonTemplate.append(" \"gt\": %d,\n \"lte\": %d\n"); + } else { + jsonTemplate.append(" \"gte\": %d,\n \"lte\": %d\n"); + } + + jsonTemplate + .append( + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + "}" + ); long dateEndTime = end.toEpochMilli(); - String formattedJson = String.format(Locale.ROOT, jsonTemplate, detectorId, dateEndTime, dateEndTime); + String formattedJson = null; + + if (approximateDataEndTime) { + formattedJson = String.format(Locale.ROOT, jsonTemplate.toString(), detectorId, dateEndTime - intervalMillis, dateEndTime); + } else { + formattedJson = String.format(Locale.ROOT, jsonTemplate.toString(), detectorId, dateEndTime, dateEndTime); + } + request.setJsonEntity(formattedJson); // wait until results are available @@ -135,6 +264,7 @@ protected List getAnomalyResult(String detectorId, Instant end, int String matchAll = "{\n" + " \"size\": 1000,\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}"; request.setJsonEntity(matchAll); JsonArray hits = getHits(client, request); + LOG.info("Query: {}", formattedJson); LOG.info("match all result: {}", hits); } catch (Exception e) { LOG.warn("Exception while waiting for match all result", e); @@ -143,12 +273,26 @@ protected List getAnomalyResult(String detectorId, Instant end, int return new ArrayList<>(); } - protected double getAnomalyGrade(JsonObject source) { + protected List getRealTimeAnomalyResult(String detectorId, Instant end, int entitySize, RestClient client) + throws InterruptedException { + return getAnomalyResult(detectorId, end, entitySize, client, false, 0); + } + + public double getAnomalyGrade(JsonObject source) { return source.get("anomaly_grade").getAsDouble(); } - protected String getEntity(JsonObject source) { - return source.get("entity").getAsJsonArray().get(0).getAsJsonObject().get("value").getAsString(); + public double getConfidence(JsonObject source) { + return source.get("confidence").getAsDouble(); + } + + public String getEntity(JsonObject source) { + JsonElement element = source.get("entity"); + if (element == null) { + // single stream + return "dummy"; + } + return element.getAsJsonArray().get(0).getAsJsonObject().get("value").getAsString(); } /** @@ -157,7 +301,7 @@ protected String getEntity(JsonObject source) { * @param defaultVal default anomaly time. Usually data end time. * @return anomaly event time. */ - protected Instant getAnomalyTime(JsonObject source, Instant defaultVal) { + public Instant getAnomalyTime(JsonObject source, Instant defaultVal) { JsonElement anomalyTime = source.get("approx_anomaly_start_time"); if (anomalyTime != null) { long epochhMillis = anomalyTime.getAsLong(); @@ -166,6 +310,39 @@ protected Instant getAnomalyTime(JsonObject source, Instant defaultVal) { return defaultVal; } + public JsonObject getFeature(JsonObject source, int index) { + JsonArray featureDataArray = source.getAsJsonArray("feature_data"); + + // Get the index element from the JsonArray + return featureDataArray.get(index).getAsJsonObject(); + } + + public JsonObject getImputed(JsonObject source, int index) { + JsonArray featureDataArray = source.getAsJsonArray("feature_imputed"); + if (featureDataArray == null) { + return null; + } + + // Get the index element from the JsonArray + return featureDataArray.get(index).getAsJsonObject(); + } + + protected JsonObject getImputed(JsonObject source, String featureId) { + JsonArray featureDataArray = source.getAsJsonArray("feature_imputed"); + if (featureDataArray == null) { + return null; + } + + for (int i = 0; i < featureDataArray.size(); i++) { + // Get the index element from the JsonArray + JsonObject jsonObject = featureDataArray.get(i).getAsJsonObject(); + if (jsonObject.get("feature_id").getAsString().equals(featureId)) { + return jsonObject; + } + } + return null; + } + protected String createDetector(RestClient client, String detectorJson) throws Exception { Request request = new Request("POST", "/_plugins/_anomaly_detection/detectors/"); @@ -247,9 +424,11 @@ protected void waitForInitDetector(String detectorId, RestClient client) throws * @param client OpenSearch Client * @param end date end time of the most recent detection period * @param entitySize the number of entity results to wait for + * @return initial result * @throws Exception when failing to query/indexing from/to OpenSearch */ - protected void simulateWaitForInitDetector(String detectorId, RestClient client, Instant end, int entitySize) throws Exception { + protected List simulateWaitForInitDetector(String detectorId, RestClient client, Instant end, int entitySize) + throws Exception { long startTime = System.currentTimeMillis(); long duration = 0; @@ -257,45 +436,42 @@ protected void simulateWaitForInitDetector(String detectorId, RestClient client, Thread.sleep(1_000); - List sourceList = getAnomalyResult(detectorId, end, entitySize, client); + List sourceList = getRealTimeAnomalyResult(detectorId, end, entitySize, client); if (sourceList.size() > 0 && getAnomalyGrade(sourceList.get(0)) >= 0) { - break; + return sourceList; } duration = System.currentTimeMillis() - startTime; } while (duration <= 60_000); - assertTrue("time out while waiting for initing detector", duration <= 60_000); + assertTrue("time out while waiting for initing detector", false); + return null; } - protected void bulkIndexData(List data, String datasetName, RestClient client, String mapping, int ingestDataSize) - throws Exception { - createIndex(datasetName, client, mapping); - StringBuilder bulkRequestBuilder = new StringBuilder(); - LOG.info("data size {}", data.size()); - int count = 0; - int pickedIngestSize = Math.min(ingestDataSize, data.size()); - for (int i = 0; i < pickedIngestSize; i++) { - bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + datasetName + "\", \"_id\" : \"" + i + "\" } }\n"); - bulkRequestBuilder.append(data.get(i).toString()).append("\n"); - count++; - if (count >= batchSize || i == pickedIngestSize - 1) { - count = 0; - TestHelpers - .makeRequest( - client, - "POST", - "_bulk?refresh=true", - null, - toHttpEntity(bulkRequestBuilder.toString()), - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) - ); - Thread.sleep(1_000); + protected List waitForHistoricalDetector( + String detectorId, + RestClient client, + Instant end, + int entitySize, + int intervalMillis + ) throws Exception { + + long startTime = System.currentTimeMillis(); + long duration = 0; + do { + + Thread.sleep(1_000); + + List sourceList = getAnomalyResult(detectorId, end, entitySize, client, true, intervalMillis); + if (sourceList.size() > 0 && getAnomalyGrade(sourceList.get(0)) >= 0) { + return sourceList; } - } - waitAllSyncheticDataIngested(data.size(), datasetName, client); - LOG.info("data ingestion complete"); + duration = System.currentTimeMillis() - startTime; + } while (duration <= 60_000); + + assertTrue("time out while waiting for historical detector to finish", false); + return null; } /** @@ -313,7 +489,7 @@ protected void simulateStartDetector(String detectorId, Instant begin, Instant e runDetectionResult(detectorId, begin, end, client, entitySize); } - protected int isAnomaly(Instant time, List> labels) { + public int isAnomaly(Instant time, List> labels) { for (int i = 0; i < labels.size(); i++) { Entry window = labels.get(i); if (time.compareTo(window.getKey()) >= 0 && time.compareTo(window.getValue()) <= 0) { @@ -331,4 +507,124 @@ protected List getData(String datasetFileName) throws Exception { jsonArray.iterator().forEachRemaining(i -> list.add(i.getAsJsonObject())); return list; } + + protected Instant dataToExecutionTime(Instant instant, Duration windowDelay) { + return instant.plus(windowDelay); + } + + /** + * Assume the data is sorted in time. The method look up and below startIndex + * and return how many timestamps equal to timestampStr. + * @param startIndex where to start look for timestamp + * @return how many timestamps equal to timestampStr + */ + protected int findGivenTimeEntities(int startIndex, List data) { + String timestampStr = data.get(startIndex).get("timestamp").getAsString(); + int count = 1; + for (int i = startIndex - 1; i >= 0; i--) { + String trainTimeStr = data.get(i).get("timestamp").getAsString(); + if (trainTimeStr.equals(timestampStr)) { + count++; + } else { + break; + } + } + for (int i = startIndex + 1; i < data.size(); i++) { + String trainTimeStr = data.get(i).get("timestamp").getAsString(); + if (trainTimeStr.equals(timestampStr)) { + count++; + } else { + break; + } + } + return count; + } + + /** + * + * @param beginTimeStampAsString data start time in string + * @param entityMap a map to record the number of times we have seen a timestamp. Used to detect missing values. + * @param windowDelay ingestion delay + * @param intervalMinutes detector interval + * @param detectorId detector Id + * @param client RestFul client + * @param numberOfEntities the number of entities. + * @return whether we erred out. + */ + protected boolean scoreOneResult( + String beginTimeStampAsString, + TreeMap entityMap, + Duration windowDelay, + int intervalMinutes, + String detectorId, + RestClient client, + int numberOfEntities + ) { + Integer newCount = entityMap.compute(beginTimeStampAsString, (key, oldValue) -> (oldValue == null) ? 1 : oldValue + 1); + if (newCount > 1) { + // we have seen this timestamp before. Without this line, we will get rcf IllegalArgumentException about out of order tuples + return false; + } + Instant begin = dataToExecutionTime(Instant.ofEpochMilli(Long.parseLong(beginTimeStampAsString)), windowDelay); + Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); + try { + runDetectionResult(detectorId, begin, end, client, numberOfEntities); + } catch (Exception e) { + LOG.error("failed to run detection result", e); + return true; + } + return false; + } + + protected List startRealTimeDetector( + TrainResult trainResult, + int numberOfEntities, + int intervalMinutes, + boolean imputeEnabled + ) throws Exception { + Instant executeBegin = dataToExecutionTime(trainResult.trainTime, trainResult.windowDelay); + Instant executeEnd = executeBegin.plus(intervalMinutes, ChronoUnit.MINUTES); + Instant dataEnd = trainResult.trainTime.plus(intervalMinutes, ChronoUnit.MINUTES); + + LOG.info("start detector {}, dataStart {}, dataEnd {}", trainResult.detectorId, trainResult.trainTime, dataEnd); + simulateStartDetector(trainResult.detectorId, executeBegin, executeEnd, client(), numberOfEntities); + int resultsToWait = numberOfEntities; + if (!imputeEnabled) { + resultsToWait = findGivenTimeEntities(trainResult.rawDataTrainTestSplit - 1, trainResult.data); + } + LOG.info("wait for initting detector {}. {} results are expected.", trainResult.detectorId, resultsToWait); + return simulateWaitForInitDetector(trainResult.detectorId, client(), dataEnd, resultsToWait); + } + + protected List startHistoricalDetector( + TrainResult trainResult, + int numberOfEntities, + int intervalMinutes, + boolean imputeEnabled + ) throws Exception { + LOG.info("start historical detector {}", trainResult.detectorId); + startHistorical(trainResult.detectorId, trainResult.firstDataTime, trainResult.finalDataTime, client(), numberOfEntities); + int resultsToWait = numberOfEntities; + if (!imputeEnabled) { + findGivenTimeEntities(trainResult.data.size() - 1, trainResult.data); + } + LOG + .info( + "wait for historical detector {} at {}. {} results are expected.", + trainResult.detectorId, + trainResult.finalDataTime, + resultsToWait + ); + return waitForHistoricalDetector( + trainResult.detectorId, + client(), + trainResult.finalDataTime, + resultsToWait, + intervalMinutes * 60000 + ); + } + + public static boolean areDoublesEqual(double d1, double d2) { + return Math.abs(d1 - d2) < EPSILON; + } } diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java b/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java index fe9776f11..97faa29f5 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java @@ -262,7 +262,7 @@ public void setup() throws Exception { when(adTaskCacheManager.hasQueriedResultIndex(anyString())).thenReturn(false); - detector = TestHelpers.randomAnomalyDetectorWithEmptyFeature(); + detector = TestHelpers.randomAnomalyDetector("timestamp", "sourceIndex"); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(2); listener.onResponse(Optional.of(detector)); diff --git a/src/test/java/org/opensearch/ad/e2e/AbstractMissingSingleFeatureTestCase.java b/src/test/java/org/opensearch/ad/e2e/AbstractMissingSingleFeatureTestCase.java new file mode 100644 index 000000000..3eec05657 --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/AbstractMissingSingleFeatureTestCase.java @@ -0,0 +1,301 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.e2e; + +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.NavigableSet; +import java.util.TreeMap; + +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.timeseries.AbstractSyntheticDataTest; +import org.opensearch.timeseries.dataprocessor.ImputationMethod; + +import com.google.gson.JsonObject; + +public abstract class AbstractMissingSingleFeatureTestCase extends MissingIT { + @Override + protected String genDetector( + int trainTestSplit, + long windowDelayMinutes, + boolean hc, + ImputationMethod imputation, + long trainTimeMillis + ) { + StringBuilder sb = new StringBuilder(); + // common part + sb + .append( + "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" + + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " + + "\"true\", \"aggregation_query\": { \"Feature1\": { \"avg\": { \"field\": \"data\" } } } }" + + "], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " + + "\"history\": %d," + ); + + if (windowDelayMinutes > 0) { + sb + .append( + String + .format( + Locale.ROOT, + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}},", + windowDelayMinutes + ) + ); + } + if (hc) { + sb.append("\"category_field\": [\"%s\"], "); + } + + switch (imputation) { + case ZERO: + sb.append("\"imputation_option\": { \"method\": \"zero\" },"); + break; + case PREVIOUS: + sb.append("\"imputation_option\": { \"method\": \"previous\" },"); + break; + case FIXED_VALUES: + sb + .append( + "\"imputation_option\": { \"method\": \"fixed_values\", \"defaultFill\": [{ \"feature_name\" : \"feature 1\", \"data\": 1 }] }," + ); + break; + } + // end + sb.append("\"schema_version\": 0}"); + + if (hc) { + return String.format(Locale.ROOT, sb.toString(), datasetName, intervalMinutes, trainTestSplit - 1, categoricalField); + } else { + return String.format(Locale.ROOT, sb.toString(), datasetName, intervalMinutes, trainTestSplit - 1); + } + + } + + @Override + protected AbstractSyntheticDataTest.GenData genData( + int trainTestSplit, + int numberOfEntities, + AbstractSyntheticDataTest.MISSING_MODE missingMode + ) throws Exception { + return genUniformSingleFeatureData( + intervalMinutes, + trainTestSplit, + numberOfEntities, + categoricalField, + missingMode, + continuousImputeStartIndex, + continuousImputeEndIndex, + randomDoubles + ); + } + + @Override + protected double extractFeatureValue(JsonObject source) { + JsonObject feature0 = getFeature(source, 0); + return feature0.get("data").getAsDouble(); + } + + protected void verifyGrade(Integer testIndex, ImputationMethod imputation, double anomalyGrade) { + if (testIndex == continuousImputeStartIndex || testIndex == continuousImputeEndIndex + 1) { + switch (imputation) { + case ZERO: + assertTrue(Double.compare(anomalyGrade, 0) > 0); + break; + case PREVIOUS: + assertEquals(anomalyGrade, 0, EPSILON); + break; + case FIXED_VALUES: + assertTrue(Double.compare(anomalyGrade, 0) > 0); + break; + default: + assertTrue(false); + break; + } + } else { + assertEquals("testIndex: " + testIndex, 0, anomalyGrade, EPSILON); + } + } + + protected boolean verifyConfidence(Integer testIndex, double confidence, Double lastConfidence) { + if (lastConfidence == null) { + return false; + } + + // we will see confidence increasing again after some point in shingle size (default 8) + if (testIndex <= continuousImputeStartIndex || testIndex >= continuousImputeEndIndex + 8) { + assertTrue( + String.format(Locale.ROOT, "confidence: %f, lastConfidence: %f, testIndex: %d", confidence, lastConfidence, testIndex), + Double.compare(confidence, lastConfidence) >= 0 + ); + } else if (testIndex > continuousImputeStartIndex && testIndex <= continuousImputeEndIndex) { + assertTrue( + String.format(Locale.ROOT, "confidence: %f, lastConfidence: %f, testIndex: %d", confidence, lastConfidence, testIndex), + Double.compare(confidence, lastConfidence) <= 0 + ); + } + return true; + } + + @Override + protected void runTest( + long firstDataStartTime, + AbstractSyntheticDataTest.GenData dataGenerated, + Duration windowDelay, + String detectorId, + int numberOfEntities, + AbstractSyntheticDataTest.MISSING_MODE mode, + ImputationMethod imputation, + int numberOfMissingToCheck, + boolean realTime + ) { + int errors = 0; + List data = dataGenerated.data; + long lastDataStartTime = data.get(data.size() - 1).get("timestamp").getAsLong(); + long dataStartTime = firstDataStartTime + intervalMillis; + NavigableSet missingTimestamps = dataGenerated.missingTimestamps; + NavigableSet> missingEntities = dataGenerated.missingEntities; + + // we might miss timestamps at the end + if (mode == AbstractSyntheticDataTest.MISSING_MODE.MISSING_TIMESTAMP + || mode == AbstractSyntheticDataTest.MISSING_MODE.CONTINUOUS_IMPUTE) { + lastDataStartTime = Math.max(missingTimestamps.last(), lastDataStartTime); + } else if (mode == AbstractSyntheticDataTest.MISSING_MODE.MISSING_ENTITY) { + lastDataStartTime = Math.max(missingEntities.last().getLeft(), lastDataStartTime); + } + // an entity might have missing values (e.g., at timestamp 1694713200000). + // Use a map to record the number of times we have seen them. + // data start time -> the number of entities + TreeMap entityMap = new TreeMap<>(); + + int missingIndex = 0; + Map lastConfidence = new HashMap<>(); + + // look two default shingle size after imputation area + long continuousModeStopTime = dataStartTime + intervalMillis * (continuousImputeEndIndex + 16); + + // exit when reaching last date time or we have seen at least three missing values. + // In continuous impute mode, we may want to read a few more points to check if confidence increases or not + LOG.info("lastDataStartTime: {}, dataStartTime: {}", lastDataStartTime, dataStartTime); + // test data 0 is used trigger cold start + int testIndex = 1; + while (lastDataStartTime >= dataStartTime && (missingIndex <= numberOfMissingToCheck && continuousModeStopTime >= dataStartTime)) { + // no need to call _run api in each interval in historical case + if (realTime + && scoreOneResult( + String.valueOf(dataStartTime), + entityMap, + windowDelay, + intervalMinutes, + detectorId, + client(), + numberOfEntities + )) { + errors++; + } + + LOG.info("test index: {}", testIndex); + + Instant begin = Instant.ofEpochMilli(dataStartTime); + Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); + try { + List sourceList = null; + if (realTime) { + sourceList = getRealTimeAnomalyResult(detectorId, end, numberOfEntities, client()); + } else { + sourceList = getAnomalyResult(detectorId, end, numberOfEntities, client(), true, intervalMillis); + } + + assertTrue( + String + .format( + Locale.ROOT, + "the number of results is %d at %s, expected %d ", + sourceList.size(), + end.toEpochMilli(), + numberOfEntities + ), + sourceList.size() == numberOfEntities + ); + + // used to track if any entity within a timestamp has imputation and then we increment + // missingIndex outside the loop. Used in MISSING_TIMESTAMP mode. + boolean imputed = false; + for (int j = 0; j < numberOfEntities; j++) { + JsonObject source = sourceList.get(j); + JsonObject feature0 = getFeature(source, 0); + double dataValue = feature0.get("data").getAsDouble(); + + JsonObject imputed0 = getImputed(source, 0); + + String entity = getEntity(source); + + if (mode == AbstractSyntheticDataTest.MISSING_MODE.MISSING_TIMESTAMP && missingTimestamps.contains(dataStartTime)) { + verifyImputation(imputation, lastSeen, dataValue, imputed0, entity); + imputed = true; + } else if (mode == AbstractSyntheticDataTest.MISSING_MODE.MISSING_ENTITY + && missingEntities.contains(Pair.of(dataStartTime, entity))) { + verifyImputation(imputation, lastSeen, dataValue, imputed0, entity); + missingIndex++; + } else if (mode == AbstractSyntheticDataTest.MISSING_MODE.CONTINUOUS_IMPUTE) { + int imputeIndex = getIndex(missingTimestamps, dataStartTime); + double grade = getAnomalyGrade(source); + verifyGrade(testIndex, imputation, grade); + + if (imputeIndex >= 0) { + imputed = true; + } + + double confidence = getConfidence(source); + verifyConfidence(testIndex, confidence, lastConfidence.get(entity)); + lastConfidence.put(entity, confidence); + } else { + assertEquals(null, imputed0); + } + + lastSeen.put(entity, dataValue); + } + if (imputed) { + missingIndex++; + } + } catch (Exception e) { + errors++; + LOG.error("failed to get detection results", e); + } finally { + testIndex++; + } + + dataStartTime += intervalMillis; + } + + // at least numberOfMissingToCheck missing value imputation is seen + assertTrue( + String.format(Locale.ROOT, "missingIndex %d, numberOfMissingToCheck %d", missingIndex, numberOfMissingToCheck), + missingIndex >= numberOfMissingToCheck + ); + assertTrue(errors < maxError); + } + + /** + * + * @param element type + * @param set ordered set + * @param element element to compare + * @return the index of the element that is less than or equal to the given element. + */ + protected static int getIndex(NavigableSet set, E element) { + if (!set.contains(element)) { + return -1; + } + return set.headSet(element, true).size() - 1; + } +} diff --git a/src/test/java/org/opensearch/ad/e2e/AbstractRuleModelPerfTestCase.java b/src/test/java/org/opensearch/ad/e2e/AbstractRuleModelPerfTestCase.java new file mode 100644 index 000000000..30f07de3d --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/AbstractRuleModelPerfTestCase.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.e2e; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + +import org.apache.commons.lang3.tuple.Triple; + +import com.google.gson.JsonObject; + +public abstract class AbstractRuleModelPerfTestCase extends AbstractRuleTestCase { + protected void verifyTestResults( + Triple, Integer, Map>> testResults, + Map>> anomalies, + Map minPrecision, + Map minRecall, + int maxError + ) { + Map resultMap = testResults.getLeft(); + Map> foundWindows = testResults.getRight(); + + for (Entry entry : resultMap.entrySet()) { + String entity = entry.getKey(); + double[] testResultsArray = entry.getValue(); + double positives = testResultsArray[0]; + double truePositives = testResultsArray[1]; + + // precision = predicted anomaly points that are true / predicted anomaly points + double precision = positives > 0 ? truePositives / positives : 0; + double minPrecisionValue = minPrecision.getOrDefault(entity, .4); + assertTrue( + String + .format( + Locale.ROOT, + "precision expected at least %f but got %f. positives %f, truePositives %f", + minPrecisionValue, + precision, + positives, + truePositives + ), + precision >= minPrecisionValue + ); + + // recall = windows containing predicted anomaly points / total anomaly windows + int anomalyWindow = anomalies.getOrDefault(entity, new ArrayList<>()).size(); + int foundWindowSize = foundWindows.getOrDefault(entity, new HashSet<>()).size(); + double recall = anomalyWindow > 0 ? foundWindowSize * 1.0d / anomalyWindow : 0; + double minRecallValue = minRecall.getOrDefault(entity, .7); + assertTrue( + String + .format( + Locale.ROOT, + "recall should be at least %f but got %f. anomalyWindow %d, foundWindowSize %d ", + minRecallValue, + recall, + anomalyWindow, + foundWindowSize + ), + recall >= minRecallValue + ); + + LOG.info("Entity {}, Precision: {}, Window recall: {}", entity, precision, recall); + } + + int errors = testResults.getMiddle(); + assertTrue(errors <= maxError); + } + + protected void analyzeResults( + Map>> anomalies, + Map res, + Map> foundWindow, + String beginTimeStampAsString, + int entitySize, + Instant begin, + List sourceList + ) { + assertTrue( + String + .format( + Locale.ROOT, + "the number of results is %d at %s, expected %d ", + sourceList.size(), + beginTimeStampAsString, + entitySize + ), + sourceList.size() == entitySize + ); + for (int j = 0; j < entitySize; j++) { + JsonObject source = sourceList.get(j); + double anomalyGrade = getAnomalyGrade(source); + assertTrue("anomalyGrade cannot be negative", anomalyGrade >= 0); + if (anomalyGrade > 0) { + String entity = getEntity(source); + double[] entityResult = res.computeIfAbsent(entity, key -> new double[] { 0, 0 }); + // positive++ + entityResult[0]++; + Instant anomalyTime = getAnomalyTime(source, begin); + LOG.info("Found anomaly: entity {}, time {} result {}.", entity, anomalyTime, source); + int anomalyWindow = isAnomaly(anomalyTime, anomalies.getOrDefault(entity, new ArrayList<>())); + if (anomalyWindow != -1) { + LOG.info("True anomaly: entity {}, time {}.", entity, begin); + // truePositives++; + entityResult[1]++; + Set window = foundWindow.computeIfAbsent(entity, key -> new HashSet<>()); + window.add(anomalyWindow); + } + } + } + } +} diff --git a/src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java b/src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java index a5d6c8686..941af35a5 100644 --- a/src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java +++ b/src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java @@ -5,39 +5,32 @@ package org.opensearch.ad.e2e; +import java.io.File; +import java.io.FileReader; +import java.nio.charset.Charset; import java.time.Duration; import java.time.Instant; -import java.time.temporal.ChronoUnit; +import java.time.format.DateTimeFormatter; +import java.util.AbstractMap.SimpleEntry; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Locale; -import java.util.TreeMap; +import java.util.Map; +import java.util.Map.Entry; import org.opensearch.ad.AbstractADSyntheticDataTest; import org.opensearch.client.RestClient; +import com.google.gson.JsonElement; import com.google.gson.JsonObject; +import com.google.gson.JsonParser; -public class AbstractRuleTestCase extends AbstractADSyntheticDataTest { - - protected static class TrainResult { - String detectorId; - List data; - // actual index of training data. As we have multiple entities, - // trainTestSplit means how many groups of entities are used for training. - // rawDataTrainTestSplit is the actual index of training data. - int rawDataTrainTestSplit; - Duration windowDelay; - - public TrainResult(String detectorId, List data, int rawDataTrainTestSplit, Duration windowDelay) { - this.detectorId = detectorId; - this.data = data; - this.rawDataTrainTestSplit = rawDataTrainTestSplit; - this.windowDelay = windowDelay; - } - } +public abstract class AbstractRuleTestCase extends AbstractADSyntheticDataTest { + String categoricalField = "componentName"; /** - * Ingest all of the data in file datasetName + * Ingest all of the data in file datasetName and create detector * * @param datasetName data set file name * @param intervalMinutes detector interval @@ -47,14 +40,64 @@ public TrainResult(String detectorId, List data, int rawDataTrainTes * @return TrainResult for the following method calls * @throws Exception failing to ingest data */ - protected TrainResult ingestTrainData( + protected TrainResult ingestTrainDataAndCreateDetector( String datasetName, int intervalMinutes, int numberOfEntities, int trainTestSplit, boolean useDateNanos ) throws Exception { - return ingestTrainData(datasetName, intervalMinutes, numberOfEntities, trainTestSplit, useDateNanos, -1); + return ingestTrainDataAndCreateDetector(datasetName, intervalMinutes, numberOfEntities, trainTestSplit, useDateNanos, -1); + } + + protected TrainResult ingestTrainDataAndCreateDetector( + String datasetName, + int intervalMinutes, + int numberOfEntities, + int trainTestSplit, + boolean useDateNanos, + int ingestDataSize + ) throws Exception { + TrainResult trainResult = ingestTrainData( + datasetName, + intervalMinutes, + numberOfEntities, + trainTestSplit, + useDateNanos, + ingestDataSize + ); + + String detector = genDetector(datasetName, intervalMinutes, trainTestSplit, trainResult); + String detectorId = createDetector(client(), detector); + LOG.info("Created detector {}", detectorId); + trainResult.detectorId = detectorId; + + return trainResult; + } + + protected String genDetector(String datasetName, int intervalMinutes, int trainTestSplit, TrainResult trainResult) { + String detector = String + .format( + Locale.ROOT, + "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" + + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " + + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"transform._doc_count\" } } } }" + + "], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " + + "\"category_field\": [\"%s\"], " + + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," + + "\"history\": %d," + + "\"schema_version\": 0," + + "\"rules\": [{\"action\": \"ignore_anomaly\", \"conditions\": [{\"feature_name\": \"feature 1\", \"threshold_type\": \"actual_over_expected_ratio\", \"operator\": \"lte\", \"value\": 0.3}, " + + "{\"feature_name\": \"feature 1\", \"threshold_type\": \"expected_over_actual_ratio\", \"operator\": \"lte\", \"value\": 0.3}" + + "]}]" + + "}", + datasetName, + intervalMinutes, + categoricalField, + trainResult.windowDelay.toMinutes(), + trainTestSplit - 1 + ); + return detector; } protected TrainResult ingestTrainData( @@ -70,7 +113,6 @@ protected TrainResult ingestTrainData( List data = getData(dataFileName); RestClient client = client(); - String categoricalField = "componentName"; String mapping = String .format( Locale.ROOT, @@ -105,111 +147,29 @@ protected TrainResult ingestTrainData( */ long windowDelayMinutes = Duration.between(trainTime, Instant.now()).toMinutes(); Duration windowDelay = Duration.ofMinutes(windowDelayMinutes); - - String detector = String - .format( - Locale.ROOT, - "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" - + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " - + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"transform._doc_count\" } } } }" - + "], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " - + "\"category_field\": [\"%s\"], " - + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," - + "\"history\": %d," - + "\"schema_version\": 0," - + "\"rules\": [{\"action\": \"ignore_anomaly\", \"conditions\": [{\"feature_name\": \"feature 1\", \"threshold_type\": \"actual_over_expected_ratio\", \"operator\": \"lte\", \"value\": 0.3}, " - + "{\"feature_name\": \"feature 1\", \"threshold_type\": \"expected_over_actual_ratio\", \"operator\": \"lte\", \"value\": 0.3}" - + "]}]" - + "}", - datasetName, - intervalMinutes, - categoricalField, - windowDelayMinutes, - trainTestSplit - 1 - ); - String detectorId = createDetector(client, detector); - LOG.info("Created detector {}", detectorId); - - Instant executeBegin = dataToExecutionTime(trainTime, windowDelay); - Instant executeEnd = executeBegin.plus(intervalMinutes, ChronoUnit.MINUTES); - Instant dataEnd = trainTime.plus(intervalMinutes, ChronoUnit.MINUTES); - - LOG.info("start detector {}", detectorId); - simulateStartDetector(detectorId, executeBegin, executeEnd, client, numberOfEntities); - int resultsToWait = findTrainTimeEntities(rawDataTrainTestSplit - 1, data); - LOG.info("wait for initting detector {}. {} results are expected.", detectorId, resultsToWait); - simulateWaitForInitDetector(detectorId, client, dataEnd, resultsToWait); - - return new TrainResult(detectorId, data, rawDataTrainTestSplit, windowDelay); + return new TrainResult(null, data, rawDataTrainTestSplit, windowDelay, trainTime); } - /** - * Assume the data is sorted in time. The method look up and below startIndex - * and return how many timestamps equal to timestampStr. - * @param startIndex where to start look for timestamp - * @return how many timestamps equal to timestampStr - */ - protected int findTrainTimeEntities(int startIndex, List data) { - String timestampStr = data.get(startIndex).get("timestamp").getAsString(); - int count = 1; - for (int i = startIndex - 1; i >= 0; i--) { - String trainTimeStr = data.get(i).get("timestamp").getAsString(); - if (trainTimeStr.equals(timestampStr)) { - count++; - } else { - break; + public Map>> getAnomalyWindowsMap(String labelFileName) throws Exception { + JsonObject jsonObject = JsonParser + .parseReader(new FileReader(new File(getClass().getResource(labelFileName).toURI()), Charset.defaultCharset())) + .getAsJsonObject(); + + Map>> map = new HashMap<>(); + for (Map.Entry entry : jsonObject.entrySet()) { + List> anomalies = new ArrayList<>(); + JsonElement value = entry.getValue(); + if (value.isJsonArray()) { + for (JsonElement elem : value.getAsJsonArray()) { + JsonElement beginElement = elem.getAsJsonArray().get(0); + JsonElement endElement = elem.getAsJsonArray().get(1); + Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(beginElement.getAsString())); + Instant end = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(endElement.getAsString())); + anomalies.add(new SimpleEntry<>(begin, end)); + } } + map.put(entry.getKey(), anomalies); } - for (int i = startIndex + 1; i < data.size(); i++) { - String trainTimeStr = data.get(i).get("timestamp").getAsString(); - if (trainTimeStr.equals(timestampStr)) { - count++; - } else { - break; - } - } - return count; + return map; } - - protected Instant dataToExecutionTime(Instant instant, Duration windowDelay) { - return instant.plus(windowDelay); - } - - /** - * - * @param testData current data to score - * @param entityMap a map to record the number of times we have seen a timestamp. Used to detect missing values. - * @param windowDelay ingestion delay - * @param intervalMinutes detector interval - * @param detectorId detector Id - * @param client RestFul client - * @param numberOfEntities the number of entities. - * @return whether we erred out. - */ - protected boolean scoreOneResult( - JsonObject testData, - TreeMap entityMap, - Duration windowDelay, - int intervalMinutes, - String detectorId, - RestClient client, - int numberOfEntities - ) { - String beginTimeStampAsString = testData.get("timestamp").getAsString(); - Integer newCount = entityMap.compute(beginTimeStampAsString, (key, oldValue) -> (oldValue == null) ? 1 : oldValue + 1); - if (newCount > 1) { - // we have seen this timestamp before. Without this line, we will get rcf IllegalArgumentException about out of order tuples - return false; - } - Instant begin = dataToExecutionTime(Instant.ofEpochMilli(Long.parseLong(beginTimeStampAsString)), windowDelay); - Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); - try { - runDetectionResult(detectorId, begin, end, client, numberOfEntities); - } catch (Exception e) { - LOG.error("failed to run detection result", e); - return true; - } - return false; - } - } diff --git a/src/test/java/org/opensearch/ad/e2e/HistoricalMissingSingleFeatureIT.java b/src/test/java/org/opensearch/ad/e2e/HistoricalMissingSingleFeatureIT.java new file mode 100644 index 000000000..997292c22 --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/HistoricalMissingSingleFeatureIT.java @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.e2e; + +import org.opensearch.timeseries.AbstractSyntheticDataTest; +import org.opensearch.timeseries.dataprocessor.ImputationMethod; + +public class HistoricalMissingSingleFeatureIT extends AbstractMissingSingleFeatureTestCase { + public void testSingleStream() throws Exception { + int numberOfEntities = 1; + + AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.CONTINUOUS_IMPUTE; + ImputationMethod method = ImputationMethod.ZERO; + + AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode); + + ingestUniformSingleFeatureData( + -1, // ingest all + dataGenerated.data + ); + + TrainResult trainResult = createAndStartHistoricalDetector( + numberOfEntities, + trainTestSplit, + dataGenerated.data, + method, + false, + dataGenerated.testStartTime + ); + + // we allowed 25 (continuousImputeEndIndex - continuousImputeStartIndex + 1) continuous missing timestamps in shouldSkipDataPoint + runTest( + dataGenerated.testStartTime, + dataGenerated, + trainResult.windowDelay, + trainResult.detectorId, + numberOfEntities, + mode, + method, + continuousImputeEndIndex - continuousImputeStartIndex + 1, + false + ); + } + + public void testHCFixed() throws Exception { + int numberOfEntities = 1; + + AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.CONTINUOUS_IMPUTE; + ImputationMethod method = ImputationMethod.FIXED_VALUES; + + AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode); + + ingestUniformSingleFeatureData( + -1, // ingest all + dataGenerated.data + ); + + TrainResult trainResult = createAndStartHistoricalDetector( + numberOfEntities, + trainTestSplit, + dataGenerated.data, + method, + true, + dataGenerated.testStartTime + ); + + // we allowed 25 (continuousImputeEndIndex - continuousImputeStartIndex + 1) continuous missing timestamps in shouldSkipDataPoint + runTest( + dataGenerated.testStartTime, + dataGenerated, + trainResult.windowDelay, + trainResult.detectorId, + numberOfEntities, + mode, + method, + continuousImputeEndIndex - continuousImputeStartIndex + 1, + false + ); + } + + public void testHCPrevious() throws Exception { + int numberOfEntities = 1; + + AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.CONTINUOUS_IMPUTE; + ImputationMethod method = ImputationMethod.PREVIOUS; + + AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode); + + ingestUniformSingleFeatureData( + -1, // ingest all + dataGenerated.data + ); + + TrainResult trainResult = createAndStartHistoricalDetector( + numberOfEntities, + trainTestSplit, + dataGenerated.data, + method, + true, + dataGenerated.testStartTime + ); + + // we allowed 25 (continuousImputeEndIndex - continuousImputeStartIndex + 1) continuous missing timestamps in shouldSkipDataPoint + runTest( + dataGenerated.testStartTime, + dataGenerated, + trainResult.windowDelay, + trainResult.detectorId, + numberOfEntities, + mode, + method, + continuousImputeEndIndex - continuousImputeStartIndex + 1, + false + ); + } +} diff --git a/src/test/java/org/opensearch/ad/e2e/HistoricalRuleModelPerfIT.java b/src/test/java/org/opensearch/ad/e2e/HistoricalRuleModelPerfIT.java new file mode 100644 index 000000000..a5c19d4c7 --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/HistoricalRuleModelPerfIT.java @@ -0,0 +1,118 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.e2e; + +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.TreeMap; + +import org.apache.commons.lang3.tuple.Triple; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.core.Logger; +import org.opensearch.client.RestClient; + +import com.google.gson.JsonObject; + +public class HistoricalRuleModelPerfIT extends AbstractRuleModelPerfTestCase { + static final Logger LOG = (Logger) LogManager.getLogger(HistoricalRuleModelPerfIT.class); + + public void testRule() throws Exception { + // TODO: this test case will run for a much longer time and timeout with security enabled + if (!isHttps()) { + disableResourceNotFoundFaultTolerence(); + // there are 8 entities in the data set. Each one needs 1500 rows as training data. + Map minPrecision = new HashMap<>(); + minPrecision.put("Phoenix", 0.4); + minPrecision.put("Scottsdale", 0.5); + Map minRecall = new HashMap<>(); + minRecall.put("Phoenix", 0.9); + minRecall.put("Scottsdale", 0.6); + verifyRule("rule", 10, minPrecision.size(), 1500, minPrecision, minRecall, 20); + } + } + + public void verifyRule( + String datasetName, + int intervalMinutes, + int numberOfEntities, + int trainTestSplit, + Map minPrecision, + Map minRecall, + int maxError + ) throws Exception { + + String labelFileName = String.format(Locale.ROOT, "data/%s.label", datasetName); + Map>> anomalies = getAnomalyWindowsMap(labelFileName); + + TrainResult trainResult = ingestTrainDataAndCreateDetector(datasetName, intervalMinutes, numberOfEntities, trainTestSplit, false); + startHistoricalDetector(trainResult, numberOfEntities, intervalMinutes, false); + + Triple, Integer, Map>> results = getTestResults( + trainResult.detectorId, + trainResult.data, + trainResult.rawDataTrainTestSplit, + intervalMinutes, + anomalies, + client(), + numberOfEntities, + trainResult.windowDelay + ); + verifyTestResults(results, anomalies, minPrecision, minRecall, maxError); + } + + private Triple, Integer, Map>> getTestResults( + String detectorId, + List data, + int rawTrainTestSplit, + int intervalMinutes, + Map>> anomalies, + RestClient client, + int numberOfEntities, + Duration windowDelay + ) throws Exception { + + Map res = new HashMap<>(); + int errors = 0; + // an entity might have missing values (e.g., at timestamp 1694713200000). + // Use a map to record the number of times we have seen them. + // data start time -> the number of entities + TreeMap entityMap = new TreeMap<>(); + // historical won't detect the last point as + // 1) ParseUtils.batchFeatureQuery uses left closed right open range to construct query + // 2) ADBatchTaskRunner.runNextPiece will stop when the next piece start time is larger than or equal to dataEnd time + // 3) ADBatchTaskRunner.getDateRangeOfSourceData will make data start/end time equal to min/max data time. + for (int i = rawTrainTestSplit; i < data.size() - numberOfEntities; i++) { + entityMap.compute(data.get(i).get("timestamp").getAsString(), (key, oldValue) -> (oldValue == null) ? 1 : oldValue + 1); + } + + // hash set to dedup + Map> foundWindow = new HashMap<>(); + long intervalMillis = intervalMinutes * 60000; + + // Iterate over the TreeMap in ascending order of keys + for (Map.Entry entry : entityMap.entrySet()) { + String beginTimeStampAsString = entry.getKey(); + int entitySize = entry.getValue(); + Instant begin = Instant.ofEpochMilli(Long.parseLong(beginTimeStampAsString)); + Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); + try { + List sourceList = getAnomalyResult(detectorId, end, entitySize, client, true, intervalMillis); + analyzeResults(anomalies, res, foundWindow, beginTimeStampAsString, entitySize, begin, sourceList); + } catch (Exception e) { + errors++; + LOG.error("failed to get detection results", e); + } + } + return Triple.of(res, errors, foundWindow); + } +} diff --git a/src/test/java/org/opensearch/ad/e2e/MissingIT.java b/src/test/java/org/opensearch/ad/e2e/MissingIT.java new file mode 100644 index 000000000..21f216819 --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/MissingIT.java @@ -0,0 +1,181 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.e2e; + +import java.time.Duration; +import java.time.Instant; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.BeforeClass; +import org.opensearch.ad.AbstractADSyntheticDataTest; +import org.opensearch.client.RestClient; +import org.opensearch.timeseries.AbstractSyntheticDataTest; +import org.opensearch.timeseries.dataprocessor.ImputationMethod; + +import com.google.gson.JsonObject; + +public abstract class MissingIT extends AbstractADSyntheticDataTest { + protected static double min = 200.0; + protected static double max = 240.0; + protected static int dataSize = 400; + + protected static List randomDoubles; + protected static String datasetName = "missing"; + + protected int intervalMinutes = 10; + public long intervalMillis = intervalMinutes * 60000L; + protected String categoricalField = "componentName"; + protected int maxError = 20; + protected int trainTestSplit = 100; + + public int continuousImputeStartIndex = 11; + public int continuousImputeEndIndex = 35; + + protected Map lastSeen = new HashMap<>(); + + @BeforeClass + public static void setUpOnce() { + // Generate the list of doubles + randomDoubles = generateUniformRandomDoubles(dataSize, min, max); + } + + protected void verifyImputation( + ImputationMethod imputation, + Map lastSeen, + double dataValue, + JsonObject imputed0, + String entity + ) { + assertTrue(imputed0.get("imputed").getAsBoolean()); + switch (imputation) { + case ZERO: + assertEquals(0, dataValue, EPSILON); + break; + case PREVIOUS: + // if we have recorded lastSeen + Double entityValue = lastSeen.get(entity); + if (entityValue != null && !areDoublesEqual(entityValue, -1)) { + assertEquals(entityValue, dataValue, EPSILON); + } + break; + case FIXED_VALUES: + assertEquals(1, dataValue, EPSILON); + break; + default: + assertTrue(false); + break; + } + } + + protected TrainResult createAndStartRealTimeDetector( + int numberOfEntities, + int trainTestSplit, + List data, + ImputationMethod imputation, + boolean hc, + long trainTimeMillis + ) throws Exception { + TrainResult trainResult = createDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis); + List result = startRealTimeDetector(trainResult, numberOfEntities, intervalMinutes, true); + recordLastSeenFromResult(result); + + return trainResult; + } + + protected TrainResult createAndStartHistoricalDetector( + int numberOfEntities, + int trainTestSplit, + List data, + ImputationMethod imputation, + boolean hc, + long trainTimeMillis + ) throws Exception { + TrainResult trainResult = createDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis); + List result = startHistoricalDetector(trainResult, numberOfEntities, intervalMinutes, true); + recordLastSeenFromResult(result); + + return trainResult; + } + + protected void recordLastSeenFromResult(List result) { + for (int j = 0; j < result.size(); j++) { + JsonObject source = result.get(j); + lastSeen.put(getEntity(source), extractFeatureValue(source)); + } + } + + protected TrainResult createDetector( + int numberOfEntities, + int trainTestSplit, + List data, + ImputationMethod imputation, + boolean hc, + long trainTimeMillis + ) throws Exception { + Instant trainTime = Instant.ofEpochMilli(trainTimeMillis); + + Duration windowDelay = getWindowDelay(trainTimeMillis); + String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), hc, imputation, trainTimeMillis); + + RestClient client = client(); + String detectorId = createDetector(client, detector); + LOG.info("Created detector {}", detectorId); + + return new TrainResult(detectorId, data, trainTestSplit * numberOfEntities, windowDelay, trainTime); + } + + protected Duration getWindowDelay(long trainTimeMillis) { + /* + * AD accepts windowDelay in the unit of minutes. Thus, we need to convert the delay in minutes. This will + * make it easier to search for results based on data end time. Otherwise, real data time and the converted + * data time from request time. + * Assume x = real data time. y= real window delay. y'= window delay in minutes. If y and y' are different, + * x + y - y' != x. + */ + long currentTime = System.currentTimeMillis(); + long windowDelayMinutes = (trainTimeMillis - currentTime) / 60000; + LOG.info("train time {}, current time {}, window delay {}", trainTimeMillis, currentTime, windowDelayMinutes); + return Duration.ofMinutes(windowDelayMinutes); + } + + protected void ingestUniformSingleFeatureData(int ingestDataSize, List data) throws Exception { + ingestUniformSingleFeatureData(ingestDataSize, data, datasetName, categoricalField); + } + + protected JsonObject createJsonObject(long timestamp, String component, double dataValue) { + return createJsonObject(timestamp, component, dataValue, categoricalField); + } + + protected abstract String genDetector( + int trainTestSplit, + long windowDelayMinutes, + boolean hc, + ImputationMethod imputation, + long trainTimeMillis + ); + + protected abstract AbstractSyntheticDataTest.GenData genData( + int trainTestSplit, + int numberOfEntities, + AbstractSyntheticDataTest.MISSING_MODE missingMode + ) throws Exception; + + protected abstract void runTest( + long firstDataStartTime, + AbstractSyntheticDataTest.GenData dataGenerated, + Duration windowDelay, + String detectorId, + int numberOfEntities, + AbstractSyntheticDataTest.MISSING_MODE mode, + ImputationMethod imputation, + int numberOfMissingToCheck, + boolean realTime + ); + + protected abstract double extractFeatureValue(JsonObject source); +} diff --git a/src/test/java/org/opensearch/ad/e2e/MissingMultiFeatureIT.java b/src/test/java/org/opensearch/ad/e2e/MissingMultiFeatureIT.java new file mode 100644 index 000000000..1fe3bcb6f --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/MissingMultiFeatureIT.java @@ -0,0 +1,357 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.e2e; + +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.NavigableSet; +import java.util.TreeMap; +import java.util.TreeSet; + +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.timeseries.AbstractSyntheticDataTest; +import org.opensearch.timeseries.dataprocessor.ImputationMethod; + +import com.google.gson.JsonObject; + +public class MissingMultiFeatureIT extends MissingIT { + + public void testSingleStream() throws Exception { + int numberOfEntities = 1; + + AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.NO_MISSING_DATA; + ImputationMethod method = ImputationMethod.ZERO; + + AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode); + + // only ingest train data to avoid validation error as we use latest data time as starting point. + // otherwise, we will have too many missing points. + ingestUniformSingleFeatureData( + trainTestSplit + numberOfEntities * 6, // we only need a few to verify and trigger train. + dataGenerated.data + ); + + TrainResult trainResult = createAndStartRealTimeDetector( + numberOfEntities, + trainTestSplit, + dataGenerated.data, + method, + false, + dataGenerated.testStartTime + ); + + runTest( + dataGenerated.testStartTime, + dataGenerated, + trainResult.windowDelay, + trainResult.detectorId, + numberOfEntities, + mode, + method, + 3, + true + ); + } + + public void testHCFixed() throws Exception { + int numberOfEntities = 2; + + AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.NO_MISSING_DATA; + ImputationMethod method = ImputationMethod.FIXED_VALUES; + + AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode); + + // only ingest train data to avoid validation error as we use latest data time as starting point. + // otherwise, we will have too many missing points. + ingestUniformSingleFeatureData( + trainTestSplit + numberOfEntities * 6, // we only need a few to verify and trigger train. + dataGenerated.data + ); + + TrainResult trainResult = createAndStartRealTimeDetector( + numberOfEntities, + trainTestSplit, + dataGenerated.data, + method, + true, + dataGenerated.testStartTime + ); + + runTest( + dataGenerated.testStartTime, + dataGenerated, + trainResult.windowDelay, + trainResult.detectorId, + numberOfEntities, + mode, + method, + 3, + true + ); + } + + public void testHCPrevious() throws Exception { + int numberOfEntities = 2; + + AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.NO_MISSING_DATA; + ImputationMethod method = ImputationMethod.PREVIOUS; + + AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode); + + // only ingest train data to avoid validation error as we use latest data time as starting point. + // otherwise, we will have too many missing points. + ingestUniformSingleFeatureData( + trainTestSplit + numberOfEntities * 6, // we only need a few to verify and trigger train. + dataGenerated.data + ); + + TrainResult trainResult = createAndStartRealTimeDetector( + numberOfEntities, + trainTestSplit, + dataGenerated.data, + method, + true, + dataGenerated.testStartTime + ); + + runTest( + dataGenerated.testStartTime, + dataGenerated, + trainResult.windowDelay, + trainResult.detectorId, + numberOfEntities, + mode, + method, + 3, + true + ); + } + + @Override + protected String genDetector( + int trainTestSplit, + long windowDelayMinutes, + boolean hc, + ImputationMethod imputation, + long trainTimeMillis + ) { + StringBuilder sb = new StringBuilder(); + + // feature with filter so that we only get data for training and test will have missing value on this feature + String featureWithFilter = String + .format( + Locale.ROOT, + "{\n" + + " \"feature_id\": \"feature1\",\n" + + " \"feature_name\": \"feature 1\",\n" + + " \"feature_enabled\": true,\n" + + " \"importance\": 1,\n" + + " \"aggregation_query\": {\n" + + " \"Feature1\": {\n" + + " \"filter\": {\n" + + " \"bool\": {\n" + + " \"must\": [\n" + + " {\n" + + " \"range\": {\n" + + " \"timestamp\": {\n" + + " \"lte\": %d\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"aggregations\": {\n" + + " \"deny_max1\": {\n" + + " \"max\": {\n" + + " \"field\": \"data\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}", + trainTimeMillis + ); + + // common part + sb + .append( + "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" + + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_id\": \"feature2\", \"feature_name\": \"feature 2\", \"feature_enabled\": " + + "\"true\", \"aggregation_query\": { \"Feature2\": { \"avg\": { \"field\": \"data\" } } } }," + + featureWithFilter + + "], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " + + "\"history\": %d," + ); + + if (windowDelayMinutes > 0) { + sb + .append( + String + .format( + Locale.ROOT, + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}},", + windowDelayMinutes + ) + ); + } + if (hc) { + sb.append("\"category_field\": [\"%s\"], "); + } + + switch (imputation) { + case ZERO: + sb.append("\"imputation_option\": { \"method\": \"zero\" },"); + break; + case PREVIOUS: + sb.append("\"imputation_option\": { \"method\": \"previous\" },"); + break; + case FIXED_VALUES: + sb + .append( + "\"imputation_option\": { \"method\": \"fixed_values\", \"defaultFill\": [{ \"feature_name\" : \"feature 1\", \"data\": 1 }, { \"feature_name\" : \"feature 2\", \"data\": 2 }] }," + ); + break; + } + // end + sb.append("\"schema_version\": 0}"); + + if (hc) { + return String.format(Locale.ROOT, sb.toString(), datasetName, intervalMinutes, trainTestSplit - 1, categoricalField); + } else { + return String.format(Locale.ROOT, sb.toString(), datasetName, intervalMinutes, trainTestSplit - 1); + } + } + + @Override + protected AbstractSyntheticDataTest.GenData genData( + int trainTestSplit, + int numberOfEntities, + AbstractSyntheticDataTest.MISSING_MODE missingMode + ) throws Exception { + List data = new ArrayList<>(); + long currentTime = System.currentTimeMillis(); + long intervalMillis = intervalMinutes * 60000L; + long oldestTime = currentTime - intervalMillis * trainTestSplit / numberOfEntities; + int entityIndex = 0; + NavigableSet> missingEntities = new TreeSet<>(); + NavigableSet missingTimestamps = new TreeSet<>(); + long testStartTime = 0; + + for (int i = 0; i < randomDoubles.size(); i++) { + // we won't miss the train time (the first point triggering cold start) + if (oldestTime > currentTime && testStartTime == 0) { + LOG.info("test start time {}, index {}, current time {}", oldestTime, data.size(), currentTime); + testStartTime = oldestTime; + } + JsonObject jsonObject = createJsonObject(oldestTime, "entity" + entityIndex, randomDoubles.get(i)); + data.add(jsonObject); + entityIndex = (entityIndex + 1) % numberOfEntities; + if (entityIndex == 0) { + oldestTime += intervalMillis; + } + } + return new AbstractSyntheticDataTest.GenData(data, missingEntities, missingTimestamps, testStartTime); + } + + @Override + protected void runTest( + long firstDataStartTime, + AbstractSyntheticDataTest.GenData dataGenerated, + Duration windowDelay, + String detectorId, + int numberOfEntities, + AbstractSyntheticDataTest.MISSING_MODE mode, + ImputationMethod imputation, + int numberOfMissingToCheck, + boolean realTime + ) { + int errors = 0; + List data = dataGenerated.data; + long lastDataStartTime = data.get(data.size() - 1).get("timestamp").getAsLong(); + + long dataStartTime = firstDataStartTime + intervalMinutes * 60000; + int missingIndex = 0; + + // an entity might have missing values (e.g., at timestamp 1694713200000). + // Use a map to record the number of times we have seen them. + // data start time -> the number of entities + TreeMap entityMap = new TreeMap<>(); + + // exit when reaching last date time or we have seen at least three missing values + while (lastDataStartTime >= dataStartTime && missingIndex <= numberOfMissingToCheck) { + if (scoreOneResult( + String.valueOf(dataStartTime), + entityMap, + windowDelay, + intervalMinutes, + detectorId, + client(), + numberOfEntities + )) { + errors++; + } + + Instant begin = Instant.ofEpochMilli(dataStartTime); + Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); + try { + List sourceList = getRealTimeAnomalyResult(detectorId, end, numberOfEntities, client()); + + assertTrue( + String + .format( + Locale.ROOT, + "the number of results is %d at %s, expected %d ", + sourceList.size(), + end.toEpochMilli(), + numberOfEntities + ), + sourceList.size() == numberOfEntities + ); + + for (int j = 0; j < numberOfEntities; j++) { + JsonObject source = sourceList.get(j); + + double dataValue = extractFeatureValue(source); + + String entity = getEntity(source); + + JsonObject imputed1 = getImputed(source, "feature1"); + // the feature starts missing since train time. + verifyImputation(imputation, lastSeen, dataValue, imputed1, entity); + + JsonObject imputed2 = getImputed(source, "feature2"); + assertTrue(!imputed2.get("imputed").getAsBoolean()); + } + missingIndex++; + } catch (Exception e) { + errors++; + LOG.error("failed to get detection results", e); + } + + dataStartTime += intervalMinutes * 60000; + } + + assertTrue(missingIndex > numberOfMissingToCheck); + assertTrue(errors < maxError); + } + + @Override + protected double extractFeatureValue(JsonObject source) { + for (int i = 0; i < 2; i++) { + JsonObject feature = getFeature(source, i); + if (feature.get("feature_name").getAsString().equals("feature 1")) { + return feature.get("data").getAsDouble(); + } + } + throw new IllegalArgumentException("Fail to find feature 1"); + } +} diff --git a/src/test/java/org/opensearch/ad/e2e/MissingSingleFeatureIT.java b/src/test/java/org/opensearch/ad/e2e/MissingSingleFeatureIT.java new file mode 100644 index 000000000..a8717d53d --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/MissingSingleFeatureIT.java @@ -0,0 +1,120 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.e2e; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.core.Logger; +import org.opensearch.timeseries.AbstractSyntheticDataTest; +import org.opensearch.timeseries.dataprocessor.ImputationMethod; + +public class MissingSingleFeatureIT extends AbstractMissingSingleFeatureTestCase { + public static final Logger LOG = (Logger) LogManager.getLogger(MissingSingleFeatureIT.class); + + public void testSingleStream() throws Exception { + int numberOfEntities = 1; + + AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.MISSING_TIMESTAMP; + ImputationMethod method = ImputationMethod.ZERO; + + AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode); + + ingestUniformSingleFeatureData( + -1, // ingest all + dataGenerated.data + ); + + TrainResult trainResult = createAndStartRealTimeDetector( + numberOfEntities, + trainTestSplit, + dataGenerated.data, + method, + false, + dataGenerated.testStartTime + ); + + runTest( + dataGenerated.testStartTime, + dataGenerated, + trainResult.windowDelay, + trainResult.detectorId, + numberOfEntities, + mode, + method, + 3, + true + ); + } + + public void testHCMissingTimeStamp() throws Exception { + int numberOfEntities = 2; + + AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.MISSING_TIMESTAMP; + ImputationMethod method = ImputationMethod.PREVIOUS; + + AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode); + + ingestUniformSingleFeatureData( + -1, // ingest all + dataGenerated.data + ); + + TrainResult trainResult = createAndStartRealTimeDetector( + numberOfEntities, + trainTestSplit, + dataGenerated.data, + method, + true, + dataGenerated.testStartTime + ); + + runTest( + dataGenerated.testStartTime, + dataGenerated, + trainResult.windowDelay, + trainResult.detectorId, + numberOfEntities, + mode, + method, + 3, + true + ); + } + + public void testHCMissingEntity() throws Exception { + int numberOfEntities = 2; + + AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.MISSING_ENTITY; + ImputationMethod method = ImputationMethod.FIXED_VALUES; + + AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode); + + ingestUniformSingleFeatureData( + -1, // ingest all + dataGenerated.data + ); + + TrainResult trainResult = createAndStartRealTimeDetector( + numberOfEntities, + trainTestSplit, + dataGenerated.data, + method, + true, + dataGenerated.testStartTime + ); + + runTest( + dataGenerated.testStartTime, + dataGenerated, + trainResult.windowDelay, + trainResult.detectorId, + numberOfEntities, + mode, + method, + 3, + true + ); + } +} diff --git a/src/test/java/org/opensearch/ad/e2e/PreviewMissingSingleFeatureIT.java b/src/test/java/org/opensearch/ad/e2e/PreviewMissingSingleFeatureIT.java new file mode 100644 index 000000000..6b0273c0a --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/PreviewMissingSingleFeatureIT.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.e2e; + +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Map; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.core.Logger; +import org.opensearch.common.xcontent.support.XContentMapValues; +import org.opensearch.timeseries.AbstractSyntheticDataTest; +import org.opensearch.timeseries.dataprocessor.ImputationMethod; + +public class PreviewMissingSingleFeatureIT extends AbstractMissingSingleFeatureTestCase { + public static final Logger LOG = (Logger) LogManager.getLogger(MissingSingleFeatureIT.class); + + @SuppressWarnings("unchecked") + public void testSingleStream() throws Exception { + + int numberOfEntities = 1; + + AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.MISSING_TIMESTAMP; + ImputationMethod method = ImputationMethod.ZERO; + + AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode); + + ingestUniformSingleFeatureData( + -1, // ingest all + dataGenerated.data + ); + + Duration windowDelay = getWindowDelay(dataGenerated.testStartTime); + String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), false, method, dataGenerated.testStartTime); + + Instant begin = Instant.ofEpochMilli(dataGenerated.data.get(0).get("timestamp").getAsLong()); + Instant end = Instant.ofEpochMilli(dataGenerated.data.get(dataGenerated.data.size() - 1).get("timestamp").getAsLong()); + Map result = preview(detector, begin, end, client()); + // We return java.lang.IllegalArgumentException: Insufficient data for preview results. Minimum required: 400 + // But we return empty results instead. Read comments in AnomalyDetectorRunner.onFailure. + List results = (List) XContentMapValues.extractValue(result, "anomaly_result"); + assertTrue(results.size() == 0); + + } + + @SuppressWarnings("unchecked") + public void testHC() throws Exception { + + int numberOfEntities = 2; + + AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.MISSING_TIMESTAMP; + ImputationMethod method = ImputationMethod.ZERO; + + AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode); + + ingestUniformSingleFeatureData( + -1, // ingest all + dataGenerated.data + ); + + Duration windowDelay = getWindowDelay(dataGenerated.testStartTime); + String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), true, method, dataGenerated.testStartTime); + + Instant begin = Instant.ofEpochMilli(dataGenerated.data.get(0).get("timestamp").getAsLong()); + Instant end = Instant.ofEpochMilli(dataGenerated.data.get(dataGenerated.data.size() - 1).get("timestamp").getAsLong()); + Map result = preview(detector, begin, end, client()); + // We return java.lang.IllegalArgumentException: Insufficient data for preview results. Minimum required: 400 + // But we return empty results instead. Read comments in AnomalyDetectorRunner.onFailure. + List results = (List) XContentMapValues.extractValue(result, "anomaly_result"); + assertTrue(results.size() == 0); + + } +} diff --git a/src/test/java/org/opensearch/ad/e2e/PreviewRuleIT.java b/src/test/java/org/opensearch/ad/e2e/PreviewRuleIT.java new file mode 100644 index 000000000..8b481e3c9 --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/PreviewRuleIT.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.e2e; + +import java.util.List; +import java.util.Map; + +import org.opensearch.common.xcontent.support.XContentMapValues; + +public class PreviewRuleIT extends AbstractRuleTestCase { + @SuppressWarnings("unchecked") + public void testRule() throws Exception { + // TODO: this test case will run for a much longer time and timeout with security enabled + if (!isHttps()) { + disableResourceNotFoundFaultTolerence(); + + String datasetName = "rule"; + int intervalMinutes = 10; + int numberOfEntities = 2; + int trainTestSplit = 100; + + TrainResult trainResult = ingestTrainData( + datasetName, + intervalMinutes, + numberOfEntities, + trainTestSplit, + true, + // ingest just enough for finish the test + (trainTestSplit + 1) * numberOfEntities + ); + + String detector = genDetector(datasetName, intervalMinutes, trainTestSplit, trainResult); + Map result = preview(detector, trainResult.firstDataTime, trainResult.finalDataTime, client()); + List results = (List) XContentMapValues.extractValue(result, "anomaly_result"); + assertTrue(results.size() > 100); + Map firstResult = (Map) results.get(0); + assertTrue((Double) XContentMapValues.extractValue(firstResult, "anomaly_grade") >= 0); + List feature = (List) XContentMapValues.extractValue(firstResult, "feature_data"); + Map firstFeatureValue = (Map) feature.get(0); + assertTrue((Double) XContentMapValues.extractValue(firstFeatureValue, "data") != null); + } + } +} diff --git a/src/test/java/org/opensearch/ad/e2e/RealTimeMissingSingleFeatureModelPerfIT.java b/src/test/java/org/opensearch/ad/e2e/RealTimeMissingSingleFeatureModelPerfIT.java new file mode 100644 index 000000000..0dc01d186 --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/RealTimeMissingSingleFeatureModelPerfIT.java @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.e2e; + +import org.opensearch.timeseries.AbstractSyntheticDataTest; +import org.opensearch.timeseries.dataprocessor.ImputationMethod; + +public class RealTimeMissingSingleFeatureModelPerfIT extends AbstractMissingSingleFeatureTestCase { + public void testSingleStream() throws Exception { + int numberOfEntities = 1; + + AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.CONTINUOUS_IMPUTE; + ImputationMethod method = ImputationMethod.ZERO; + + AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode); + + ingestUniformSingleFeatureData( + -1, // ingest all + dataGenerated.data + ); + + TrainResult trainResult = createAndStartRealTimeDetector( + numberOfEntities, + trainTestSplit, + dataGenerated.data, + method, + false, + dataGenerated.testStartTime + ); + + // we allowed 25 (continuousImputeEndIndex - continuousImputeStartIndex + 1) continuous missing timestamps in shouldSkipDataPoint + runTest( + dataGenerated.testStartTime, + dataGenerated, + trainResult.windowDelay, + trainResult.detectorId, + numberOfEntities, + mode, + method, + continuousImputeEndIndex - continuousImputeStartIndex + 1, + true + ); + } + + public void testHCFixed() throws Exception { + int numberOfEntities = 1; + + AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.CONTINUOUS_IMPUTE; + ImputationMethod method = ImputationMethod.FIXED_VALUES; + + AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode); + + ingestUniformSingleFeatureData( + -1, // ingest all + dataGenerated.data + ); + + TrainResult trainResult = createAndStartRealTimeDetector( + numberOfEntities, + trainTestSplit, + dataGenerated.data, + method, + true, + dataGenerated.testStartTime + ); + + // we allowed 25 (continuousImputeEndIndex - continuousImputeStartIndex + 1) continuous missing timestamps in shouldSkipDataPoint + runTest( + dataGenerated.testStartTime, + dataGenerated, + trainResult.windowDelay, + trainResult.detectorId, + numberOfEntities, + mode, + method, + continuousImputeEndIndex - continuousImputeStartIndex + 1, + true + ); + } + + public void testHCPrevious() throws Exception { + int numberOfEntities = 1; + + AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.CONTINUOUS_IMPUTE; + ImputationMethod method = ImputationMethod.PREVIOUS; + + AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode); + + ingestUniformSingleFeatureData( + -1, // ingest all + dataGenerated.data + ); + + TrainResult trainResult = createAndStartRealTimeDetector( + numberOfEntities, + trainTestSplit, + dataGenerated.data, + method, + true, + dataGenerated.testStartTime + ); + + // we allowed 25 (continuousImputeEndIndex - continuousImputeStartIndex + 1) continuous missing timestamps in shouldSkipDataPoint + runTest( + dataGenerated.testStartTime, + dataGenerated, + trainResult.windowDelay, + trainResult.detectorId, + numberOfEntities, + mode, + method, + continuousImputeEndIndex - continuousImputeStartIndex + 1, + true + ); + } +} diff --git a/src/test/java/org/opensearch/ad/e2e/RuleIT.java b/src/test/java/org/opensearch/ad/e2e/RealTimeRuleIT.java similarity index 86% rename from src/test/java/org/opensearch/ad/e2e/RuleIT.java rename to src/test/java/org/opensearch/ad/e2e/RealTimeRuleIT.java index 92f733f01..650fec64d 100644 --- a/src/test/java/org/opensearch/ad/e2e/RuleIT.java +++ b/src/test/java/org/opensearch/ad/e2e/RealTimeRuleIT.java @@ -14,7 +14,7 @@ import com.google.gson.JsonObject; -public class RuleIT extends AbstractRuleTestCase { +public class RealTimeRuleIT extends AbstractRuleTestCase { public void testRuleWithDateNanos() throws Exception { // TODO: this test case will run for a much longer time and timeout with security enabled if (!isHttps()) { @@ -25,7 +25,7 @@ public void testRuleWithDateNanos() throws Exception { int numberOfEntities = 2; int trainTestSplit = 100; - TrainResult trainResult = ingestTrainData( + TrainResult trainResult = ingestTrainDataAndCreateDetector( datasetName, intervalMinutes, numberOfEntities, @@ -35,11 +35,12 @@ public void testRuleWithDateNanos() throws Exception { (trainTestSplit + 1) * numberOfEntities ); + startRealTimeDetector(trainResult, numberOfEntities, intervalMinutes, false); List data = trainResult.data; LOG.info("scoring data at {}", data.get(trainResult.rawDataTrainTestSplit).get("timestamp").getAsString()); // one run call will evaluate all entities within an interval - int numberEntitiesScored = findTrainTimeEntities(trainResult.rawDataTrainTestSplit, data); + int numberEntitiesScored = findGivenTimeEntities(trainResult.rawDataTrainTestSplit, data); // an entity might have missing values (e.g., at timestamp 1694713200000). // Use a map to record the number of times we have seen them. // data start time -> the number of entities @@ -47,7 +48,7 @@ public void testRuleWithDateNanos() throws Exception { // rawDataTrainTestSplit is the actual index of next test data. assertFalse( scoreOneResult( - data.get(trainResult.rawDataTrainTestSplit), + data.get(trainResult.rawDataTrainTestSplit).get("timestamp").getAsString(), entityMap, trainResult.windowDelay, intervalMinutes, @@ -64,7 +65,7 @@ public void testRuleWithDateNanos() throws Exception { Instant begin = Instant.ofEpochMilli(Long.parseLong(beginTimeStampAsString)); Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); try { - List sourceList = getAnomalyResult(trainResult.detectorId, end, numberEntitiesScored, client()); + List sourceList = getRealTimeAnomalyResult(trainResult.detectorId, end, numberEntitiesScored, client()); assertTrue( String diff --git a/src/test/java/org/opensearch/ad/e2e/RealTimeRuleModelPerfIT.java b/src/test/java/org/opensearch/ad/e2e/RealTimeRuleModelPerfIT.java new file mode 100644 index 000000000..5062fe63c --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/RealTimeRuleModelPerfIT.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.e2e; + +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.TreeMap; + +import org.apache.commons.lang3.tuple.Triple; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.core.Logger; +import org.opensearch.client.RestClient; + +import com.google.gson.JsonObject; + +public class RealTimeRuleModelPerfIT extends AbstractRuleModelPerfTestCase { + static final Logger LOG = (Logger) LogManager.getLogger(RealTimeRuleModelPerfIT.class); + + public void testRule() throws Exception { + // TODO: this test case will run for a much longer time and timeout with security enabled + if (!isHttps()) { + disableResourceNotFoundFaultTolerence(); + // there are 8 entities in the data set. Each one needs 1500 rows as training data. + Map minPrecision = new HashMap<>(); + minPrecision.put("Phoenix", 0.5); + minPrecision.put("Scottsdale", 0.5); + Map minRecall = new HashMap<>(); + minRecall.put("Phoenix", 0.9); + minRecall.put("Scottsdale", 0.6); + verifyRule("rule", 10, minPrecision.size(), 1500, minPrecision, minRecall, 20); + } + } + + public void verifyRule( + String datasetName, + int intervalMinutes, + int numberOfEntities, + int trainTestSplit, + Map minPrecision, + Map minRecall, + int maxError + ) throws Exception { + verifyRule(datasetName, intervalMinutes, numberOfEntities, trainTestSplit, minPrecision, minRecall, maxError, false); + } + + public void verifyRule( + String datasetName, + int intervalMinutes, + int numberOfEntities, + int trainTestSplit, + Map minPrecision, + Map minRecall, + int maxError, + boolean useDateNanos + ) throws Exception { + + String labelFileName = String.format(Locale.ROOT, "data/%s.label", datasetName); + Map>> anomalies = getAnomalyWindowsMap(labelFileName); + + TrainResult trainResult = ingestTrainDataAndCreateDetector( + datasetName, + intervalMinutes, + numberOfEntities, + trainTestSplit, + useDateNanos + ); + startRealTimeDetector(trainResult, numberOfEntities, intervalMinutes, false); + + Triple, Integer, Map>> results = getTestResults( + trainResult.detectorId, + trainResult.data, + trainResult.rawDataTrainTestSplit, + intervalMinutes, + anomalies, + client(), + numberOfEntities, + trainResult.windowDelay + ); + verifyTestResults(results, anomalies, minPrecision, minRecall, maxError); + } + + private Triple, Integer, Map>> getTestResults( + String detectorId, + List data, + int rawTrainTestSplit, + int intervalMinutes, + Map>> anomalies, + RestClient client, + int numberOfEntities, + Duration windowDelay + ) throws Exception { + + Map res = new HashMap<>(); + int errors = 0; + // an entity might have missing values (e.g., at timestamp 1694713200000). + // Use a map to record the number of times we have seen them. + // data start time -> the number of entities + TreeMap entityMap = new TreeMap<>(); + for (int i = rawTrainTestSplit; i < data.size(); i++) { + if (scoreOneResult( + data.get(i).get("timestamp").getAsString(), + entityMap, + windowDelay, + intervalMinutes, + detectorId, + client, + numberOfEntities + )) { + errors++; + } + } + + // hash set to dedup + Map> foundWindow = new HashMap<>(); + + // Iterate over the TreeMap in ascending order of keys + for (Map.Entry entry : entityMap.entrySet()) { + String beginTimeStampAsString = entry.getKey(); + int entitySize = entry.getValue(); + Instant begin = Instant.ofEpochMilli(Long.parseLong(beginTimeStampAsString)); + Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); + try { + List sourceList = getRealTimeAnomalyResult(detectorId, end, entitySize, client); + + analyzeResults(anomalies, res, foundWindow, beginTimeStampAsString, entitySize, begin, sourceList); + } catch (Exception e) { + errors++; + LOG.error("failed to get detection results", e); + } + } + return Triple.of(res, errors, foundWindow); + } +} diff --git a/src/test/java/org/opensearch/ad/e2e/RuleModelPerfIT.java b/src/test/java/org/opensearch/ad/e2e/RuleModelPerfIT.java deleted file mode 100644 index 208cfbe48..000000000 --- a/src/test/java/org/opensearch/ad/e2e/RuleModelPerfIT.java +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ad.e2e; - -import java.io.File; -import java.io.FileReader; -import java.nio.charset.Charset; -import java.time.Duration; -import java.time.Instant; -import java.time.format.DateTimeFormatter; -import java.time.temporal.ChronoUnit; -import java.util.AbstractMap.SimpleEntry; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Set; -import java.util.TreeMap; - -import org.apache.commons.lang3.tuple.Triple; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.core.Logger; -import org.opensearch.client.RestClient; - -import com.google.gson.JsonElement; -import com.google.gson.JsonObject; -import com.google.gson.JsonParser; - -public class RuleModelPerfIT extends AbstractRuleTestCase { - static final Logger LOG = (Logger) LogManager.getLogger(RuleModelPerfIT.class); - - public void testRule() throws Exception { - // TODO: this test case will run for a much longer time and timeout with security enabled - if (!isHttps()) { - disableResourceNotFoundFaultTolerence(); - // there are 8 entities in the data set. Each one needs 1500 rows as training data. - Map minPrecision = new HashMap<>(); - minPrecision.put("Phoenix", 0.5); - minPrecision.put("Scottsdale", 0.5); - Map minRecall = new HashMap<>(); - minRecall.put("Phoenix", 0.9); - minRecall.put("Scottsdale", 0.6); - verifyRule("rule", 10, minPrecision.size(), 1500, minPrecision, minRecall, 20); - } - } - - private void verifyTestResults( - Triple, Integer, Map>> testResults, - Map>> anomalies, - Map minPrecision, - Map minRecall, - int maxError - ) { - Map resultMap = testResults.getLeft(); - Map> foundWindows = testResults.getRight(); - - for (Entry entry : resultMap.entrySet()) { - String entity = entry.getKey(); - double[] testResultsArray = entry.getValue(); - double positives = testResultsArray[0]; - double truePositives = testResultsArray[1]; - - // precision = predicted anomaly points that are true / predicted anomaly points - double precision = positives > 0 ? truePositives / positives : 0; - double minPrecisionValue = minPrecision.getOrDefault(entity, .4); - assertTrue( - String - .format( - Locale.ROOT, - "precision expected at least %f but got %f. positives %f, truePositives %f", - minPrecisionValue, - precision, - positives, - truePositives - ), - precision >= minPrecisionValue - ); - - // recall = windows containing predicted anomaly points / total anomaly windows - int anomalyWindow = anomalies.getOrDefault(entity, new ArrayList<>()).size(); - int foundWindowSize = foundWindows.getOrDefault(entity, new HashSet<>()).size(); - double recall = anomalyWindow > 0 ? foundWindowSize * 1.0d / anomalyWindow : 0; - double minRecallValue = minRecall.getOrDefault(entity, .7); - assertTrue( - String - .format( - Locale.ROOT, - "recall should be at least %f but got %f. anomalyWindow %d, foundWindowSize %d ", - minRecallValue, - recall, - anomalyWindow, - foundWindowSize - ), - recall >= minRecallValue - ); - - LOG.info("Entity {}, Precision: {}, Window recall: {}", entity, precision, recall); - } - - int errors = testResults.getMiddle(); - assertTrue(errors <= maxError); - } - - public void verifyRule( - String datasetName, - int intervalMinutes, - int numberOfEntities, - int trainTestSplit, - Map minPrecision, - Map minRecall, - int maxError - ) throws Exception { - verifyRule(datasetName, intervalMinutes, numberOfEntities, trainTestSplit, minPrecision, minRecall, maxError, false); - } - - public void verifyRule( - String datasetName, - int intervalMinutes, - int numberOfEntities, - int trainTestSplit, - Map minPrecision, - Map minRecall, - int maxError, - boolean useDateNanos - ) throws Exception { - - String labelFileName = String.format(Locale.ROOT, "data/%s.label", datasetName); - Map>> anomalies = getAnomalyWindowsMap(labelFileName); - - TrainResult trainResult = ingestTrainData(datasetName, intervalMinutes, numberOfEntities, trainTestSplit, useDateNanos); - - Triple, Integer, Map>> results = getTestResults( - trainResult.detectorId, - trainResult.data, - trainResult.rawDataTrainTestSplit, - intervalMinutes, - anomalies, - client(), - numberOfEntities, - trainResult.windowDelay - ); - verifyTestResults(results, anomalies, minPrecision, minRecall, maxError); - } - - private Triple, Integer, Map>> getTestResults( - String detectorId, - List data, - int rawTrainTestSplit, - int intervalMinutes, - Map>> anomalies, - RestClient client, - int numberOfEntities, - Duration windowDelay - ) throws Exception { - - Map res = new HashMap<>(); - int errors = 0; - // an entity might have missing values (e.g., at timestamp 1694713200000). - // Use a map to record the number of times we have seen them. - // data start time -> the number of entities - TreeMap entityMap = new TreeMap<>(); - for (int i = rawTrainTestSplit; i < data.size(); i++) { - if (scoreOneResult(data.get(i), entityMap, windowDelay, intervalMinutes, detectorId, client, numberOfEntities)) { - errors++; - } - } - - // hash set to dedup - Map> foundWindow = new HashMap<>(); - - // Iterate over the TreeMap in ascending order of keys - for (Map.Entry entry : entityMap.entrySet()) { - String beginTimeStampAsString = entry.getKey(); - int entitySize = entry.getValue(); - Instant begin = Instant.ofEpochMilli(Long.parseLong(beginTimeStampAsString)); - Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); - try { - List sourceList = getAnomalyResult(detectorId, end, entitySize, client); - - assertTrue( - String - .format( - Locale.ROOT, - "the number of results is %d at %s, expected %d ", - sourceList.size(), - beginTimeStampAsString, - entitySize - ), - sourceList.size() == entitySize - ); - for (int j = 0; j < entitySize; j++) { - JsonObject source = sourceList.get(j); - double anomalyGrade = getAnomalyGrade(source); - assertTrue("anomalyGrade cannot be negative", anomalyGrade >= 0); - if (anomalyGrade > 0) { - String entity = getEntity(source); - double[] entityResult = res.computeIfAbsent(entity, key -> new double[] { 0, 0 }); - // positive++ - entityResult[0]++; - Instant anomalyTime = getAnomalyTime(source, begin); - LOG.info("Found anomaly: entity {}, time {} result {}.", entity, anomalyTime, source); - int anomalyWindow = isAnomaly(anomalyTime, anomalies.getOrDefault(entity, new ArrayList<>())); - if (anomalyWindow != -1) { - LOG.info("True anomaly: entity {}, time {}.", entity, begin); - // truePositives++; - entityResult[1]++; - Set window = foundWindow.computeIfAbsent(entity, key -> new HashSet<>()); - window.add(anomalyWindow); - } - } - } - } catch (Exception e) { - errors++; - LOG.error("failed to get detection results", e); - } - } - return Triple.of(res, errors, foundWindow); - } - - public Map>> getAnomalyWindowsMap(String labelFileName) throws Exception { - JsonObject jsonObject = JsonParser - .parseReader(new FileReader(new File(getClass().getResource(labelFileName).toURI()), Charset.defaultCharset())) - .getAsJsonObject(); - - Map>> map = new HashMap<>(); - for (Map.Entry entry : jsonObject.entrySet()) { - List> anomalies = new ArrayList<>(); - JsonElement value = entry.getValue(); - if (value.isJsonArray()) { - for (JsonElement elem : value.getAsJsonArray()) { - JsonElement beginElement = elem.getAsJsonArray().get(0); - JsonElement endElement = elem.getAsJsonArray().get(1); - Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(beginElement.getAsString())); - Instant end = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(endElement.getAsString())); - anomalies.add(new SimpleEntry<>(begin, end)); - } - } - map.put(entry.getKey(), anomalies); - } - return map; - } -} diff --git a/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java b/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java index 75855a1a8..50f067e0d 100644 --- a/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java +++ b/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java @@ -157,7 +157,7 @@ private double[] getTestResults( Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(i).get("timestamp").getAsString())); Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); try { - List sourceList = getAnomalyResult(detectorId, end, 1, client); + List sourceList = getRealTimeAnomalyResult(detectorId, end, 1, client); assertTrue("anomalyGrade cannot be negative", sourceList.size() == 1); double anomalyGrade = getAnomalyGrade(sourceList.get(0)); assertTrue("anomalyGrade cannot be negative", anomalyGrade >= 0); diff --git a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java index 092fa982a..f2a33964d 100644 --- a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java +++ b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java @@ -184,12 +184,12 @@ public void getColdStartData_returnExpectedToListener( }).when(searchFeatureDao).getLatestDataTime(eq(detector), eq(Optional.empty()), eq(AnalysisType.AD), any(ActionListener.class)); if (latestTime != null) { doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(4); listener.onResponse(samples); return null; }) .when(searchFeatureDao) - .getFeatureSamplesForPeriods(eq(detector), eq(sampleRanges), eq(AnalysisType.AD), any(ActionListener.class)); + .getFeatureSamplesForPeriods(eq(detector), eq(sampleRanges), eq(AnalysisType.AD), eq(false), any(ActionListener.class)); } ActionListener> listener = mock(ActionListener.class); @@ -237,7 +237,7 @@ public void getColdStartData_throwToListener_onQueryCreationError() throws Excep }).when(searchFeatureDao).getLatestDataTime(eq(detector), eq(Optional.empty()), eq(AnalysisType.AD), any(ActionListener.class)); doThrow(IOException.class) .when(searchFeatureDao) - .getFeatureSamplesForPeriods(eq(detector), any(), eq(AnalysisType.AD), any(ActionListener.class)); + .getFeatureSamplesForPeriods(eq(detector), any(), eq(AnalysisType.AD), eq(false), any(ActionListener.class)); ActionListener> listener = mock(ActionListener.class); featureManager.getColdStartData(detector, listener); @@ -317,8 +317,8 @@ private void getPreviewFeaturesTemplate(List> samplesResults, ActionListener>> listener = null; - if (args[3] instanceof ActionListener) { - listener = (ActionListener>>) args[3]; + if (args[4] instanceof ActionListener) { + listener = (ActionListener>>) args[4]; } if (querySuccess) { @@ -328,7 +328,7 @@ private void getPreviewFeaturesTemplate(List> samplesResults, } return null; - }).when(searchFeatureDao).getFeatureSamplesForPeriods(eq(detector), eq(sampleRanges), eq(AnalysisType.AD), any()); + }).when(searchFeatureDao).getFeatureSamplesForPeriods(eq(detector), eq(sampleRanges), eq(AnalysisType.AD), eq(false), any()); ActionListener listener = mock(ActionListener.class); featureManager.getPreviewFeatures(detector, start, end, listener); @@ -434,7 +434,7 @@ private void setupSearchFeatureDaoForGetCurrentFeatures( AtomicBoolean isPreQuery = new AtomicBoolean(true); doAnswer(invocation -> { - ActionListener>> daoListener = invocation.getArgument(3); + ActionListener>> daoListener = invocation.getArgument(4); if (isPreQuery.get()) { isPreQuery.set(false); daoListener.onResponse(preQueryResponse); @@ -448,7 +448,7 @@ private void setupSearchFeatureDaoForGetCurrentFeatures( return null; }) .when(searchFeatureDao) - .getFeatureSamplesForPeriods(eq(detector), any(List.class), eq(AnalysisType.AD), any(ActionListener.class)); + .getFeatureSamplesForPeriods(eq(detector), any(List.class), eq(AnalysisType.AD), eq(true), any(ActionListener.class)); } private Object[] getCurrentFeaturesTestData_whenAfterQueryResultsFormFullShingle() { @@ -487,7 +487,7 @@ public void getCurrentFeatures_returnExpectedProcessedFeatures_whenAfterQueryRes // Start test Optional listenerResponse = getCurrentFeatures(detector, testStartTime, testEndTime); verify(searchFeatureDao, times(expectedNumQueriesToSearchFeatureDao)) - .getFeatureSamplesForPeriods(eq(detector), any(List.class), eq(AnalysisType.AD), any(ActionListener.class)); + .getFeatureSamplesForPeriods(eq(detector), any(List.class), eq(AnalysisType.AD), eq(true), any(ActionListener.class)); assertTrue(listenerResponse.isPresent()); double[] actualProcessedFeatures = listenerResponse.get(); @@ -519,7 +519,7 @@ public void getCurrentFeatures_returnExpectedProcessedFeatures_IOException( // Start test Exception listenerResponse = getCurrentFeaturesOnFailure(detector, start, end); verify(searchFeatureDao, times(expectedNumQueriesToSearchFeatureDao)) - .getFeatureSamplesForPeriods(eq(detector), any(List.class), eq(AnalysisType.AD), any(ActionListener.class)); + .getFeatureSamplesForPeriods(eq(detector), any(List.class), eq(AnalysisType.AD), eq(true), any(ActionListener.class)); assertTrue(listenerResponse instanceof IOException); } @@ -565,7 +565,7 @@ public void getCurrentFeatures_returnExpectedProcessedFeatures_whenAfterQueryRes // Start test Optional listenerResponse = getCurrentFeatures(detector, testStartTime, testEndTime); verify(searchFeatureDao, times(expectedNumQueriesToSearchFeatureDao)) - .getFeatureSamplesForPeriods(eq(detector), any(List.class), eq(AnalysisType.AD), any(ActionListener.class)); + .getFeatureSamplesForPeriods(eq(detector), any(List.class), eq(AnalysisType.AD), eq(true), any(ActionListener.class)); assertTrue(listenerResponse.isPresent()); } @@ -605,7 +605,7 @@ public void getCurrentFeatures_returnNoProcessedOrUnprocessedFeatures_whenMissin // Start test Optional listenerResponse = getCurrentFeatures(detector, testStartTime, testEndTime); verify(searchFeatureDao, times(expectedNumQueriesToSearchFeatureDao)) - .getFeatureSamplesForPeriods(eq(detector), any(List.class), eq(AnalysisType.AD), any(ActionListener.class)); + .getFeatureSamplesForPeriods(eq(detector), any(List.class), eq(AnalysisType.AD), eq(true), any(ActionListener.class)); assertFalse(listenerResponse.isPresent()); } @@ -635,7 +635,7 @@ public void getCurrentFeatures_returnNoProcessedFeatures_whenAfterQueryResultsCa // Start test Optional listenerResponse = getCurrentFeatures(detector, testStartTime, testEndTime); verify(searchFeatureDao, times(expectedNumQueriesToSearchFeatureDao)) - .getFeatureSamplesForPeriods(eq(detector), any(List.class), eq(AnalysisType.AD), any(ActionListener.class)); + .getFeatureSamplesForPeriods(eq(detector), any(List.class), eq(AnalysisType.AD), eq(true), any(ActionListener.class)); assertTrue(listenerResponse.isPresent()); } @@ -665,7 +665,7 @@ public void getCurrentFeatures_returnExceptionToListener_whenQueryThrowsIOExcept ActionListener> listener = mock(ActionListener.class); featureManager.getCurrentFeatures(detector, testStartTime, testEndTime, AnalysisType.AD, listener); verify(searchFeatureDao, times(expectedNumQueriesToSearchFeatureDao)) - .getFeatureSamplesForPeriods(eq(detector), any(List.class), eq(AnalysisType.AD), any(ActionListener.class)); + .getFeatureSamplesForPeriods(eq(detector), any(List.class), eq(AnalysisType.AD), eq(true), any(ActionListener.class)); verify(listener).onFailure(any(IOException.class)); } @@ -693,12 +693,24 @@ public void getCurrentFeatures_returnExpectedFeatures_cacheMissingData( // first call to cache missing points featureManager.getCurrentFeatures(detector, firstStartTime, firstEndTime, AnalysisType.AD, mock(ActionListener.class)); verify(searchFeatureDao, times(1)) - .getFeatureSamplesForPeriods(eq(detector), argThat(list -> list.size() == 1), eq(AnalysisType.AD), any(ActionListener.class)); + .getFeatureSamplesForPeriods( + eq(detector), + argThat(list -> list.size() == 1), + eq(AnalysisType.AD), + eq(true), + any(ActionListener.class) + ); // second call should only fetch current point even if previous points missing Optional listenerResponse = getCurrentFeatures(detector, secondStartTime, secondEndTime); verify(searchFeatureDao, times(2)) - .getFeatureSamplesForPeriods(eq(detector), argThat(list -> list.size() == 1), eq(AnalysisType.AD), any(ActionListener.class)); + .getFeatureSamplesForPeriods( + eq(detector), + argThat(list -> list.size() == 1), + eq(AnalysisType.AD), + eq(true), + any(ActionListener.class) + ); assertTrue(listenerResponse.isPresent()); } @@ -759,7 +771,7 @@ public void getCurrentFeatures_returnExpectedFeatures_withTimeJitterUpToHalfInte // Start test Optional listenerResponse = getCurrentFeatures(detector, testStartTime, testEndTime); verify(searchFeatureDao, times(expectedNumQueriesToSearchFeatureDao)) - .getFeatureSamplesForPeriods(eq(detector), any(List.class), eq(AnalysisType.AD), any(ActionListener.class)); + .getFeatureSamplesForPeriods(eq(detector), any(List.class), eq(AnalysisType.AD), eq(true), any(ActionListener.class)); assertTrue(listenerResponse.isPresent()); } diff --git a/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java b/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java index 637733067..983d3a8b8 100644 --- a/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java +++ b/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java @@ -224,6 +224,14 @@ private Map createMapping() { roles_mapping.put("fields", Collections.singletonMap("keyword", Collections.singletonMap("type", "keyword"))); user_nested_mapping.put("roles", roles_mapping); mappings.put(CommonName.USER_FIELD, user_mapping); + + Map imputed_mapping = new HashMap<>(); + imputed_mapping.put("type", "nested"); + mappings.put(AnomalyResult.FEATURE_IMPUTED, imputed_mapping); + Map imputed_nested_mapping = new HashMap<>(); + imputed_mapping.put(CommonName.PROPERTIES, imputed_nested_mapping); + imputed_nested_mapping.put("feature_id", Collections.singletonMap("type", "keyword")); + imputed_nested_mapping.put("imputed", Collections.singletonMap("type", "boolean")); return mappings; } diff --git a/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java b/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java index af90fc16b..cc70f6e34 100644 --- a/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java +++ b/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java @@ -27,6 +27,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.io.IOException; import java.time.Clock; import java.time.Duration; import java.time.Instant; @@ -197,7 +198,7 @@ public void setup() { descriptor.setAttribution(attributionVec); descriptor.setTotalUpdates(numSamples); descriptor.setRelevantAttribution(new double[] { 0, 0, 0, 0, 0 }); - when(trcf.process(any(), anyLong())).thenReturn(descriptor); + when(trcf.process(any(), anyLong(), any())).thenReturn(descriptor); ExecutorService executorService = mock(ExecutorService.class); when(threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); @@ -363,7 +364,9 @@ public void getRcfResult_returnExpectedToListener() { expectedValuesList, likelihood, threshold, - numTrees + numTrees, + point, + null ); verify(listener).onResponse(eq(expected)); @@ -802,7 +805,7 @@ public void maintenance_returnExpectedToListener_doNothing() { } @Test - public void getPreviewResults_returnNoAnomalies_forNoAnomalies() { + public void getPreviewResults_returnNoAnomalies_forNoAnomalies() throws IOException { int numPoints = 1000; double[][] points = Stream.generate(() -> new double[] { 0 }).limit(numPoints).toArray(double[][]::new); List> timeRanges = IntStream @@ -811,14 +814,17 @@ public void getPreviewResults_returnNoAnomalies_forNoAnomalies() { .collect(Collectors.toList()); Features features = new Features(timeRanges, points); - List results = modelManager.getPreviewResults(features, shingleSize, 0.0001); + AnomalyDetector detector = mock(AnomalyDetector.class); + when(detector.getShingleSize()).thenReturn(shingleSize); + when(detector.getRecencyEmphasis()).thenReturn(10000); + List results = modelManager.getPreviewResults(features, detector); assertEquals(numPoints, results.size()); assertTrue(results.stream().noneMatch(r -> r.getGrade() > 0)); } @Test - public void getPreviewResults_returnAnomalies_forLastAnomaly() { + public void getPreviewResults_returnAnomalies_forLastAnomaly() throws IOException { int numPoints = 1000; double[][] points = Stream.generate(() -> new double[] { 0 }).limit(numPoints).toArray(double[][]::new); points[points.length - 1] = new double[] { 1. }; @@ -828,7 +834,10 @@ public void getPreviewResults_returnAnomalies_forLastAnomaly() { .collect(Collectors.toList()); Features features = new Features(timeRanges, points); - List results = modelManager.getPreviewResults(features, shingleSize, 0.0001); + AnomalyDetector detector = mock(AnomalyDetector.class); + when(detector.getShingleSize()).thenReturn(shingleSize); + when(detector.getRecencyEmphasis()).thenReturn(10000); + List results = modelManager.getPreviewResults(features, detector); assertEquals(numPoints, results.size()); assertTrue(results.stream().limit(numPoints - 1).noneMatch(r -> r.getGrade() > 0)); @@ -836,9 +845,13 @@ public void getPreviewResults_returnAnomalies_forLastAnomaly() { } @Test(expected = IllegalArgumentException.class) - public void getPreviewResults_throwIllegalArgument_forInvalidInput() { + public void getPreviewResults_throwIllegalArgument_forInvalidInput() throws IOException { Features features = new Features(new ArrayList>(), new double[0][0]); - modelManager.getPreviewResults(features, shingleSize, 0.0001); + + AnomalyDetector detector = mock(AnomalyDetector.class); + when(detector.getShingleSize()).thenReturn(shingleSize); + when(detector.getRecencyEmphasis()).thenReturn(10000); + modelManager.getPreviewResults(features, detector); } @Test @@ -998,7 +1011,9 @@ public void score_with_trcf() { descriptor.getExpectedValuesList(), descriptor.getLikelihoodOfValues(), descriptor.getThreshold(), - numTrees + numTrees, + this.point, + null ), result ); @@ -1015,7 +1030,7 @@ public void score_throw() { when(rcf.getShingleSize()).thenReturn(8); when(rcf.getDimensions()).thenReturn(40); when(this.trcf.getForest()).thenReturn(rcf); - doThrow(new IllegalArgumentException()).when(trcf).process(any(), anyLong()); + doThrow(new IllegalArgumentException()).when(trcf).process(any(), anyLong(), any()); when(this.modelState.getSamples()) .thenReturn(new ArrayDeque<>(Arrays.asList(new Sample(this.point, Instant.now(), Instant.now())))); modelManager.score(new Sample(this.point, Instant.now(), Instant.now()), this.modelId, this.modelState, anomalyDetector); diff --git a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java index 422c1be94..62d5c8f97 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java +++ b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java @@ -311,6 +311,7 @@ public void testParseAnomalyDetectorWithEmptyUiMetadata() throws IOException { public void testInvalidShingleSize() throws Exception { Feature feature = TestHelpers.randomFeature(); + List featureList = ImmutableList.of(feature); TestHelpers .assertFailWith( ValidationException.class, @@ -321,7 +322,7 @@ public void testInvalidShingleSize() throws Exception { randomAlphaOfLength(5), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(feature), + featureList, TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -332,7 +333,7 @@ public void testInvalidShingleSize() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -346,6 +347,7 @@ public void testInvalidShingleSize() throws Exception { public void testNullDetectorName() throws Exception { Feature feature = TestHelpers.randomFeature(); + List featureList = ImmutableList.of(feature); TestHelpers .assertFailWith( ValidationException.class, @@ -356,7 +358,7 @@ public void testNullDetectorName() throws Exception { randomAlphaOfLength(5), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(feature), + featureList, TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -367,7 +369,7 @@ public void testNullDetectorName() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -381,6 +383,7 @@ public void testNullDetectorName() throws Exception { public void testBlankDetectorName() throws Exception { Feature feature = TestHelpers.randomFeature(); + List featureList = ImmutableList.of(feature); TestHelpers .assertFailWith( ValidationException.class, @@ -391,7 +394,7 @@ public void testBlankDetectorName() throws Exception { randomAlphaOfLength(5), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(feature), + featureList, TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -402,7 +405,7 @@ public void testBlankDetectorName() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -416,6 +419,7 @@ public void testBlankDetectorName() throws Exception { public void testNullTimeField() throws Exception { Feature feature = TestHelpers.randomFeature(); + List featureList = ImmutableList.of(feature); TestHelpers .assertFailWith( ValidationException.class, @@ -426,7 +430,7 @@ public void testNullTimeField() throws Exception { randomAlphaOfLength(5), null, ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(feature), + featureList, TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -437,7 +441,7 @@ public void testNullTimeField() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -451,6 +455,7 @@ public void testNullTimeField() throws Exception { public void testNullIndices() throws Exception { Feature feature = TestHelpers.randomFeature(); + List featureList = ImmutableList.of(feature); TestHelpers .assertFailWith( ValidationException.class, @@ -461,7 +466,7 @@ public void testNullIndices() throws Exception { randomAlphaOfLength(5), randomAlphaOfLength(5), null, - ImmutableList.of(feature), + featureList, TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -472,7 +477,7 @@ public void testNullIndices() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -486,6 +491,7 @@ public void testNullIndices() throws Exception { public void testEmptyIndices() throws Exception { Feature feature = TestHelpers.randomFeature(); + List featureList = ImmutableList.of(feature); TestHelpers .assertFailWith( ValidationException.class, @@ -496,7 +502,7 @@ public void testEmptyIndices() throws Exception { randomAlphaOfLength(5), randomAlphaOfLength(5), ImmutableList.of(), - ImmutableList.of(feature), + featureList, TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -507,7 +513,7 @@ public void testEmptyIndices() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -521,6 +527,7 @@ public void testEmptyIndices() throws Exception { public void testNullDetectionInterval() throws Exception { Feature feature = TestHelpers.randomFeature(); + List featureList = ImmutableList.of(feature); TestHelpers .assertFailWith( ValidationException.class, @@ -531,7 +538,7 @@ public void testNullDetectionInterval() throws Exception { randomAlphaOfLength(5), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(feature), + featureList, TestHelpers.randomQuery(), null, TestHelpers.randomIntervalTimeConfiguration(), @@ -542,7 +549,7 @@ public void testNullDetectionInterval() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -556,6 +563,7 @@ public void testNullDetectionInterval() throws Exception { public void testInvalidRecency() { Feature feature = TestHelpers.randomFeature(); + List featureList = ImmutableList.of(feature); ValidationException exception = expectThrows( ValidationException.class, () -> new AnomalyDetector( @@ -565,7 +573,7 @@ public void testInvalidRecency() { randomAlphaOfLength(30), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)), - ImmutableList.of(feature), + featureList, TestHelpers.randomQuery(), new IntervalTimeConfiguration(0, ChronoUnit.MINUTES), TestHelpers.randomIntervalTimeConfiguration(), @@ -576,7 +584,7 @@ public void testInvalidRecency() { null, null, null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), -1, randomIntBetween(1, 256), randomIntBetween(1, 1000), @@ -591,6 +599,7 @@ public void testInvalidRecency() { public void testInvalidDetectionInterval() { Feature feature = TestHelpers.randomFeature(); + List featureList = ImmutableList.of(feature); ValidationException exception = expectThrows( ValidationException.class, () -> new AnomalyDetector( @@ -600,7 +609,7 @@ public void testInvalidDetectionInterval() { randomAlphaOfLength(30), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)), - ImmutableList.of(feature), + featureList, TestHelpers.randomQuery(), new IntervalTimeConfiguration(0, ChronoUnit.MINUTES), TestHelpers.randomIntervalTimeConfiguration(), @@ -611,7 +620,7 @@ public void testInvalidDetectionInterval() { null, null, null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), null, // emphasis is not customized randomIntBetween(1, 256), randomIntBetween(1, 1000), @@ -626,6 +635,7 @@ public void testInvalidDetectionInterval() { public void testInvalidWindowDelay() { Feature feature = TestHelpers.randomFeature(); + List featureList = ImmutableList.of(feature); IllegalArgumentException exception = expectThrows( IllegalArgumentException.class, () -> new AnomalyDetector( @@ -635,7 +645,7 @@ public void testInvalidWindowDelay() { randomAlphaOfLength(30), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)), - ImmutableList.of(feature), + featureList, TestHelpers.randomQuery(), new IntervalTimeConfiguration(1, ChronoUnit.MINUTES), new IntervalTimeConfiguration(-1, ChronoUnit.MINUTES), @@ -646,7 +656,7 @@ public void testInvalidWindowDelay() { null, null, null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), null, // emphasis is not customized randomIntBetween(1, 256), randomIntBetween(1, 1000), @@ -676,6 +686,7 @@ public void testEmptyFeatures() throws IOException { public void testGetShingleSize() throws IOException { Feature feature = TestHelpers.randomFeature(); + List featureList = ImmutableList.of(feature); Config anomalyDetector = new AnomalyDetector( randomAlphaOfLength(5), randomLong(), @@ -683,7 +694,7 @@ public void testGetShingleSize() throws IOException { randomAlphaOfLength(5), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(feature), + featureList, TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -694,7 +705,7 @@ public void testGetShingleSize() throws IOException { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -709,6 +720,7 @@ public void testGetShingleSize() throws IOException { public void testGetShingleSizeReturnsDefaultValue() throws IOException { int seasonalityIntervals = randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2); Feature feature = TestHelpers.randomFeature(); + List featureList = ImmutableList.of(feature); Config anomalyDetector = new AnomalyDetector( randomAlphaOfLength(5), randomLong(), @@ -727,7 +739,7 @@ public void testGetShingleSizeReturnsDefaultValue() throws IOException { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), randomIntBetween(1, 10000), seasonalityIntervals, randomIntBetween(1, 1000), @@ -757,7 +769,7 @@ public void testGetShingleSizeReturnsDefaultValue() throws IOException { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), null, null, randomIntBetween(1, 1000), @@ -789,7 +801,7 @@ public void testNullFeatureAttributes() throws IOException { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption(0), + null, randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -821,7 +833,7 @@ public void testValidateResultIndex() throws IOException { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption(0), + null, randomIntBetween(1, 10000), randomIntBetween(1, TimeSeriesSettings.MAX_SHINGLE_SIZE * TimeSeriesSettings.SEASONALITY_TO_SHINGLE_RATIO), randomIntBetween(1, 1000), @@ -864,11 +876,11 @@ public void testParseAnomalyDetector_withCustomIndex_withCustomResultIndexMinSiz + "\"time_field\":\"HmdFH\",\"indices\":[\"ffsBF\"],\"filter_query\":{\"bool\":{\"filter\":[{\"exists\":" + "{\"field\":\"value\",\"boost\":1}}],\"adjust_pure_negative\":true,\"boost\":1}},\"window_delay\":" + "{\"period\":{\"interval\":2,\"unit\":\"Minutes\"}},\"shingle_size\":8,\"schema_version\":-512063255," - + "\"feature_attributes\":[{\"feature_id\":\"OTYJs\",\"feature_name\":\"eYYCM\",\"feature_enabled\":false," + + "\"feature_attributes\":[{\"feature_id\":\"OTYJs\",\"feature_name\":\"eYYCM\",\"feature_enabled\":true," + "\"aggregation_query\":{\"XzewX\":{\"value_count\":{\"field\":\"ok\"}}}}],\"recency_emphasis\":3342," + "\"history\":62,\"last_update_time\":1717192049845,\"category_field\":[\"Tcqcb\"],\"result_index\":" + "\"opensearch-ad-plugin-result-test\",\"imputation_option\":{\"method\":\"FIXED_VALUES\",\"defaultFill\"" - + ":[],\"integerSensitive\":false},\"suggested_seasonality\":64,\"detection_interval\":{\"period\":" + + ":[{\"feature_name\":\"eYYCM\", \"data\": 3}]},\"suggested_seasonality\":64,\"detection_interval\":{\"period\":" + "{\"interval\":5,\"unit\":\"Minutes\"}},\"detector_type\":\"MULTI_ENTITY\",\"rules\":[],\"result_index_min_size\":1500}"; AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString), "id", 1L, null, null); assertEquals(1500, (int) parsedDetector.getCustomResultIndexMinSize()); @@ -891,11 +903,11 @@ public void testParseAnomalyDetector_withCustomIndex_withCustomResultIndexMinAge + "\"time_field\":\"HmdFH\",\"indices\":[\"ffsBF\"],\"filter_query\":{\"bool\":{\"filter\":[{\"exists\":" + "{\"field\":\"value\",\"boost\":1}}],\"adjust_pure_negative\":true,\"boost\":1}},\"window_delay\":" + "{\"period\":{\"interval\":2,\"unit\":\"Minutes\"}},\"shingle_size\":8,\"schema_version\":-512063255," - + "\"feature_attributes\":[{\"feature_id\":\"OTYJs\",\"feature_name\":\"eYYCM\",\"feature_enabled\":false," + + "\"feature_attributes\":[{\"feature_id\":\"OTYJs\",\"feature_name\":\"eYYCM\",\"feature_enabled\":true," + "\"aggregation_query\":{\"XzewX\":{\"value_count\":{\"field\":\"ok\"}}}}],\"recency_emphasis\":3342," + "\"history\":62,\"last_update_time\":1717192049845,\"category_field\":[\"Tcqcb\"],\"result_index\":" + "\"opensearch-ad-plugin-result-test\",\"imputation_option\":{\"method\":\"FIXED_VALUES\",\"defaultFill\"" - + ":[],\"integerSensitive\":false},\"suggested_seasonality\":64,\"detection_interval\":{\"period\":" + + ":[{\"feature_name\":\"eYYCM\", \"data\": 3}]},\"suggested_seasonality\":64,\"detection_interval\":{\"period\":" + "{\"interval\":5,\"unit\":\"Minutes\"}},\"detector_type\":\"MULTI_ENTITY\",\"rules\":[],\"result_index_min_age\":7}"; AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString), "id", 1L, null, null); assertEquals(7, (int) parsedDetector.getCustomResultIndexMinAge()); @@ -906,11 +918,11 @@ public void testParseAnomalyDetector_withCustomIndex_withCustomResultIndexTTL() + "\"time_field\":\"HmdFH\",\"indices\":[\"ffsBF\"],\"filter_query\":{\"bool\":{\"filter\":[{\"exists\":" + "{\"field\":\"value\",\"boost\":1}}],\"adjust_pure_negative\":true,\"boost\":1}},\"window_delay\":" + "{\"period\":{\"interval\":2,\"unit\":\"Minutes\"}},\"shingle_size\":8,\"schema_version\":-512063255," - + "\"feature_attributes\":[{\"feature_id\":\"OTYJs\",\"feature_name\":\"eYYCM\",\"feature_enabled\":false," + + "\"feature_attributes\":[{\"feature_id\":\"OTYJs\",\"feature_name\":\"eYYCM\",\"feature_enabled\":true," + "\"aggregation_query\":{\"XzewX\":{\"value_count\":{\"field\":\"ok\"}}}}],\"recency_emphasis\":3342," + "\"history\":62,\"last_update_time\":1717192049845,\"category_field\":[\"Tcqcb\"],\"result_index\":" + "\"opensearch-ad-plugin-result-test\",\"imputation_option\":{\"method\":\"FIXED_VALUES\",\"defaultFill\"" - + ":[],\"integerSensitive\":false},\"suggested_seasonality\":64,\"detection_interval\":{\"period\":" + + ":[{\"feature_name\":\"eYYCM\", \"data\": 3}]},\"suggested_seasonality\":64,\"detection_interval\":{\"period\":" + "{\"interval\":5,\"unit\":\"Minutes\"}},\"detector_type\":\"MULTI_ENTITY\",\"rules\":[],\"result_index_ttl\":30}"; AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString), "id", 1L, null, null); assertEquals(30, (int) parsedDetector.getCustomResultIndexTTL()); diff --git a/src/test/java/org/opensearch/ad/model/FeatureImputedTests.java b/src/test/java/org/opensearch/ad/model/FeatureImputedTests.java new file mode 100644 index 000000000..f90f73510 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/FeatureImputedTests.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.model; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; + +import java.io.IOException; + +import org.hamcrest.MatcherAssert; +import org.junit.Before; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; + +public class FeatureImputedTests extends OpenSearchTestCase { + + private FeatureImputed featureImputed; + private String featureId = "feature_1"; + private Boolean imputed = true; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + featureImputed = new FeatureImputed(featureId, imputed); + } + + public void testParseFeatureImputed() throws IOException { + String jsonString = TestHelpers.xContentBuilderToString(featureImputed.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + + // Parse the JSON content + XContentParser parser = JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); // move to the first token + + // Call the parse method + FeatureImputed parsedFeatureImputed = FeatureImputed.parse(parser); + + // Verify the parsed object + assertNotNull("Parsed FeatureImputed should not be null", parsedFeatureImputed); + assertEquals("Feature ID should match", featureId, parsedFeatureImputed.getFeatureId()); + assertEquals("Imputed value should match", imputed, parsedFeatureImputed.isImputed()); + } + + public void testWriteToAndReadFrom() throws IOException { + // Serialize the object + BytesStreamOutput output = new BytesStreamOutput(); + featureImputed.writeTo(output); + + // Deserialize the object + StreamInput streamInput = output.bytes().streamInput(); + FeatureImputed readFeatureImputed = new FeatureImputed(streamInput); + + // Verify the deserialized object + MatcherAssert.assertThat("Feature ID should match", readFeatureImputed.getFeatureId(), equalTo(featureId)); + MatcherAssert.assertThat("Imputed value should match", readFeatureImputed.isImputed(), equalTo(imputed)); + + // verify equals/hashCode + MatcherAssert.assertThat("FeatureImputed should match", featureImputed, equalTo(featureImputed)); + MatcherAssert.assertThat("FeatureImputed should match", featureImputed, not(equalTo(streamInput))); + MatcherAssert.assertThat("FeatureImputed should match", readFeatureImputed, equalTo(featureImputed)); + MatcherAssert.assertThat("FeatureImputed should match", readFeatureImputed.hashCode(), equalTo(featureImputed.hashCode())); + } +} diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java index 115c41093..d4dd1878c 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java @@ -48,6 +48,7 @@ import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADInferencer; import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyDetector; @@ -102,6 +103,7 @@ public class CheckpointReadWorkerTests extends AbstractRateLimitingTest { FeatureRequest request, request2, request3; ClusterSettings clusterSettings; ADStats adStats; + ADInferencer inferencer; @Override public void setUp() throws Exception { @@ -150,6 +152,7 @@ public void setUp() throws Exception { }; adStats = new ADStats(statsMap); + inferencer = new ADInferencer(modelManager, adStats, checkpoint, coldstartQueue, resultWriteStrategy, cacheProvider, threadPool); // Integer.MAX_VALUE makes a huge heap worker = new ADCheckpointReadWorker( @@ -171,12 +174,10 @@ public void setUp() throws Exception { checkpoint, coldstartQueue, nodeStateManager, - anomalyDetectionIndices, cacheProvider, TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, - adStats, - resultWriteStrategy + inferencer ); request = new FeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, new double[] { 0 }, 0, entity, null); @@ -247,14 +248,14 @@ private void regularTestSetUp(RegularSetUpConfig config) { public void testRegular() { regularTestSetUp(new RegularSetUpConfig.Builder().build()); - verify(resultWriteStrategy, times(1)).saveResult(any(), any(), any(), anyString()); + verify(resultWriteStrategy, times(1)).saveResult(any(), any(), any(), any(), anyString(), any(), any(), any()); verify(checkpointWriteQueue, never()).write(any(), anyBoolean(), any()); } public void testCannotLoadModel() { regularTestSetUp(new RegularSetUpConfig.Builder().canHostModel(false).build()); - verify(resultWriteStrategy, times(1)).saveResult(any(), any(), any(), anyString()); + verify(resultWriteStrategy, times(1)).saveResult(any(), any(), any(), any(), anyString(), any(), any(), any()); verify(checkpointWriteQueue, times(1)).write(any(), anyBoolean(), any()); } @@ -262,7 +263,7 @@ public void testNoFullModel() { regularTestSetUp(new RegularSetUpConfig.Builder().fullModel(false).build()); // even though saveResult is called, the actual won't happen as the rcf score is 0 // we have the guard condition at the beginning of saveResult method. - verify(resultWriteStrategy, times(1)).saveResult(any(), any(), any(), anyString()); + verify(resultWriteStrategy, times(1)).saveResult(any(), any(), any(), any(), anyString(), any(), any(), any()); verify(checkpointWriteQueue, never()).write(any(), anyBoolean(), any()); } @@ -552,12 +553,10 @@ public void testRemoveUnusedQueues() { checkpoint, coldstartQueue, nodeStateManager, - anomalyDetectionIndices, cacheProvider, TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, - adStats, - resultWriteStrategy + inferencer ); regularTestSetUp(new RegularSetUpConfig.Builder().build()); @@ -604,12 +603,10 @@ public void testSettingUpdatable() { checkpoint, coldstartQueue, nodeStateManager, - anomalyDetectionIndices, cacheProvider, TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, - adStats, - resultWriteStrategy + inferencer ); List requests = new ArrayList<>(); @@ -657,12 +654,10 @@ public void testOpenCircuitBreaker() { checkpoint, coldstartQueue, nodeStateManager, - anomalyDetectionIndices, cacheProvider, TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, - adStats, - resultWriteStrategy + inferencer ); List requests = new ArrayList<>(); @@ -805,7 +800,7 @@ public void testFailToScore() { state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); when(checkpoint.processHCGetResponse(any(), anyString(), anyString())).thenReturn(state); - // anyString won't match null. That's why we use any() at position 4 instead of anyString. + // anyString won't match null. That's why we use any() at position 2 instead of anyString. doThrow(new IllegalArgumentException()).when(modelManager).getResult(any(), any(), anyString(), any(), any()); List requests = new ArrayList<>(); @@ -813,7 +808,7 @@ public void testFailToScore() { worker.putAll(requests); verify(modelManager, times(1)).getResult(any(), any(), anyString(), any(), any()); - verify(resultWriteStrategy, never()).saveResult(any(), any(), any(), anyString()); + verify(resultWriteStrategy, never()).saveResult(any(), any(), any(), any(), anyString(), any(), any(), any()); verify(checkpointWriteQueue, never()).write(any(), anyBoolean(), any()); verify(coldstartQueue, times(1)).put(any()); Object val = adStats.getStat(StatNames.AD_MODEL_CORRUTPION_COUNT.getName()).getValue(); diff --git a/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java b/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java index 933436368..264b19660 100644 --- a/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java +++ b/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java @@ -47,6 +47,7 @@ import org.opensearch.timeseries.TaskProfile; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.model.TimeSeriesTask; @@ -196,6 +197,8 @@ public static Response createAnomalyDetector( boolean historical ) throws Exception { Instant now = Instant.now(); + List featureList = ImmutableList + .of(TestHelpers.randomFeature(randomAlphaOfLength(5), valueField, aggregationMethod, true)); AnomalyDetector detector = new AnomalyDetector( randomAlphaOfLength(10), randomLong(), @@ -204,7 +207,7 @@ public static Response createAnomalyDetector( randomAlphaOfLength(30), timeField, ImmutableList.of(indexName), - ImmutableList.of(TestHelpers.randomFeature(randomAlphaOfLength(5), valueField, aggregationMethod, true)), + featureList, filterQuery == null ? TestHelpers.randomQuery("{\"match_all\":{\"boost\":1}}") : TestHelpers.randomQuery(filterQuery), new IntervalTimeConfiguration(detectionIntervalInMinutes, ChronoUnit.MINUTES), new IntervalTimeConfiguration(windowDelayIntervalInMinutes, ChronoUnit.MINUTES), @@ -215,7 +218,7 @@ public static Response createAnomalyDetector( categoryFields, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption(1), + TestHelpers.randomImputationOption(featureList), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), diff --git a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java index ea3d448b0..85db01ad8 100644 --- a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java +++ b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java @@ -132,6 +132,7 @@ private AnomalyDetector createIndexAndGetAnomalyDetector(String indexName, List< public void testCreateAnomalyDetectorWithDuplicateName() throws Exception { AnomalyDetector detector = createIndexAndGetAnomalyDetector(INDEX_NAME); Feature feature = TestHelpers.randomFeature(); + List featureList = ImmutableList.of(feature); AnomalyDetector detectorDuplicateName = new AnomalyDetector( AnomalyDetector.NO_ID, randomLong(), @@ -139,7 +140,7 @@ public void testCreateAnomalyDetectorWithDuplicateName() throws Exception { randomAlphaOfLength(5), randomAlphaOfLength(5), detector.getIndices(), - ImmutableList.of(feature), + featureList, TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -150,7 +151,7 @@ public void testCreateAnomalyDetectorWithDuplicateName() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -261,7 +262,7 @@ public void testUpdateAnomalyDetectorCategoryField() throws Exception { ImmutableList.of(randomAlphaOfLength(5)), detector.getUser(), null, - TestHelpers.randomImputationOption((int) expectedFeatures), + TestHelpers.randomImputationOption(features), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -327,7 +328,7 @@ public void testUpdateAnomalyDetector() throws Exception { null, detector.getUser(), null, - TestHelpers.randomImputationOption((int) expectedFeatures), + TestHelpers.randomImputationOption(features), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -398,7 +399,7 @@ public void testUpdateAnomalyDetectorNameToExisting() throws Exception { null, detector1.getUser(), null, - TestHelpers.randomImputationOption((int) expectedFeatures), + TestHelpers.randomImputationOption(features), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -446,7 +447,7 @@ public void testUpdateAnomalyDetectorNameToNew() throws Exception { null, detector.getUser(), null, - TestHelpers.randomImputationOption((int) expectedFeatures), + TestHelpers.randomImputationOption(features), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -500,7 +501,7 @@ public void testUpdateAnomalyDetectorWithNotExistingIndex() throws Exception { null, detector.getUser(), null, - TestHelpers.randomImputationOption((int) expectedFeatures), + TestHelpers.randomImputationOption(features), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -871,7 +872,7 @@ public void testUpdateAnomalyDetectorWithRunningAdJob() throws Exception { null, detector.getUser(), null, - TestHelpers.randomImputationOption((int) expectedFeatures), + TestHelpers.randomImputationOption(features), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), diff --git a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java index 0e1078968..1e60e0b4a 100644 --- a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java +++ b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java @@ -134,9 +134,6 @@ private List startHistoricalAnalysis(int categoryFieldSize, String resul if (!TaskState.RUNNING.name().equals(adTaskProfile.getTask().getState())) { adTaskProfile = (ADTaskProfile) waitUntilTaskReachState(detectorId, ImmutableSet.of(TaskState.RUNNING.name())).get(0); } - // if (adTaskProfile.getTotalEntitiesCount() == null) { - // adTaskProfile = (ADTaskProfile) waitUntilEntityCountAvailable(detectorId).get(0); - // } assertEquals((int) Math.pow(categoryFieldDocCount, categoryFieldSize), adTaskProfile.getTotalEntitiesCount().intValue()); assertTrue(adTaskProfile.getPendingEntitiesCount() > 0); assertTrue(adTaskProfile.getRunningEntitiesCount() > 0); diff --git a/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java b/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java index 46cd0e619..867514e6a 100644 --- a/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java +++ b/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java @@ -215,7 +215,7 @@ public void testAllLegacyOpenDistroSettingsFallback() { public void testSettingsGetValue() { Settings settings = Settings.builder().put("plugins.anomaly_detection.request_timeout", "42s").build(); assertEquals(AnomalyDetectorSettings.AD_REQUEST_TIMEOUT.get(settings), TimeValue.timeValueSeconds(42)); - assertEquals(LegacyOpenDistroAnomalyDetectorSettings.REQUEST_TIMEOUT.get(settings), TimeValue.timeValueSeconds(10)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.REQUEST_TIMEOUT.get(settings), TimeValue.timeValueSeconds(60)); settings = Settings.builder().put("plugins.anomaly_detection.max_anomaly_detectors", 99).build(); assertEquals(AnomalyDetectorSettings.AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings), Integer.valueOf(99)); diff --git a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java index 3d489747b..0bac122d6 100644 --- a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java +++ b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java @@ -911,7 +911,7 @@ public void testMaintainRunningRealtimeTasks() { @SuppressWarnings("unchecked") public void testStartHistoricalAnalysisWithNoOwningNode() throws IOException { - AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableList.of()); + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableList.of(randomFeature(true))); DateRange detectionDateRange = TestHelpers.randomDetectionDateRange(); User user = null; int availableTaskSlots = randomIntBetween(1, 10); diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java index 6878f8d53..99762c5b1 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java @@ -69,7 +69,8 @@ import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADInferencer; import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyDetector; @@ -78,6 +79,7 @@ import org.opensearch.ad.ratelimit.ADColdStartWorker; import org.opensearch.ad.ratelimit.ADResultWriteRequest; import org.opensearch.ad.ratelimit.ADResultWriteWorker; +import org.opensearch.ad.ratelimit.ADSaveResultStrategy; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.stats.ADStats; import org.opensearch.ad.task.ADTaskManager; @@ -93,7 +95,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.index.Index; @@ -121,6 +122,7 @@ import org.opensearch.timeseries.ml.ModelState; import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.ratelimit.FeatureRequest; import org.opensearch.timeseries.stats.StatNames; import org.opensearch.timeseries.stats.TimeSeriesStat; import org.opensearch.timeseries.stats.suppliers.CounterSupplier; @@ -164,6 +166,8 @@ public class AnomalyResultTests extends AbstractTimeSeriesTest { private ADTaskManager adTaskManager; private ADCheckpointReadWorker checkpointReadQueue; private ADCacheProvider cacheProvider; + private ADInferencer inferencer; + private ADColdStartWorker coldStartWorker; @BeforeClass public static void setUpBeforeClass() { @@ -251,7 +255,9 @@ public void setUp() throws Exception { expectedValuesList, likelihood, threshold, - 30 + 30, + new double[2], + null ) ); return null; @@ -309,6 +315,7 @@ public void setUp() throws Exception { put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); + put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); } }; @@ -347,6 +354,17 @@ public void setUp() throws Exception { cacheProvider = mock(ADCacheProvider.class); when(cacheProvider.get()).thenReturn(mock(ADPriorityCache.class)); + + coldStartWorker = mock(ADColdStartWorker.class); + inferencer = new ADInferencer( + normalModelManager, + adStats, + mock(ADCheckpointDao.class), + coldStartWorker, + mock(ADSaveResultStrategy.class), + cacheProvider, + threadPool + ); } @Override @@ -385,11 +403,8 @@ public void testNormal() throws IOException, InterruptedException { cacheProvider, stateManager, mock(ADCheckpointReadWorker.class), - normalModelManager, - mock(ADIndexManagement.class), - resultWriteWorker, - adStats, - mock(ADColdStartWorker.class) + inferencer, + threadPool ); new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); @@ -504,14 +519,6 @@ public void sendRequest( when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(discoveryNode); when(hashRing.getNodeByAddress(any(TransportAddress.class))).thenReturn(discoveryNode); // register handler on testNodes[1] - // new RCFResultTransportAction( - // new ActionFilters(Collections.emptySet()), - // testNodes[1].transportService, - // normalModelManager, - // adCircuitBreakerService, - // hashRing, - // adStats - // ); new ADSingleStreamResultTransportAction( testNodes[1].transportService, new ActionFilters(Collections.emptySet()), @@ -519,11 +526,8 @@ public void sendRequest( mock(ADCacheProvider.class), stateManager, mock(ADCheckpointReadWorker.class), - normalModelManager, - mock(ADIndexManagement.class), - mock(ADResultWriteWorker.class), - adStats, - mock(ADColdStartWorker.class) + inferencer, + threadPool ); TransportService realTransportService = testNodes[0].transportService; @@ -571,14 +575,6 @@ public void testInsufficientCapacityExceptionDuringColdStart() { .thenReturn(Optional.of(new LimitExceededException(adID, CommonMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG))); // These constructors register handler in transport service - // new RCFResultTransportAction( - // new ActionFilters(Collections.emptySet()), - // transportService, - // rcfManager, - // adCircuitBreakerService, - // hashRing, - // adStats - // ); new ADSingleStreamResultTransportAction( transportService, new ActionFilters(Collections.emptySet()), @@ -586,11 +582,8 @@ public void testInsufficientCapacityExceptionDuringColdStart() { mock(ADCacheProvider.class), stateManager, mock(ADCheckpointReadWorker.class), - rcfManager, - mock(ADIndexManagement.class), - mock(ADResultWriteWorker.class), - adStats, - mock(ADColdStartWorker.class) + inferencer, + threadPool ); new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); @@ -620,17 +613,29 @@ public void testInsufficientCapacityExceptionDuringColdStart() { } @SuppressWarnings("unchecked") - public void testInsufficientCapacityExceptionDuringRestoringModel() { + public void testInsufficientCapacityExceptionDuringRestoringModel() throws InterruptedException { + ADModelManager badModelManager = mock(ADModelManager.class); + doThrow(new NullPointerException()).when(badModelManager).getResult(any(), any(), any(), any(), any()); - ADModelManager rcfManager = mock(ADModelManager.class); + inferencer = new ADInferencer( + badModelManager, + adStats, + mock(ADCheckpointDao.class), + coldStartWorker, + mock(ADSaveResultStrategy.class), + cacheProvider, + threadPool + ); ADPriorityCache adPriorityCache = mock(ADPriorityCache.class); when(cacheProvider.get()).thenReturn(adPriorityCache); when(adPriorityCache.get(anyString(), any())).thenReturn(mock(ModelState.class)); - doThrow(new NotSerializableExceptionWrapper(new LimitExceededException(adID, CommonMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG))) - .when(rcfManager) - .getResult(any(), any(), any(), any(), any()); + CountDownLatch inProgress = new CountDownLatch(1); + doAnswer(invocation -> { + inProgress.countDown(); + return null; + }).when(coldStartWorker).put(any(FeatureRequest.class)); // These constructors register handler in transport service new ADSingleStreamResultTransportAction( @@ -640,11 +645,8 @@ public void testInsufficientCapacityExceptionDuringRestoringModel() { cacheProvider, stateManager, mock(ADCheckpointReadWorker.class), - rcfManager, - mock(ADIndexManagement.class), - mock(ADResultWriteWorker.class), - adStats, - mock(ADColdStartWorker.class) + inferencer, + threadPool ); new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); @@ -670,7 +672,9 @@ public void testInsufficientCapacityExceptionDuringRestoringModel() { PlainActionFuture listener = new PlainActionFuture<>(); action.doExecute(null, request, listener); - verify(stateManager, times(1)).setException(eq(adID), any(LimitExceededException.class)); + inProgress.await(30, TimeUnit.SECONDS); + // null pointer exception caused re-cold start + verify(coldStartWorker, times(1)).put(any(FeatureRequest.class)); } private TransportResponseHandler rcfResponseHandler(TransportResponseHandler handler) { @@ -761,14 +765,6 @@ public void sendRequest( when(hashRing.getNodeByAddress(any(TransportAddress.class))).thenReturn(discoveryNode); // register handlers on testNodes[1] ActionFilters actionFilters = new ActionFilters(Collections.emptySet()); - // new RCFResultTransportAction( - // actionFilters, - // testNodes[1].transportService, - // normalModelManager, - // adCircuitBreakerService, - // hashRing, - // adStats - // ); new ADSingleStreamResultTransportAction( testNodes[1].transportService, actionFilters, @@ -776,11 +772,8 @@ public void sendRequest( mock(ADCacheProvider.class), stateManager, mock(ADCheckpointReadWorker.class), - normalModelManager, - mock(ADIndexManagement.class), - mock(ADResultWriteWorker.class), - adStats, - mock(ADColdStartWorker.class) + inferencer, + threadPool ); new ThresholdResultTransportAction(actionFilters, testNodes[1].transportService, normalModelManager); @@ -819,14 +812,6 @@ public void testCircuitBreaker() { when(breakerService.isOpen()).thenReturn(true); // These constructors register handler in transport service - // new RCFResultTransportAction( - // new ActionFilters(Collections.emptySet()), - // transportService, - // normalModelManager, - // breakerService, - // hashRing, - // adStats - // ); new ADSingleStreamResultTransportAction( transportService, new ActionFilters(Collections.emptySet()), @@ -834,11 +819,8 @@ public void testCircuitBreaker() { mock(ADCacheProvider.class), stateManager, mock(ADCheckpointReadWorker.class), - normalModelManager, - mock(ADIndexManagement.class), - mock(ADResultWriteWorker.class), - adStats, - mock(ADColdStartWorker.class) + inferencer, + threadPool ); new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); @@ -916,11 +898,8 @@ private void nodeNotConnectedExceptionTemplate(boolean isRCF, boolean temporary, mock(ADCacheProvider.class), stateManager, mock(ADCheckpointReadWorker.class), - normalModelManager, - mock(ADIndexManagement.class), - mock(ADResultWriteWorker.class), - adStats, - mock(ADColdStartWorker.class) + inferencer, + threadPool ); AnomalyResultTransportAction action = new AnomalyResultTransportAction( @@ -1003,14 +982,6 @@ public void testMute() { public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOException { // These constructors register handler in transport service - // new RCFResultTransportAction( - // new ActionFilters(Collections.emptySet()), - // transportService, - // normalModelManager, - // adCircuitBreakerService, - // hashRing, - // adStats - // ); new ADSingleStreamResultTransportAction( transportService, new ActionFilters(Collections.emptySet()), @@ -1018,11 +989,8 @@ public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOE mock(ADCacheProvider.class), stateManager, mock(ADCheckpointReadWorker.class), - normalModelManager, - mock(ADIndexManagement.class), - mock(ADResultWriteWorker.class), - adStats, - mock(ADColdStartWorker.class) + inferencer, + threadPool ); Optional localNode = Optional.of(clusterService.state().nodes().getLocalNode()); @@ -1239,14 +1207,6 @@ public ColdStartConfig build() { @SuppressWarnings("unchecked") private void setUpColdStart(ThreadPool mockThreadPool, ColdStartConfig config) { - doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(4); - listener.onResponse(Optional.empty()); - return null; - }) - .when(featureQuery) - .getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), eq(AnalysisType.AD), any(ActionListener.class)); - doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); if (config.getCheckpointException == null) { @@ -1271,7 +1231,9 @@ private void setUpColdStart(ThreadPool mockThreadPool, ColdStartConfig config) { .getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), eq(AnalysisType.AD), any(ActionListener.class)); ADCacheProvider cacheProvider = mock(ADCacheProvider.class); - when(cacheProvider.get()).thenReturn(mock(ADPriorityCache.class)); + ADPriorityCache priorityCache = mock(ADPriorityCache.class); + when(cacheProvider.get()).thenReturn(priorityCache); + when(priorityCache.get(any(), any())).thenReturn(null); // register action handler new ADSingleStreamResultTransportAction( @@ -1281,11 +1243,8 @@ private void setUpColdStart(ThreadPool mockThreadPool, ColdStartConfig config) { cacheProvider, stateManager, checkpointReadQueue, - normalModelManager, - mock(ADIndexManagement.class), - mock(ADResultWriteWorker.class), - adStats, - mock(ADColdStartWorker.class) + inferencer, + threadPool ); } @@ -1430,14 +1389,6 @@ private void globalBlockTemplate(BlockType type, String errLogMsg, Settings inde when(hackedClusterService.state()).thenReturn(blockedClusterState); // These constructors register handler in transport service - // new RCFResultTransportAction( - // new ActionFilters(Collections.emptySet()), - // transportService, - // normalModelManager, - // adCircuitBreakerService, - // hashRing, - // adStats - // ); new ADSingleStreamResultTransportAction( transportService, new ActionFilters(Collections.emptySet()), @@ -1445,11 +1396,8 @@ private void globalBlockTemplate(BlockType type, String errLogMsg, Settings inde mock(ADCacheProvider.class), stateManager, mock(ADCheckpointReadWorker.class), - normalModelManager, - mock(ADIndexManagement.class), - mock(ADResultWriteWorker.class), - adStats, - mock(ADColdStartWorker.class) + inferencer, + threadPool ); new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); @@ -1617,13 +1565,17 @@ public void testColdStartEndRunExceptionNow() { verify(featureQuery, never()).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); } - @SuppressWarnings({ "unchecked" }) - public void testColdStartBecauseFailtoGetCheckpoint() { + public void testColdStartBecauseFailtoGetCheckpoint() throws InterruptedException { ThreadPool mockThreadPool = mock(ThreadPool.class); setUpColdStart( mockThreadPool, new ColdStartConfig.Builder().getCheckpointException(new IndexNotFoundException(ADCommonName.CHECKPOINT_INDEX_NAME)).build() ); + CountDownLatch inProgress = new CountDownLatch(1); + doAnswer(invocation -> { + inProgress.countDown(); + return null; + }).when(checkpointReadQueue).put(any()); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1648,6 +1600,8 @@ public void testColdStartBecauseFailtoGetCheckpoint() { action.doExecute(null, request, listener); AnomalyResultResponse response = listener.actionGet(10000L); + + inProgress.await(30, TimeUnit.SECONDS); assertEquals(Double.NaN, response.getAnomalyGrade(), 0.001); verify(checkpointReadQueue, times(1)).put(any()); } diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java index f8208215f..b1deb5025 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java @@ -219,7 +219,7 @@ private AnomalyDetector randomDetector(List indices, List featu null, null, null, - TestHelpers.randomImputationOption((int) features.stream().filter(Feature::getEnabled).count()), + TestHelpers.randomImputationOption(features), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -249,7 +249,7 @@ private AnomalyDetector randomHCDetector(List indices, List fea ImmutableList.of(categoryField), null, null, - TestHelpers.randomImputationOption((int) features.stream().filter(Feature::getEnabled).count()), + TestHelpers.randomImputationOption(features), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), diff --git a/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java index e2b211a8d..e35e85c87 100644 --- a/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java @@ -55,6 +55,7 @@ import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.ml.ADCheckpointDao; import org.opensearch.ad.ml.ADColdStart; +import org.opensearch.ad.ml.ADInferencer; import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.ratelimit.ADCheckpointReadWorker; @@ -135,6 +136,7 @@ public class EntityResultTransportActionTests extends AbstractTimeSeriesTest { ClusterService clusterService; ADStats adStats; ADSaveResultStrategy resultSaver; + ADInferencer inferencer; @BeforeClass public static void setUpBeforeClass() { @@ -261,10 +263,11 @@ public void setUp() throws Exception { adStats = new ADStats(statsMap); resultSaver = new ADSaveResultStrategy(1, resultWriteQueue); + inferencer = new ADInferencer(manager, adStats, checkpointDao, entityColdStartQueue, resultSaver, provider, threadPool); + entityResult = new EntityADResultTransportAction( actionFilters, transportService, - manager, adCircuitBreakerService, provider, stateManager, @@ -272,9 +275,7 @@ public void setUp() throws Exception { checkpointReadQueue, coldEntityQueue, threadPool, - entityColdStartQueue, - adStats, - resultSaver + inferencer ); // timeout in 60 seconds @@ -388,10 +389,10 @@ public void testJsonResponse() throws IOException, JsonPathNotFoundException { public void testFailToScore() { ADModelManager spyModelManager = spy(manager); doThrow(new IllegalArgumentException()).when(spyModelManager).getResult(any(), any(), anyString(), any(), any()); + inferencer = new ADInferencer(spyModelManager, adStats, checkpointDao, entityColdStartQueue, resultSaver, provider, threadPool); entityResult = new EntityADResultTransportAction( actionFilters, transportService, - spyModelManager, adCircuitBreakerService, provider, stateManager, @@ -399,9 +400,7 @@ public void testFailToScore() { checkpointReadQueue, coldEntityQueue, threadPool, - entityColdStartQueue, - adStats, - resultSaver + inferencer ); PlainActionFuture future = PlainActionFuture.newFuture(); diff --git a/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java b/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java index 6e93ef01a..913fc64c0 100644 --- a/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java +++ b/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java @@ -78,7 +78,7 @@ public void testNullDetectorIdAndTaskAction() throws IOException { null, randomUser(), null, - TestHelpers.randomImputationOption(0), + TestHelpers.randomImputationOption(null), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), diff --git a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java index 931f3f42d..81903b375 100644 --- a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java @@ -68,6 +68,8 @@ import org.opensearch.ad.caching.ADCacheProvider; import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADInferencer; import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyDetector; @@ -175,6 +177,7 @@ public class MultiEntityResultTests extends AbstractTimeSeriesTest { private ADPriorityCache entityCache; private ADTaskManager adTaskManager; private ADSaveResultStrategy resultSaver; + private ADInferencer inferencer; @BeforeClass public static void setUpBeforeClass() { @@ -319,6 +322,16 @@ public void setUp() throws Exception { attrs3 = new HashMap<>(); attrs3.put(serviceField, app0); attrs3.put(hostField, server3); + + inferencer = new ADInferencer( + normalModelManager, + adStats, + mock(ADCheckpointDao.class), + entityColdStartQueue, + resultSaver, + provider, + threadPool + ); } @Override @@ -413,7 +426,6 @@ private void setUpEntityResult(int nodeIndex, NodeStateManager nodeStateManager) new ActionFilters(Collections.emptySet()), // since we send requests to testNodes[1] testNodes[nodeIndex].transportService, - normalModelManager, adCircuitBreakerService, provider, nodeStateManager, @@ -421,9 +433,7 @@ private void setUpEntityResult(int nodeIndex, NodeStateManager nodeStateManager) checkpointReadQueue, coldEntityQueue, threadPool, - entityColdStartQueue, - adStats, - resultSaver + inferencer ); when(normalModelManager.getResult(any(), any(), any(), any(), any())).thenReturn(new ThresholdingResult(0, 1, 1)); @@ -782,7 +792,6 @@ public void testCircuitBreakerOpen() throws InterruptedException, IOException { new ActionFilters(Collections.emptySet()), // since we send requests to testNodes[1] testNodes[1].transportService, - normalModelManager, openBreaker, provider, spyStateManager, @@ -790,9 +799,7 @@ public void testCircuitBreakerOpen() throws InterruptedException, IOException { checkpointReadQueue, coldEntityQueue, threadPool, - entityColdStartQueue, - adStats, - resultSaver + inferencer ); CountDownLatch inProgress = new CountDownLatch(1); @@ -966,7 +973,6 @@ public void testCacheSelection() throws IOException, InterruptedException { new ActionFilters(Collections.emptySet()), // since we send requests to testNodes[1] testNodes[1].transportService, - normalModelManager, adCircuitBreakerService, provider, stateManager, @@ -974,9 +980,7 @@ public void testCacheSelection() throws IOException, InterruptedException { checkpointReadQueue, coldEntityQueue, threadPool, - entityColdStartQueue, - adStats, - resultSaver + inferencer ); CountDownLatch modelNodeInProgress = new CountDownLatch(1); diff --git a/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java index eee68c4ae..d32c3222f 100644 --- a/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java @@ -12,7 +12,6 @@ package org.opensearch.ad.transport; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; @@ -173,7 +172,7 @@ public void onFailure(Exception e) { } }; - doReturn(TestHelpers.randomThresholdingResults()).when(modelManager).getPreviewResults(any(), anyInt(), anyInt()); + doReturn(TestHelpers.randomThresholdingResults()).when(modelManager).getPreviewResults(any(), any()); doAnswer(responseMock -> { Long startTime = responseMock.getArgument(1); @@ -373,7 +372,7 @@ public void onFailure(Exception e) { Assert.assertTrue(false); } }; - doReturn(TestHelpers.randomThresholdingResults()).when(modelManager).getPreviewResults(any(), anyInt(), anyInt()); + doReturn(TestHelpers.randomThresholdingResults()).when(modelManager).getPreviewResults(any(), any()); doAnswer(responseMock -> { Long startTime = responseMock.getArgument(1); diff --git a/src/test/java/org/opensearch/ad/transport/RCFResultTests.java b/src/test/java/org/opensearch/ad/transport/RCFResultTests.java index a3ed6ee7c..82b9723ac 100644 --- a/src/test/java/org/opensearch/ad/transport/RCFResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/RCFResultTests.java @@ -139,7 +139,9 @@ public void testNormal() { expectedValuesList, likelihood, threshold, - forestSize + forestSize, + new double[] { 0 }, + null ) ); return null; @@ -312,7 +314,9 @@ public void testCircuitBreaker() { expectedValuesList, likelihood, threshold, - 30 + 30, + new double[] { 0 }, + null ) ); return null; diff --git a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java index 7ba4680e7..80cef5a15 100644 --- a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java @@ -15,6 +15,7 @@ import java.net.URL; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.List; import java.util.Locale; import org.junit.Test; @@ -377,6 +378,7 @@ private void testValidateAnomalyDetectorWithCustomResultIndex(boolean resultInde @Test public void testValidateAnomalyDetectorWithInvalidDetectorName() throws IOException { Feature feature = TestHelpers.randomFeature(); + List featureList = ImmutableList.of(feature); AnomalyDetector anomalyDetector = new AnomalyDetector( randomAlphaOfLength(5), randomLong(), @@ -384,7 +386,7 @@ public void testValidateAnomalyDetectorWithInvalidDetectorName() throws IOExcept randomAlphaOfLength(5), timeField, ImmutableList.of(randomAlphaOfLength(5).toLowerCase(Locale.ROOT)), - ImmutableList.of(feature), + featureList, TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -395,7 +397,7 @@ public void testValidateAnomalyDetectorWithInvalidDetectorName() throws IOExcept null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -424,6 +426,7 @@ public void testValidateAnomalyDetectorWithInvalidDetectorName() throws IOExcept @Test public void testValidateAnomalyDetectorWithDetectorNameTooLong() throws IOException { Feature feature = TestHelpers.randomFeature(); + List featureList = ImmutableList.of(feature); AnomalyDetector anomalyDetector = new AnomalyDetector( randomAlphaOfLength(5), randomLong(), @@ -431,7 +434,7 @@ public void testValidateAnomalyDetectorWithDetectorNameTooLong() throws IOExcept randomAlphaOfLength(5), timeField, ImmutableList.of(randomAlphaOfLength(5).toLowerCase(Locale.ROOT)), - ImmutableList.of(feature), + featureList, TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -442,7 +445,7 @@ public void testValidateAnomalyDetectorWithDetectorNameTooLong() throws IOExcept null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), diff --git a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java index 32e9e14e9..98daeb1d9 100644 --- a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java +++ b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java @@ -228,7 +228,8 @@ private AnomalyResult wrongAnomalyResult() { null, null, null, - randomDoubleBetween(1.1, 10.0, true) + randomDoubleBetween(1.1, 10.0, true), + null ); } } diff --git a/src/test/java/org/opensearch/forecast/model/ForecasterTests.java b/src/test/java/org/opensearch/forecast/model/ForecasterTests.java index 66af227ee..dd31136ad 100644 --- a/src/test/java/org/opensearch/forecast/model/ForecasterTests.java +++ b/src/test/java/org/opensearch/forecast/model/ForecasterTests.java @@ -61,7 +61,7 @@ public class ForecasterTests extends AbstractTimeSeriesTest { Integer customResultIndexTTL = null; public void testForecasterConstructor() { - ImputationOption imputationOption = TestHelpers.randomImputationOption(0); + ImputationOption imputationOption = TestHelpers.randomImputationOption(features); Forecaster forecaster = new Forecaster( forecasterId, @@ -135,7 +135,7 @@ public void testForecasterConstructorWithNullForecastInterval() { user, resultIndex, horizon, - TestHelpers.randomImputationOption(0), + TestHelpers.randomImputationOption(features), recencyEmphasis, seasonality, randomIntBetween(1, 1000), @@ -173,7 +173,7 @@ public void testNegativeInterval() { user, resultIndex, horizon, - TestHelpers.randomImputationOption(0), + TestHelpers.randomImputationOption(features), recencyEmphasis, seasonality, randomIntBetween(1, 1000), @@ -211,7 +211,7 @@ public void testMaxCategoryFieldsLimits() { user, resultIndex, horizon, - TestHelpers.randomImputationOption(0), + TestHelpers.randomImputationOption(features), recencyEmphasis, seasonality, randomIntBetween(1, 1000), @@ -249,7 +249,7 @@ public void testBlankName() { user, resultIndex, horizon, - TestHelpers.randomImputationOption(0), + TestHelpers.randomImputationOption(features), recencyEmphasis, seasonality, randomIntBetween(1, 1000), @@ -287,7 +287,7 @@ public void testInvalidCustomResultIndex() { user, resultIndex, horizon, - TestHelpers.randomImputationOption(0), + TestHelpers.randomImputationOption(features), recencyEmphasis, seasonality, randomIntBetween(1, 1000), @@ -324,7 +324,7 @@ public void testValidCustomResultIndex() { user, resultIndex, horizon, - TestHelpers.randomImputationOption(0), + TestHelpers.randomImputationOption(features), recencyEmphasis, seasonality, randomIntBetween(1, 1000), @@ -359,7 +359,7 @@ public void testInvalidHorizon() { user, resultIndex, horizon, - TestHelpers.randomImputationOption(0), + TestHelpers.randomImputationOption(features), recencyEmphasis, seasonality, randomIntBetween(1, 1000), diff --git a/src/test/java/org/opensearch/forecast/rest/ForecastRestApiIT.java b/src/test/java/org/opensearch/forecast/rest/ForecastRestApiIT.java index 43e011dba..5e5dee049 100644 --- a/src/test/java/org/opensearch/forecast/rest/ForecastRestApiIT.java +++ b/src/test/java/org/opensearch/forecast/rest/ForecastRestApiIT.java @@ -325,6 +325,49 @@ public void testSuggestOneMinute() throws Exception { int historySuggestions = ((Integer) responseMap.get("history")); assertEquals(37, historySuggestions); + + // case 4: no feature is ok + forecasterDef = "{\n" + + " \"name\": \"Second-Test-Detector-4\",\n" + + " \"description\": \"ok rate\",\n" + + " \"time_field\": \"timestamp\",\n" + + " \"indices\": [\n" + + " \"%s\"\n" + + " ],\n" + + " \"window_delay\": {\n" + + " \"period\": {\n" + + " \"interval\": 20,\n" + + " \"unit\": \"SECONDS\"\n" + + " }\n" + + " },\n" + + " \"ui_metadata\": {\n" + + " \"aabb\": {\n" + + " \"ab\": \"bb\"\n" + + " }\n" + + " },\n" + + " \"schema_version\": 2,\n" + + " \"horizon\": 24,\n" + + " \"forecast_interval\": {\n" + + " \"period\": {\n" + + " \"interval\": 4,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " }\n" + + "}"; + formattedForecaster = String.format(Locale.ROOT, forecasterDef, SYNTHETIC_DATASET_NAME); + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, SUGGEST_INTERVAL_URI), + ImmutableMap.of(), + TestHelpers.toHttpEntity(formattedForecaster), + null + ); + responseMap = entityAsMap(response); + suggestions = (Map) ((Map) responseMap.get("interval")).get("period"); + assertEquals(1, (int) suggestions.get("interval")); + assertEquals("Minutes", suggestions.get("unit")); } public void testSuggestTenMinute() throws Exception { @@ -515,6 +558,134 @@ public void testSuggestSparseData() throws Exception { assertEquals(0, responseMap.size()); } + /** + * Test data interval is larger than 1 hr and we fail to suggest + */ + public void testFailToSuggest() throws Exception { + int trainTestSplit = 100; + String categoricalField = "componentName"; + GenData dataGenerated = genUniformSingleFeatureData( + 70, + trainTestSplit, + 1, + categoricalField, + MISSING_MODE.NO_MISSING_DATA, + -1, + -1, + 50 + ); + ingestUniformSingleFeatureData(trainTestSplit, dataGenerated.data, UNIFORM_DATASET_NAME, categoricalField); + + // case 1: IntervalCalculation.findMinimumInterval cannot find any data point in the last 40 points and return 1 minute instead. + // We keep searching and find nothing below 1 hr and then return. + String forecasterDef = "{\n" + + " \"name\": \"Second-Test-Forecaster-4\",\n" + + " \"description\": \"ok rate\",\n" + + " \"time_field\": \"timestamp\",\n" + + " \"indices\": [\n" + + " \"%s\"\n" + + " ],\n" + + " \"feature_attributes\": [\n" + + " {\n" + + " \"feature_id\": \"sum1\",\n" + + " \"feature_name\": \"sum1\",\n" + + " \"feature_enabled\": true,\n" + + " \"importance\": 1,\n" + + " \"aggregation_query\": {\n" + + " \"sum1\": {\n" + + " \"sum\": {\n" + + " \"field\": \"data\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"window_delay\": {\n" + + " \"period\": {\n" + + " \"interval\": 20,\n" + + " \"unit\": \"SECONDS\"\n" + + " }\n" + + " },\n" + + " \"ui_metadata\": {\n" + + " \"aabb\": {\n" + + " \"ab\": \"bb\"\n" + + " }\n" + + " },\n" + + " \"schema_version\": 2,\n" + + " \"horizon\": 24\n" + + "}"; + + String formattedForecaster = String.format(Locale.ROOT, forecasterDef, UNIFORM_DATASET_NAME); + + Response response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, SUGGEST_INTERVAL_URI), + ImmutableMap.of(), + TestHelpers.toHttpEntity(formattedForecaster), + null + ); + assertEquals("Suggest forecaster interval failed", RestStatus.OK, TestHelpers.restStatus(response)); + Map responseMap = entityAsMap(response); + assertEquals(0, responseMap.size()); + + // case 2: IntervalCalculation.findMinimumInterval find an interval larger than 1 hr by going through the last 240 points. + // findMinimumInterval returns null and we stop searching further. + forecasterDef = "{\n" + + " \"name\": \"Second-Test-Forecaster-4\",\n" + + " \"description\": \"ok rate\",\n" + + " \"time_field\": \"timestamp\",\n" + + " \"indices\": [\n" + + " \"%s\"\n" + + " ],\n" + + " \"feature_attributes\": [\n" + + " {\n" + + " \"feature_id\": \"sum1\",\n" + + " \"feature_name\": \"sum1\",\n" + + " \"feature_enabled\": true,\n" + + " \"importance\": 1,\n" + + " \"aggregation_query\": {\n" + + " \"sum1\": {\n" + + " \"sum\": {\n" + + " \"field\": \"data\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"window_delay\": {\n" + + " \"period\": {\n" + + " \"interval\": 20,\n" + + " \"unit\": \"SECONDS\"\n" + + " }\n" + + " },\n" + + " \"ui_metadata\": {\n" + + " \"aabb\": {\n" + + " \"ab\": \"bb\"\n" + + " }\n" + + " },\n" + + " \"schema_version\": 2,\n" + + " \"horizon\": 24,\n" + + " \"history\": 240\n" + + "}"; + + formattedForecaster = String.format(Locale.ROOT, forecasterDef, UNIFORM_DATASET_NAME); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, SUGGEST_INTERVAL_URI), + ImmutableMap.of(), + TestHelpers.toHttpEntity(formattedForecaster), + null + ); + assertEquals("Suggest forecaster interval failed", RestStatus.OK, TestHelpers.restStatus(response)); + responseMap = entityAsMap(response); + assertEquals(0, responseMap.size()); + } + public void testValidate() throws Exception { loadSyntheticData(200); // case 1: forecaster interval is not set diff --git a/src/test/java/org/opensearch/timeseries/AbstractSyntheticDataTest.java b/src/test/java/org/opensearch/timeseries/AbstractSyntheticDataTest.java index c10ada512..405176b44 100644 --- a/src/test/java/org/opensearch/timeseries/AbstractSyntheticDataTest.java +++ b/src/test/java/org/opensearch/timeseries/AbstractSyntheticDataTest.java @@ -17,7 +17,11 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.NavigableSet; +import java.util.Random; +import java.util.TreeSet; +import org.apache.commons.lang3.tuple.Pair; import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.message.BasicHeader; import org.apache.logging.log4j.LogManager; @@ -29,6 +33,7 @@ import org.opensearch.client.WarningsHandler; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.AbstractSyntheticDataTest.MISSING_MODE; import org.opensearch.timeseries.settings.TimeSeriesSettings; import com.google.common.collect.ImmutableList; @@ -39,6 +44,35 @@ import com.google.gson.stream.JsonReader; public class AbstractSyntheticDataTest extends ODFERestTestCase { + public enum MISSING_MODE { + MISSING_TIMESTAMP, // missing all entities in a timestamps + MISSING_ENTITY, // missing single entity, + NO_MISSING_DATA, // no missing data + CONTINUOUS_IMPUTE, // vs random missing as above + } + + public static class GenData { + public List data; + // record missing entities and its timestamp in test data + public NavigableSet> missingEntities; + // record missing timestamps in test data + public NavigableSet missingTimestamps; + public long testStartTime; + + public GenData( + List data, + NavigableSet> missingEntities, + NavigableSet missingTimestamps, + long testStartTime + ) { + super(); + this.data = data; + this.missingEntities = missingEntities; + this.missingTimestamps = missingTimestamps; + this.testStartTime = testStartTime; + } + } + public static final Logger LOG = (Logger) LogManager.getLogger(AbstractSyntheticDataTest.class); public static final String SYNTHETIC_DATA_MAPPING = "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" } } } }"; @@ -49,6 +83,8 @@ public class AbstractSyntheticDataTest extends ODFERestTestCase { + "\"componentName\": { \"type\": \"keyword\"} } } }"; public static final String SYNTHETIC_DATASET_NAME = "synthetic"; public static final String RULE_DATASET_NAME = "rule"; + public static final String UNIFORM_DATASET_NAME = "uniform"; + public static int batchSize = 1000; /** * In real time AD, we mute a node for a detector if that node keeps returning @@ -101,9 +137,7 @@ public static void waitAllSyncheticDataIngested(int expectedSize, String dataset // Make sure all of the test data has been ingested JsonArray hits = getHits(client, request); LOG.info("Latest synthetic data:" + hits); - if (hits != null - && hits.size() == 1 - && expectedSize - 1 == hits.get(0).getAsJsonObject().getAsJsonPrimitive("_id").getAsLong()) { + if (hits != null && hits.size() == 1 && isIdExpected(expectedSize, hits)) { break; } else { request = new Request("POST", String.format(Locale.ROOT, "/%s/_refresh", datasetName)); @@ -113,6 +147,17 @@ public static void waitAllSyncheticDataIngested(int expectedSize, String dataset } while (maxWaitCycles-- >= 0); } + private static boolean isIdExpected(int expectedSize, JsonArray hits) { + // we won't have more than 3 entities with the same timestamp to make the test fast + int delta = 3; + for (int i = 0; i < hits.size(); i++) { + if (expectedSize - 1 <= hits.get(0).getAsJsonObject().getAsJsonPrimitive("_id").getAsLong() + delta) { + return true; + } + } + return false; + } + public static JsonArray getHits(RestClient client, Request request) throws IOException { Response response = client.performRequest(request); return parseHits(response); @@ -247,4 +292,213 @@ public static boolean canBeParsedAsLong(String str) { } } + public static List generateUniformRandomDoubles(int size, double min, double max) { + List randomDoubles = new ArrayList<>(size); + Random random = new Random(0); + + for (int i = 0; i < size; i++) { + double randomValue = min + (max - min) * random.nextDouble(); + randomDoubles.add(randomValue); + } + + return randomDoubles; + } + + protected JsonObject createJsonObject(long timestamp, String component, double dataValue, String categoricalField) { + JsonObject jsonObject = new JsonObject(); + jsonObject.addProperty("timestamp", timestamp); + jsonObject.addProperty(categoricalField, component); + jsonObject.addProperty("data", dataValue); + return jsonObject; + } + + public GenData genUniformSingleFeatureData( + int intervalMinutes, + int trainTestSplit, + int numberOfEntities, + String categoricalField, + MISSING_MODE missingMode, + int continuousImputeStartIndex, + int continuousImputeEndIndex, + List randomDoubles + ) { + List data = new ArrayList<>(); + long currentTime = System.currentTimeMillis(); + long intervalMillis = intervalMinutes * 60000L; + long timestampMillis = currentTime - intervalMillis * trainTestSplit / numberOfEntities; + LOG.info("begin timestamp: {}", timestampMillis); + int entityIndex = 0; + NavigableSet> missingEntities = new TreeSet<>(); + NavigableSet missingTimestamps = new TreeSet<>(); + long testStartTime = 0; + Random random = new Random(); + + for (int i = 0; i < randomDoubles.size();) { + // we won't miss the train time (the first point triggering cold start) + if (timestampMillis > currentTime && testStartTime == 0) { + LOG.info("test start time {}, index {}, current time {}", timestampMillis, data.size(), currentTime); + testStartTime = timestampMillis; + + for (int j = 0; j < numberOfEntities; j++) { + JsonObject jsonObject = createJsonObject( + timestampMillis, + "entity" + entityIndex, + randomDoubles.get(i++), + categoricalField + ); + entityIndex = (entityIndex + 1) % numberOfEntities; + data.add(jsonObject); + } + timestampMillis += intervalMillis; + + continue; + } + + if (shouldSkipDataPoint( + missingMode, + entityIndex, + testStartTime, + timestampMillis, + random, + intervalMillis, + continuousImputeStartIndex, + continuousImputeEndIndex + )) { + if (timestampMillis > currentTime) { + if (missingMode == MISSING_MODE.MISSING_TIMESTAMP || missingMode == MISSING_MODE.CONTINUOUS_IMPUTE) { + missingTimestamps.add(timestampMillis); + } else if (missingMode == MISSING_MODE.MISSING_ENTITY) { + missingEntities.add(Pair.of(timestampMillis, "entity" + entityIndex)); + entityIndex = (entityIndex + 1) % numberOfEntities; + if (entityIndex == 0) { + timestampMillis += intervalMillis; + } + } + } + + if (missingMode == MISSING_MODE.MISSING_TIMESTAMP || missingMode == MISSING_MODE.CONTINUOUS_IMPUTE) { + timestampMillis += intervalMillis; + } + } else { + JsonObject jsonObject = createJsonObject(timestampMillis, "entity" + entityIndex, randomDoubles.get(i), categoricalField); + data.add(jsonObject); + entityIndex = (entityIndex + 1) % numberOfEntities; + if (entityIndex == 0) { + timestampMillis += intervalMillis; + } + } + + i++; + } + LOG + .info( + "begin timestamp: {}, end timestamp: {}", + data.get(0).get("timestamp").getAsLong(), + data.get(data.size() - 1).get("timestamp").getAsLong() + ); + return new GenData(data, missingEntities, missingTimestamps, testStartTime); + } + + public GenData genUniformSingleFeatureData( + int intervalMinutes, + int trainTestSplit, + int numberOfEntities, + String categoricalField, + MISSING_MODE missingMode, + int continuousImputeStartIndex, + int continuousImputeEndIndex, + int dataSize + ) { + List randomDoubles = generateUniformRandomDoubles(dataSize, 200, 300); + + return genUniformSingleFeatureData( + intervalMinutes, + trainTestSplit, + numberOfEntities, + categoricalField, + missingMode, + continuousImputeStartIndex, + continuousImputeEndIndex, + randomDoubles + ); + } + + protected boolean shouldSkipDataPoint( + AbstractSyntheticDataTest.MISSING_MODE missingMode, + int entityIndex, + long testStartTime, + long currentTime, + Random random, + long intervalMillis, + int continuousImputeStartIndex, + int continuousImputeEndIndex + ) { + if (testStartTime == 0 || missingMode == AbstractSyntheticDataTest.MISSING_MODE.NO_MISSING_DATA) { + return false; + } + if (missingMode == AbstractSyntheticDataTest.MISSING_MODE.MISSING_TIMESTAMP && entityIndex == 0) { + return random.nextDouble() > 0.5; + } else if (missingMode == AbstractSyntheticDataTest.MISSING_MODE.MISSING_ENTITY) { + return random.nextDouble() > 0.5; + } else if (missingMode == AbstractSyntheticDataTest.MISSING_MODE.CONTINUOUS_IMPUTE && entityIndex == 0) { + long delta = (currentTime - testStartTime) / intervalMillis; + // start missing in a range + return delta >= continuousImputeStartIndex && delta <= continuousImputeEndIndex; + } + return false; + } + + protected void bulkIndexData(List data, String datasetName, RestClient client, String mapping, int ingestDataSize) + throws Exception { + createIndex(datasetName, client, mapping); + StringBuilder bulkRequestBuilder = new StringBuilder(); + LOG.info("data size {}", data.size()); + int count = 0; + int pickedIngestSize = Math.min(ingestDataSize, data.size()); + LOG.info("ingest size {}", pickedIngestSize); + for (int i = 0; i < pickedIngestSize; i++) { + bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + datasetName + "\", \"_id\" : \"" + i + "\" } }\n"); + bulkRequestBuilder.append(data.get(i).toString()).append("\n"); + count++; + if (count >= batchSize || i == pickedIngestSize - 1) { + count = 0; + TestHelpers + .makeRequest( + client, + "POST", + "_bulk?refresh=true", + null, + toHttpEntity(bulkRequestBuilder.toString()), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + Thread.sleep(1_000); + } + } + + waitAllSyncheticDataIngested(data.size(), datasetName, client); + LOG.info("data ingestion complete"); + } + + protected void ingestUniformSingleFeatureData(int ingestDataSize, List data, String datasetName, String categoricalField) + throws Exception { + + RestClient client = client(); + + String mapping = String + .format( + Locale.ROOT, + "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\":" + + "\"date\"" + + "}," + + " \"data\": { \"type\": \"double\" }," + + "\"%s\": { \"type\": \"keyword\"} } } }", + categoricalField + ); + + if (ingestDataSize <= 0) { + bulkIndexData(data, datasetName, client, mapping, data.size()); + } else { + bulkIndexData(data, datasetName, client, mapping, ingestDataSize); + } + } } diff --git a/src/test/java/org/opensearch/timeseries/ODFERestTestCase.java b/src/test/java/org/opensearch/timeseries/ODFERestTestCase.java index eccc2ee53..092ca210d 100644 --- a/src/test/java/org/opensearch/timeseries/ODFERestTestCase.java +++ b/src/test/java/org/opensearch/timeseries/ODFERestTestCase.java @@ -80,6 +80,7 @@ * ODFE integration test base class to support both security disabled and enabled ODFE cluster. */ public abstract class ODFERestTestCase extends OpenSearchRestTestCase { + private static final Logger LOG = (Logger) LogManager.getLogger(ODFERestTestCase.class); protected boolean isHttps() { diff --git a/src/test/java/org/opensearch/timeseries/TestHelpers.java b/src/test/java/org/opensearch/timeseries/TestHelpers.java index e633f2734..88bc10942 100644 --- a/src/test/java/org/opensearch/timeseries/TestHelpers.java +++ b/src/test/java/org/opensearch/timeseries/TestHelpers.java @@ -45,7 +45,6 @@ import java.util.Set; import java.util.concurrent.Callable; import java.util.function.Consumer; -import java.util.stream.DoubleStream; import java.util.stream.IntStream; import org.apache.hc.core5.http.ContentType; @@ -136,7 +135,6 @@ import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.search.suggest.Suggest; import org.opensearch.test.ClusterServiceUtils; -import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.rest.OpenSearchRestTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.constant.CommonMessages; @@ -332,7 +330,7 @@ public static AnomalyDetector randomAnomalyDetector( categoryFields, user, null, - TestHelpers.randomImputationOption(features == null ? 0 : (int) features.stream().filter(Feature::getEnabled).count()), + TestHelpers.randomImputationOption(features), randomIntBetween(1, 10000), randomIntBetween(1, TimeSeriesSettings.MAX_SHINGLE_SIZE * 2), randomIntBetween(1, 1000), @@ -384,7 +382,7 @@ public static AnomalyDetector randomDetector( categoryFields, null, resultIndex, - TestHelpers.randomImputationOption((int) features.stream().filter(Feature::getEnabled).count()), + TestHelpers.randomImputationOption(features), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -428,6 +426,7 @@ public static AnomalyDetector randomAnomalyDetectorUsingCategoryFields( List categoryFields, String resultIndex ) throws IOException { + List features = ImmutableList.of(randomFeature(true)); return new AnomalyDetector( detectorId, randomLong(), @@ -435,7 +434,7 @@ public static AnomalyDetector randomAnomalyDetectorUsingCategoryFields( randomAlphaOfLength(30), timeField, indices, - ImmutableList.of(randomFeature(true)), + features, randomQuery(), randomIntervalTimeConfiguration(), new IntervalTimeConfiguration(0, ChronoUnit.MINUTES), @@ -446,7 +445,7 @@ public static AnomalyDetector randomAnomalyDetectorUsingCategoryFields( categoryFields, randomUser(), resultIndex, - TestHelpers.randomImputationOption(1), + TestHelpers.randomImputationOption(features), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -484,7 +483,7 @@ public static AnomalyDetector randomAnomalyDetector(String timefield, String ind null, randomUser(), null, - TestHelpers.randomImputationOption(features == null ? 0 : (int) features.stream().filter(Feature::getEnabled).count()), + TestHelpers.randomImputationOption(features), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -514,7 +513,7 @@ public static AnomalyDetector randomAnomalyDetectorWithEmptyFeature() throws IOE null, randomUser(), null, - TestHelpers.randomImputationOption(0), + null, randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -532,6 +531,7 @@ public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguratio public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguration interval, boolean hcDetector) throws IOException { List categoryField = hcDetector ? ImmutableList.of(randomAlphaOfLength(5)) : null; Feature feature = randomFeature(); + List featureList = ImmutableList.of(feature); return new AnomalyDetector( randomAlphaOfLength(10), randomLong(), @@ -539,7 +539,7 @@ public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguratio randomAlphaOfLength(30), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)), - ImmutableList.of(feature), + featureList, randomQuery(), interval, randomIntervalTimeConfiguration(), @@ -550,7 +550,7 @@ public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguratio categoryField, randomUser(), null, - TestHelpers.randomImputationOption(feature.getEnabled() ? 1 : 0), + TestHelpers.randomImputationOption(featureList), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -588,6 +588,9 @@ public static class AnomalyDetectorBuilder { private ImputationOption imputationOption = null; private List rules = null; + // transform decay (reverse of recencyEmphasis) has to be [0, 1). So we cannot use 1. + private int recencyEmphasis = randomIntBetween(2, 10000); + public static AnomalyDetectorBuilder newInstance(int numberOfFeatures) throws IOException { return new AnomalyDetectorBuilder(numberOfFeatures); } @@ -686,8 +689,8 @@ public AnomalyDetectorBuilder setResultIndex(String resultIndex) { return this; } - public AnomalyDetectorBuilder setImputationOption(ImputationMethod method, Optional defaultFill, boolean integerSentive) { - this.imputationOption = new ImputationOption(method, defaultFill, integerSentive); + public AnomalyDetectorBuilder setImputationOption(ImputationMethod method, Map defaultFill) { + this.imputationOption = new ImputationOption(method, defaultFill); return this; } @@ -696,6 +699,11 @@ public AnomalyDetectorBuilder setRules(List rules) { return this; } + public AnomalyDetectorBuilder setRecencyEmphasis(int recencyEmphasis) { + this.recencyEmphasis = recencyEmphasis; + return this; + } + public AnomalyDetector build() { return new AnomalyDetector( detectorId, @@ -716,8 +724,7 @@ public AnomalyDetector build() { user, resultIndex, imputationOption, - // transform decay has to be [0, 1). So we cannot use 1. - randomIntBetween(2, 10000), + recencyEmphasis, randomIntBetween(1, TimeSeriesSettings.MAX_SHINGLE_SIZE * 2), // make history intervals at least TimeSeriesSettings.NUM_MIN_SAMPLES. // Otherwise, tests like EntityColdStarterTests.testTwoSegments may fail @@ -735,6 +742,7 @@ public AnomalyDetector build() { public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguration interval, boolean hcDetector, boolean featureEnabled) throws IOException { List categoryField = hcDetector ? ImmutableList.of(randomAlphaOfLength(5)) : null; + List features = ImmutableList.of(randomFeature(featureEnabled)); return new AnomalyDetector( randomAlphaOfLength(10), randomLong(), @@ -742,7 +750,7 @@ public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguratio randomAlphaOfLength(30), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)), - ImmutableList.of(randomFeature(featureEnabled)), + features, randomQuery(), interval, randomIntervalTimeConfiguration(), @@ -753,7 +761,7 @@ public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguratio categoryField, randomUser(), null, - TestHelpers.randomImputationOption(featureEnabled ? 1 : 0), + TestHelpers.randomImputationOption(features), randomIntBetween(1, 10000), randomInt(TimeSeriesSettings.MAX_SHINGLE_SIZE / 2), randomIntBetween(1, 1000), @@ -963,7 +971,8 @@ public static AnomalyResult randomAnomalyDetectResult(double score, String error relavantAttribution, pastValues, expectedValuesList, - randomDoubleBetween(1.1, 10.0, true) + randomDoubleBetween(1.1, 10.0, true), + null ); } @@ -1044,7 +1053,8 @@ public static AnomalyResult randomHCADAnomalyDetectResult( relavantAttribution, pastValues, expectedValuesList, - randomDoubleBetween(1.1, 10.0, true) + randomDoubleBetween(1.1, 10.0, true), + null ); } @@ -1683,15 +1693,34 @@ public static ClusterState createClusterState() { return clusterState; } - public static ImputationOption randomImputationOption(int featureSize) { - double[] defaultFill = DoubleStream.generate(OpenSearchTestCase::randomDouble).limit(featureSize).toArray(); - ImputationOption fixedValue = new ImputationOption(ImputationMethod.FIXED_VALUES, Optional.of(defaultFill), false); - ImputationOption linear = new ImputationOption(ImputationMethod.LINEAR, Optional.of(defaultFill), false); - ImputationOption linearIntSensitive = new ImputationOption(ImputationMethod.LINEAR, Optional.of(defaultFill), true); - ImputationOption zero = new ImputationOption(ImputationMethod.ZERO); - ImputationOption previous = new ImputationOption(ImputationMethod.PREVIOUS); + public static Map randomFixedValue(List features) { + Map map = new HashMap<>(); + if (features == null) { + return map; + } + + Random random = new Random(); + + for (int i = 0; i < features.size(); i++) { + if (features.get(i).getEnabled()) { + double randomValue = random.nextDouble(); // generate a random double value + map.put(features.get(i).getName(), randomValue); + } + } + + return map; + } + + public static ImputationOption randomImputationOption(List features) { + Map randomFixedValue = randomFixedValue(features); + + List options = new ArrayList<>(); + if (randomFixedValue.size() != 0) { + options.add(new ImputationOption(ImputationMethod.FIXED_VALUES, randomFixedValue)); + } - List options = List.of(fixedValue, linear, linearIntSensitive, zero, previous); + options.add(new ImputationOption(ImputationMethod.ZERO)); + options.add(new ImputationOption(ImputationMethod.PREVIOUS)); // Select a random option int randomIndex = Randomness.get().nextInt(options.size()); @@ -1729,7 +1758,7 @@ public static class ForecasterBuilder { description = randomAlphaOfLength(20); timeField = randomAlphaOfLength(5); indices = ImmutableList.of(randomAlphaOfLength(10)); - features = ImmutableList.of(randomFeature()); + features = ImmutableList.of(randomFeature(true)); filterQuery = randomQuery(); forecastInterval = randomIntervalTimeConfiguration(); windowDelay = randomIntervalTimeConfiguration(); @@ -1741,7 +1770,7 @@ public static class ForecasterBuilder { user = randomUser(); resultIndex = null; horizon = randomIntBetween(1, 20); - imputationOption = randomImputationOption((int) features.stream().filter(Feature::getEnabled).count()); + imputationOption = randomImputationOption(features); customResultIndexMinSize = null; customResultIndexMinAge = null; customResultIndexTTL = null; @@ -1894,6 +1923,7 @@ public Forecaster build() { public static Forecaster randomForecaster() throws IOException { Feature feature = randomFeature(); + List featureList = ImmutableList.of(feature); return new Forecaster( randomAlphaOfLength(10), randomLong(), @@ -1901,7 +1931,7 @@ public static Forecaster randomForecaster() throws IOException { randomAlphaOfLength(20), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(10)), - ImmutableList.of(feature), + featureList, randomQuery(), randomIntervalTimeConfiguration(), randomIntervalTimeConfiguration(), @@ -1913,7 +1943,7 @@ public static Forecaster randomForecaster() throws IOException { randomUser(), null, randomIntBetween(1, 20), - randomImputationOption(feature.getEnabled() ? 1 : 0), + randomImputationOption(featureList), randomIntBetween(1, 1000), randomIntBetween(1, 128), randomIntBetween(1, 1000), diff --git a/src/test/java/org/opensearch/timeseries/dataprocessor/ImputationOptionTests.java b/src/test/java/org/opensearch/timeseries/dataprocessor/ImputationOptionTests.java index 9adb57ed9..df1d4a2e9 100644 --- a/src/test/java/org/opensearch/timeseries/dataprocessor/ImputationOptionTests.java +++ b/src/test/java/org/opensearch/timeseries/dataprocessor/ImputationOptionTests.java @@ -6,8 +6,11 @@ package org.opensearch.timeseries.dataprocessor; import java.io.IOException; -import java.util.Optional; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; +import org.junit.BeforeClass; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.common.bytes.BytesReference; @@ -19,14 +22,46 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.test.OpenSearchTestCase; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + public class ImputationOptionTests extends OpenSearchTestCase { + private static ObjectMapper mapper; + private static Map map; + private static String xContent; + + @BeforeClass + public static void setUpOnce() { + mapper = new ObjectMapper(); + double[] defaultFill = { 1.0, 2.0, 3.0 }; + map = new HashMap<>(); + map.put("a", defaultFill[0]); + map.put("b", defaultFill[1]); + map.put("c", defaultFill[2]); + + xContent = "{" + + "\"method\":\"FIXED_VALUES\"," + + "\"defaultFill\":[{\"feature_name\":\"a\", \"data\":1.0},{\"feature_name\":\"b\", \"data\":2.0},{\"feature_name\":\"c\", \"data\":3.0}]}"; + } + + private Map randomMap(double[] defaultFill) { + Map map = new HashMap<>(); + + for (int i = 0; i < defaultFill.length; i++) { + String randomKey = UUID.randomUUID().toString(); // generate a random UUID string as the key + map.put(randomKey, defaultFill[i]); + } + + return map; + } public void testStreamInputAndOutput() throws IOException { // Prepare the data to be read by the StreamInput object. ImputationMethod method = ImputationMethod.PREVIOUS; double[] defaultFill = { 1.0, 2.0, 3.0 }; + Map map1 = randomMap(defaultFill); - ImputationOption option = new ImputationOption(method, Optional.of(defaultFill), false); + ImputationOption option = new ImputationOption(method, map1); // Write the ImputationOption to the StreamOutput. BytesStreamOutput out = new BytesStreamOutput(); @@ -39,26 +74,25 @@ public void testStreamInputAndOutput() throws IOException { // Check that the created ImputationOption has the correct values. assertEquals(method, inOption.getMethod()); - assertArrayEquals(defaultFill, inOption.getDefaultFill().get(), 1e-6); + assertEquals(map1, inOption.getDefaultFill()); } public void testToXContent() throws IOException { - double[] defaultFill = { 1.0, 2.0, 3.0 }; - ImputationOption imputationOption = new ImputationOption(ImputationMethod.FIXED_VALUES, Optional.of(defaultFill), false); - String xContent = "{" + "\"method\":\"FIXED_VALUES\"," + "\"defaultFill\":[1.0,2.0,3.0],\"integerSensitive\":false" + "}"; + ImputationOption imputationOption = new ImputationOption(ImputationMethod.FIXED_VALUES, map); XContentBuilder builder = imputationOption.toXContent(JsonXContent.contentBuilder(), ToXContent.EMPTY_PARAMS); String actualJson = BytesReference.bytes(builder).utf8ToString(); - assertEquals(xContent, actualJson); + JsonNode expectedTree = mapper.readTree(xContent); + JsonNode actualTree = mapper.readTree(actualJson); + + assertEquals(expectedTree, actualTree); } public void testParse() throws IOException { - String xContent = "{" + "\"method\":\"FIXED_VALUES\"," + "\"defaultFill\":[1.0,2.0,3.0],\"integerSensitive\":false" + "}"; - double[] defaultFill = { 1.0, 2.0, 3.0 }; - ImputationOption imputationOption = new ImputationOption(ImputationMethod.FIXED_VALUES, Optional.of(defaultFill), false); + ImputationOption imputationOption = new ImputationOption(ImputationMethod.FIXED_VALUES, map); try ( XContentParser parser = JsonXContent.jsonXContent @@ -73,22 +107,24 @@ public void testParse() throws IOException { ImputationOption parsedOption = ImputationOption.parse(parser); assertEquals(imputationOption.getMethod(), parsedOption.getMethod()); - assertTrue(imputationOption.getDefaultFill().isPresent()); - assertTrue(parsedOption.getDefaultFill().isPresent()); - assertEquals(imputationOption.getDefaultFill().get().length, parsedOption.getDefaultFill().get().length); - for (int i = 0; i < imputationOption.getDefaultFill().get().length; i++) { - assertEquals(imputationOption.getDefaultFill().get()[i], parsedOption.getDefaultFill().get()[i], 0); - } + assertTrue(imputationOption.getDefaultFill().size() > 0); + assertTrue(parsedOption.getDefaultFill().size() > 0); + + // The assertEquals method checks if the two maps are equal. The Map interface's equals method ensures that + // the maps are considered equal if they contain the same key-value pairs, regardless of the order in which + // they were inserted. + assertEquals(imputationOption.getDefaultFill(), parsedOption.getDefaultFill()); } } public void testEqualsAndHashCode() { double[] defaultFill1 = { 1.0, 2.0, 3.0 }; - double[] defaultFill2 = { 4.0, 5.0, 6.0 }; - ImputationOption option1 = new ImputationOption(ImputationMethod.FIXED_VALUES, Optional.of(defaultFill1), false); - ImputationOption option2 = new ImputationOption(ImputationMethod.FIXED_VALUES, Optional.of(defaultFill1), false); - ImputationOption option3 = new ImputationOption(ImputationMethod.LINEAR, Optional.of(defaultFill2), false); + Map map1 = randomMap(defaultFill1); + + ImputationOption option1 = new ImputationOption(ImputationMethod.FIXED_VALUES, map1); + ImputationOption option2 = new ImputationOption(ImputationMethod.FIXED_VALUES, map1); + ImputationOption option3 = new ImputationOption(ImputationMethod.PREVIOUS); // Test reflexivity assertTrue(option1.equals(option1)); @@ -98,7 +134,7 @@ public void testEqualsAndHashCode() { assertTrue(option2.equals(option1)); // Test transitivity - ImputationOption option2Clone = new ImputationOption(ImputationMethod.FIXED_VALUES, Optional.of(defaultFill1), false); + ImputationOption option2Clone = new ImputationOption(ImputationMethod.FIXED_VALUES, map1); assertTrue(option1.equals(option2)); assertTrue(option2.equals(option2Clone)); assertTrue(option1.equals(option2Clone)); diff --git a/src/test/java/org/opensearch/timeseries/feature/NoPowermockSearchFeatureDaoTests.java b/src/test/java/org/opensearch/timeseries/feature/NoPowermockSearchFeatureDaoTests.java index 47118ca8a..836f0eafa 100644 --- a/src/test/java/org/opensearch/timeseries/feature/NoPowermockSearchFeatureDaoTests.java +++ b/src/test/java/org/opensearch/timeseries/feature/NoPowermockSearchFeatureDaoTests.java @@ -643,7 +643,7 @@ public void testParseBuckets() throws InstantiationException, true ); - Optional parsedResult = searchFeatureDao.parseBucket(bucket, Arrays.asList(featureId)); + Optional parsedResult = searchFeatureDao.parseBucket(bucket, Arrays.asList(featureId), false); assertTrue(parsedResult.isPresent()); double[] parsedCardinality = parsedResult.get(); diff --git a/src/test/java/org/opensearch/timeseries/feature/SearchFeatureDaoParamTests.java b/src/test/java/org/opensearch/timeseries/feature/SearchFeatureDaoParamTests.java index 6868a3b9f..c9cdd44cb 100644 --- a/src/test/java/org/opensearch/timeseries/feature/SearchFeatureDaoParamTests.java +++ b/src/test/java/org/opensearch/timeseries/feature/SearchFeatureDaoParamTests.java @@ -267,11 +267,15 @@ private Object[] getFeaturesForPeriodData() { return new Object[] { new Object[] { asList(max), asList(maxName), new double[] { maxValue }, }, new Object[] { asList(percentiles), asList(percentileName), new double[] { percentileValue } }, - new Object[] { asList(missing), asList(missingName), null }, - new Object[] { asList(infinity), asList(infinityName), null }, + // we keep missing data + new Object[] { asList(missing), asList(missingName), new double[] { Double.NaN } }, + new Object[] { asList(infinity), asList(infinityName), new double[] { Double.NaN } }, new Object[] { asList(max, percentiles), asList(maxName, percentileName), new double[] { maxValue, percentileValue } }, new Object[] { asList(max, percentiles), asList(percentileName, maxName), new double[] { percentileValue, maxValue } }, - new Object[] { asList(max, percentiles, missing), asList(maxName, percentileName, missingName), null }, }; + new Object[] { + asList(max, percentiles, missing), + asList(maxName, percentileName, missingName), + new double[] { maxValue, percentileValue, Double.NaN } }, }; } private Object[] getFeaturesForSampledPeriodsData() { diff --git a/src/test/java/org/opensearch/timeseries/indices/rest/handler/IntervalCalculationTests.java b/src/test/java/org/opensearch/timeseries/indices/rest/handler/IntervalCalculationTests.java new file mode 100644 index 000000000..96ed15438 --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/indices/rest/handler/IntervalCalculationTests.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.indices.rest.handler; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import java.time.Clock; +import java.time.Instant; +import java.time.ZoneId; +import java.util.Arrays; +import java.util.Map; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.client.Client; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.histogram.Histogram; +import org.opensearch.search.aggregations.bucket.histogram.LongBounds; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.rest.handler.AggregationPrep; +import org.opensearch.timeseries.rest.handler.IntervalCalculation; +import org.opensearch.timeseries.rest.handler.IntervalCalculation.IntervalRecommendationListener; +import org.opensearch.timeseries.util.SecurityClientUtil; + +public class IntervalCalculationTests extends OpenSearchTestCase { + + private IntervalCalculation intervalCalculation; + private Clock clock; + private ActionListener mockIntervalListener; + private AggregationPrep mockAggregationPrep; + private Client mockClient; + private SecurityClientUtil mockClientUtil; + private User user; + private Map mockTopEntity; + private IntervalTimeConfiguration mockIntervalConfig; + private LongBounds mockLongBounds; + private Config mockConfig; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + clock = Clock.fixed(Instant.now(), ZoneId.systemDefault()); + mockIntervalListener = mock(ActionListener.class); + mockAggregationPrep = mock(AggregationPrep.class); + mockClient = mock(Client.class); + mockClientUtil = mock(SecurityClientUtil.class); + user = TestHelpers.randomUser(); + mockTopEntity = mock(Map.class); + mockIntervalConfig = mock(IntervalTimeConfiguration.class); + mockLongBounds = mock(LongBounds.class); + mockConfig = mock(Config.class); + + intervalCalculation = new IntervalCalculation( + mockConfig, + mock(TimeValue.class), + mockClient, + mockClientUtil, + user, + AnalysisType.AD, + clock, + mock(SearchFeatureDao.class), + System.currentTimeMillis(), + mockTopEntity + ); + } + + public void testOnResponseExpirationEpochMsPassed() { + long expirationEpochMs = clock.millis() - 1000; // Expired 1 second ago + + IntervalRecommendationListener listener = intervalCalculation.new IntervalRecommendationListener( + mockIntervalListener, new SearchSourceBuilder(), mockIntervalConfig, expirationEpochMs, mockLongBounds + ); + + Histogram histogram = mock(Histogram.class); + when(histogram.getName()).thenReturn(AggregationPrep.AGGREGATION); + Aggregations aggs = new Aggregations(Arrays.asList(histogram)); + SearchResponseSections sections = new SearchResponseSections( + new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0), + aggs, + null, + false, + null, + null, + 1 + ); + listener.onResponse(new SearchResponse(sections, null, 0, 0, 0, 0L, ShardSearchFailure.EMPTY_ARRAY, SearchResponse.Clusters.EMPTY)); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(ValidationException.class); + verify(mockIntervalListener).onFailure(argumentCaptor.capture()); + ValidationException validationException = argumentCaptor.getValue(); + assertEquals(CommonMessages.TIMEOUT_ON_INTERVAL_REC, validationException.getMessage()); + assertEquals(ValidationIssueType.TIMEOUT, validationException.getType()); + assertEquals(ValidationAspect.MODEL, validationException.getAspect()); + } + + /** + * AggregationPrep.validateAndRetrieveHistogramAggregation throws ValidationException because + * response.getAggregations() returns null. + */ + public void testOnFailure() { + long expirationEpochMs = clock.millis() - 1000; // Expired 1 second ago + SearchResponse mockResponse = mock(SearchResponse.class); + + when(mockConfig.getHistoryIntervals()).thenReturn(40); + + IntervalRecommendationListener listener = intervalCalculation.new IntervalRecommendationListener( + mockIntervalListener, new SearchSourceBuilder(), mockIntervalConfig, expirationEpochMs, mockLongBounds + ); + + listener.onResponse(mockResponse); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(ValidationException.class); + verify(mockIntervalListener).onFailure(argumentCaptor.capture()); + ValidationException validationException = argumentCaptor.getValue(); + assertEquals(CommonMessages.MODEL_VALIDATION_FAILED_UNEXPECTEDLY, validationException.getMessage()); + assertEquals(ValidationIssueType.AGGREGATION, validationException.getType()); + assertEquals(ValidationAspect.MODEL, validationException.getAspect()); + } +} diff --git a/src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobTransportActionTests.java b/src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobTransportActionTests.java index 126d18492..3ea8d0fec 100644 --- a/src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobTransportActionTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobTransportActionTests.java @@ -205,7 +205,7 @@ public void testStartHistoricalAnalysisForMultiCategoryHCWithUser() throws IOExc waitUntil(() -> { try { ADTask task = getADTask(taskId); - return !TestHelpers.HISTORICAL_ANALYSIS_RUNNING_STATS.contains(task.getState()); + return HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(task.getState()); } catch (IOException e) { return false; } diff --git a/src/test/java/org/opensearch/timeseries/transport/CronTransportActionTests.java b/src/test/java/org/opensearch/timeseries/transport/CronTransportActionTests.java index dd6e66ab9..7939e522c 100644 --- a/src/test/java/org/opensearch/timeseries/transport/CronTransportActionTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/CronTransportActionTests.java @@ -90,7 +90,6 @@ public void setUp() throws Exception { actionFilters, tarnsportStatemanager, modelManager, - featureManager, cacheProvider, forecastCacheProvider, entityColdStarter,