diff --git a/build.gradle b/build.gradle index dd839d743..464360b3e 100644 --- a/build.gradle +++ b/build.gradle @@ -278,7 +278,7 @@ evaluationDependsOnChildren() task release(type: Copy, group: 'build') { dependsOn allprojects*.tasks.build from(zipTree(project.tasks.bundlePlugin.outputs.files.getSingleFile())) - into "build/plugins/opendistro-anomaly-detection" + into "build/plugins/opensearch-anomaly-detection" includeEmptyDirs = false // ES versions < 6.3 have a top-level opensearch directory inside the plugin zip which we need to remove eachFile { it.path = it.path - "opensearch/" } @@ -327,6 +327,20 @@ List jacocoExclusions = [ 'org.opensearch.ad.indices.AnomalyDetectionIndices', 'org.opensearch.ad.transport.handler.MultiEntityResultHandler', 'org.opensearch.ad.util.ThrowingSupplierWrapper', + 'org.opensearch.ad.transport.EntityResultTransportAction', + 'org.opensearch.ad.transport.EntityResultTransportAction.*', + 'org.opensearch.ad.transport.AnomalyResultTransportAction.*', + 'org.opensearch.ad.transport.ProfileNodeResponse', + 'org.opensearch.ad.transport.ADResultBulkResponse', + 'org.opensearch.ad.transport.AggregationType', + 'org.opensearch.ad.EntityProfileRunner', + 'org.opensearch.ad.NodeStateManager', + 'org.opensearch.ad.util.BulkUtil', + 'org.opensearch.ad.util.ExceptionUtil', + 'org.opensearch.ad.feature.SearchFeatureDao', + 'org.opensearch.ad.feature.CompositeRetriever.*', + 'org.opensearch.ad.feature.ScriptMaker', + 'org.opensearch.ad.ml.EntityModel', ] jacocoTestCoverageVerification { @@ -431,11 +445,11 @@ afterEvaluate { prefix '/usr' license 'ASL-2.0' - maintainer 'OpenDistro for Elasticsearch Team ' - url 'https://opendistro.github.io/for-elasticsearch/downloads.html' + maintainer 'OpenSearch ' + url 'https://opensearch.org/downloads.html' summary ''' - Anomaly Detection plugin for OpenDistro for Elasticsearch. - Reference documentation can be found at https://opendistro.github.io/for-elasticsearch-docs/. + Anomaly Detection plugin for OpenSearch. + Reference documentation can be found at https://docs-beta.opensearch.org/docs/ad/. '''.stripIndent().replace('\n', ' ').trim() } diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java index b99dddbb1..c9f640ccc 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java @@ -468,12 +468,14 @@ private void indexAnomalyResult( String detectorId = jobParameter.getName(); detectorEndRunExceptionCount.remove(detectorId); try { + // reset error if different from previously recorded one + detectionStateHandler.saveError(response.getError(), detectorId); // skipping writing to the result index if not necessary // For a single-entity detector, the result is not useful if error is null // and rcf score (thus anomaly grade/confidence) is null. - // For a multi-entity detector, we don't need to save on the detector level. - // We always return 0 rcf score if there is no error. - if (response.getAnomalyScore() <= 0 && response.getError() == null) { + // For a HCAD detector, we don't need to save on the detector level. + // We return 0 or Double.NaN rcf score if there is no error. 
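Illustrative aside, not part of this diff: the comment above describes the new skip rule for detector-level results, and the actual changed condition follows immediately below. This standalone sketch restates that rule with a hypothetical helper and hypothetical names so the intent can be read (and tested) in isolation.

    // Hedged sketch: skip indexing a detector-level result when there is no error
    // and the RCF score carries no information (0 or NaN for HCAD detectors).
    final class ResultIndexingPolicy {
        private ResultIndexingPolicy() {}

        static boolean shouldSkipIndexing(double anomalyScore, String error) {
            boolean noUsefulScore = anomalyScore <= 0 || Double.isNaN(anomalyScore);
            return noUsefulScore && error == null;
        }

        public static void main(String[] args) {
            System.out.println(shouldSkipIndexing(Double.NaN, null));       // true: nothing to index
            System.out.println(shouldSkipIndexing(0.7, null));              // false: real score, index it
            System.out.println(shouldSkipIndexing(0, "node unresponsive")); // false: error must be recorded
        }
    }
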
+ if ((response.getAnomalyScore() <= 0 || Double.isNaN(response.getAnomalyScore())) && response.getError() == null) { return; } IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) ((AnomalyDetectorJob) jobParameter).getWindowDelay(); @@ -499,7 +501,6 @@ private void indexAnomalyResult( indexUtil.getSchemaVersion(ADIndex.RESULT) ); anomalyResultHandler.index(anomalyResult, detectorId); - detectionStateHandler.saveError(response.getError(), detectorId); } catch (Exception e) { log.error("Failed to index anomaly result for " + detectorId, e); } finally { diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java b/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java index 94a4f841e..7620df170 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java @@ -35,6 +35,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.Random; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -57,6 +58,7 @@ import org.opensearch.ad.dataprocessor.LinearUniformInterpolator; import org.opensearch.ad.dataprocessor.SingleFeatureLinearUniformInterpolator; import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.ScriptMaker; import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.indices.AnomalyDetectionIndices; import org.opensearch.ad.ml.CheckpointDao; @@ -68,6 +70,11 @@ import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.DetectorInternalState; +import org.opensearch.ad.ratelimit.CheckpointReadWorker; +import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ColdEntityWorker; +import org.opensearch.ad.ratelimit.EntityColdStartWorker; +import org.opensearch.ad.ratelimit.ResultWriteWorker; import org.opensearch.ad.rest.RestAnomalyDetectorJobAction; import org.opensearch.ad.rest.RestDeleteAnomalyDetectorAction; import org.opensearch.ad.rest.RestExecuteAnomalyDetectorAction; @@ -82,6 +89,7 @@ import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.settings.EnabledSetting; import org.opensearch.ad.settings.LegacyOpenDistroAnomalyDetectorSettings; +import org.opensearch.ad.settings.NumericSetting; import org.opensearch.ad.stats.ADStat; import org.opensearch.ad.stats.ADStats; import org.opensearch.ad.stats.StatNames; @@ -176,6 +184,7 @@ import org.opensearch.jobscheduler.spi.JobSchedulerExtension; import org.opensearch.jobscheduler.spi.ScheduledJobParser; import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.monitor.jvm.JvmInfo; import org.opensearch.monitor.jvm.JvmService; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.Plugin; @@ -313,6 +322,7 @@ public Collection createComponents( Supplier repositoriesServiceSupplier ) { EnabledSetting.getInstance().init(clusterService); + NumericSetting.getInstance().init(clusterService); this.client = client; this.threadPool = threadPool; Settings settings = environment.settings(); @@ -331,35 +341,25 @@ public Collection createComponents( xContentRegistry, interpolator, clientUtil, - threadPool, settings, - clusterService + clusterService, + gson ); JvmService jvmService = new JvmService(environment.settings()); RandomCutForestSerDe rcfSerde = new RandomCutForestSerDe(); - CheckpointDao checkpoint = new CheckpointDao( - client, - 
clientUtil, - CommonName.CHECKPOINT_INDEX_NAME, - gson, - rcfSerde, - HybridThresholdingModel.class, - getClock(), - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - anomalyDetectionIndices, - AnomalyDetectorSettings.MAX_BULK_CHECKPOINT_SIZE, - AnomalyDetectorSettings.CHECKPOINT_BULK_PER_SECOND - ); double modelMaxSizePercent = AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(settings); + ADCircuitBreakerService adCircuitBreakerService = new ADCircuitBreakerService(jvmService).init(); + MemoryTracker memoryTracker = new MemoryTracker( jvmService, modelMaxSizePercent, AnomalyDetectorSettings.DESIRED_MODEL_SIZE_PERCENTAGE, clusterService, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + adCircuitBreakerService ); ModelPartitioner modelPartitioner = new ModelPartitioner( @@ -396,6 +396,60 @@ public Collection createComponents( AD_THREAD_POOL_NAME ); + long heapSizeBytes = JvmInfo.jvmInfo().getMem().getHeapMax().getBytes(); + + CheckpointDao checkpoint = new CheckpointDao( + client, + clientUtil, + CommonName.CHECKPOINT_INDEX_NAME, + gson, + rcfSerde, + HybridThresholdingModel.class, + anomalyDetectionIndices, + AnomalyDetectorSettings.MAX_CHECKPOINT_BYTES + ); + + Random random = new Random(42); + + CheckpointWriteWorker checkpointWriteQueue = new CheckpointWriteWorker( + heapSizeBytes, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + checkpoint, + CommonName.CHECKPOINT_INDEX_NAME, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + stateManager, + AnomalyDetectorSettings.HOURLY_MAINTENANCE + ); + + EntityCache cache = new PriorityCache( + checkpoint, + AnomalyDetectorSettings.DEDICATED_CACHE_SIZE.get(settings), + AnomalyDetectorSettings.CHECKPOINT_TTL, + AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES, + memoryTracker, + AnomalyDetectorSettings.MULTI_ENTITY_NUM_TREES, + getClock(), + clusterService, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + threadPool, + checkpointWriteQueue, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT + ); + + CacheProvider cacheProvider = new CacheProvider(cache); + EntityColdStarter entityColdStarter = new EntityColdStarter( getClock(), threadPool, @@ -416,10 +470,29 @@ public Collection createComponents( AnomalyDetectorSettings.THRESHOLD_DOWNSAMPLES, AnomalyDetectorSettings.THRESHOLD_MAX_SAMPLES, featureManager, + settings, AnomalyDetectorSettings.HOURLY_MAINTENANCE, - AnomalyDetectorSettings.MAX_SMALL_STATES, - checkpoint, - settings + checkpointWriteQueue + ); + + EntityColdStartWorker coldstartQueue = new EntityColdStartWorker( + heapSizeBytes, + AnomalyDetectorSettings.ENTITY_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + entityColdStarter, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + stateManager 
); ModelManager modelManager = new ModelManager( @@ -447,24 +520,81 @@ public Collection createComponents( memoryTracker ); - EntityCache cache = new PriorityCache( - checkpoint, - AnomalyDetectorSettings.DEDICATED_CACHE_SIZE, - AnomalyDetectorSettings.CHECKPOINT_TTL, - AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES, - memoryTracker, - modelManager, - AnomalyDetectorSettings.MULTI_ENTITY_NUM_TREES, - getClock(), + MultiEntityResultHandler multiEntityResultHandler = new MultiEntityResultHandler( + client, + settings, + threadPool, + anomalyDetectionIndices, + this.clientUtil, + this.indexUtils, + clusterService + ); + + ResultWriteWorker resultWriteQueue = new ResultWriteWorker( + heapSizeBytes, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_SIZE_IN_BYTES, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, clusterService, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - AnomalyDetectorSettings.NUM_MIN_SAMPLES, + random, + adCircuitBreakerService, + threadPool, settings, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + multiEntityResultHandler, + xContentRegistry, + stateManager, + AnomalyDetectorSettings.HOURLY_MAINTENANCE + ); + + CheckpointReadWorker checkpointReadQueue = new CheckpointReadWorker( + heapSizeBytes, + AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + adCircuitBreakerService, threadPool, - AnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND.get(settings) + settings, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + modelManager, + checkpoint, + coldstartQueue, + resultWriteQueue, + stateManager, + anomalyDetectionIndices, + cacheProvider, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue ); - CacheProvider cacheProvider = new CacheProvider(cache); + ColdEntityWorker coldEntityQueue = new ColdEntityWorker( + heapSizeBytes, + AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + checkpointReadQueue, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + stateManager + ); HashRing hashRing = new HashRing(nodeFilter, getClock(), settings); @@ -505,8 +635,8 @@ public Collection createComponents( .put(StatNames.AD_BATCH_TASK_FAILURE_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) .build(); - adStats = new ADStats(indexUtils, modelManager, stats); - ADCircuitBreakerService adCircuitBreakerService = new ADCircuitBreakerService(jvmService).init(); + adStats = new ADStats(stats); + this.detectorStateHandler = new DetectionStateHandler( client, settings, @@ -520,17 +650,6 @@ public Collection createComponents( stateManager ); - MultiEntityResultHandler multiEntityResultHandler = new MultiEntityResultHandler( - client, - settings, - 
threadPool, - anomalyDetectionIndices, - this.clientUtil, - this.indexUtils, - clusterService, - stateManager - ); - adTaskCacheManager = new ADTaskCacheManager(settings, clusterService, memoryTracker); adTaskManager = new ADTaskManager( settings, @@ -598,7 +717,14 @@ public Collection createComponents( cacheProvider, adTaskManager, adBatchTaskRunner, - adSearchHandler + adSearchHandler, + coldstartQueue, + resultWriteQueue, + checkpointReadQueue, + checkpointWriteQueue, + coldEntityQueue, + entityColdStarter, + new ScriptMaker() ); } @@ -618,7 +744,9 @@ public List> getExecutorBuilders(Settings settings) { new ScalingExecutorBuilder( AD_THREAD_POOL_NAME, 1, - Math.max(1, OpenSearchExecutors.allocatedProcessors(settings) / 4), + // HCAD can be heavy after supporting 1 million entities. + // Limit to use at most half of the processors. + Math.max(1, OpenSearchExecutors.allocatedProcessors(settings) / 2), TimeValue.timeValueMinutes(10), AD_THREAD_POOL_PREFIX + AD_THREAD_POOL_NAME ), @@ -635,61 +763,92 @@ public List> getExecutorBuilders(Settings settings) { @Override public List> getSettings() { List> enabledSetting = EnabledSetting.getInstance().getSettings(); + List> numericSetting = NumericSetting.getInstance().getSettings(); List> systemSetting = ImmutableList .of( - LegacyOpenDistroAnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, - LegacyOpenDistroAnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS, - LegacyOpenDistroAnomalyDetectorSettings.MAX_ANOMALY_FEATURES, - LegacyOpenDistroAnomalyDetectorSettings.REQUEST_TIMEOUT, + // HCAD cache + LegacyOpenDistroAnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND, + AnomalyDetectorSettings.DEDICATED_CACHE_SIZE, + // Detector config LegacyOpenDistroAnomalyDetectorSettings.DETECTION_INTERVAL, LegacyOpenDistroAnomalyDetectorSettings.DETECTION_WINDOW_DELAY, - LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, - LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS, + LegacyOpenDistroAnomalyDetectorSettings.MAX_ANOMALY_FEATURES, + AnomalyDetectorSettings.DETECTION_INTERVAL, + AnomalyDetectorSettings.DETECTION_WINDOW_DELAY, + AnomalyDetectorSettings.MAX_ANOMALY_FEATURES, + // Fault tolerance + LegacyOpenDistroAnomalyDetectorSettings.REQUEST_TIMEOUT, LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE, LegacyOpenDistroAnomalyDetectorSettings.COOLDOWN_MINUTES, LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_MINUTES, LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_INITIAL_DELAY, LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_BACKOFF, - LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD, - LegacyOpenDistroAnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, - LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, - LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW, - LegacyOpenDistroAnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT, - LegacyOpenDistroAnomalyDetectorSettings.MAX_PRIMARY_SHARDS, - LegacyOpenDistroAnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES, - LegacyOpenDistroAnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND, - LegacyOpenDistroAnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE, - LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS, - LegacyOpenDistroAnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, - LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE, - AnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, - 
AnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS, - AnomalyDetectorSettings.MAX_ANOMALY_FEATURES, AnomalyDetectorSettings.REQUEST_TIMEOUT, - AnomalyDetectorSettings.DETECTION_INTERVAL, - AnomalyDetectorSettings.DETECTION_WINDOW_DELAY, - AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, - AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS, AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE, AnomalyDetectorSettings.COOLDOWN_MINUTES, AnomalyDetectorSettings.BACKOFF_MINUTES, AnomalyDetectorSettings.BACKOFF_INITIAL_DELAY, AnomalyDetectorSettings.MAX_RETRY_FOR_BACKOFF, + // result index rollover + LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, + LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS, + LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD, + AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, + AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD, AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD, + // resource usage control + LegacyOpenDistroAnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, + LegacyOpenDistroAnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, + LegacyOpenDistroAnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS, + LegacyOpenDistroAnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT, + LegacyOpenDistroAnomalyDetectorSettings.MAX_PRIMARY_SHARDS, AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, - AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, - AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW, + AnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, + AnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS, AnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT, + AnomalyDetectorSettings.INDEX_PRESSURE_HARD_LIMIT, AnomalyDetectorSettings.MAX_PRIMARY_SHARDS, + // Security + LegacyOpenDistroAnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES, AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES, - AnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND, + // Historical + LegacyOpenDistroAnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE, + LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS, + LegacyOpenDistroAnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, + LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE, AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE, AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS, AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, - AnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE + AnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE, + // rate limiting + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_BATCH_SIZE, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_BATCH_SIZE, + AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_SECS, + // query limit + LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, + 
LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW, + AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, + AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW, + AnomalyDetectorSettings.PAGE_SIZE ); - return unmodifiableList(Stream.concat(enabledSetting.stream(), systemSetting.stream()).collect(Collectors.toList())); + return unmodifiableList( + Stream + .of(enabledSetting.stream(), systemSetting.stream(), numericSetting.stream()) + .reduce(Stream::concat) + .orElseGet(Stream::empty) + .collect(Collectors.toList()) + ); } @Override diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java index 1e5ddf037..78d7f81a5 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java @@ -74,7 +74,6 @@ import org.opensearch.common.xcontent.XContentParseException; import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHits; @@ -250,7 +249,7 @@ private void prepareProfile( onGetDetectorForPrepare(listener, profilesToCollect); } }, exception -> { - if (exception instanceof IndexNotFoundException) { + if (ExceptionUtil.isIndexNotAvailable(exception)) { logger.info(exception.getMessage()); onGetDetectorForPrepare(listener, profilesToCollect); } else { @@ -364,7 +363,7 @@ private ActionListener onGetDetectorState( listener.onResponse(profileBuilder.build()); } }, exception -> { - if (exception instanceof IndexNotFoundException) { + if (ExceptionUtil.isIndexNotAvailable(exception)) { // detector state index is not created yet listener.onResponse(new DetectorProfile.Builder().build()); } else { @@ -482,7 +481,7 @@ private ActionListener onInittedEver( listener.onResponse(profileBuilder.build()); } }, exception -> { - if (exception instanceof IndexNotFoundException) { + if (ExceptionUtil.isIndexNotAvailable(exception)) { // anomaly result index is not created yet processInitResponse(detector, profilesToCollect, totalUpdates, false, profileBuilder, listener); } else { @@ -525,7 +524,7 @@ private ActionListener onPollRCFUpdates( Exception causeException = (Exception) cause; if (ExceptionUtil .isException(causeException, ResourceNotFoundException.class, ExceptionUtil.RESOURCE_NOT_FOUND_EXCEPTION_NAME_UNDERSCORE) - || (causeException instanceof IndexNotFoundException + || (ExceptionUtil.isIndexNotAvailable(causeException) && causeException.getMessage().contains(CommonName.CHECKPOINT_INDEX_NAME))) { // cannot find checkpoint // We don't want to show the estimated time remaining to initialize diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java index 592a4a315..8efa3ea13 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java @@ -29,7 +29,6 @@ import java.io.IOException; import java.time.Instant; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Locale; @@ -76,6 +75,7 @@ public AnomalyDetectorRunner(ModelManager modelManager, FeatureManager featureMa * @param detector anomaly detector instance * @param startTime detection period start time * @param endTime 
detection period end time + * @param context stored thread context * @param listener handle anomaly result * @throws IOException - if a user gives wrong query input when defining a detector */ @@ -96,6 +96,7 @@ public void executeDetector( // This also requires front-end change to handle error message correspondingly // We return empty list for now to avoid breaking front-end listener.onResponse(Collections.emptyList()); + return; } ActionListener entityAnomalyResultListener = ActionListener .wrap( @@ -119,7 +120,7 @@ public void executeDetector( ActionListener.wrap(features -> { List entityResults = modelManager.getPreviewResults(features.getProcessedFeatures()); List sampledEntityResults = sample( - parsePreviewResult(detector, features, entityResults, Arrays.asList(entity)), + parsePreviewResult(detector, features, entityResults, entity), maxPreviewResults ); multiEntitiesResponseListener.onResponse(new EntityAnomalyResult(sampledEntityResults)); @@ -146,6 +147,7 @@ private void onFailure(Exception e, ActionListener> listener // We return empty list for now to avoid breaking front-end if (e instanceof OpenSearchSecurityException) { listener.onFailure(e); + return; } listener.onResponse(Collections.emptyList()); } @@ -154,7 +156,7 @@ private List parsePreviewResult( AnomalyDetector detector, Features features, List results, - List entity + Entity entity ) { // unprocessedFeatures[][], each row is for one date range. // For example, unprocessedFeatures[0][2] is for the first time range, the third feature diff --git a/src/main/java/org/opensearch/ad/EntityModelSize.java b/src/main/java/org/opensearch/ad/EntityModelSize.java deleted file mode 100644 index 3233fce99..000000000 --- a/src/main/java/org/opensearch/ad/EntityModelSize.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. 
- */ - -package org.opensearch.ad; - -public interface EntityModelSize { - /** - * Gets an entity's model sizes - * - * @param detectorId Detector Id - * @param entityModelId Entity's model Id - * @return the entity's memory size - */ - long getModelSize(String detectorId, String entityModelId); -} diff --git a/src/main/java/org/opensearch/ad/EntityProfileRunner.java b/src/main/java/org/opensearch/ad/EntityProfileRunner.java index 1d2d1b623..0355ee1bd 100644 --- a/src/main/java/org/opensearch/ad/EntityProfileRunner.java +++ b/src/main/java/org/opensearch/ad/EntityProfileRunner.java @@ -28,11 +28,11 @@ import static org.opensearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; import static org.opensearch.ad.model.AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.CATEGORY_FIELD_LIMIT; import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import java.security.InvalidParameterException; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; @@ -47,17 +47,20 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.EntityProfile; import org.opensearch.ad.model.EntityProfileName; import org.opensearch.ad.model.EntityState; import org.opensearch.ad.model.InitProgressProfile; import org.opensearch.ad.model.IntervalTimeConfiguration; +import org.opensearch.ad.settings.NumericSetting; import org.opensearch.ad.transport.EntityProfileAction; import org.opensearch.ad.transport.EntityProfileRequest; import org.opensearch.ad.transport.EntityProfileResponse; import org.opensearch.ad.util.MultiResponsesDelegateActionListener; import org.opensearch.ad.util.ParseUtils; import org.opensearch.client.Client; +import org.opensearch.cluster.routing.Preference; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.common.xcontent.XContentParser; @@ -74,6 +77,7 @@ public class EntityProfileRunner extends AbstractProfileRunner { private final Logger logger = LogManager.getLogger(EntityProfileRunner.class); static final String NOT_HC_DETECTOR_ERR_MSG = "This is not a high cardinality detector"; + static final String NO_ENTITY = "Cannot find entity"; private Client client; private NamedXContentRegistry xContentRegistry; @@ -93,7 +97,7 @@ public EntityProfileRunner(Client client, NamedXContentRegistry xContentRegistry */ public void profile( String detectorId, - String entityValue, + Entity entityValue, Set profilesToCollect, ActionListener listener ) { @@ -112,16 +116,15 @@ public void profile( ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); AnomalyDetector detector = AnomalyDetector.parse(parser, detectorId); - List categoryField = detector.getCategoryField(); - if (categoryField == null || categoryField.size() == 0) { + List categoryFields = detector.getCategoryField(); + int maxCategoryFields = NumericSetting.maxCategoricalFields(); + if (categoryFields == null || categoryFields.size() == 0) { listener.onFailure(new InvalidParameterException(NOT_HC_DETECTOR_ERR_MSG)); - } else if (categoryField.size() > CATEGORY_FIELD_LIMIT) { + } else if (categoryFields.size() > maxCategoryFields) { listener - .onFailure( - new 
InvalidParameterException(CommonErrorMessages.CATEGORICAL_FIELD_NUMBER_SURPASSED + CATEGORY_FIELD_LIMIT) - ); + .onFailure(new InvalidParameterException(CommonErrorMessages.getTooManyCategoricalFieldErr(maxCategoryFields))); } else { - prepareEntityProfile(listener, detectorId, entityValue, profilesToCollect, detector, categoryField.get(0)); + validateEntity(entityValue, categoryFields, detectorId, profilesToCollect, detector, listener); } } catch (Exception t) { listener.onFailure(t); @@ -132,10 +135,71 @@ public void profile( }, listener::onFailure)); } + /** + * Verify if the input entity exists or not in case of typos. + * + * If a user deletes the entity after job start, then we will not be able to + * get this entity in the index. For this case, we will not return a profile + * for this entity even if it's running on some data node. the entity's model + * will be deleted by another entity or by maintenance due to long inactivity. + * + * @param entity Entity accessor + * @param categoryFields category fields defined for a detector + * @param detectorId Detector Id + * @param profilesToCollect Profile to collect from the input + * @param detector Detector config accessor + * @param listener Callback to send responses. + */ + private void validateEntity( + Entity entity, + List categoryFields, + String detectorId, + Set profilesToCollect, + AnomalyDetector detector, + ActionListener listener + ) { + Map attributes = entity.getAttributes(); + if (attributes == null || attributes.size() != categoryFields.size()) { + listener.onFailure(new InvalidParameterException("Empty entity attributes")); + return; + } + for (String field : categoryFields) { + if (false == attributes.containsKey(field)) { + listener.onFailure(new InvalidParameterException("Cannot find " + field)); + return; + } + } + + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(detector.getFilterQuery()); + + for (TermQueryBuilder term : entity.getTermQueryBuilders()) { + internalFilterQuery.filter(term); + } + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(internalFilterQuery).size(1); + + SearchRequest searchRequest = new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder) + .preference(Preference.LOCAL.toString()); + + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + try { + if (searchResponse.getHits().getHits().length == 0) { + listener.onFailure(new InvalidParameterException(NO_ENTITY)); + return; + } + prepareEntityProfile(listener, detectorId, entity, profilesToCollect, detector, categoryFields.get(0)); + } catch (Exception e) { + listener.onFailure(new InvalidParameterException(NO_ENTITY)); + return; + } + }, e -> listener.onFailure(new InvalidParameterException(NO_ENTITY)))); + + } + private void prepareEntityProfile( ActionListener listener, String detectorId, - String entityValue, + Entity entityValue, Set profilesToCollect, AnomalyDetector detector, String categoryField @@ -146,18 +210,13 @@ private void prepareEntityProfile( .execute( EntityProfileAction.INSTANCE, request, - ActionListener - .wrap( - r -> getJob(detectorId, categoryField, entityValue, profilesToCollect, detector, r, listener), - listener::onFailure - ) + ActionListener.wrap(r -> getJob(detectorId, entityValue, profilesToCollect, detector, r, listener), listener::onFailure) ); } private void getJob( String detectorId, - String categoryField, - String entityValue, + Entity entityValue, Set profilesToCollect, AnomalyDetector detector, 
EntityProfileResponse entityProfileResponse, @@ -194,7 +253,7 @@ private void getJob( ); if (profilesToCollect.contains(EntityProfileName.MODELS)) { - EntityProfile.Builder builder = new EntityProfile.Builder(categoryField, entityValue); + EntityProfile.Builder builder = new EntityProfile.Builder(); if (false == job.isEnabled()) { delegateListener.onResponse(builder.build()); } else { @@ -207,7 +266,6 @@ private void getJob( profileStateRelated( entityProfileResponse.getTotalUpdates(), detectorId, - categoryField, entityValue, profilesToCollect, detector, @@ -220,7 +278,7 @@ private void getJob( long enabledTimeMs = job.getEnabledTime().toEpochMilli(); SearchRequest lastSampleTimeRequest = createLastSampleTimeRequest(detectorId, enabledTimeMs, entityValue); - EntityProfile.Builder builder = new EntityProfile.Builder(categoryField, entityValue); + EntityProfile.Builder builder = new EntityProfile.Builder(); Optional isActiveOp = entityProfileResponse.isActive(); if (isActiveOp.isPresent()) { @@ -252,12 +310,12 @@ private void getJob( listener.onFailure(e); } } else { - sendUnknownState(profilesToCollect, categoryField, entityValue, true, listener); + sendUnknownState(profilesToCollect, entityValue, true, listener); } }, exception -> { if (exception instanceof IndexNotFoundException) { logger.info(exception.getMessage()); - sendUnknownState(profilesToCollect, categoryField, entityValue, true, listener); + sendUnknownState(profilesToCollect, entityValue, true, listener); } else { logger.error(CommonErrorMessages.FAIL_TO_GET_PROFILE_MSG + detectorId, exception); listener.onFailure(exception); @@ -268,40 +326,37 @@ private void getJob( private void profileStateRelated( long totalUpdates, String detectorId, - String categoryField, - String entityValue, + Entity entityValue, Set profilesToCollect, AnomalyDetector detector, AnomalyDetectorJob job, MultiResponsesDelegateActionListener delegateListener ) { if (totalUpdates == 0) { - sendUnknownState(profilesToCollect, categoryField, entityValue, false, delegateListener); + sendUnknownState(profilesToCollect, entityValue, false, delegateListener); } else if (false == job.isEnabled()) { - sendUnknownState(profilesToCollect, categoryField, entityValue, false, delegateListener); + sendUnknownState(profilesToCollect, entityValue, false, delegateListener); } else if (totalUpdates >= requiredSamples) { - sendRunningState(profilesToCollect, categoryField, entityValue, delegateListener); + sendRunningState(profilesToCollect, entityValue, delegateListener); } else { - sendInitState(profilesToCollect, categoryField, entityValue, detector, totalUpdates, delegateListener); + sendInitState(profilesToCollect, entityValue, detector, totalUpdates, delegateListener); } } /** * Send unknown state back * @param profilesToCollect Profiles to Collect - * @param categoryField Category field * @param entityValue Entity value * @param immediate whether we should terminate workflow and respond immediately * @param delegateListener Delegate listener */ private void sendUnknownState( Set profilesToCollect, - String categoryField, - String entityValue, + Entity entityValue, boolean immediate, ActionListener delegateListener ) { - EntityProfile.Builder builder = new EntityProfile.Builder(categoryField, entityValue); + EntityProfile.Builder builder = new EntityProfile.Builder(); if (profilesToCollect.contains(EntityProfileName.STATE)) { builder.state(EntityState.UNKNOWN); } @@ -314,11 +369,10 @@ private void sendUnknownState( private void sendRunningState( Set profilesToCollect, 
- String categoryField, - String entityValue, + Entity entityValue, MultiResponsesDelegateActionListener delegateListener ) { - EntityProfile.Builder builder = new EntityProfile.Builder(categoryField, entityValue); + EntityProfile.Builder builder = new EntityProfile.Builder(); if (profilesToCollect.contains(EntityProfileName.STATE)) { builder.state(EntityState.RUNNING); } @@ -331,13 +385,12 @@ private void sendRunningState( private void sendInitState( Set profilesToCollect, - String categoryField, - String entityValue, + Entity entityValue, AnomalyDetector detector, long updates, MultiResponsesDelegateActionListener delegateListener ) { - EntityProfile.Builder builder = new EntityProfile.Builder(categoryField, entityValue); + EntityProfile.Builder builder = new EntityProfile.Builder(); if (profilesToCollect.contains(EntityProfileName.STATE)) { builder.state(EntityState.INIT); } @@ -349,14 +402,55 @@ private void sendInitState( delegateListener.onResponse(builder.build()); } - private SearchRequest createLastSampleTimeRequest(String detectorId, long enabledTime, String entityValue) { + private SearchRequest createLastSampleTimeRequest(String detectorId, long enabledTime, Entity entity) { BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); String path = "entity"; - String entityValueFieldName = path + ".value"; - TermQueryBuilder entityValueFilterQuery = QueryBuilders.termQuery(entityValueFieldName, entityValue); - NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(path, entityValueFilterQuery, ScoreMode.None); - boolQueryBuilder.filter(nestedQueryBuilder); + String entityName = path + ".name"; + String entityValue = path + ".value"; + + for (Map.Entry attribute : entity.getAttributes().entrySet()) { + /* + * each attribute pair corresponds to a nested query like + "nested": { + "query": { + "bool": { + "filter": [ + { + "term": { + "entity.name": { + "value": "turkey4", + "boost": 1 + } + } + }, + { + "term": { + "entity.value": { + "value": "Turkey", + "boost": 1 + } + } + } + ] + } + }, + "path": "entity", + "ignore_unmapped": false, + "score_mode": "none", + "boost": 1 + } + },*/ + BoolQueryBuilder nestedBoolQueryBuilder = new BoolQueryBuilder(); + + TermQueryBuilder entityNameFilterQuery = QueryBuilders.termQuery(entityName, attribute.getKey()); + nestedBoolQueryBuilder.filter(entityNameFilterQuery); + TermQueryBuilder entityValueFilterQuery = QueryBuilders.termQuery(entityValue, attribute.getValue()); + nestedBoolQueryBuilder.filter(entityValueFilterQuery); + + NestedQueryBuilder nestedNameQueryBuilder = new NestedQueryBuilder(path, nestedBoolQueryBuilder, ScoreMode.None); + boolQueryBuilder.filter(nestedNameQueryBuilder); + } boolQueryBuilder.filter(QueryBuilders.termQuery(AnomalyResult.DETECTOR_ID_FIELD, detectorId)); diff --git a/src/main/java/org/opensearch/ad/MemoryTracker.java b/src/main/java/org/opensearch/ad/MemoryTracker.java index 8904cd13f..fb2429a3f 100644 --- a/src/main/java/org/opensearch/ad/MemoryTracker.java +++ b/src/main/java/org/opensearch/ad/MemoryTracker.java @@ -34,6 +34,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.common.exception.LimitExceededException; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.cluster.service.ClusterService; @@ -50,7 +51,7 @@ public class MemoryTracker { public enum Origin { SINGLE_ENTITY_DETECTOR, - MULTI_ENTITY_DETECTOR, + HC_DETECTOR, HISTORICAL_SINGLE_ENTITY_DETECTOR, 
} @@ -66,6 +67,7 @@ public enum Origin { // we observe threshold model uses a fixed size array and the size is the same private int thresholdModelBytes; private int sampleSize; + private ADCircuitBreakerService adCircuitBreakerService; /** * Constructor @@ -75,13 +77,15 @@ public enum Origin { * @param modelDesiredSizePercentage percentage of heap for the desired size of a model * @param clusterService Cluster service object * @param sampleSize The sample size used by stream samplers in a RCF forest + * @param adCircuitBreakerService Memory circuit breaker */ public MemoryTracker( JvmService jvmService, double modelMaxSizePercentage, double modelDesiredSizePercentage, ClusterService clusterService, - int sampleSize + int sampleSize, + ADCircuitBreakerService adCircuitBreakerService ) { this.totalMemoryBytes = 0; this.totalMemoryBytesByOrigin = new EnumMap(Origin.class); @@ -95,19 +99,19 @@ public MemoryTracker( .addSettingsUpdateConsumer(MODEL_MAX_SIZE_PERCENTAGE, it -> this.heapLimitBytes = (long) (heapSize * it)); this.thresholdModelBytes = 180_000; this.sampleSize = sampleSize; - } - - public synchronized boolean isHostingAllowed(String detectorId, RandomCutForest rcf) { - return canAllocateReserved(detectorId, estimateModelSize(rcf)); + this.adCircuitBreakerService = adCircuitBreakerService; } /** - * @param detectorId Detector Id, used in error message - * @param requiredBytes required bytes in memory - * @return whether there is memory required for AD + * This function derives from the old code: https://tinyurl.com/2eaabja6 + * + * @param detectorId Detector Id + * @param rcf Random cut forest model + * @return true if there is enough memory; otherwise throw LimitExceededException. */ - public synchronized boolean canAllocateReserved(String detectorId, long requiredBytes) { - if (reservedMemoryBytes + requiredBytes <= heapLimitBytes) { + public synchronized boolean isHostingAllowed(String detectorId, RandomCutForest rcf) { + long requiredBytes = estimateModelSize(rcf); + if (canAllocateReserved(requiredBytes)) { return true; } else { throw new LimitExceededException( @@ -124,12 +128,21 @@ public synchronized boolean canAllocateReserved(String detectorId, long required } /** - * Whether allocating memory is allowed + * @param requiredBytes required bytes to allocate + * @return whether there is enough memory for the required bytes. This is + * true when circuit breaker is closed and there is enough reserved memory. + */ + public synchronized boolean canAllocateReserved(long requiredBytes) { + return (false == adCircuitBreakerService.isOpen() && reservedMemoryBytes + requiredBytes <= heapLimitBytes); + } + + /** * @param bytes required bytes - * @return true if allowed; false otherwise + * @return whether there is enough memory for the required bytes. This is + * true when circuit breaker is closed and there is enough overall memory. */ public synchronized boolean canAllocate(long bytes) { - return totalMemoryBytes + bytes <= heapLimitBytes; + return false == adCircuitBreakerService.isOpen() && totalMemoryBytes + bytes <= heapLimitBytes; } public synchronized void consumeMemory(long memoryToConsume, boolean reserved, Origin origin) { @@ -243,8 +256,8 @@ public long getTotalMemoryBytes() { } /** - * In case of bugs/race conditions when allocating/releasing memory, sync used bytes - * infrequently by recomputing memory usage. + * In case of bugs/race conditions or users dynamically changing dedicated/shared + * cache size, sync used bytes infrequently by recomputing memory usage.
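Illustrative aside, not part of this diff: with the canAllocate/canAllocateReserved changes above, every allocation check now requires two things at once, a closed AD circuit breaker and room under the configured heap budget. A minimal standalone sketch of that gate, with illustrative names and the breaker reduced to a BooleanSupplier:

    // Hedged sketch: allocation is allowed only while the circuit breaker is closed
    // and the requested bytes still fit under the heap limit.
    final class AllocationGate {
        private final long heapLimitBytes;
        private final java.util.function.BooleanSupplier circuitBreakerOpen;
        private long reservedMemoryBytes; // bytes promised to the dedicated cache
        private long totalMemoryBytes;    // reserved plus shared-cache bytes

        AllocationGate(long heapLimitBytes, java.util.function.BooleanSupplier circuitBreakerOpen) {
            this.heapLimitBytes = heapLimitBytes;
            this.circuitBreakerOpen = circuitBreakerOpen;
        }

        synchronized boolean canAllocateReserved(long requiredBytes) {
            return !circuitBreakerOpen.getAsBoolean() && reservedMemoryBytes + requiredBytes <= heapLimitBytes;
        }

        synchronized boolean canAllocate(long bytes) {
            return !circuitBreakerOpen.getAsBoolean() && totalMemoryBytes + bytes <= heapLimitBytes;
        }
    }
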
* @param origin Origin * @param totalBytes total bytes from recomputing * @param reservedBytes reserved bytes from recomputing @@ -256,6 +269,7 @@ public synchronized boolean syncMemoryState(Origin origin, long totalBytes, long if (totalBytes == recordedTotalBytes && reservedBytes == recordedReservedBytes) { return false; } + LOG .info( String diff --git a/src/main/java/org/opensearch/ad/NodeState.java b/src/main/java/org/opensearch/ad/NodeState.java index b2b0eaaf1..08c41bd55 100644 --- a/src/main/java/org/opensearch/ad/NodeState.java +++ b/src/main/java/org/opensearch/ad/NodeState.java @@ -44,13 +44,13 @@ public class NodeState implements ExpiringState { private AnomalyDetector detectorDef; // number of partitions private int partitonNumber; - // checkpoint fetch time + // last access time private Instant lastAccessTime; // last detection error recorded in result index. Used by DetectorStateHandler // to check if the error for a detector has changed or not. If changed, trigger indexing. private Optional lastDetectionError; - // last training error. Used to save cold start error by a concurrent cold start thread. - private Optional lastColdStartException; + // last error. + private Optional exception; // flag indicating whether checkpoint for the detector exists private boolean checkPointExists; // clock to get current time @@ -64,7 +64,7 @@ public NodeState(String detectorId, Clock clock) { this.partitonNumber = -1; this.lastAccessTime = clock.instant(); this.lastDetectionError = Optional.empty(); - this.lastColdStartException = Optional.empty(); + this.exception = Optional.empty(); this.checkPointExists = false; this.clock = clock; this.coldStartRunning = false; @@ -148,19 +148,19 @@ public void setLastDetectionError(String lastError) { /** * - * @return last cold start exception if any + * @return last exception if any */ - public Optional getLastColdStartException() { + public Optional getException() { refreshLastUpdateTime(); - return lastColdStartException; + return exception; } /** * - * @param lastColdStartError last cold start exception if any + * @param exception exception to record */ - public void setLastColdStartException(AnomalyDetectionException lastColdStartError) { - this.lastColdStartException = Optional.ofNullable(lastColdStartError); + public void setException(AnomalyDetectionException exception) { + this.exception = Optional.ofNullable(exception); refreshLastUpdateTime(); } diff --git a/src/main/java/org/opensearch/ad/NodeStateManager.java b/src/main/java/org/opensearch/ad/NodeStateManager.java index 915468d7a..b1244a022 100644 --- a/src/main/java/org/opensearch/ad/NodeStateManager.java +++ b/src/main/java/org/opensearch/ad/NodeStateManager.java @@ -30,13 +30,13 @@ import java.time.Clock; import java.time.Duration; -import java.time.Instant; import java.util.Map; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.util.Strings; import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; @@ -49,6 +49,7 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.transport.BackPressureRouting; import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.client.Client; import org.opensearch.common.lease.Releasable; import org.opensearch.common.settings.Settings; @@ -74,8 +75,6 @@ public class 
NodeStateManager implements MaintenanceState, CleanState { private final Clock clock; private final Settings settings; private final Duration stateTtl; - // last time we are throttled due to too much index pressure - private Instant lastIndexThrottledTime; public static final String NO_ERROR = "no_error"; @@ -109,7 +108,6 @@ public NodeStateManager( this.clock = clock; this.settings = settings; this.stateTtl = stateTtl; - this.lastIndexThrottledTime = Instant.MIN; } /** @@ -160,7 +158,7 @@ private ActionListener onGetDetectorResponse(String adID, ActionLis } String xc = response.getSourceAsString(); - LOG.info("Fetched anomaly detector: {}", xc); + LOG.debug("Fetched anomaly detector: {}", xc); try ( XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, xc) @@ -286,33 +284,64 @@ public void setLastDetectionError(String adID, String error) { } /** - * Set last cold start error of a detector + * Get a detector's exception. The method has a side effect. + * We reset the error after calling the method because + * 1) We record a detector's exception in each interval. There is no need + * to record it twice. + * 2) EndRunExceptions can stop job running. We only want to send the same + * signal once for each exception. * @param adID detector id - * @param exception exception, can be null + * @return the detector's exception */ - public void setLastColdStartException(String adID, AnomalyDetectionException exception) { - NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock)); - state.setLastColdStartException(exception); - } - - /** - * Get last cold start exception of a detector. The method has side effect. - * We reset error after calling the method since cold start exception can stop job running. - * @param adID detector id - * @return last cold start exception for the detector - */ - public Optional fetchColdStartException(String adID) { + public Optional fetchExceptionAndClear(String adID) { NodeState state = states.get(adID); if (state == null) { return Optional.empty(); } - Optional exception = state.getLastColdStartException(); - // since cold start exception can stop job running, we set it to null after using it once. - exception.ifPresent(e -> setLastColdStartException(adID, null)); + Optional exception = state.getException(); + exception.ifPresent(e -> state.setException(null)); return exception; } + /** + * For a single-stream detector, we have one exception per interval. When + * an interval starts, it fetches and clears the exception. + * For HCAD, there can be one exception per entity. To not bloat memory + * with exceptions, we will keep only one exception. An exception has 3 purposes: + * 1) stop the detector if nothing else works; + * 2) increment error stats so we can ticket high-error domains; + * 3) debugging. + * + * For HCAD, we record all entities' exceptions in anomaly results. So 3) + * is covered. As long as we keep one exception among all exceptions, 2) + * is covered. So the only thing we have to pay attention to is keeping EndRunException. + * When overriding an exception, EndRunException has priority.
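Illustrative aside, not part of this diff: the policy described above (keep at most one exception per detector, never let a lower-priority exception displace an EndRunException, clear the exception on fetch) can be sketched on its own. The class below is a simplified stand-in for NodeState plus ExceptionUtil.selectHigherPriorityException; all names are illustrative.

    // Hedged sketch of the "one exception per detector, EndRunException wins" policy.
    final class SingleExceptionHolder {
        // simplified stand-in for the plugin's EndRunException
        static class EndRunException extends RuntimeException {
            EndRunException(String message) { super(message); }
        }

        private Exception current;

        synchronized void record(Exception candidate) {
            if (candidate == null) {
                return;
            }
            // an existing EndRunException is only replaced by another EndRunException
            if (current instanceof EndRunException && !(candidate instanceof EndRunException)) {
                return;
            }
            current = candidate;
        }

        synchronized java.util.Optional<Exception> fetchAndClear() {
            java.util.Optional<Exception> result = java.util.Optional.ofNullable(current);
            current = null; // reset so the same signal is sent at most once per interval
            return result;
        }
    }
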
+ * @param detectorId Detector Id + * @param e Exception to set + */ + public void setException(String detectorId, Exception e) { + if (e == null || Strings.isEmpty(detectorId)) { + return; + } + NodeState state = states.computeIfAbsent(detectorId, d -> new NodeState(detectorId, clock)); + Optional exception = state.getException(); + if (exception.isPresent()) { + Exception higherPriorityException = ExceptionUtil.selectHigherPriorityException(e, exception.get()); + if (higherPriorityException != e) { + return; + } + } + + AnomalyDetectionException adExep = null; + if (e instanceof AnomalyDetectionException) { + adExep = (AnomalyDetectionException) e; + } else { + adExep = new AnomalyDetectionException(detectorId, e); + } + state.setException(adExep); + } + /** * Whether last cold start for the detector is running * @param adID detector ID @@ -342,12 +371,4 @@ public Releasable markColdStartRunning(String adID) { } }; } - - public Instant getLastIndexThrottledTime() { - return lastIndexThrottledTime; - } - - public void setLastIndexThrottledTime(Instant lastIndexThrottledTime) { - this.lastIndexThrottledTime = lastIndexThrottledTime; - } } diff --git a/src/main/java/org/opensearch/ad/caching/CacheBuffer.java b/src/main/java/org/opensearch/ad/caching/CacheBuffer.java index ae1b75f00..b68a59821 100644 --- a/src/main/java/org/opensearch/ad/caching/CacheBuffer.java +++ b/src/main/java/org/opensearch/ad/caching/CacheBuffer.java @@ -29,9 +29,11 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.ArrayList; import java.util.List; import java.util.Map.Entry; import java.util.Optional; +import java.util.Random; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; @@ -40,13 +42,13 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ad.ExpiringState; -import org.opensearch.ad.MaintenanceState; import org.opensearch.ad.MemoryTracker; import org.opensearch.ad.MemoryTracker.Origin; -import org.opensearch.ad.ml.CheckpointDao; import org.opensearch.ad.ml.EntityModel; import org.opensearch.ad.ml.ModelState; import org.opensearch.ad.model.InitProgressProfile; +import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.ratelimit.RequestPriority; /** * We use a layered cache to manage active entities’ states. We have a two-level @@ -66,54 +68,55 @@ * top minimumCapacity active entities (last X entities in priorityList) as in dedicated * cache and all others in shared cache. */ -public class CacheBuffer implements ExpiringState, MaintenanceState { +public class CacheBuffer implements ExpiringState { private static final Logger LOG = LogManager.getLogger(CacheBuffer.class); // max entities to track per detector private final int MAX_TRACKING_ENTITIES = 1000000; - private final int minimumCapacity; + // the reserved cache size. 
So no matter how many entities there are, we will + // keep the size for minimum capacity entities + private int minimumCapacity; + // key -> value private final ConcurrentHashMap> items; // memory consumption per entity private final long memoryConsumptionPerEntity; private final MemoryTracker memoryTracker; - private final CheckpointDao checkpointDao; private final Duration modelTtl; private final String detectorId; private Instant lastUsedTime; - private final long reservedBytes; + private long reservedBytes; private final PriorityTracker priorityTracker; private final Clock clock; + private final CheckpointWriteWorker checkpointWriteQueue; + private final Random random; public CacheBuffer( int minimumCapacity, long intervalSecs, - CheckpointDao checkpointDao, long memoryConsumptionPerEntity, MemoryTracker memoryTracker, Clock clock, Duration modelTtl, - String detectorId + String detectorId, + CheckpointWriteWorker checkpointWriteQueue, + Random random ) { - if (minimumCapacity <= 0) { - throw new IllegalArgumentException("minimum capacity should be larger than 0"); - } - this.minimumCapacity = minimumCapacity; + this.memoryConsumptionPerEntity = memoryConsumptionPerEntity; + setMinimumCapacity(minimumCapacity); this.items = new ConcurrentHashMap<>(); - - this.memoryConsumptionPerEntity = memoryConsumptionPerEntity; this.memoryTracker = memoryTracker; - this.checkpointDao = checkpointDao; this.modelTtl = modelTtl; this.detectorId = detectorId; this.lastUsedTime = clock.instant(); - this.reservedBytes = memoryConsumptionPerEntity * minimumCapacity; this.clock = clock; this.priorityTracker = new PriorityTracker(clock, intervalSecs, clock.instant().getEpochSecond(), MAX_TRACKING_ENTITIES); + this.checkpointWriteQueue = checkpointWriteQueue; + this.random = random; } /** @@ -166,7 +169,7 @@ private void put(String entityModelId, ModelState value, float prio // Since we have already considered them while allocating CacheBuffer, // skip bookkeeping. if (!sharedCacheEmpty()) { - memoryTracker.consumeMemory(memoryConsumptionPerEntity, false, Origin.MULTI_ENTITY_DETECTOR); + memoryTracker.consumeMemory(memoryConsumptionPerEntity, false, Origin.HC_DETECTOR); } } else { update(entityModelId); @@ -240,14 +243,19 @@ public ModelState remove(String keyToRemove) { ModelState valueRemoved = items.remove(keyToRemove); if (valueRemoved != null) { - // if we releasing a shared cache item, release memory as well. if (!reserved) { - memoryTracker.releaseMemory(memoryConsumptionPerEntity, false, Origin.MULTI_ENTITY_DETECTOR); + // release in shared memory + memoryTracker.releaseMemory(memoryConsumptionPerEntity, false, Origin.HC_DETECTOR); } - checkpointDao.write(valueRemoved, keyToRemove); EntityModel modelRemoved = valueRemoved.getModel(); if (modelRemoved != null) { + // null model has only samples. For null model we save a checkpoint + // regardless of last checkpoint time. 
whether If we don't save, + // we throw the new samples and might never be able to initialize the model + boolean isNullModel = modelRemoved.getRcf() == null || modelRemoved.getThreshold() == null; + checkpointWriteQueue.write(valueRemoved, isNullModel, RequestPriority.MEDIUM); + modelRemoved.clear(); } } @@ -289,8 +297,8 @@ public boolean canReplaceWithinDetector(float priority) { if (items.isEmpty()) { return false; } - Entry minPriorityItem = priorityTracker.getMinimumPriority(); - return minPriorityItem != null && priority > minPriorityItem.getValue(); + Optional> minPriorityItem = priorityTracker.getMinimumPriority(); + return minPriorityItem.isPresent() && priority > minPriorityItem.get().getValue(); } /** @@ -306,8 +314,13 @@ public ModelState replace(String entityModelId, ModelState> maintenance() { + List> modelsToSave = new ArrayList<>(); + List> removedStates = new ArrayList<>(); items.entrySet().stream().forEach(entry -> { String entityModelId = entry.getKey(); try { @@ -324,21 +337,45 @@ public void maintenance() { // put: not a problem as we are unlikely to maintain an entry that's not // already in the cache // remove method saves checkpoint as well - remove(entityModelId); - } else { - // we can have ConcurrentModificationException when serializing - // and updating rcf model at the same time. To prevent this, - // we need to have a deep copy of models or have a lock. Both - // options are costly. - // As we are gonna retry serializing either when the entity is - // evicted out of cache or during the next maintenance period, - // don't do anything when the exception happens. - checkpointDao.write(modelState, entityModelId); + removedStates.add(remove(entityModelId)); + } else if (random.nextInt(6) == 0) { + // checkpoint is relatively big compared to other queued requests + // save checkpoints with 1/6 probability as we expect to save + // all every 6 hours statistically + // + // Background: + // We will save a checkpoint when + // + // (a)removing the model from cache. + // (b) cold start + // (c) no complete model only a few samples. If we don't save new samples, + // we will never be able to have enough samples for a trained mode. + // (d) periodically save in case of exceptions. + // + // This branch is doing d). Previously, I will do it every hour for all + // in-cache models. Consider we are moving to 1M entities, this will bring + // the cluster in a heavy payload every hour. That's why I am doing it randomly + // (expected 6 hours for each checkpoint statistically). + // + // I am doing it random since maintaining a state of which one has been saved + // and which one hasn't are not cheap. Also, the models in the cache can be + // dynamically changing. Will have to maintain the state in the removing logic. + // Random is a lazy way to deal with this as it is stateless and statistically sound. + // + // If a checkpoint does not fall into the 6-hour bucket in a particular scenario, the model + // is stale (i.e., we don't recover from the freshest model in disaster.). + // + // All in all, randomness is mostly due to performance and easy maintenance. + modelsToSave.add(modelState); } + } catch (Exception e) { LOG.warn("Failed to finish maintenance for model id " + entityModelId, e); } }); + + checkpointWriteQueue.writeAll(modelsToSave, detectorId, false, RequestPriority.MEDIUM); + return removedStates; } /** @@ -388,9 +425,9 @@ public void clear() { // not a problem as we are releasing memory in MemoryTracker. 
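The maintenance() branch above saves each in-cache model's checkpoint with probability 1/6 per (roughly hourly) maintenance run, so each checkpoint is refreshed about once every six hours in expectation without tracking per-model save state. A minimal, self-contained sketch of that selection idea follows; the class and constant names are illustrative assumptions, not plugin code.

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Sketch only: pick roughly 1/6 of the given model ids on each maintenance run.
// With hourly maintenance this refreshes each checkpoint about every 6 hours in expectation.
final class RandomCheckpointSelector {
    private static final int SAVE_ONE_IN = 6; // assumed constant
    private final Random random;

    RandomCheckpointSelector(Random random) {
        this.random = random;
    }

    List<String> selectForCheckpoint(List<String> inCacheModelIds) {
        List<String> toSave = new ArrayList<>();
        for (String modelId : inCacheModelIds) {
            if (random.nextInt(SAVE_ONE_IN) == 0) {
                toSave.add(modelId);
            }
        }
        return toSave;
    }
}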
// The newly added one loses references and soon GC will collect it. // We have memory tracking correction to fix incorrect memory usage record. - memoryTracker.releaseMemory(getReservedBytes(), true, Origin.MULTI_ENTITY_DETECTOR); + memoryTracker.releaseMemory(getReservedBytes(), true, Origin.HC_DETECTOR); if (!sharedCacheEmpty()) { - memoryTracker.releaseMemory(getBytesInSharedCache(), false, Origin.MULTI_ENTITY_DETECTOR); + memoryTracker.releaseMemory(getBytesInSharedCache(), false, Origin.HC_DETECTOR); } items.clear(); priorityTracker.clearPriority(); @@ -449,11 +486,19 @@ public String getDetectorId() { return detectorId; } - public List> getAllModels() { + public List> getAllModels() { return items.values().stream().collect(Collectors.toList()); } public PriorityTracker getPriorityTracker() { return priorityTracker; } + + public void setMinimumCapacity(int minimumCapacity) { + if (minimumCapacity < 0) { + throw new IllegalArgumentException("minimum capacity should be larger than or equal 0"); + } + this.minimumCapacity = minimumCapacity; + this.reservedBytes = memoryConsumptionPerEntity * minimumCapacity; + } } diff --git a/src/main/java/org/opensearch/ad/caching/DoorKeeper.java b/src/main/java/org/opensearch/ad/caching/DoorKeeper.java index f1f15d2a8..792ce1a42 100644 --- a/src/main/java/org/opensearch/ad/caching/DoorKeeper.java +++ b/src/main/java/org/opensearch/ad/caching/DoorKeeper.java @@ -38,9 +38,7 @@ import com.google.common.hash.Funnels; /** - * A bloom filter placed in front of inactive entity cache to - * filter out unpopular items that are not likely to appear more - * than once. + * A bloom filter with regular reset. * * Reference: https://arxiv.org/abs/1512.00727 * diff --git a/src/main/java/org/opensearch/ad/caching/EntityCache.java b/src/main/java/org/opensearch/ad/caching/EntityCache.java index 06b108902..5e59a20ae 100644 --- a/src/main/java/org/opensearch/ad/caching/EntityCache.java +++ b/src/main/java/org/opensearch/ad/caching/EntityCache.java @@ -26,29 +26,31 @@ package org.opensearch.ad.caching; +import java.util.Collection; import java.util.List; +import java.util.Optional; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.ad.CleanState; import org.opensearch.ad.DetectorModelSize; -import org.opensearch.ad.EntityModelSize; import org.opensearch.ad.MaintenanceState; import org.opensearch.ad.ml.EntityModel; import org.opensearch.ad.ml.ModelState; import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; +import org.opensearch.ad.model.ModelProfile; -public interface EntityCache extends MaintenanceState, CleanState, DetectorModelSize, EntityModelSize { +public interface EntityCache extends MaintenanceState, CleanState, DetectorModelSize { /** * Get the ModelState associated with the entity. May or may not load the * ModelState depending on the underlying cache's eviction policy. * * @param modelId Model Id * @param detector Detector config object - * @param datapoint The most recent data point - * @param entityName The Entity's name * @return the ModelState associated with the model or null if no cached item * for the entity */ - ModelState get(String modelId, AnomalyDetector detector, double[] datapoint, String entityName); + ModelState get(String modelId, AnomalyDetector detector); /** * Get the number of active entities of a detector @@ -109,4 +111,47 @@ public interface EntityCache extends MaintenanceState, CleanState, DetectorModel * milliseconds when the entity's state is lastly used. Otherwise, return -1. 
*/ long getLastActiveMs(String detectorId, String entityModelId); + + /** + * Release memory when memory circuit breaker is open + */ + void releaseMemoryForOpenCircuitBreaker(); + + /** + * Select candidate entities for which we can load models + * @param cacheMissEntities Cache miss entities + * @param detectorId Detector Id + * @param detector Detector object + * @return A list of entities that are admitted into the cache as a result of the + * update and the left-over entities + */ + Pair, List> selectUpdateCandidate( + Collection cacheMissEntities, + String detectorId, + AnomalyDetector detector + ); + + /** + * + * @param detector Detector config + * @param toUpdate Model state candidate + * @return if we can host the given model state + */ + boolean hostIfPossible(AnomalyDetector detector, ModelState toUpdate); + + /** + * + * @param detectorId Detector Id + * @return a detector's model information + */ + List getAllModelProfile(String detectorId); + + /** + * Gets an entity's model sizes + * + * @param detectorId Detector Id + * @param entityModelId Entity's model Id + * @return the entity's memory size + */ + Optional getModelProfile(String detectorId, String entityModelId); } diff --git a/src/main/java/org/opensearch/ad/caching/PriorityCache.java b/src/main/java/org/opensearch/ad/caching/PriorityCache.java index d71dfdf18..2d904660c 100644 --- a/src/main/java/org/opensearch/ad/caching/PriorityCache.java +++ b/src/main/java/org/opensearch/ad/caching/PriorityCache.java @@ -26,34 +26,34 @@ package org.opensearch.ad.caching; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.COOLDOWN_MINUTES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.DEDICATED_CACHE_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE; import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.util.AbstractMap.SimpleImmutableEntry; -import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; -import java.util.Queue; +import java.util.PriorityQueue; import java.util.Random; +import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.ReentrantLock; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.Triple; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.core.util.Throwables; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.OpenSearchException; -import org.opensearch.action.ActionListener; -import org.opensearch.action.support.TransportActions; import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.MemoryTracker; import org.opensearch.ad.MemoryTracker.Origin; @@ -61,20 +61,20 @@ import org.opensearch.ad.constant.CommonErrorMessages; import org.opensearch.ad.ml.CheckpointDao; import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager; import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.ml.ModelState; import 
org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; +import org.opensearch.ad.model.ModelProfile; +import org.opensearch.ad.ratelimit.CheckpointWriteWorker; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; +import org.opensearch.common.Strings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.index.IndexNotFoundException; import org.opensearch.threadpool.ThreadPool; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; -import com.google.common.util.concurrent.RateLimiter; public class PriorityCache implements EntityCache { private final Logger LOG = LogManager.getLogger(PriorityCache.class); @@ -82,22 +82,25 @@ public class PriorityCache implements EntityCache { // detector id -> CacheBuffer, weight based private final Map activeEnities; private final CheckpointDao checkpointDao; - private final int dedicatedCacheSize; + private volatile int dedicatedCacheSize; // LRU Cache private Cache> inActiveEntities; private final MemoryTracker memoryTracker; - private final ModelManager modelManager; private final ReentrantLock maintenanceLock; private final int numberOfTrees; private final Clock clock; private final Duration modelTtl; - private final int numMinSamples; + // A bloom filter placed in front of inactive entity cache to + // filter out unpopular items that are not likely to appear more + // than once. private Map doorKeepers; - private Instant cooldownStart; - private int coolDownMinutes; private ThreadPool threadPool; private Random random; - private RateLimiter cacheMissHandlingLimiter; + private CheckpointWriteWorker checkpointWriteQueue; + // iterating through all of inactive entities is heavy. We don't want to do + // it again and again for no obvious benefits. 
+ private Instant lastInActiveEntityMaintenance; + protected int maintenanceFreqConstant; public PriorityCache( CheckpointDao checkpointDao, @@ -105,26 +108,30 @@ public PriorityCache( Duration inactiveEntityTtl, int maxInactiveStates, MemoryTracker memoryTracker, - ModelManager modelManager, int numberOfTrees, Clock clock, ClusterService clusterService, Duration modelTtl, - int numMinSamples, - Settings settings, ThreadPool threadPool, - int cacheMissRateHandlingLimiter + CheckpointWriteWorker checkpointWriteQueue, + int maintenanceFreqConstant ) { this.checkpointDao = checkpointDao; - this.dedicatedCacheSize = dedicatedCacheSize; + this.activeEnities = new ConcurrentHashMap<>(); + this.dedicatedCacheSize = dedicatedCacheSize; + clusterService.getClusterSettings().addSettingsUpdateConsumer(DEDICATED_CACHE_SIZE, (it) -> { + this.dedicatedCacheSize = it; + this.setDedicatedCacheSizeListener(); + this.tryClearUpMemory(); + }, this::validateDedicatedCacheSize); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MODEL_MAX_SIZE_PERCENTAGE, it -> this.tryClearUpMemory()); + this.memoryTracker = memoryTracker; - this.modelManager = modelManager; this.maintenanceLock = new ReentrantLock(); this.numberOfTrees = numberOfTrees; this.clock = clock; this.modelTtl = modelTtl; - this.numMinSamples = numMinSamples; this.doorKeepers = new ConcurrentHashMap<>(); this.inActiveEntities = CacheBuilder @@ -134,34 +141,30 @@ public PriorityCache( .concurrencyLevel(1) .build(); - this.cooldownStart = Instant.MIN; - this.coolDownMinutes = (int) (COOLDOWN_MINUTES.get(settings).getMinutes()); this.threadPool = threadPool; this.random = new Random(42); - - this.cacheMissHandlingLimiter = RateLimiter.create(cacheMissRateHandlingLimiter); - clusterService - .getClusterSettings() - .addSettingsUpdateConsumer(MAX_CACHE_MISS_HANDLING_PER_SECOND, it -> this.cacheMissHandlingLimiter = RateLimiter.create(it)); + this.checkpointWriteQueue = checkpointWriteQueue; + this.lastInActiveEntityMaintenance = Instant.MIN; + this.maintenanceFreqConstant = maintenanceFreqConstant; } @Override - public ModelState get(String modelId, AnomalyDetector detector, double[] datapoint, String entityName) { + public ModelState get(String modelId, AnomalyDetector detector) { String detectorId = detector.getDetectorId(); CacheBuffer buffer = computeBufferIfAbsent(detector, detectorId); ModelState modelState = buffer.get(modelId); // during maintenance period, stop putting new entries - if (modelState == null) { + if (!maintenanceLock.isLocked() && modelState == null) { DoorKeeper doorKeeper = doorKeepers .computeIfAbsent( detectorId, id -> { // reset every 60 intervals return new DoorKeeper( - AnomalyDetectorSettings.DOOR_KEEPER_MAX_INSERTION, + AnomalyDetectorSettings.DOOR_KEEPER_FOR_CACHE_MAX_INSERTION, AnomalyDetectorSettings.DOOR_KEEPER_FAULSE_POSITIVE_RATE, - detector.getDetectionIntervalDuration().multipliedBy(60), + detector.getDetectionIntervalDuration().multipliedBy(AnomalyDetectorSettings.DOOR_KEEPER_MAINTENANCE_FREQ), clock ); } @@ -173,188 +176,302 @@ public ModelState get(String modelId, AnomalyDetector detector, dou return null; } - ModelState state = inActiveEntities.getIfPresent(modelId); - - // compute updated priority - // We don’t want to admit the latest entity for correctness by throwing out a - // hot entity. We have a priority (time-decayed count) sensitive to - // the number of hits, length of time, and sampling interval. 
Examples: - // 1) an entity from a 5-minute interval detector that is hit 5 times in the - // past 25 minutes should have an equal chance of using the cache along with - // an entity from a 1-minute interval detector that is hit 5 times in the past - // 5 minutes. - // 2) It might be the case that the frequency of entities changes dynamically - // during run-time. For example, entity A showed up for the first 500 times, - // but entity B showed up for the next 500 times. Our priority should give - // entity B higher priority than entity A as time goes by. - // 3) Entity A will have a higher priority than entity B if A runs - // for a longer time given other things are equal. - // - // We ensure fairness by using periods instead of absolute duration. Entity A - // accessed once three intervals ago should have the same priority with entity B - // accessed once three periods ago, though they belong to detectors of different - // intervals. - float priority = 0; - if (state != null) { - priority = state.getPriority(); + try { + ModelState state = inActiveEntities.get(modelId, new Callable>() { + @Override + public ModelState call() { + return new ModelState<>(null, modelId, detectorId, ModelType.ENTITY.getName(), clock, 0); + } + }); + + // make sure no model has been stored due to previous race conditions + state.setModel(null); + + // compute updated priority + // We don’t want to admit the latest entity for correctness by throwing out a + // hot entity. We have a priority (time-decayed count) sensitive to + // the number of hits, length of time, and sampling interval. Examples: + // 1) an entity from a 5-minute interval detector that is hit 5 times in the + // past 25 minutes should have an equal chance of using the cache along with + // an entity from a 1-minute interval detector that is hit 5 times in the past + // 5 minutes. + // 2) It might be the case that the frequency of entities changes dynamically + // during run-time. For example, entity A showed up for the first 500 times, + // but entity B showed up for the next 500 times. Our priority should give + // entity B higher priority than entity A as time goes by. + // 3) Entity A will have a higher priority than entity B if A runs + // for a longer time given other things are equal. + // + // We ensure fairness by using periods instead of absolute duration. Entity A + // accessed once three intervals ago should have the same priority with entity B + // accessed once three periods ago, though they belong to detectors of different + // intervals. + + // update state using new priority or create a new one + state.setPriority(buffer.getPriorityTracker().getUpdatedPriority(state.getPriority())); + + // adjust shared memory in case we have used dedicated cache memory for other detectors + if (random.nextInt(maintenanceFreqConstant) == 1) { + tryClearUpMemory(); + } + } catch (Exception e) { + LOG.error(new ParameterizedMessage("Fail to update priority of [{}]", modelId), e); } - priority = buffer.getPriorityTracker().getUpdatedPriority(priority); - // update state using new priority or create a new one - if (state != null) { - state.setPriority(priority); - } else { - EntityModel model = new EntityModel(modelId, new ArrayDeque<>(), null, null); - state = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); - } + } - if (random.nextInt(10_000) == 1) { - // clear up memory with 1/10000 probability since this operation is costly, but we have to do it from time to time. 
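To make the time-decayed priority concrete, here is a minimal, self-contained sketch of an exponentially decayed hit counter keyed by detection periods. The decay constant and class name are assumptions and this is not PriorityTracker's actual formula; it only illustrates why counting in periods (rather than wall-clock seconds) keeps detectors with different intervals comparable.

import java.time.Clock;

// Sketch only: a time-decayed hit counter keyed by detection periods.
class DecayedHitCounter {
    private static final double DECAY = 0.9;   // assumed per-period retention factor
    private final Clock clock;
    private final long intervalSecs;            // detector interval in seconds
    private double count = 0;
    private long lastPeriod;

    DecayedHitCounter(Clock clock, long intervalSecs) {
        this.clock = clock;
        this.intervalSecs = intervalSecs;
        this.lastPeriod = clock.instant().getEpochSecond() / intervalSecs;
    }

    // Record one access: decay the old count by the number of elapsed detection
    // periods, then add one. A 1-minute-interval detector and a 5-minute-interval
    // detector hit once per interval therefore end up with the same priority.
    synchronized double onAccess() {
        long period = clock.instant().getEpochSecond() / intervalSecs;
        count = count * Math.pow(DECAY, period - lastPeriod) + 1;
        lastPeriod = period;
        return count;
    }
}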
- // e.g., we need to adjust shared entity memory size if some reserved memory gets allocated. Use 10_000 since our - // query limit is 1k by default and we can have 10 detectors: 10 * 1k. We also do this in hourly maintenance window no - // matter what. - tryClearUpMemory(); - } - if (!maintenanceLock.isLocked() - && cacheMissHandlingLimiter.tryAcquire() - && hostIfPossible(buffer, detectorId, modelId, entityName, detector, state, priority)) { - addSample(state, datapoint); - inActiveEntities.invalidate(modelId); - } else { - // put to inactive cache if we cannot host or get the lock or get rate permits - // only keep weights in inactive cache to keep it small. - // It can be dangerous to exceed a few dozen MBs, especially - // in small heap machine like t2. - inActiveEntities.put(modelId, state); + return modelState; + } + + private Optional> getStateFromInactiveEntiiyCache(String modelId) { + if (modelId == null) { + return Optional.empty(); + } + + // null if not even recorded in inActiveEntities yet because of doorKeeper + return Optional.ofNullable(inActiveEntities.getIfPresent(modelId)); + } + + @Override + public boolean hostIfPossible(AnomalyDetector detector, ModelState toUpdate) { + if (toUpdate == null) { + return false; + } + String modelId = toUpdate.getModelId(); + String detectorId = toUpdate.getDetectorId(); + + if (Strings.isEmpty(modelId) || Strings.isEmpty(detectorId)) { + return false; + } + + CacheBuffer buffer = computeBufferIfAbsent(detector, detectorId); + + Optional> state = getStateFromInactiveEntiiyCache(modelId); + if (false == state.isPresent()) { + return false; + } + + ModelState modelState = state.get(); + + float priority = modelState.getPriority(); + + toUpdate.setLastUsedTime(clock.instant()); + toUpdate.setPriority(priority); + + // current buffer's dedicated cache has free slots or can allocate in shared cache + if (buffer.dedicatedCacheAvailable() || memoryTracker.canAllocate(buffer.getMemoryConsumptionPerEntity())) { + // buffer.put will call MemoryTracker.consumeMemory + buffer.put(modelId, toUpdate); + return true; + } + + if (memoryTracker.canAllocate(buffer.getMemoryConsumptionPerEntity())) { + // buffer.put will call MemoryTracker.consumeMemory + buffer.put(modelId, toUpdate); + return true; + } + + // can replace an entity in the same CacheBuffer living in reserved or shared cache + if (buffer.canReplaceWithinDetector(priority)) { + ModelState removed = buffer.replace(modelId, toUpdate); + // null in the case of some other threads have emptied the queue at + // the same time so there is nothing to replace + if (removed != null) { + addIntoInactiveCache(removed); + return true; } } - return modelState; + // If two threads try to remove the same entity and add their own state, the 2nd remove + // returns null and only the first one succeeds. 
+ float scaledPriority = buffer.getPriorityTracker().getScaledPriority(priority); + Triple bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); + CacheBuffer bufferToRemove = bufferToRemoveEntity.getLeft(); + String entityModelId = bufferToRemoveEntity.getMiddle(); + ModelState removed = null; + if (bufferToRemove != null && ((removed = bufferToRemove.remove(entityModelId)) != null)) { + buffer.put(modelId, toUpdate); + addIntoInactiveCache(removed); + return true; + } + + return false; } - /** - * Whether host an entity is possible - * @param buffer the destination buffer for the given entity - * @param detectorId Detector Id - * @param modelId Model Id - * @param entityName Entity's name - * @param detector Detector Config - * @param state State to host - * @param priority The entity's priority - * @return true if possible; false otherwise - */ - private boolean hostIfPossible( - CacheBuffer buffer, + private void addIntoInactiveCache(ModelState removed) { + if (removed == null) { + return; + } + // set last used time for profile API so that we know when an entities is evicted + removed.setLastUsedTime(clock.instant()); + removed.setModel(null); + inActiveEntities.put(removed.getModelId(), removed); + } + + private void addEntity(List destination, Entity entity, String detectorId) { + // It's possible our doorkeepr prevented the entity from entering inactive entities cache + if (entity != null) { + Optional modelId = entity.getModelId(detectorId); + if (modelId.isPresent() && inActiveEntities.getIfPresent(modelId.get()) != null) { + destination.add(entity); + } + } + } + + @Override + public Pair, List> selectUpdateCandidate( + Collection cacheMissEntities, String detectorId, - String modelId, - String entityName, - AnomalyDetector detector, - ModelState state, - float priority + AnomalyDetector detector ) { + List hotEntities = new ArrayList<>(); + List coldEntities = new ArrayList<>(); + + CacheBuffer buffer = activeEnities.get(detectorId); + if (buffer == null) { + // don't want to create side-effects by creating a CacheBuffer + // In current implementation, this branch is impossible as we call + // PriorityCache.get method before invoking this method. The + // PriorityCache.get method creates a CacheBuffer if not present. + // Since this method is public, need to deal with this case in case of misuse. + return Pair.of(hotEntities, coldEntities); + } + + Iterator cacheMissEntitiesIter = cacheMissEntities.iterator(); // current buffer's dedicated cache has free slots - // thread safe as each detector has one thread at one time and only the - // thread can access its buffer. - if (buffer.dedicatedCacheAvailable()) { - buffer.put(modelId, state); - } else if (memoryTracker.canAllocate(buffer.getMemoryConsumptionPerEntity())) { + while (cacheMissEntitiesIter.hasNext() && buffer.dedicatedCacheAvailable()) { + addEntity(hotEntities, cacheMissEntitiesIter.next(), detectorId); + } + + while (cacheMissEntitiesIter.hasNext() && memoryTracker.canAllocate(buffer.getMemoryConsumptionPerEntity())) { // can allocate in shared cache // race conditions can happen when multiple threads evaluating this condition. // This is a problem as our AD memory usage is close to full and we put - // more things than we planned. One model in multi-entity case is small, + // more things than we planned. One model in HCAD is small, // it is fine we exceed a little. We have regular maintenance to remove // extra memory usage. 
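The hostIfPossible checks above form a fixed admission cascade: a free dedicated slot, then free shared memory, then replacing a lower-priority entity within the same detector, then replacing one in another detector's shared slot, otherwise rejection. The sketch below restates that order of checks with plain booleans; the enum and method names are illustrative assumptions, not the plugin's API.

// Sketch only: the order of checks used when deciding whether a cache-miss entity can be hosted.
enum Admission { DEDICATED_SLOT, SHARED_MEMORY, REPLACE_SAME_DETECTOR, REPLACE_OTHER_DETECTOR, REJECT }

final class AdmissionPolicy {
    static Admission decide(
        boolean dedicatedSlotFree,
        boolean sharedMemoryAvailable,
        boolean lowerPriorityEntityInSameDetector,
        boolean lowerPriorityEntityInOtherDetector
    ) {
        if (dedicatedSlotFree) {
            return Admission.DEDICATED_SLOT;
        }
        if (sharedMemoryAvailable) {
            return Admission.SHARED_MEMORY;
        }
        if (lowerPriorityEntityInSameDetector) {
            return Admission.REPLACE_SAME_DETECTOR;
        }
        if (lowerPriorityEntityInOtherDetector) {
            return Admission.REPLACE_OTHER_DETECTOR;
        }
        return Admission.REJECT;
    }
}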
- buffer.put(modelId, state); - } else if (buffer.canReplaceWithinDetector(priority)) { + addEntity(hotEntities, cacheMissEntitiesIter.next(), detectorId); + } + + // check if we can replace anything in dedicated or shared cache + // have a copy since we need to do the iteration twice: one for + // dedicated cache and one for shared cache + List otherBufferReplaceCandidates = new ArrayList<>(); + + while (cacheMissEntitiesIter.hasNext()) { // can replace an entity in the same CacheBuffer living in reserved // or shared cache // thread safe as each detector has one thread at one time and only the // thread can access its buffer. - ModelState removed = buffer.replace(modelId, state); - if (removed != null) { - // set last used time for profile API so that we know when an entities is evicted - removed.setLastUsedTime(clock.instant()); - inActiveEntities.put(removed.getModelId(), removed); + Entity entity = cacheMissEntitiesIter.next(); + Optional modelId = entity.getModelId(detectorId); + + if (false == modelId.isPresent()) { + continue; + } + + Optional> state = getStateFromInactiveEntiiyCache(modelId.get()); + if (false == state.isPresent()) { + // not even recorded in inActiveEntities yet because of doorKeeper + continue; + } + + ModelState modelState = state.get(); + float priority = modelState.getPriority(); + + if (buffer.canReplaceWithinDetector(priority)) { + addEntity(hotEntities, entity, detectorId); + } else { + // re-evaluate replacement condition in other buffers + otherBufferReplaceCandidates.add(entity); } - } else { + } + + // record current minimum priority among all detectors to save redundant + // scanning of all CacheBuffers + CacheBuffer bufferToRemove = null; + float minPriority = Float.MIN_VALUE; + + // check if we can replace in other CacheBuffer + cacheMissEntitiesIter = otherBufferReplaceCandidates.iterator(); + + while (cacheMissEntitiesIter.hasNext()) { // If two threads try to remove the same entity and add their own state, the 2nd remove // returns null and only the first one succeeds. 
+ Entity entity = cacheMissEntitiesIter.next(); + Optional modelId = entity.getModelId(detectorId); + + if (false == modelId.isPresent()) { + continue; + } + + Optional> inactiveState = getStateFromInactiveEntiiyCache(modelId.get()); + if (false == inactiveState.isPresent()) { + // empty state should not stand a chance to replace others + continue; + } + + ModelState state = inactiveState.get(); + + float priority = state.getPriority(); float scaledPriority = buffer.getPriorityTracker().getScaledPriority(priority); - Entry bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); - CacheBuffer bufferToRemove = bufferToRemoveEntity.getKey(); - String entityModelId = bufferToRemoveEntity.getValue(); - ModelState removed = null; - if (bufferToRemove != null && ((removed = bufferToRemove.remove(entityModelId)) != null)) { - buffer.put(modelId, state); - // set last used time for profile API so that we know when an entities is evicted - removed.setLastUsedTime(clock.instant()); - inActiveEntities.put(removed.getModelId(), removed); - } else { - return false; + + if (scaledPriority <= minPriority) { + // not even larger than the minPriority, we can put this to coldEntities + addEntity(coldEntities, entity, detectorId); + continue; } - } - maybeRestoreOrTrainModel(modelId, entityName, state); - return true; - } + // Float.MIN_VALUE means we need to re-iterate through all CacheBuffers + if (minPriority == Float.MIN_VALUE) { + Triple bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); + bufferToRemove = bufferToRemoveEntity.getLeft(); + minPriority = bufferToRemoveEntity.getRight(); + } - private void addSample(ModelState stateToPromote, double[] datapoint) { - // add samples - Queue samples = stateToPromote.getModel().getSamples(); - samples.add(datapoint); - // only keep the recent numMinSamples - while (samples.size() > this.numMinSamples) { - samples.remove(); + if (bufferToRemove != null) { + addEntity(hotEntities, entity, detectorId); + // reset minPriority after the replacement so that we need to iterate all CacheBuffer + // again + minPriority = Float.MIN_VALUE; + } else { + // after trying everything, we can now safely put this to cold entities list + addEntity(coldEntities, entity, detectorId); + } } - } - private void maybeRestoreOrTrainModel(String modelId, String entityName, ModelState state) { - EntityModel entityModel = state.getModel(); - // rate limit in case of OpenSearchRejectedExecutionException from get threadpool whose queue capacity is 1k - if (entityModel != null - && (entityModel.getRcf() == null || entityModel.getThreshold() == null) - && cooldownStart.plus(Duration.ofMinutes(coolDownMinutes)).isBefore(clock.instant())) { - checkpointDao - .restoreModelCheckpoint( - modelId, - ActionListener - .wrap(checkpoint -> modelManager.processEntityCheckpoint(checkpoint, modelId, entityName, state), exception -> { - Throwable cause = Throwables.getRootCause(exception); - if (cause instanceof IndexNotFoundException) { - modelManager.processEntityCheckpoint(Optional.empty(), modelId, entityName, state); - } else if (cause instanceof RejectedExecutionException - || TransportActions.isShardNotAvailableException(cause)) { - LOG.error("too many get AD model checkpoint requests or shard not avialble"); - cooldownStart = clock.instant(); - } else { - LOG.error("Fail to restore models for " + modelId, exception); - } - }) - ); - } + return Pair.of(hotEntities, coldEntities); } private CacheBuffer computeBufferIfAbsent(AnomalyDetector detector, String 
detectorId) { - return activeEnities.computeIfAbsent(detectorId, k -> { + CacheBuffer buffer = activeEnities.get(detectorId); + if (buffer == null) { long requiredBytes = getReservedDetectorMemory(detector); - tryClearUpMemory(); - if (memoryTracker.canAllocateReserved(detectorId, requiredBytes)) { - memoryTracker.consumeMemory(requiredBytes, true, Origin.MULTI_ENTITY_DETECTOR); + if (memoryTracker.canAllocateReserved(requiredBytes)) { + memoryTracker.consumeMemory(requiredBytes, true, Origin.HC_DETECTOR); long intervalSecs = detector.getDetectorIntervalInSeconds(); - return new CacheBuffer( + buffer = new CacheBuffer( dedicatedCacheSize, intervalSecs, - checkpointDao, memoryTracker.estimateModelSize(detector, numberOfTrees), memoryTracker, clock, modelTtl, - detectorId + detectorId, + checkpointWriteQueue, + random ); + activeEnities.put(detectorId, buffer); + // There can be race conditions between tryClearUpMemory and + // activeEntities.put above as tryClearUpMemory accesses activeEnities too. + // Put tryClearUpMemory after consumeMemory to prevent that. + tryClearUpMemory(); + } else { + throw new LimitExceededException(detectorId, CommonErrorMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG); } - // if hosting not allowed, exception will be thrown by isHostingAllowed - throw new LimitExceededException(detectorId, CommonErrorMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG); - }); + + } + return buffer; } private long getReservedDetectorMemory(AnomalyDetector detector) { @@ -372,29 +489,37 @@ private long getReservedDetectorMemory(AnomalyDetector detector) { * @param candidatePriority the candidate entity's priority * @return the CacheBuffer if we can find a CacheBuffer to make room for the candidate entity */ - private Entry canReplaceInSharedCache(CacheBuffer originBuffer, float candidatePriority) { + private Triple canReplaceInSharedCache(CacheBuffer originBuffer, float candidatePriority) { CacheBuffer minPriorityBuffer = null; - float minPriority = Float.MAX_VALUE; + float minPriority = candidatePriority; String minPriorityEntityModelId = null; for (Map.Entry entry : activeEnities.entrySet()) { CacheBuffer buffer = entry.getValue(); if (buffer != originBuffer && buffer.canRemove()) { - Entry priorityEntry = buffer.getPriorityTracker().getMinimumScaledPriority(); - float priority = priorityEntry.getValue(); + Optional> priorityEntry = buffer.getPriorityTracker().getMinimumScaledPriority(); + if (!priorityEntry.isPresent()) { + continue; + } + float priority = priorityEntry.get().getValue(); if (candidatePriority > priority && priority < minPriority) { minPriority = priority; minPriorityBuffer = buffer; - minPriorityEntityModelId = priorityEntry.getKey(); + minPriorityEntityModelId = priorityEntry.get().getKey(); } } } - return new SimpleImmutableEntry<>(minPriorityBuffer, minPriorityEntityModelId); + return Triple.of(minPriorityBuffer, minPriorityEntityModelId, minPriority); } + /** + * Clear up overused memory. Can happen due to race condition or other detectors + * consumes resources from shared memory. + * tryClearUpMemory is ran using AD threadpool because the function is expensive. 
+ */ private void tryClearUpMemory() { try { if (maintenanceLock.tryLock()) { - clearMemory(); + threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME).execute(() -> clearMemory()); } else { threadPool.schedule(() -> { try { @@ -414,29 +539,46 @@ private void tryClearUpMemory() { private void clearMemory() { recalculateUsedMemory(); long memoryToShed = memoryTracker.memoryToShed(); - float minPriority = Float.MAX_VALUE; - CacheBuffer minPriorityBuffer = null; - String minPriorityEntityModelId = null; - while (memoryToShed > 0) { + PriorityQueue> removalCandiates = null; + if (memoryToShed > 0) { + // sort the triple in an ascending order of priority + removalCandiates = new PriorityQueue<>((x, y) -> Float.compare(x.getLeft(), y.getLeft())); for (Map.Entry entry : activeEnities.entrySet()) { CacheBuffer buffer = entry.getValue(); - Entry priorityEntry = buffer.getPriorityTracker().getMinimumScaledPriority(); - float priority = priorityEntry.getValue(); - if (buffer.canRemove() && priority < minPriority) { - minPriority = priority; - minPriorityBuffer = buffer; - minPriorityEntityModelId = priorityEntry.getKey(); + Optional> priorityEntry = buffer.getPriorityTracker().getMinimumScaledPriority(); + if (!priorityEntry.isPresent()) { + continue; + } + float priority = priorityEntry.get().getValue(); + if (buffer.canRemove()) { + removalCandiates.add(Triple.of(priority, buffer, priorityEntry.get().getKey())); } } - if (minPriorityBuffer != null) { - minPriorityBuffer.remove(minPriorityEntityModelId); - long memoryReleased = minPriorityBuffer.getMemoryConsumptionPerEntity(); - memoryTracker.releaseMemory(memoryReleased, false, Origin.MULTI_ENTITY_DETECTOR); - memoryToShed -= memoryReleased; - } else { + } + while (memoryToShed > 0) { + if (false == removalCandiates.isEmpty()) { + Triple toRemove = removalCandiates.poll(); + CacheBuffer minPriorityBuffer = toRemove.getMiddle(); + String minPriorityEntityModelId = toRemove.getRight(); + + ModelState removed = minPriorityBuffer.remove(minPriorityEntityModelId); + memoryToShed -= minPriorityBuffer.getMemoryConsumptionPerEntity(); + addIntoInactiveCache(removed); + + if (minPriorityBuffer.canRemove()) { + // can remove another one + Optional> priorityEntry = minPriorityBuffer.getPriorityTracker().getMinimumScaledPriority(); + if (priorityEntry.isPresent()) { + removalCandiates.add(Triple.of(priorityEntry.get().getValue(), minPriorityBuffer, priorityEntry.get().getKey())); + } + } + } + + if (removalCandiates.isEmpty()) { break; } } + } /** @@ -450,7 +592,7 @@ private void recalculateUsedMemory() { reserved += buffer.getReservedBytes(); shared += buffer.getBytesInSharedCache(); } - memoryTracker.syncMemoryState(Origin.MULTI_ENTITY_DETECTOR, reserved + shared, reserved); + memoryTracker.syncMemoryState(Origin.HC_DETECTOR, reserved + shared, reserved); } /** @@ -473,10 +615,15 @@ public void maintenance() { activeEnities.remove(detectorId); cacheBuffer.clear(); } else { - cacheBuffer.maintenance(); + List> removedStates = cacheBuffer.maintenance(); + for (ModelState state : removedStates) { + addIntoInactiveCache(state); + } } }); - checkpointDao.flush(); + + maintainInactiveCache(); + doorKeepers.entrySet().stream().forEach(doorKeeperEntry -> { String detectorId = doorKeeperEntry.getKey(); DoorKeeper doorKeeper = doorKeeperEntry.getValue(); @@ -607,22 +754,6 @@ public Map getModelSize(String detectorId) { return res; } - @Override - /** - * Gets an entity's model state - * - * @param detectorId detector id - * @param entityModelId entity model 
id - * @return the model state - */ - public long getModelSize(String detectorId, String entityModelId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); - if (cacheBuffer != null && cacheBuffer.getModel(entityModelId).isPresent()) { - return cacheBuffer.getMemoryConsumptionPerEntity(); - } - return -1L; - } - /** * Return the last active time of an entity's state. * @@ -648,4 +779,112 @@ public long getLastActiveMs(String detectorId, String entityModelId) { } return -1L; } + + @Override + public void releaseMemoryForOpenCircuitBreaker() { + maintainInactiveCache(); + + tryClearUpMemory(); + activeEnities.values().stream().forEach(cacheBuffer -> { + if (cacheBuffer.canRemove()) { + ModelState removed = cacheBuffer.remove(); + addIntoInactiveCache(removed); + } + }); + } + + private void maintainInactiveCache() { + if (lastInActiveEntityMaintenance.plus(this.modelTtl).isAfter(clock.instant())) { + // don't scan inactive cache too frequently as it is costly + return; + } + + // force maintenance of the cache. ref: https://tinyurl.com/pyy3p9v6 + inActiveEntities.cleanUp(); + + // // make sure no model has been stored due to bugs + for (ModelState state : inActiveEntities.asMap().values()) { + EntityModel model = state.getModel(); + if (model != null && (model.getRcf() != null || model.getThreshold() != null)) { + LOG + .warn( + new ParameterizedMessage( + "Inactive entity's model is null: [{}], [{}]. Maybe there are bugs.", + model.getRcf(), + model.getThreshold() + ) + ); + state.setModel(null); + } + } + + lastInActiveEntityMaintenance = clock.instant(); + } + + /** + * Called when dedicated cache size changes. Will adjust existing cache buffer's + * cache size + */ + private void setDedicatedCacheSizeListener() { + activeEnities.values().stream().forEach(cacheBuffer -> cacheBuffer.setMinimumCapacity(dedicatedCacheSize)); + } + + @Override + public List getAllModelProfile(String detectorId) { + CacheBuffer cacheBuffer = activeEnities.get(detectorId); + List res = new ArrayList<>(); + if (cacheBuffer != null) { + long size = cacheBuffer.getMemoryConsumptionPerEntity(); + cacheBuffer.getAllModels().forEach(entry -> { + EntityModel model = entry.getModel(); + Entity entity = null; + if (model != null && model.getEntity().isPresent()) { + entity = model.getEntity().get(); + } + res.add(new ModelProfile(entry.getModelId(), entity, size)); + }); + } + return res; + } + + /** + * Gets an entity's model state + * + * @param detectorId detector id + * @param entityModelId entity model id + * @return the model state + */ + @Override + public Optional getModelProfile(String detectorId, String entityModelId) { + CacheBuffer cacheBuffer = activeEnities.get(detectorId); + if (cacheBuffer != null && cacheBuffer.getModel(entityModelId).isPresent()) { + EntityModel model = cacheBuffer.getModel(entityModelId).get(); + Entity entity = null; + if (model != null && model.getEntity().isPresent()) { + entity = model.getEntity().get(); + } + return Optional.of(new ModelProfile(entityModelId, entity, cacheBuffer.getMemoryConsumptionPerEntity())); + } + return Optional.empty(); + } + + /** + * Throw an IllegalArgumentException even the dedicated size increases cannot + * be fulfilled. 
+ * + * @param newDedicatedCacheSize the new dedicated cache size to validate + */ + private void validateDedicatedCacheSize(Integer newDedicatedCacheSize) { + if (this.dedicatedCacheSize < newDedicatedCacheSize) { + int delta = newDedicatedCacheSize - this.dedicatedCacheSize; + long totalIncreasedBytes = 0; + for (CacheBuffer cacheBuffer : activeEnities.values()) { + totalIncreasedBytes += cacheBuffer.getMemoryConsumptionPerEntity() * delta; + } + + if (false == memoryTracker.canAllocateReserved(totalIncreasedBytes)) { + throw new IllegalArgumentException("We don't have enough memory for the required change"); + } + } + } } diff --git a/src/main/java/org/opensearch/ad/caching/PriorityTracker.java b/src/main/java/org/opensearch/ad/caching/PriorityTracker.java index aecdf7ae1..d0b0ee3ef 100644 --- a/src/main/java/org/opensearch/ad/caching/PriorityTracker.java +++ b/src/main/java/org/opensearch/ad/caching/PriorityTracker.java @@ -184,28 +184,40 @@ public PriorityTracker(Clock clock, long intervalSecs, long landmarkEpoch, int m /** * Get the minimum priority entity and compute its scaled priority. * Used to compare entity priorities among detectors. - * @return the minimum priority entity's ID and scaled priority + * @return the minimum priority entity's ID and scaled priority or Optional.empty + * if the priority list is empty */ - public Entry getMinimumScaledPriority() { + public Optional> getMinimumScaledPriority() { + if (priorityList.isEmpty()) { + return Optional.empty(); + } PriorityNode smallest = priorityList.first(); - return new SimpleImmutableEntry<>(smallest.key, getScaledPriority(smallest.priority)); + return Optional.of(new SimpleImmutableEntry<>(smallest.key, getScaledPriority(smallest.priority))); } /** * Get the minimum priority entity and compute its scaled priority. * Used to compare entity priorities within the same detector. 
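A compact sketch of the validation idea in validateDedicatedCacheSize: an increase of the dedicated cache size must be covered by reserved memory across every active detector buffer, while a decrease never needs extra memory. The helper signature below is an assumption for illustration; the plugin consults MemoryTracker.canAllocateReserved instead.

import java.util.Collection;
import java.util.function.LongPredicate;

// Sketch only: reject a dedicated-cache-size increase that reserved memory cannot cover.
final class DedicatedCacheSizeValidator {
    static void validateIncrease(int oldSize, int newSize,
            Collection<Long> perEntityBytesByDetector, LongPredicate canReserve) {
        if (newSize <= oldSize) {
            return; // shrinking never needs additional reserved memory
        }
        long delta = newSize - oldSize;
        long required = perEntityBytesByDetector.stream().mapToLong(bytes -> bytes * delta).sum();
        if (!canReserve.test(required)) {
            throw new IllegalArgumentException("We don't have enough memory for the required change");
        }
    }
}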
- * @return the minimum priority entity's ID and scaled priority + * @return the minimum priority entity's ID and scaled priority or Optional.empty + * if the priority list is empty */ - public Entry getMinimumPriority() { + public Optional> getMinimumPriority() { + if (priorityList.isEmpty()) { + return Optional.empty(); + } PriorityNode smallest = priorityList.first(); - return new SimpleImmutableEntry<>(smallest.key, smallest.priority); + return Optional.of(new SimpleImmutableEntry<>(smallest.key, smallest.priority)); } /** * - * @return the minimum priority entity's Id + * @return the minimum priority entity's Id or Optional.empty + * if the priority list is empty */ public Optional getMinimumPriorityEntityId() { + if (priorityList.isEmpty()) { + return Optional.empty(); + } return Optional.of(priorityList).map(list -> list.first()).map(node -> node.key); } @@ -214,6 +226,9 @@ public Optional getMinimumPriorityEntityId() { * @return Get maximum priority entity's Id */ public Optional getHighestPriorityEntityId() { + if (priorityList.isEmpty()) { + return Optional.empty(); + } return Optional.of(priorityList).map(list -> list.last()).map(node -> node.key); } diff --git a/src/main/java/org/opensearch/ad/constant/CommonErrorMessages.java b/src/main/java/org/opensearch/ad/constant/CommonErrorMessages.java index f3bf421f2..ba4ca9663 100644 --- a/src/main/java/org/opensearch/ad/constant/CommonErrorMessages.java +++ b/src/main/java/org/opensearch/ad/constant/CommonErrorMessages.java @@ -26,6 +26,8 @@ package org.opensearch.ad.constant; +import java.util.Locale; + public class CommonErrorMessages { public static final String AD_ID_MISSING_MSG = "AD ID is missing"; public static final String MODEL_ID_MISSING_MSG = "Model ID is missing"; @@ -36,6 +38,8 @@ public class CommonErrorMessages { public static final String FEATURE_NOT_AVAILABLE_ERR_MSG = "No Feature in current detection window."; public static final String MEMORY_CIRCUIT_BROKEN_ERR_MSG = "AD memory circuit is broken."; public static final String DISABLED_ERR_MSG = "AD plugin is disabled. 
To enable update plugins.anomaly_detection.enabled to true"; + // We need this invalid query tag to show proper error message on frontend + // refer to AD Dashboard code: https://tinyurl.com/8b5n8hat public static final String INVALID_SEARCH_QUERY_MSG = "Invalid search query."; public static final String ALL_FEATURES_DISABLED_ERR_MSG = "Having trouble querying data because all of your features have been disabled."; @@ -49,4 +53,11 @@ public class CommonErrorMessages { public static String DETECTOR_IS_RUNNING = "Detector is already running"; public static String DETECTOR_MISSING = "Detector is missing"; public static String AD_TASK_ACTION_MISSING = "AD task action is missing"; + public static final String BUG_RESPONSE = "We might have bugs."; + + private static final String TOO_MANY_CATEGORICAL_FIELD_ERR_MSG_FORMAT = "We can have only %d categorical field."; + + public static String getTooManyCategoricalFieldErr(int limit) { + return String.format(Locale.ROOT, TOO_MANY_CATEGORICAL_FIELD_ERR_MSG_FORMAT, limit); + } } diff --git a/src/main/java/org/opensearch/ad/constant/CommonMessageAttributes.java b/src/main/java/org/opensearch/ad/constant/CommonMessageAttributes.java deleted file mode 100644 index ed8840d6f..000000000 --- a/src/main/java/org/opensearch/ad/constant/CommonMessageAttributes.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. 
- */ - -package org.opensearch.ad.constant; - -public class CommonMessageAttributes { - - // ====================================== - // Json keys - // ====================================== - public static final String RCF_SCORE_JSON_KEY = "rCFScore"; - public static final String ID_JSON_KEY = "adID"; - public static final String MODEL_ID_JSON_KEY = "modelID"; - public static final String FEATURE_JSON_KEY = "features"; - public static final String CONFIDENCE_JSON_KEY = "confidence"; - public static final String ANOMALY_GRADE_JSON_KEY = "anomalyGrade"; - public static final String QUEUE_JSON_KEY = "queue"; - public static final String START_JSON_KEY = "start"; - public static final String END_JSON_KEY = "end"; -} diff --git a/src/main/java/org/opensearch/ad/constant/CommonName.java b/src/main/java/org/opensearch/ad/constant/CommonName.java index e29e56e76..2a3721445 100644 --- a/src/main/java/org/opensearch/ad/constant/CommonName.java +++ b/src/main/java/org/opensearch/ad/constant/CommonName.java @@ -73,11 +73,17 @@ public class CommonName { public static final String MODELS = "models"; public static final String MODEL = "model"; public static final String INIT_PROGRESS = "init_progress"; + public static final String MODEL_SIZE_IN_BYTES = "model_size_in_bytes"; + public static final String CATEGORICAL_FIELD = "category_field"; public static final String TOTAL_ENTITIES = "total_entities"; public static final String ACTIVE_ENTITIES = "active_entities"; public static final String ENTITY_INFO = "entity_info"; public static final String TOTAL_UPDATES = "total_updates"; + + // ====================================== + // Historical detectors + // ====================================== public static final String AD_TASK = "ad_task"; public static final String AD_TASK_REMOTE = "ad_task_remote"; public static final String CANCEL_TASK = "cancel_task"; @@ -107,4 +113,25 @@ public class CommonName { public static final String DATE_HISTOGRAM = "date_histogram"; // feature aggregation name public static final String FEATURE_AGGS = "feature_aggs"; + + // ====================================== + // Used in almost all components + // ====================================== + public static final String MODEL_ID_KEY = "model_id"; + public static final String DETECTOR_ID_KEY = "detector_id"; + public static final String ENTITY_KEY = "entity"; + + // ====================================== + // Used in toXContent + // ====================================== + public static final String RCF_SCORE_JSON_KEY = "rCFScore"; + public static final String ID_JSON_KEY = "adID"; + public static final String FEATURE_JSON_KEY = "features"; + public static final String CONFIDENCE_JSON_KEY = "confidence"; + public static final String ANOMALY_GRADE_JSON_KEY = "anomalyGrade"; + public static final String QUEUE_JSON_KEY = "queue"; + public static final String START_JSON_KEY = "start"; + public static final String END_JSON_KEY = "end"; + public static final String VALUE_JSON_KEY = "value"; + public static final String ENTITIES_JSON_KEY = "entities"; } diff --git a/src/main/java/org/opensearch/ad/feature/AbstractRetriever.java b/src/main/java/org/opensearch/ad/feature/AbstractRetriever.java new file mode 100644 index 000000000..bad4c06ec --- /dev/null +++ b/src/main/java/org/opensearch/ad/feature/AbstractRetriever.java @@ -0,0 +1,58 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source 
license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.feature; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +import org.opensearch.ad.common.exception.EndRunException; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; +import org.opensearch.search.aggregations.metrics.InternalTDigestPercentiles; +import org.opensearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue; +import org.opensearch.search.aggregations.metrics.Percentile; + +public class AbstractRetriever { + protected double parseAggregation(Aggregation aggregation) { + Double result = null; + if (aggregation instanceof SingleValue) { + result = ((SingleValue) aggregation).value(); + } else if (aggregation instanceof InternalTDigestPercentiles) { + Iterator percentile = ((InternalTDigestPercentiles) aggregation).iterator(); + if (percentile.hasNext()) { + result = percentile.next().getValue(); + } + } + return Optional + .ofNullable(result) + .orElseThrow(() -> new EndRunException("Failed to parse aggregation " + aggregation, true).countedInStats(false)); + } + + protected Optional parseBucket(MultiBucketsAggregation.Bucket bucket, List featureIds) { + return parseAggregations(Optional.ofNullable(bucket).map(b -> b.getAggregations()), featureIds); + } + + protected Optional parseAggregations(Optional aggregations, List featureIds) { + return aggregations + .map(aggs -> aggs.asMap()) + .map( + map -> featureIds + .stream() + .mapToDouble(id -> Optional.ofNullable(map.get(id)).map(this::parseAggregation).orElse(Double.NaN)) + .toArray() + ) + .filter(result -> Arrays.stream(result).noneMatch(d -> Double.isNaN(d) || Double.isInfinite(d))); + } +} diff --git a/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java b/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java new file mode 100644 index 000000000..ea9054eca --- /dev/null +++ b/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java @@ -0,0 +1,367 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
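The parseAggregations contract in AbstractRetriever above can be summarized in plain Java: build a feature vector in the order of the requested feature ids, substituting NaN for missing aggregations, and discard the whole row if any value is missing or non-finite. The names below are illustrative, not plugin code.

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;

// Sketch only: turn per-feature aggregation values into a feature row, or empty if incomplete.
final class FeatureRowParser {
    static Optional<double[]> toFeatureRow(Map<String, Double> aggValuesByFeatureId, List<String> featureIds) {
        double[] row = featureIds.stream()
            .mapToDouble(id -> Optional.ofNullable(aggValuesByFeatureId.get(id)).orElse(Double.NaN))
            .toArray();
        boolean invalid = Arrays.stream(row).anyMatch(d -> Double.isNaN(d) || Double.isInfinite(d));
        return invalid ? Optional.empty() : Optional.of(row);
    }
}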
+ */ + +package org.opensearch.ad.feature; + +import java.io.IOException; +import java.time.Clock; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.apache.commons.lang.builder.ToStringBuilder; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; +import org.opensearch.ad.model.Feature; +import org.opensearch.ad.util.ParseUtils; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation.Bucket; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; + +/** + * + * Use pagination to fetch entities. If there are more than pageSize entities, + * we will fetch them in the next page. We implement pagination with composite query. + * Results are decomposed into pages. Each page encapsulates aggregated values for + * each entity and is sent to model nodes according to the hash ring mapping from + * entity model Id to a data node. 
+ * + */ +public class CompositeRetriever extends AbstractRetriever { + public static final String AGG_NAME_COMP = "comp_agg"; + private static final Logger LOG = LogManager.getLogger(CompositeRetriever.class); + + private final long dataStartEpoch; + private final long dataEndEpoch; + private final AnomalyDetector anomalyDetector; + private final NamedXContentRegistry xContent; + private final Client client; + private int totalResults; + private int maxEntities; + private final int pageSize; + private long expirationEpochMs; + private Clock clock; + + public CompositeRetriever( + long dataStartEpoch, + long dataEndEpoch, + AnomalyDetector anomalyDetector, + NamedXContentRegistry xContent, + Client client, + long expirationEpochMs, + Clock clock, + Settings settings, + int maxEntitiesPerInterval, + int pageSize + ) { + this.dataStartEpoch = dataStartEpoch; + this.dataEndEpoch = dataEndEpoch; + this.anomalyDetector = anomalyDetector; + this.xContent = xContent; + this.client = client; + this.totalResults = 0; + this.maxEntities = maxEntitiesPerInterval; + this.pageSize = pageSize; + this.expirationEpochMs = expirationEpochMs; + this.clock = clock; + } + + // a constructor that provide default value of clock + public CompositeRetriever( + long dataStartEpoch, + long dataEndEpoch, + AnomalyDetector anomalyDetector, + NamedXContentRegistry xContent, + Client client, + long expirationEpochMs, + Settings settings, + int maxEntitiesPerInterval, + int pageSize + ) { + this( + dataStartEpoch, + dataEndEpoch, + anomalyDetector, + xContent, + client, + expirationEpochMs, + Clock.systemUTC(), + settings, + maxEntitiesPerInterval, + pageSize + ); + } + + /** + * @return an iterator over pages + * @throws IOException - if we cannot construct valid queries according to + * detector definition + */ + public PageIterator iterator() throws IOException { + RangeQueryBuilder rangeQuery = new RangeQueryBuilder(anomalyDetector.getTimeField()) + .gte(dataStartEpoch) + .lt(dataEndEpoch) + .format("epoch_millis"); + + BoolQueryBuilder internalFilterQuery = new BoolQueryBuilder().filter(anomalyDetector.getFilterQuery()).filter(rangeQuery); + + // multiple categorical fields are supported + CompositeAggregationBuilder composite = AggregationBuilders + .composite( + AGG_NAME_COMP, + anomalyDetector.getCategoryField().stream().map(f -> new TermsValuesSourceBuilder(f).field(f)).collect(Collectors.toList()) + ) + .size(pageSize); + for (Feature feature : anomalyDetector.getFeatureAttributes()) { + AggregatorFactories.Builder internalAgg = ParseUtils + .parseAggregators(feature.getAggregation().toString(), xContent, feature.getId()); + composite.subAggregation(internalAgg.getAggregatorFactories().iterator().next()); + } + + // In order to optimize the early termination it is advised to set track_total_hits in the request to false. 
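For readers unfamiliar with composite-aggregation paging, the sketch below shows the generic after_key loop that PageIterator drives incrementally through next(): reissue the same composite query with the previous page's after_key until no key is returned or enough entities have been read. The wrapper types and names are assumptions, not the plugin's API.

import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Sketch only: generic after-key pagination over a composite aggregation.
final class AfterKeyPager {
    static <B> void pageAll(
        Function<Map<String, Object>, PageResult<B>> fetchPage, // assumed search wrapper
        int maxEntities
    ) {
        Map<String, Object> afterKey = null;
        int seen = 0;
        do {
            PageResult<B> page = fetchPage.apply(afterKey); // null afterKey means the first page
            seen += page.buckets.size();
            afterKey = page.afterKey;                       // null when no more pages exist
        } while (afterKey != null && seen < maxEntities);
    }

    static final class PageResult<B> {
        final List<B> buckets;
        final Map<String, Object> afterKey;

        PageResult(List<B> buckets, Map<String, Object> afterKey) {
            this.buckets = buckets;
            this.afterKey = afterKey;
        }
    }
}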
+ SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(internalFilterQuery) + .size(0) + .aggregation(composite) + .trackTotalHits(false); + + return new PageIterator(searchSourceBuilder); + } + + public class PageIterator { + private SearchSourceBuilder source; + // a map from categorical field name to values (type: java.lang.Comparable) + Map afterKey; + + public PageIterator(SearchSourceBuilder source) { + this.source = source; + this.afterKey = null; + } + + /** + * Results are returned using listener + * @param listener Listener to return results + */ + public void next(ActionListener listener) { + SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0]), source); + client.search(searchRequest, new ActionListener() { + @Override + public void onResponse(SearchResponse response) { + processResponse(response, () -> client.search(searchRequest, this), listener); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }); + } + + private void processResponse(SearchResponse response, Runnable retry, ActionListener listener) { + if (shouldRetryDueToEmptyPage(response)) { + updateCompositeAfterKey(response, source); + retry.run(); + return; + } + + try { + Page page = analyzePage(response); + if (totalResults < maxEntities && afterKey != null) { + updateCompositeAfterKey(response, source); + listener.onResponse(page); + } else { + listener.onResponse(null); + } + } catch (Exception ex) { + listener.onFailure(ex); + } + } + + /** + * + * @param response current response + * @return A page containing + * ** the after key + * ** query source builder to next page if any + * ** a map of composite keys to its values. The values are arranged + * according to the order of anomalyDetector.getEnabledFeatureIds(). 
+ */ + private Page analyzePage(SearchResponse response) { + Optional compositeOptional = getComposite(response); + + if (false == compositeOptional.isPresent()) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Empty response: %s", response)); + } + + CompositeAggregation composite = compositeOptional.get(); + Map results = new HashMap<>(); + /* + * + * Example composite aggregation: + * + "aggregations": { + "my_buckets": { + "after_key": { + "service": "app_6", + "host": "server_3" + }, + "buckets": [ + { + "key": { + "service": "app_6", + "host": "server_3" + }, + "doc_count": 1, + "the_max": { + "value": -38.0 + }, + "the_min": { + "value": -38.0 + } + } + ] + } + } + */ + for (Bucket bucket : composite.getBuckets()) { + Optional featureValues = parseBucket(bucket, anomalyDetector.getEnabledFeatureIds()); + // bucket.getKey() returns a map of categorical field like "host" and its value like "server_1" + if (featureValues.isPresent() && bucket.getKey() != null) { + results.put(Entity.createEntityByReordering(anomalyDetector.getDetectorId(), bucket.getKey()), featureValues.get()); + } + } + + totalResults += results.size(); + + afterKey = composite.afterKey(); + return new Page(results); + } + + private void updateCompositeAfterKey(SearchResponse r, SearchSourceBuilder search) { + Optional composite = getComposite(r); + + if (false == composite.isPresent()) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Empty response: %s", r)); + } + + updateSourceAfterKey(composite.get().afterKey(), search); + } + + private void updateSourceAfterKey(Map afterKey, SearchSourceBuilder search) { + AggregationBuilder aggBuilder = search.aggregations().getAggregatorFactories().iterator().next(); + // update after-key with the new value + if (aggBuilder instanceof CompositeAggregationBuilder) { + CompositeAggregationBuilder comp = (CompositeAggregationBuilder) aggBuilder; + comp.aggregateAfter(afterKey); + } else { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Invalid client request; expected a composite builder but instead got %s", aggBuilder) + ); + } + } + + private boolean shouldRetryDueToEmptyPage(SearchResponse response) { + Optional composite = getComposite(response); + // if there are no buckets but a next page, go fetch it instead of sending an empty response to the client + if (false == composite.isPresent()) { + return false; + } + CompositeAggregation aggr = composite.get(); + return aggr.getBuckets().isEmpty() && aggr.afterKey() != null && !aggr.afterKey().isEmpty(); + } + + Optional getComposite(SearchResponse response) { + if (response == null || response.getAggregations() == null) { + return Optional.empty(); + } + Aggregation agg = response.getAggregations().get(AGG_NAME_COMP); + if (agg == null) { + return Optional.empty(); + } + + if (agg instanceof CompositeAggregation) { + return Optional.of((CompositeAggregation) agg); + } + + throw new IllegalArgumentException(String.format(Locale.ROOT, "Not a composite response; %s", agg.getClass())); + } + + /** + * Whether next page exists. Conditions are: + * 1) we haven't fetched any page yet (totalResults == 0) or afterKey is not null + * 2) next detection interval has not started + * @return true if the iteration has more pages. 
+ */ + public boolean hasNext() { + return (totalResults == 0 || (totalResults > 0 && afterKey != null)) && expirationEpochMs > clock.millis(); + } + + @Override + public String toString() { + ToStringBuilder toStringBuilder = new ToStringBuilder(this); + + if (afterKey != null) { + toStringBuilder.append("afterKey", afterKey); + } + if (source != null) { + toStringBuilder.append("source", source); + } + + return toStringBuilder.toString(); + } + } + + public class Page { + + Map results; + + public Page(Map results) { + this.results = results; + } + + public boolean isEmpty() { + return results == null || results.isEmpty(); + } + + public Map getResults() { + return results; + } + + @Override + public String toString() { + ToStringBuilder toStringBuilder = new ToStringBuilder(this); + + if (results != null) { + toStringBuilder.append("results", results); + } + + return toStringBuilder.toString(); + } + } +} diff --git a/src/main/java/org/opensearch/ad/feature/FeatureManager.java b/src/main/java/org/opensearch/ad/feature/FeatureManager.java index f076c04a3..13793fe8b 100644 --- a/src/main/java/org/opensearch/ad/feature/FeatureManager.java +++ b/src/main/java/org/opensearch/ad/feature/FeatureManager.java @@ -586,13 +586,7 @@ void getSamplesInRangesForEntity( ActionListener>, double[][]>> listener ) throws IOException { searchFeatureDao - .getColdStartSamplesForPeriods( - detector, - sampleRanges, - entity.getValue(), - true, - getSamplesRangesListener(sampleRanges, listener) - ); + .getColdStartSamplesForPeriods(detector, sampleRanges, entity, true, getSamplesRangesListener(sampleRanges, listener)); } private ActionListener>> getSamplesRangesListener( diff --git a/src/main/java/org/opensearch/ad/feature/ScriptMaker.java b/src/main/java/org/opensearch/ad/feature/ScriptMaker.java new file mode 100644 index 000000000..ae9ab37b8 --- /dev/null +++ b/src/main/java/org/opensearch/ad/feature/ScriptMaker.java @@ -0,0 +1,100 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.feature; + +import java.util.Collections; +import java.util.List; +import java.util.Locale; + +import org.opensearch.script.Script; +import org.opensearch.script.ScriptType; + +public class ScriptMaker { + private static final String template = "\'%s\': doc['%s'].value"; + + /** + * We use composite aggregation for feature aggregation. But composite aggregation + * does not support ordering results based on doc count, which is required by + * preview and historical related components. We need to use terms aggregation. + * Terms aggregation does not support collecting terms from multiple fields in + * the same document. Scripts come to the rescue: With a script to retrieve terms + * from multiple fields, we can still use terms aggregation to partition data. + * The script disables the global ordinals optimization and will be slower than + * collecting terms from a single field. Still, it gives us the flexibility to + * implement this option at search time. For a simple example, consider the + * following query about the number field’s sum aggregation on buckets partitioned + * by category_field_1 and category_field_2 from index test. 
+ * + * Query: + * GET /test/_search + { + "aggregations": { + "term_agg": { + "terms": { + "script": { + "source": "['category_field_1': doc['category_field_1'].value, + 'category_field_2': doc['category_field_2'].value]", + "lang": "painless" + } + }, + "aggregations": { + "sum_number": { + "sum": { + "field": "number" + } + } + } + } + } + } + * + * Result: + *"aggregations": { + "term_agg": { + "doc_count_error_upper_bound": 0, + "sum_other_doc_count": 0, + "buckets": [ + { + "key": "{category_field_1=app_0, category_field_2=server_1}", + "doc_count": 1, + "sum_number": { + "value": 1449.0 + } + }, + { + "key": "{category_field_1=app_1, category_field_2=server_1}", + "doc_count": 1, + "sum_number": { + "value": 5200.0 + } + }, + ... + * + * I put two categorical field in a map for parsing the results. Otherwise, + * I won't know which categorical value is for which field. + * @param fields categorical fields + * @return script to use in terms aggregation + */ + public static Script makeTermsScript(List fields) { + StringBuffer format = new StringBuffer(); + // in painless, a map is sth like [a:b, c:d] + format.append("["); + for (int i = 0; i < fields.size(); i++) { + if (i > 0) { + format.append(","); + } + format.append(String.format(Locale.ROOT, template, fields.get(i), fields.get(i))); + } + format.append("]"); + return new Script(ScriptType.INLINE, "painless", format.toString(), Collections.emptyMap()); + } +} diff --git a/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java b/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java index 2e70de334..9363f4602 100644 --- a/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java +++ b/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java @@ -29,18 +29,18 @@ import static org.apache.commons.math3.linear.MatrixUtils.createRealMatrix; import static org.opensearch.ad.constant.CommonName.DATE_HISTOGRAM; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY; import static org.opensearch.ad.util.ParseUtils.batchFeatureQuery; import java.io.IOException; +import java.lang.reflect.Type; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.util.AbstractMap.SimpleEntry; import java.util.AbstractMap.SimpleImmutableEntry; import java.util.ArrayDeque; -import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; -import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; @@ -53,15 +53,10 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.support.ThreadedActionListener; -import org.opensearch.ad.AnomalyDetectorPlugin; -import org.opensearch.ad.common.exception.EndRunException; -import org.opensearch.ad.constant.CommonErrorMessages; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.dataprocessor.Interpolator; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.Entity; -import org.opensearch.ad.model.Feature; import org.opensearch.ad.model.IntervalTimeConfiguration; import org.opensearch.ad.util.ClientUtil; import org.opensearch.ad.util.ParseUtils; @@ -76,29 +71,30 @@ import org.opensearch.search.aggregations.Aggregation; import org.opensearch.search.aggregations.AggregationBuilders; import 
org.opensearch.search.aggregations.Aggregations; -import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; import org.opensearch.search.aggregations.bucket.composite.InternalComposite; import org.opensearch.search.aggregations.bucket.range.InternalDateRange; import org.opensearch.search.aggregations.bucket.range.InternalDateRange.Bucket; import org.opensearch.search.aggregations.bucket.terms.Terms; import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; -import org.opensearch.search.aggregations.metrics.InternalTDigestPercentiles; import org.opensearch.search.aggregations.metrics.Max; import org.opensearch.search.aggregations.metrics.Min; -import org.opensearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue; -import org.opensearch.search.aggregations.metrics.Percentile; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.threadpool.ThreadPool; + +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; /** * DAO for features from search. */ -public class SearchFeatureDao { +public class SearchFeatureDao extends AbstractRetriever { protected static final String AGG_NAME_MIN = "min_timefield"; protected static final String AGG_NAME_TERM = "term_agg"; + private static final Type multiTermsAttributesType = new TypeToken>() { + }.getType(); + private static final Logger logger = LogManager.getLogger(SearchFeatureDao.class); // Dependencies @@ -106,9 +102,8 @@ public class SearchFeatureDao { private final NamedXContentRegistry xContent; private final Interpolator interpolator; private final ClientUtil clientUtil; - private ThreadPool threadPool; - private int maxEntitiesPerQuery; private int maxEntitiesForPreview; + private final Gson gson; /** * Constructor injection. 
@@ -117,28 +112,26 @@ public class SearchFeatureDao { * @param xContent ES XContentRegistry * @param interpolator interpolator for missing values * @param clientUtil utility for ES client - * @param threadPool accessor to different threadpools * @param settings ES settings * @param clusterService ES ClusterService + * @param gson Gson accessor */ public SearchFeatureDao( Client client, NamedXContentRegistry xContent, Interpolator interpolator, ClientUtil clientUtil, - ThreadPool threadPool, Settings settings, - ClusterService clusterService + ClusterService clusterService, + Gson gson ) { this.client = client; this.xContent = xContent; this.interpolator = interpolator; this.clientUtil = clientUtil; - this.threadPool = threadPool; - this.maxEntitiesPerQuery = MAX_ENTITIES_PER_QUERY.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_ENTITIES_PER_QUERY, it -> maxEntitiesPerQuery = it); this.maxEntitiesForPreview = MAX_ENTITIES_FOR_PREVIEW.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_ENTITIES_FOR_PREVIEW, it -> maxEntitiesForPreview = it); + this.gson = gson; } /** @@ -197,10 +190,18 @@ public void getHighestCountEntities(AnomalyDetector detector, long startTime, lo .includeUpper(false); BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().filter(rangeQuery).filter(detector.getFilterQuery()); - TermsAggregationBuilder termsAgg = AggregationBuilders - .terms(AGG_NAME_TERM) - .field(detector.getCategoryField().get(0)) - .size(maxEntitiesForPreview); + TermsAggregationBuilder termsAgg = AggregationBuilders.terms(AGG_NAME_TERM).size(maxEntitiesForPreview); + if (detector.getCategoryField() == null || detector.getCategoryField().isEmpty()) { + listener.onResponse(null); + return; + } + + if (detector.getCategoryField().size() == 1) { + termsAgg.field(detector.getCategoryField().get(0)); + } else { + termsAgg.script(ScriptMaker.makeTermsScript(detector.getCategoryField())); + } + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() .query(boolQueryBuilder) .aggregation(termsAgg) @@ -213,7 +214,6 @@ public void getHighestCountEntities(AnomalyDetector detector, long startTime, lo listener.onResponse(Collections.emptyList()); return; } - List results = aggs .asList() .stream() @@ -222,26 +222,53 @@ public void getHighestCountEntities(AnomalyDetector detector, long startTime, lo .map(bucket -> bucket.getKeyAsString()) .collect(Collectors.toList()) .stream() - .map(entityValue -> new Entity(detector.getCategoryField().get(0), entityValue)) + .map(entityValue -> parseCategoricalField(entityValue, detector)) .collect(Collectors.toList()); listener.onResponse(results); }, listener::onFailure); client.search(searchRequest, termsListener); } + /** + * Precondition: + * {@code detector != null && detector.getCategoryField().size() > 0 } + * + * @param entityValue the representation of the entity's attributes. For + * single-attribute entity, it is the value of the attribute like "server_1". + * For multi-attribute entity, it is a map of attribute names and values like + * " {service=app_0, host=server_1}" . 
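As a rough illustration of the script-based terms aggregation used above when a detector has more than one category field (illustrative only; the field names "service" and "host" and the size are assumptions), the request built for two fields would look approximately like this:

import java.util.Collections;

import org.opensearch.script.Script;
import org.opensearch.script.ScriptType;
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;

public class MultiFieldTermsSketch {
    public static TermsAggregationBuilder buildTermsAgg(int size) {
        // Roughly what ScriptMaker.makeTermsScript produces for fields "service" and "host":
        // a painless map literal keyed by field name, so each bucket key stays parseable.
        String source = "['service': doc['service'].value,'host': doc['host'].value]";
        Script script = new Script(ScriptType.INLINE, "painless", source, Collections.emptyMap());
        return AggregationBuilders.terms("term_agg").script(script).size(size);
    }
}

Bucket keys produced by such a script look like "{service=app_0, host=server_1}", which is the entityValue string this method turns back into an Entity.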
+ * @param detector Anomaly detector + * @return an Entity object corresponding to the entity + */ + private Entity parseCategoricalField(String entityValue, AnomalyDetector detector) { + List categoricalFields = detector.getCategoryField(); + if (categoricalFields.size() == 1) { + return Entity.createSingleAttributeEntity(detector.getDetectorId(), detector.getCategoryField().get(0), entityValue); + } + return Entity + .createEntityByReordering( + detector.getDetectorId(), + AccessController + .doPrivileged((PrivilegedAction>) () -> gson.fromJson(entityValue, multiTermsAttributesType)) + ); + } + /** * Get the entity's earliest and latest timestamps * @param detector detector config - * @param entityName entity's name + * @param entity the entity's information * @param listener listener to return back the requested timestamps */ public void getEntityMinMaxDataTime( AnomalyDetector detector, - String entityName, + Entity entity, ActionListener, Optional>> listener ) { - TermQueryBuilder term = new TermQueryBuilder(detector.getCategoryField().get(0), entityName); - BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(term); + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery(); + + for (TermQueryBuilder term : entity.getTermQueryBuilders()) { + internalFilterQuery.filter(term); + } SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() .query(internalFilterQuery) @@ -352,21 +379,6 @@ private Optional parseResponse(SearchResponse response, List f return parseAggregations(Optional.ofNullable(response).map(resp -> resp.getAggregations()), featureIds); } - private double parseAggregation(Aggregation aggregation) { - Double result = null; - if (aggregation instanceof SingleValue) { - result = ((SingleValue) aggregation).value(); - } else if (aggregation instanceof InternalTDigestPercentiles) { - Iterator percentile = ((InternalTDigestPercentiles) aggregation).iterator(); - if (percentile.hasNext()) { - result = percentile.next().getValue(); - } - } - return Optional - .ofNullable(result) - .orElseThrow(() -> new EndRunException("Failed to parse aggregation " + aggregation, true).countedInStats(false)); - } - /** * Gets samples of features for the time ranges. * @@ -735,30 +747,14 @@ private SearchRequest createPreviewSearchRequest(AnomalyDetector detector, List< } } - private Optional parseBucket(InternalDateRange.Bucket bucket, List featureIds) { - return parseAggregations(Optional.ofNullable(bucket).map(b -> b.getAggregations()), featureIds); - } - - private Optional parseAggregations(Optional aggregations, List featureIds) { - return aggregations - .map(aggs -> aggs.asMap()) - .map( - map -> featureIds - .stream() - .mapToDouble(id -> Optional.ofNullable(map.get(id)).map(this::parseAggregation).orElse(Double.NaN)) - .toArray() - ) - .filter(result -> Arrays.stream(result).noneMatch(d -> Double.isNaN(d) || Double.isInfinite(d))); - } - public void getColdStartSamplesForPeriods( AnomalyDetector detector, List> ranges, - String entityName, + Entity entity, boolean includesEmptyBucket, ActionListener>> listener ) throws IOException { - SearchRequest request = createColdStartFeatureSearchRequest(detector, ranges, entityName); + SearchRequest request = createColdStartFeatureSearchRequest(detector, ranges, entity); client.search(request, ActionListener.wrap(response -> { Aggregations aggs = response.getAggregations(); @@ -795,85 +791,9 @@ public void getColdStartSamplesForPeriods( }, listener::onFailure)); } - /** - * Get features by entities. 
An entity is one combination of particular - * categorical fields’ value. A categorical field in this setting refers to - * an OpenSearch field of type keyword or ip. Specifically, an entity - * can be the IP address 182.3.4.5. - * @param detector Accessor to the detector object - * @param startMilli Start of time range to query - * @param endMilli End of time range to query - * @param listener Listener to return entities and their data points - */ - public void getFeaturesByEntities( - AnomalyDetector detector, - long startMilli, - long endMilli, - ActionListener> listener - ) { - try { - RangeQueryBuilder rangeQuery = new RangeQueryBuilder(detector.getTimeField()) - .gte(startMilli) - .lt(endMilli) - .format("epoch_millis"); - - BoolQueryBuilder internalFilterQuery = new BoolQueryBuilder().filter(detector.getFilterQuery()).filter(rangeQuery); - - /* Terms aggregation implementation.*/ - // Support one category field - TermsAggregationBuilder termsAgg = AggregationBuilders - .terms(AGG_NAME_TERM) - .field(detector.getCategoryField().get(0)) - .size(maxEntitiesPerQuery); - for (Feature feature : detector.getFeatureAttributes()) { - AggregatorFactories.Builder internalAgg = ParseUtils - .parseAggregators(feature.getAggregation().toString(), xContent, feature.getId()); - termsAgg.subAggregation(internalAgg.getAggregatorFactories().iterator().next()); - } - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .query(internalFilterQuery) - .size(0) - .aggregation(termsAgg) - .trackTotalHits(false); - SearchRequest searchRequest = new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder); - - ActionListener termsListener = ActionListener.wrap(response -> { - Aggregations aggs = response.getAggregations(); - if (aggs == null) { - listener.onResponse(Collections.emptyMap()); - return; - } - - Map results = aggs - .asList() - .stream() - .filter(agg -> AGG_NAME_TERM.equals(agg.getName())) - .flatMap(agg -> ((Terms) agg).getBuckets().stream()) - .collect(Collectors.toMap(Terms.Bucket::getKeyAsString, bucket -> parseBucket(bucket, detector.getEnabledFeatureIds()))) - .entrySet() - .stream() - .filter(entry -> entry.getValue().isPresent()) - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().get())); - - listener.onResponse(results); - }, listener::onFailure); - - client - .search( - searchRequest, - new ThreadedActionListener<>(logger, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, termsListener, false) - ); - - } catch (Exception e) { - // TODO: catch concrete exception and check if they should be counted in stats or not - throw new EndRunException(detector.getDetectorId(), CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, e, false); - } - } - - private SearchRequest createColdStartFeatureSearchRequest(AnomalyDetector detector, List> ranges, String entityName) { + private SearchRequest createColdStartFeatureSearchRequest(AnomalyDetector detector, List> ranges, Entity entity) { try { - SearchSourceBuilder searchSourceBuilder = ParseUtils.generateEntityColdStartQuery(detector, ranges, entityName, xContent); + SearchSourceBuilder searchSourceBuilder = ParseUtils.generateEntityColdStartQuery(detector, ranges, entity, xContent); return new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder); } catch (IOException e) { logger @@ -890,7 +810,8 @@ private SearchRequest createColdStartFeatureSearchRequest(AnomalyDetector detect } } - private Optional parseBucket(MultiBucketsAggregation.Bucket bucket, List 
featureIds) { + @Override + public Optional parseBucket(MultiBucketsAggregation.Bucket bucket, List featureIds) { return parseAggregations(Optional.ofNullable(bucket).map(b -> b.getAggregations()), featureIds); } } diff --git a/src/main/java/org/opensearch/ad/indices/AnomalyDetectionIndices.java b/src/main/java/org/opensearch/ad/indices/AnomalyDetectionIndices.java index 9b495be4c..4ccf02db6 100644 --- a/src/main/java/org/opensearch/ad/indices/AnomalyDetectionIndices.java +++ b/src/main/java/org/opensearch/ad/indices/AnomalyDetectionIndices.java @@ -26,7 +26,7 @@ package org.opensearch.ad.indices; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD; import static org.opensearch.ad.settings.AnomalyDetectorSettings.ANOMALY_DETECTION_STATE_INDEX_MAPPING_FILE; @@ -160,7 +160,7 @@ public AnomalyDetectionIndices( this.threadPool = threadPool; this.clusterService.addLocalNodeMasterListener(this); this.historyRolloverPeriod = AD_RESULT_HISTORY_ROLLOVER_PERIOD.get(settings); - this.historyMaxDocs = AD_RESULT_HISTORY_MAX_DOCS.get(settings); + this.historyMaxDocs = AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD.get(settings); this.historyRetentionPeriod = AD_RESULT_HISTORY_RETENTION_PERIOD.get(settings); this.maxPrimaryShards = MAX_PRIMARY_SHARDS.get(settings); @@ -171,7 +171,7 @@ public AnomalyDetectionIndices( this.allUpdated = false; this.updateRunning = new AtomicBoolean(false); - this.clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_RESULT_HISTORY_MAX_DOCS, it -> historyMaxDocs = it); + this.clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD, it -> historyMaxDocs = it); this.clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_RESULT_HISTORY_ROLLOVER_PERIOD, it -> { historyRolloverPeriod = it; @@ -370,13 +370,17 @@ private void choosePrimaryShards(CreateIndexRequest request) { Settings .builder() // put 1 primary shards per hot node if possible - .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Math.min(nodeFilter.getNumberOfEligibleDataNodes(), maxPrimaryShards)) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, getNumberOfPrimaryShards()) // 1 replica for better search performance and fail-over .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1) .put("index.hidden", true) ); } + private int getNumberOfPrimaryShards() { + return Math.min(nodeFilter.getNumberOfEligibleDataNodes(), maxPrimaryShards); + } + /** * Create anomaly result index without checking exist or not. 
* @@ -501,7 +505,7 @@ void rolloverAndDeleteHistoryIndex() { choosePrimaryShards(createRequest); - rollOverRequest.addMaxIndexDocsCondition(historyMaxDocs); + rollOverRequest.addMaxIndexDocsCondition(historyMaxDocs * getNumberOfPrimaryShards()); adminClient.indices().rolloverIndex(rollOverRequest, ActionListener.wrap(response -> { if (!response.isRolledOver()) { logger diff --git a/src/main/java/org/opensearch/ad/ml/CheckpointDao.java b/src/main/java/org/opensearch/ad/ml/CheckpointDao.java index aae0e3974..2dcb1e5c6 100644 --- a/src/main/java/org/opensearch/ad/ml/CheckpointDao.java +++ b/src/main/java/org/opensearch/ad/ml/CheckpointDao.java @@ -26,24 +26,20 @@ package org.opensearch.ad.ml; +import java.io.IOException; import java.security.AccessController; import java.security.PrivilegedAction; -import java.time.Clock; -import java.time.Duration; import java.time.Instant; import java.time.ZoneOffset; import java.time.ZonedDateTime; import java.util.AbstractMap.SimpleImmutableEntry; import java.util.ArrayDeque; import java.util.Arrays; -import java.util.ConcurrentModificationException; import java.util.HashMap; import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.locks.ReentrantLock; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -51,7 +47,6 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.ResourceAlreadyExistsException; import org.opensearch.action.ActionListener; -import org.opensearch.action.DocWriteRequest; import org.opensearch.action.bulk.BulkAction; import org.opensearch.action.bulk.BulkItemResponse; import org.opensearch.action.bulk.BulkRequest; @@ -60,13 +55,16 @@ import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; +import org.opensearch.action.get.MultiGetAction; +import org.opensearch.action.get.MultiGetRequest; +import org.opensearch.action.get.MultiGetResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.IndicesOptions; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.AnomalyDetectionIndices; -import org.opensearch.ad.util.BulkUtil; +import org.opensearch.ad.model.Entity; import org.opensearch.ad.util.ClientUtil; import org.opensearch.client.Client; import org.opensearch.index.IndexNotFoundException; @@ -78,7 +76,6 @@ import com.amazon.randomcutforest.RandomCutForest; import com.amazon.randomcutforest.serialize.RandomCutForestSerDe; -import com.google.common.util.concurrent.RateLimiter; import com.google.gson.Gson; import com.google.gson.JsonObject; import com.google.gson.JsonParser; @@ -116,15 +113,12 @@ public class CheckpointDao { private Gson gson; private RandomCutForestSerDe rcfSerde; - private ConcurrentLinkedQueue> requests; - private final ReentrantLock lock; private final Class thresholdingModelClass; - private final Duration checkpointInterval; - private final Clock clock; + private final AnomalyDetectionIndices indexUtil; - private final RateLimiter bulkRateLimiter; - private final int maxBulkRequestSize; private final JsonParser parser = new JsonParser(); + // we won't read/write a checkpoint larger than a threshold + private final int maxCheckpointBytes; /** * Constructor with dependencies and configuration. 
@@ -135,11 +129,8 @@ public class CheckpointDao { * @param gson accessor to Gson functionality * @param rcfSerde accessor to rcf serialization/deserialization * @param thresholdingModelClass thresholding model's class - * @param clock a UTC clock - * @param checkpointInterval how often we should save a checkpoint * @param indexUtil Index utility methods - * @param maxBulkRequestSize max number of index request a bulk can contain - * @param bulkPerSecond bulk requests per second + * @param maxCheckpointBytes max checkpoint size in bytes */ public CheckpointDao( Client client, @@ -148,28 +139,17 @@ public CheckpointDao( Gson gson, RandomCutForestSerDe rcfSerde, Class thresholdingModelClass, - Clock clock, - Duration checkpointInterval, AnomalyDetectionIndices indexUtil, - int maxBulkRequestSize, - double bulkPerSecond + int maxCheckpointBytes ) { this.client = client; this.clientUtil = clientUtil; this.indexName = indexName; this.gson = gson; this.rcfSerde = rcfSerde; - this.requests = new ConcurrentLinkedQueue<>(); - this.lock = new ReentrantLock(); this.thresholdingModelClass = thresholdingModelClass; - this.clock = clock; - this.checkpointInterval = checkpointInterval; this.indexUtil = indexUtil; - // each checkpoint with model initialized is roughly 250 KB if we are using shingle size 1 with 1 feature - // 1k limit will send 250 KB * 1000 = 250 MB - this.maxBulkRequestSize = maxBulkRequestSize; - // 1 bulk request per 1/bulkPerSecond seconds. - this.bulkRateLimiter = RateLimiter.create(bulkPerSecond); + this.maxCheckpointBytes = maxCheckpointBytes; } /** @@ -250,122 +230,37 @@ private void saveModelCheckpointAsync(Map source, String modelId } /** - * Bulk writing model states prepared previously - */ - public void flush() { - try { - // in case that other threads are doing bulk as well. - if (!lock.tryLock()) { - return; - } - if (requests.size() > 0 && bulkRateLimiter.tryAcquire()) { - final BulkRequest bulkRequest = new BulkRequest(); - // at most 1000 index requests per bulk - for (int i = 0; i < maxBulkRequestSize; i++) { - DocWriteRequest req = requests.poll(); - if (req == null) { - break; - } - - bulkRequest.add(req); - } - if (indexUtil.doesCheckpointIndexExist()) { - flush(bulkRequest); - } else { - indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { - if (initResponse.isAcknowledged()) { - flush(bulkRequest); - } else { - throw new RuntimeException("Creating checkpoint with mappings call not acknowledged."); - } - }, exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we sending the create request - flush(bulkRequest); - } else { - logger.error(String.format(Locale.ROOT, "Unexpected error creating index %s", indexName), exception); - } - })); - } - } - } finally { - if (lock.isHeldByCurrentThread()) { - lock.unlock(); - } - } - } - - private void flush(BulkRequest bulkRequest) { - clientUtil.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(r -> { - if (r.hasFailures()) { - requests.addAll(BulkUtil.getIndexRequestToRetry(bulkRequest, r)); - } else if (requests.size() >= maxBulkRequestSize / 2) { - // during maintenance, we may have much more waiting in the queue. - // trigger another flush if that's the case. - flush(); - } - }, e -> { - logger.error("Failed bulking checkpoints", e); - // retry during next bulk. 
- for (DocWriteRequest req : bulkRequest.requests()) { - requests.add(req); - } - })); - } - - /** - * Prepare bulking the input model state to the checkpoint index. - * We don't save checkpoints within checkpointInterval again. - * @param modelState Model state - * @param modelId Model Id - */ - public void write(ModelState modelState, String modelId) { - write(modelState, modelId, false); - } - - /** - * Prepare bulking the input model state to the checkpoint index. - * We don't save checkpoints within checkpointInterval again, except this - * is from cold start. This method will update the input state's last - * checkpoint time if the checkpoint is staged (ready to be written in the - * next batch). - * @param modelState Model state - * @param modelId Model Id - * @param coldStart whether the checkpoint comes from cold start + * Prepare for index request using the contents of the given model state + * @param modelState an entity model state + * @return serialized JSON map or empty map if the state is too bloated + * @throws IOException when serialization fails */ - public void write(ModelState modelState, String modelId, boolean coldStart) { - Instant instant = modelState.getLastCheckpointTime(); - // Instant.MIN is the default value. We don't save until we are sure. - if ((instant == Instant.MIN || instant.plus(checkpointInterval).isAfter(clock.instant())) && !coldStart) { - return; + public Map toIndexSource(ModelState modelState) throws IOException { + Map source = new HashMap<>(); + EntityModel model = modelState.getModel(); + String serializedModel = toCheckpoint(model); + if (serializedModel == null || serializedModel.length() > maxCheckpointBytes) { + logger + .warn( + new ParameterizedMessage( + "[{}]'s model empty or too large: [{}] bytes", + modelState.getModelId(), + serializedModel == null ? 0 : serializedModel.length() + ) + ); + return source; } - // It is possible 2 states of the same model id gets saved: one overwrite another. - // This can happen if previous checkpoint hasn't been saved to disk, while the - // 1st one creates a new state without restoring. - if (modelState.getModel() != null) { - try { - // we can have ConcurrentModificationException when calling toCheckpoint - // and updating rcf model at the same time. To prevent this, - // we need to have a deep copy of models or have a lock. Both - // options are costly. - // As we are gonna retry serializing either when the entity is - // evicted out of cache or during the next maintenance period, - // don't do anything when the exception happens. 
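In place of the removed internal queue and rate limiter, checkpoint I/O is exposed through the explicit batch entry points batchRead and batchWrite defined later in this file, with batching and throttling handled by the rate-limit worker queues (for example the CheckpointWriteWorker used by EntityColdStarter below). The following is a minimal sketch of how callers might assemble those requests; the index name, model ids, and source maps are placeholders, and in the plugin the source map comes from toIndexSource:

import java.util.Map;

import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.get.MultiGetRequest;
import org.opensearch.action.index.IndexRequest;

public class CheckpointBatchSketch {
    private static final String CHECKPOINT_INDEX = "checkpoint-index"; // placeholder, not the plugin's real index name

    // One multi-get covering several model ids replaces per-model GET requests.
    public static MultiGetRequest buildBatchRead(Iterable<String> modelIds) {
        MultiGetRequest request = new MultiGetRequest();
        for (String modelId : modelIds) {
            request.add(CHECKPOINT_INDEX, modelId);
        }
        return request;
    }

    // One bulk request carrying several serialized checkpoints replaces the old in-memory queue.
    public static BulkRequest buildBatchWrite(Map<String, Map<String, Object>> sourceByModelId) {
        BulkRequest request = new BulkRequest();
        sourceByModelId.forEach((modelId, source) -> request.add(new IndexRequest(CHECKPOINT_INDEX).id(modelId).source(source)));
        return request;
    }
}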
- String serializedModel = toCheckpoint(modelState.getModel()); - Map source = new HashMap<>(); - source.put(DETECTOR_ID, modelState.getDetectorId()); - source.put(FIELD_MODEL, serializedModel); - source.put(TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); - source.put(CommonName.SCHEMA_VERSION_FIELD, indexUtil.getSchemaVersion(ADIndex.CHECKPOINT)); - requests.add(new IndexRequest(indexName).id(modelId).source(source)); - modelState.setLastCheckpointTime(clock.instant()); - if (requests.size() >= maxBulkRequestSize) { - flush(); - } - } catch (ConcurrentModificationException e) { - logger.info(new ParameterizedMessage("Concurrent modification while serializing models for [{}]", modelId), e); - } + String detectorId = modelState.getDetectorId(); + source.put(DETECTOR_ID, detectorId); + source.put(FIELD_MODEL, serializedModel); + source.put(TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); + source.put(CommonName.SCHEMA_VERSION_FIELD, indexUtil.getSchemaVersion(ADIndex.CHECKPOINT)); + Optional entity = model.getEntity(); + if (entity.isPresent()) { + source.put(CommonName.ENTITY_KEY, entity.get()); } + + return source; } /** @@ -385,8 +280,12 @@ public Optional getModelCheckpoint(String modelId) { .map(source -> (String) source.get(FIELD_MODEL)); } - String toCheckpoint(EntityModel model) { + public String toCheckpoint(EntityModel model) { return AccessController.doPrivileged((PrivilegedAction) () -> { + if (model == null) { + logger.warn("Empty model"); + return null; + } JsonObject json = new JsonObject(); json.add(ENTITY_SAMPLE, gson.toJsonTree(model.getSamples())); if (model.getRcf() != null) { @@ -475,28 +374,52 @@ private void logFailure(BulkByScrollResponse response, String detectorID) { } } - private Entry fromEntityModelCheckpoint(Map checkpoint, String modelId) { + /** + * Load json checkpoint into models + * + * @param checkpoint json checkpoint contents + * @param modelId Model Id + * @return a pair of entity model and its last checkpoint time; or empty if + * the raw checkpoint is too large + */ + public Optional> fromEntityModelCheckpoint(Map checkpoint, String modelId) { try { - return AccessController.doPrivileged((PrivilegedAction>) () -> { + return AccessController.doPrivileged((PrivilegedAction>>) () -> { String model = (String) (checkpoint.get(FIELD_MODEL)); + if (model.length() > maxCheckpointBytes) { + logger.warn(new ParameterizedMessage("[{}]'s model too large: [{}] bytes", modelId, model.length())); + return Optional.empty(); + } JsonObject json = parser.parse(model).getAsJsonObject(); + // verified, don't need privileged call to get permission ArrayDeque samples = new ArrayDeque<>( Arrays.asList(this.gson.fromJson(json.getAsJsonArray(ENTITY_SAMPLE), new double[0][0].getClass())) ); RandomCutForest rcf = null; if (json.has(ENTITY_RCF)) { + // verified, don't need privileged call to get permission rcf = rcfSerde.fromJson(json.getAsJsonPrimitive(ENTITY_RCF).getAsString()); } ThresholdingModel threshold = null; if (json.has(ENTITY_THRESHOLD)) { + // verified, don't need privileged call to get permission threshold = this.gson.fromJson(json.getAsJsonPrimitive(ENTITY_THRESHOLD).getAsString(), thresholdingModelClass); } String lastCheckpointTimeString = (String) (checkpoint.get(TIMESTAMP)); Instant timestamp = Instant.parse(lastCheckpointTimeString); - return new SimpleImmutableEntry<>(new EntityModel(modelId, samples, rcf, threshold), timestamp); + Entity entity = null; + Object serializedEntity = checkpoint.get(CommonName.ENTITY_KEY); + if (serializedEntity != null) { + try { 
+ entity = Entity.fromJsonArray(serializedEntity); + } catch (Exception e) { + logger.error(new ParameterizedMessage("fail to parse entity", serializedEntity), e); + } + } + return Optional.of(new SimpleImmutableEntry<>(new EntityModel(entity, samples, rcf, threshold), timestamp)); }); - } catch (RuntimeException e) { + } catch (Exception e) { logger.warn("Exception while deserializing checkpoint", e); throw e; } @@ -505,17 +428,30 @@ private Entry fromEntityModelCheckpoint(Map>> listener) { - clientUtil.asyncRequest(new GetRequest(indexName, modelId), client::get, ActionListener.wrap(response -> { - Optional> checkpointString = processRawCheckpoint(response); - if (checkpointString.isPresent()) { - listener.onResponse(Optional.of(fromEntityModelCheckpoint(checkpointString.get(), modelId))); - } else { - listener.onResponse(Optional.empty()); - } - }, listener::onFailure)); + clientUtil + .asyncRequest( + new GetRequest(indexName, modelId), + client::get, + ActionListener.wrap(response -> { listener.onResponse(processGetResponse(response, modelId)); }, listener::onFailure) + ); + } + + /** + * Process a checkpoint GetResponse and return the EntityModel object + * @param response Checkpoint Index GetResponse + * @param modelId Model Id + * @return a pair of entity model and its last checkpoint time + */ + public Optional> processGetResponse(GetResponse response, String modelId) { + Optional> checkpointString = processRawCheckpoint(response); + if (checkpointString.isPresent()) { + return fromEntityModelCheckpoint(checkpointString.get(), modelId); + } else { + return Optional.empty(); + } } /** @@ -544,4 +480,31 @@ private Optional processModelCheckpoint(GetResponse response) { private Optional> processRawCheckpoint(GetResponse response) { return Optional.ofNullable(response).filter(GetResponse::isExists).map(GetResponse::getSource); } + + public void batchRead(MultiGetRequest request, ActionListener listener) { + clientUtil.execute(MultiGetAction.INSTANCE, request, listener); + } + + public void batchWrite(BulkRequest request, ActionListener listener) { + if (indexUtil.doesCheckpointIndexExist()) { + clientUtil.execute(BulkAction.INSTANCE, request, listener); + } else { + indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { + if (initResponse.isAcknowledged()) { + clientUtil.execute(BulkAction.INSTANCE, request, listener); + } else { + // create index failure. Notify callers using listener. 
+ listener.onFailure(new RuntimeException("Creating checkpoint with mappings call not acknowledged.")); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + clientUtil.execute(BulkAction.INSTANCE, request, listener); + } else { + logger.error(String.format(Locale.ROOT, "Unexpected error creating checkpoint index"), exception); + listener.onFailure(exception); + } + })); + } + } } diff --git a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java b/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java index 990dba310..8adb3b34c 100644 --- a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java +++ b/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java @@ -36,11 +36,11 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Map.Entry; import java.util.Optional; import java.util.Queue; -import java.util.concurrent.RejectedExecutionException; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import java.util.stream.DoubleStream; import java.util.stream.Stream; @@ -52,27 +52,31 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ThreadedActionListener; import org.opensearch.ad.AnomalyDetectorPlugin; +import org.opensearch.ad.MaintenanceState; import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.caching.DoorKeeper; import org.opensearch.ad.common.exception.AnomalyDetectionException; import org.opensearch.ad.common.exception.EndRunException; import org.opensearch.ad.dataprocessor.Interpolator; import org.opensearch.ad.feature.FeatureManager; import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.IntervalTimeConfiguration; -import org.opensearch.common.lease.Releasable; +import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.ratelimit.RequestPriority; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; import com.amazon.randomcutforest.RandomCutForest; -import com.google.common.cache.Cache; -import com.google.common.cache.CacheBuilder; /** - * Training models for multi-entity detectors + * Training models for HCAD detectors * */ -public class EntityColdStarter { +public class EntityColdStarter implements MaintenanceState { private static final Logger logger = LogManager.getLogger(EntityColdStarter.class); private final Clock clock; private final ThreadPool threadPool; @@ -94,9 +98,13 @@ public class EntityColdStarter { private final int shingleSize; private Instant lastThrottledColdStartTime; private final FeatureManager featureManager; - private final Cache lastColdStartTime; - private final CheckpointDao checkpointDao; private int coolDownMinutes; + // A bloom filter checked before cold start to ensure we don't repeatedly + // retry cold start of the same model. + // keys are detector ids. 
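The door keeper described above is essentially a probabilistic "have we tried this recently" gate. Below is a minimal sketch of the same idea using Guava's BloomFilter; this is not the plugin's actual DoorKeeper class, which also resets itself periodically, and the expected insertion count and false-positive rate are assumptions:

import java.nio.charset.StandardCharsets;

import com.google.common.hash.BloomFilter;
import com.google.common.hash.Funnels;

public class ColdStartGateSketch {
    // approximate set of model ids for which cold start was already attempted
    private final BloomFilter<String> attempted = BloomFilter.create(Funnels.stringFunnel(StandardCharsets.UTF_8), 1_000_000, 0.01);

    // return true if cold start should be triggered for this model id
    public boolean tryAcquire(String modelId) {
        if (attempted.mightContain(modelId)) {
            return false; // likely attempted already; skip to avoid repeated cold starts
        }
        attempted.put(modelId);
        return true;
    }
}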
+ private Map doorKeepers; + private final Duration modelTtl; + private final CheckpointWriteWorker checkpointWriteQueue; /** * Constructor @@ -121,10 +129,11 @@ public class EntityColdStarter { * @param thresholdDownsamples the number of samples to keep during downsampling * @param thresholdMaxSamples the max number of samples before downsampling * @param featureManager Used to create features for models. - * @param lastColdStartTimestampTtl max time to retain last cold start timestamp - * @param maxCacheSize max cache size - * @param checkpointDao utility to interact with the checkpoint index * @param settings ES settings accessor + * @param modelTtl time-to-live before last access time of the cold start cache. + * We have a cache to record entities that have run cold starts to avoid + * repeated unsuccessful cold start. + * @param checkpointWriteQueue queue to insert model checkpoints */ public EntityColdStarter( Clock clock, @@ -146,10 +155,9 @@ public EntityColdStarter( int thresholdDownsamples, long thresholdMaxSamples, FeatureManager featureManager, - Duration lastColdStartTimestampTtl, - long maxCacheSize, - CheckpointDao checkpointDao, - Settings settings + Settings settings, + Duration modelTtl, + CheckpointWriteWorker checkpointWriteQueue ) { this.clock = clock; this.lastThrottledColdStartTime = Instant.MIN; @@ -171,98 +179,148 @@ public EntityColdStarter( this.thresholdDownsamples = thresholdDownsamples; this.thresholdMaxSamples = thresholdMaxSamples; this.featureManager = featureManager; - - this.lastColdStartTime = CacheBuilder - .newBuilder() - .expireAfterAccess(lastColdStartTimestampTtl.toHours(), TimeUnit.HOURS) - .maximumSize(maxCacheSize) - .concurrencyLevel(1) - .build(); - this.checkpointDao = checkpointDao; this.coolDownMinutes = (int) (COOLDOWN_MINUTES.get(settings).getMinutes()); + this.doorKeepers = new ConcurrentHashMap<>(); + this.modelTtl = modelTtl; + this.checkpointWriteQueue = checkpointWriteQueue; } - /** - * Training model for an entity - * @param modelId model Id corresponding to the entity - * @param entityName the entity's name - * @param detectorId the detector Id corresponding to the entity - * @param modelState model state associated with the entity - */ - private void coldStart(String modelId, String entityName, String detectorId, ModelState modelState) { - // Rate limiting: if last cold start of the detector is not finished, we don't trigger another one. - if (nodeStateManager.isColdStartRunning(detectorId)) { - return; - } + private ActionListener> onGetDetector( + String modelId, + Entity entity, + String detectorId, + ModelState modelState, + ActionListener listener + ) { + return ActionListener.wrap(detectorOptional -> { + boolean earlyExit = true; + try { + if (false == detectorOptional.isPresent()) { + logger.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + return; + } - // Won't retry cold start within one hour for an entity; if threadpool queue is full, won't retry within 5 minutes - // 5 minutes is derived by 1000 (threadpool queue size) / 4 (1 cold start per 4 seconds according to the Http logs - // experiment) = 250 seconds. 
- if (lastColdStartTime.getIfPresent(modelId) == null - && lastThrottledColdStartTime.plus(Duration.ofMinutes(coolDownMinutes)).isBefore(clock.instant())) { + AnomalyDetector detector = detectorOptional.get(); - final Releasable coldStartFinishingCallback = nodeStateManager.markColdStartRunning(detectorId); + DoorKeeper doorKeeper = doorKeepers + .computeIfAbsent( + detectorId, + id -> { + // reset every 60 intervals + return new DoorKeeper( + AnomalyDetectorSettings.DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION, + AnomalyDetectorSettings.DOOR_KEEPER_FAULSE_POSITIVE_RATE, + detector.getDetectionIntervalDuration().multipliedBy(AnomalyDetectorSettings.DOOR_KEEPER_MAINTENANCE_FREQ), + clock + ); + } + ); - logger.debug("Trigger cold start for {}", modelId); + // Won't retry cold start within 60 intervals for an entity + if (doorKeeper.mightContain(modelId)) { + return; + } - ActionListener>> nestedListener = ActionListener.wrap(trainingData -> { - if (trainingData.isPresent()) { - List dataPoints = trainingData.get(); - // only train models if we have enough samples - if (hasEnoughSample(dataPoints, modelState) == false) { - combineTrainSamples(dataPoints, modelId, modelState); - } else { - trainModelFromDataSegments(dataPoints, modelId, modelState); + doorKeeper.put(modelId); + + ActionListener>> coldStartCallBack = ActionListener.wrap(trainingData -> { + try { + if (trainingData.isPresent()) { + List dataPoints = trainingData.get(); + // only train models if we have enough samples + if (hasEnoughSample(dataPoints, modelState) == false) { + combineTrainSamples(dataPoints, modelId, modelState); + } else { + trainModelFromDataSegments(dataPoints, entity, modelState); + } + logger.info("Succeeded in training entity: {}", modelId); + } else { + logger.info("Cannot get training data for {}", modelId); + } + } finally { + listener.onResponse(null); } - logger.info("Succeeded in training entity: {}", modelId); - } else { - logger.info("Cannot get training data for {}", modelId); - } - }, exception -> { - Throwable cause = Throwables.getRootCause(exception); - if (cause instanceof RejectedExecutionException) { - logger.error("too many requests"); - lastThrottledColdStartTime = Instant.now(); - } else if (cause instanceof AnomalyDetectionException || exception instanceof AnomalyDetectionException) { - // e.g., cannot find anomaly detector - nodeStateManager.setLastColdStartException(detectorId, (AnomalyDetectionException) exception); - } else { - logger.error(new ParameterizedMessage("Error while cold start {}", modelId), exception); + + }, exception -> { + try { + logger.error(new ParameterizedMessage("Error while cold start {}", modelId), exception); + Throwable cause = Throwables.getRootCause(exception); + if (ExceptionUtil.isOverloaded(cause)) { + logger.error("too many requests"); + lastThrottledColdStartTime = Instant.now(); + } else if (cause instanceof AnomalyDetectionException || exception instanceof AnomalyDetectionException) { + // e.g., cannot find anomaly detector + nodeStateManager.setException(detectorId, exception); + } else { + nodeStateManager.setException(detectorId, new AnomalyDetectionException(detectorId, cause)); + } + } finally { + listener.onFailure(exception); + } + }); + + threadPool + .executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME) + .execute( + () -> getEntityColdStartData( + detectorId, + entity, + shingleSize, + new ThreadedActionListener<>( + logger, + threadPool, + AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, + coldStartCallBack, + false + ) + ) + ); + earlyExit = 
false; + } finally { + if (earlyExit) { + listener.onResponse(null); } - }); + } - final ActionListener>> listenerWithReleaseCallback = ActionListener - .runAfter(nestedListener, coldStartFinishingCallback::close); + }, exception -> { + logger.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception); + listener.onFailure(exception); + }); + } - threadPool - .executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME) - .execute( - () -> getEntityColdStartData( - detectorId, - entityName, - shingleSize, - new ThreadedActionListener<>( - logger, - threadPool, - AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, - listenerWithReleaseCallback, - false - ) - ) - ); + /** + * Training model for an entity + * @param modelId model Id corresponding to the entity + * @param entity the entity's information + * @param detectorId the detector Id corresponding to the entity + * @param modelState model state associated with the entity + * @param listener call back to call after cold start + */ + private void coldStart( + String modelId, + Entity entity, + String detectorId, + ModelState modelState, + ActionListener listener + ) { + logger.debug("Trigger cold start for {}", modelId); - lastColdStartTime.put(modelId, Instant.now()); + if (lastThrottledColdStartTime.plus(Duration.ofMinutes(coolDownMinutes)).isAfter(clock.instant())) { + listener.onResponse(null); + return; } + + nodeStateManager.getAnomalyDetector(detectorId, onGetDetector(modelId, entity, detectorId, modelState, listener)); } /** * Train model using given data points. * * @param dataPoints List of continuous data points, in ascending order of timestamps - * @param modelId The model Id + * @param entity Entity instance * @param entityState Entity state associated with the model Id */ - private void trainModelFromDataSegments(List dataPoints, String modelId, ModelState entityState) { + private void trainModelFromDataSegments(List dataPoints, Entity entity, ModelState entityState) { if (dataPoints == null || dataPoints.size() == 0 || dataPoints.get(0) == null || dataPoints.get(0).length == 0) { throw new IllegalArgumentException("Data points must not be empty."); } @@ -281,14 +339,14 @@ private void trainModelFromDataSegments(List dataPoints, String mode int totalLength = 0; // get continuous data points and send for training for (double[][] continuousDataPoints : dataPoints) { - double[] scores = trainRCFModel(continuousDataPoints, modelId, rcf); + double[] scores = trainRCFModel(continuousDataPoints, rcf); allScores.add(scores); totalLength += scores.length; } EntityModel model = entityState.getModel(); if (model == null) { - model = new EntityModel(modelId, new ArrayDeque<>(), null, null); + model = new EntityModel(entity, new ArrayDeque<>(), null, null); } model.setRcf(rcf); double[] joinedScores = new double[totalLength]; @@ -314,17 +372,16 @@ private void trainModelFromDataSegments(List dataPoints, String mode entityState.setLastUsedTime(clock.instant()); // save to checkpoint - checkpointDao.write(entityState, modelId, true); + checkpointWriteQueue.write(entityState, true, RequestPriority.MEDIUM); } /** * Train the RCF model using given data points * @param dataPoints Data points - * @param modelId The model Id * @param rcf RCF model to be trained * @return scores returned by RCF models */ - private double[] trainRCFModel(double[][] dataPoints, String modelId, RandomCutForest rcf) { + private double[] trainRCFModel(double[][] dataPoints, RandomCutForest rcf) { if (dataPoints.length == 0 || dataPoints[0].length == 0) { throw 
new IllegalArgumentException("Data points must not be empty."); } @@ -349,20 +406,19 @@ private double[] trainRCFModel(double[][] dataPoints, String modelId, RandomCutF * points to shingles. Finally, full shingles will be used for cold start. * * @param detectorId detector Id - * @param entityName entity's name + * @param entity the entity's information * @param entityShingleSize model's shingle size * @param listener listener to return training data */ private void getEntityColdStartData( String detectorId, - String entityName, + Entity entity, int entityShingleSize, ActionListener>> listener ) { ActionListener> getDetectorListener = ActionListener.wrap(detectorOp -> { if (!detectorOp.isPresent()) { - nodeStateManager - .setLastColdStartException(detectorId, new EndRunException(detectorId, "AnomalyDetector is not available.", true)); + nodeStateManager.setException(detectorId, new EndRunException(detectorId, "AnomalyDetector is not available.", true)); return; } List coldStartData = new ArrayList<>(); @@ -432,7 +488,7 @@ private void getEntityColdStartData( .getColdStartSamplesForPeriods( detector, sampleRanges, - entityName, + entity, false, new ThreadedActionListener<>( logger, @@ -452,7 +508,7 @@ private void getEntityColdStartData( searchFeatureDao .getEntityMinMaxDataTime( detector, - entityName, + entity, new ThreadedActionListener<>(logger, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, minMaxTimeListener, false) ); @@ -496,25 +552,40 @@ private List> getTrainSampleRanges( /** * Train models for the given entity - * @param samples Recent sample history - * @param modelId Model Id - * @param entityName The entity's name + * @param entity The entity info * @param detectorId Detector Id * @param modelState Model state associated with the entity + * @param listener callback before the method returns whenever EntityColdStarter + * finishes training or encounters exceptions. The listener helps notify the + * cold start queue to pull another request (if any) to execute. 
*/ - public void trainModel( - Queue samples, - String modelId, - String entityName, - String detectorId, - ModelState modelState - ) { + public void trainModel(Entity entity, String detectorId, ModelState modelState, ActionListener listener) { + Queue samples = modelState.getModel().getSamples(); + String modelId = modelState.getModelId(); + if (samples.size() < this.numMinSamples) { // we cannot get last RCF score since cold start happens asynchronously - coldStart(modelId, entityName, detectorId, modelState); + coldStart(modelId, entity, detectorId, modelState, listener); } else { + try { + double[][] trainData = featureManager.batchShingle(samples.toArray(new double[0][0]), this.shingleSize); + trainModelFromDataSegments(Collections.singletonList(trainData), entity, modelState); + } finally { + listener.onResponse(null); + } + } + } + + public void trainModelFromExistingSamples(ModelState modelState) { + if (modelState == null || modelState.getModel() == null || modelState.getModel().getSamples() == null) { + return; + } + + EntityModel model = modelState.getModel(); + Queue samples = model.getSamples(); + if (samples.size() >= this.numMinSamples) { double[][] trainData = featureManager.batchShingle(samples.toArray(new double[0][0]), this.shingleSize); - trainModelFromDataSegments(Collections.singletonList(trainData), modelId, modelState); + trainModelFromDataSegments(Collections.singletonList(trainData), model.getEntity().orElse(null), modelState); } } @@ -553,7 +624,7 @@ private boolean hasEnoughSample(List dataPoints, ModelState coldstartDatapoints, String modelId, ModelState entityState) { EntityModel model = entityState.getModel(); if (model == null) { - model = new EntityModel(modelId, new ArrayDeque<>(), null, null); + model = new EntityModel(null, new ArrayDeque<>(), null, null); } for (double[][] consecutivePoints : coldstartDatapoints) { for (int i = 0; i < consecutivePoints.length; i++) { @@ -561,6 +632,19 @@ private void combineTrainSamples(List coldstartDatapoints, String mo } } // save to checkpoint - checkpointDao.write(entityState, modelId, true); + checkpointWriteQueue.write(entityState, true, RequestPriority.MEDIUM); + } + + @Override + public void maintenance() { + doorKeepers.entrySet().stream().forEach(doorKeeperEntry -> { + String detectorId = doorKeeperEntry.getKey(); + DoorKeeper doorKeeper = doorKeeperEntry.getValue(); + if (doorKeeper.expired(modelTtl)) { + doorKeepers.remove(detectorId); + } else { + doorKeeper.maintenance(); + } + }); } } diff --git a/src/main/java/org/opensearch/ad/ml/EntityModel.java b/src/main/java/org/opensearch/ad/ml/EntityModel.java index 2a3c5330b..5159708b2 100644 --- a/src/main/java/org/opensearch/ad/ml/EntityModel.java +++ b/src/main/java/org/opensearch/ad/ml/EntityModel.java @@ -26,26 +26,34 @@ package org.opensearch.ad.ml; +import java.util.Optional; import java.util.Queue; +import org.opensearch.ad.model.Entity; + import com.amazon.randomcutforest.RandomCutForest; public class EntityModel { - private String modelId; + private Entity entity; // TODO: sample should record timestamp private Queue samples; private RandomCutForest rcf; private ThresholdingModel threshold; - public EntityModel(String modelId, Queue samples, RandomCutForest rcf, ThresholdingModel threshold) { - this.modelId = modelId; + public EntityModel(Entity entity, Queue samples, RandomCutForest rcf, ThresholdingModel threshold) { + this.entity = entity; this.samples = samples; this.rcf = rcf; this.threshold = threshold; } - public String getModelId() { - return 
this.modelId; + /** + * In old checkpoint mapping, we don't have entity. It's fine we are missing + * entity as it is mostly used for debugging. + * @return entity + */ + public Optional getEntity() { + return Optional.ofNullable(entity); } public Queue getSamples() { diff --git a/src/main/java/org/opensearch/ad/ml/ModelManager.java b/src/main/java/org/opensearch/ad/ml/ModelManager.java index 6f10ed4ca..1f5bb51a6 100644 --- a/src/main/java/org/opensearch/ad/ml/ModelManager.java +++ b/src/main/java/org/opensearch/ad/ml/ModelManager.java @@ -61,6 +61,7 @@ import org.opensearch.ad.feature.FeatureManager; import org.opensearch.ad.ml.rcf.CombinedRcfResult; import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; import com.amazon.randomcutforest.RandomCutForest; import com.amazon.randomcutforest.returntypes.DiVector; @@ -149,7 +150,7 @@ public String getName() { * @param minPreviewSize minimum number of data points for preview * @param modelTtl time to live for hosted models * @param checkpointInterval interval between checkpoints - * @param entityColdStarter Used train models on input data + * @param entityColdStarter HCAD cold start utility * @param modelPartitioner Used to partition RCF models * @param featureManager Used to create features for models * @param memoryTracker AD memory usage tracker @@ -178,7 +179,6 @@ public ModelManager( FeatureManager featureManager, MemoryTracker memoryTracker ) { - this.rcfSerde = rcfSerde; this.checkpointDao = checkpointDao; this.gson = gson; @@ -944,102 +944,123 @@ public void getTotalUpdates(String modelId, String detectorId, ActionListener modelState, - String modelId + String modelId, + AnomalyDetector detector, + Entity entity ) { - ThresholdingResult result = null; - if (modelState != null) { - EntityModel model = modelState.getModel(); - Queue samples = model.getSamples(); - samples.add(datapoint); - if (samples.size() > this.rcfNumMinSamples) { - samples.remove(); + EntityModel entityModel = modelState.getModel(); + + if (entityModel == null) { + entityModel = new EntityModel(entity, new ArrayDeque<>(), null, null); + modelState.setModel(entityModel); } - result = maybeTrainBeforeScore(modelState, entityName); + // trainModelFromExistingSamples may be able to make models not null + if (entityModel.getRcf() == null || entityModel.getThreshold() == null) { + entityColdStarter.trainModelFromExistingSamples(modelState); + } + + if (entityModel.getRcf() != null && entityModel.getThreshold() != null) { + return score(datapoint, modelId, modelState); + } else { + entityModel.addSample(datapoint); + return new ThresholdingResult(0, 0, 0); + } } else { - result = new ThresholdingResult(0, 0, 0); + return new ThresholdingResult(0, 0, 0); } - - return result; } - private ThresholdingResult score(Queue samples, String modelId, ModelState modelState) { + public ThresholdingResult score(double[] feature, String modelId, ModelState modelState) { EntityModel model = modelState.getModel(); + if (model == null) { + return new ThresholdingResult(0, 0, 0); + } RandomCutForest rcf = model.getRcf(); ThresholdingModel threshold = model.getThreshold(); + if (rcf == null || threshold == null) { + return new ThresholdingResult(0, 0, 0); + } - double lastRcfScore = 0; - while (samples.peek() != null) { - double[] feature = samples.poll(); - lastRcfScore = rcf.getAnomalyScore(feature); - rcf.update(feature); - threshold.update(lastRcfScore); + // clear feature not scored yet + Queue samples = model.getSamples(); + while (samples != null && 
samples.peek() != null) { + double[] recordedFeature = samples.poll(); + double rcfScore = rcf.getAnomalyScore(recordedFeature); + rcf.update(recordedFeature); + threshold.update(rcfScore); } - double anomalyGrade = threshold.grade(lastRcfScore); + double rcfScore = rcf.getAnomalyScore(feature); + rcf.update(feature); + threshold.update(rcfScore); + + double anomalyGrade = threshold.grade(rcfScore); double anomalyConfidence = computeRcfConfidence(rcf) * threshold.confidence(); - ThresholdingResult result = new ThresholdingResult(anomalyGrade, anomalyConfidence, lastRcfScore); + ThresholdingResult result = new ThresholdingResult(anomalyGrade, anomalyConfidence, rcfScore); modelState.setLastUsedTime(clock.instant()); return result; } /** - * Create model Id out of detector Id and entity name - * @param detectorId Detector Id - * @param entityValue Entity's value - * @return The model Id - */ - public String getEntityModelId(String detectorId, String entityValue) { - return detectorId + "_entity_" + entityValue; - } - - /** - * Instantiate an entity state out of checkpoint. Running cold start if the - * model is empty. Update models using recent samples if applicable. + * Instantiate an entity state out of checkpoint. Train models if there are + * enough samples. * @param checkpoint Checkpoint loaded from index + * @param entity objects to access Entity attributes * @param modelId Model Id - * @param entityName Entity's name - * @param modelState entity state to instantiate + * @param detectorId Detector Id + * + * @return updated model state + * */ - public void processEntityCheckpoint( + public ModelState processEntityCheckpoint( Optional> checkpoint, + Entity entity, String modelId, - String entityName, - ModelState modelState + String detectorId ) { + // entity state to instantiate + ModelState modelState = new ModelState<>( + new EntityModel(entity, new ArrayDeque<>(), null, null), + modelId, + detectorId, + ModelType.ENTITY.getName(), + clock, + 0 + ); + if (checkpoint.isPresent()) { Entry modelToTime = checkpoint.get(); EntityModel restoredModel = modelToTime.getKey(); combineSamples(modelState.getModel(), restoredModel); modelState.setModel(restoredModel); modelState.setLastCheckpointTime(modelToTime.getValue()); - } else { - // the time controls whether we saves this state or not. - // if it is within one hour, we don't save. - // This branch means this is the first state in record - // (the checkpoint might have been deleted). - // we have to save. 
- modelState.setLastCheckpointTime(clock.instant().minus(checkpointInterval)); + } + EntityModel model = modelState.getModel(); + if (model == null) { + model = new EntityModel(null, new ArrayDeque<>(), null, null); + modelState.setModel(model); } - if (modelState.getModel() == null) { - modelState.setModel(new EntityModel(modelId, new ArrayDeque<>(), null, null)); + if ((model.getRcf() == null || model.getThreshold() == null) + && model.getSamples() != null + && model.getSamples().size() >= rcfNumMinSamples) { + entityColdStarter.trainModelFromExistingSamples(modelState); } - maybeTrainBeforeScore(modelState, entityName); + return modelState; } private void combineSamples(EntityModel fromModel, EntityModel toModel) { @@ -1048,28 +1069,4 @@ private void combineSamples(EntityModel fromModel, EntityModel toModel) { toModel.addSample(samples.poll()); } } - - /** - * Infer whenever both models are not null and do cold start if one of the models is not there - * @param modelState Model State - * @param entityName The entity's name - * @return model inference result for the entity, return all 0 Thresholding - * result if the models are not ready - */ - private ThresholdingResult maybeTrainBeforeScore(ModelState modelState, String entityName) { - EntityModel model = modelState.getModel(); - Queue samples = model.getSamples(); - String modelId = model.getModelId(); - String detectorId = modelState.getDetectorId(); - ThresholdingResult result = null; - if (model.getRcf() == null || model.getThreshold() == null) { - entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState); - } - - // update models using recent samples - if (model.getRcf() != null && model.getThreshold() != null && result == null) { - return score(samples, modelId, modelState); - } - return new ThresholdingResult(0, 0, 0); - } } diff --git a/src/main/java/org/opensearch/ad/ml/ModelState.java b/src/main/java/org/opensearch/ad/ml/ModelState.java index e77a3b5f1..44cee8f98 100644 --- a/src/main/java/org/opensearch/ad/ml/ModelState.java +++ b/src/main/java/org/opensearch/ad/ml/ModelState.java @@ -33,19 +33,17 @@ import java.util.Map; import org.opensearch.ad.ExpiringState; +import org.opensearch.ad.constant.CommonName; /** * A ML model and states such as usage. */ public class ModelState implements ExpiringState { - public static String MODEL_ID_KEY = "model_id"; - public static String DETECTOR_ID_KEY = "detector_id"; public static String MODEL_TYPE_KEY = "model_type"; public static String LAST_USED_TIME_KEY = "last_used_time"; public static String LAST_CHECKPOINT_TIME_KEY = "last_checkpoint_time"; - public static String PRIORITY = "priority"; - + public static String PRIORITY_KEY = "priority"; private T model; private String modelId; private String detectorId; @@ -196,12 +194,27 @@ public void setPriority(float priority) { public Map getModelStateAsMap() { return new HashMap() { { - put(MODEL_ID_KEY, modelId); - put(DETECTOR_ID_KEY, detectorId); + put(CommonName.MODEL_ID_KEY, modelId); + put(CommonName.DETECTOR_ID_KEY, detectorId); put(MODEL_TYPE_KEY, modelType); - put(LAST_USED_TIME_KEY, lastUsedTime); - put(LAST_CHECKPOINT_TIME_KEY, lastCheckpointTime); - put(PRIORITY, priority); + /* A stats API broadcasts requests to all nodes and renders node responses using toXContent. + * + * For the local node, the stats API's calls toXContent on the node response directly. 
+ * For remote node, the coordinating node gets a serialized content from + * ADStatsNodeResponse.writeTo, deserializes the content, and renders the result using toXContent. + * Since ADStatsNodeResponse.writeTo uses StreamOutput::writeGenericValue, we can only use + * a long instead of the Instant object itself as + * StreamOutput::writeGenericValue only recognizes built-in types.*/ + put(LAST_USED_TIME_KEY, lastUsedTime.toEpochMilli()); + if (lastCheckpointTime != Instant.MIN) { + put(LAST_CHECKPOINT_TIME_KEY, lastCheckpointTime.toEpochMilli()); + } + if (model != null && model instanceof EntityModel) { + EntityModel summary = (EntityModel) model; + if (summary.getEntity().isPresent()) { + put(CommonName.ENTITY_KEY, summary.getEntity().get().toStat()); + } + } } }; } diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java index 773b45883..4bbbf40e2 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java @@ -26,7 +26,6 @@ package org.opensearch.ad.model; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.CATEGORY_FIELD_LIMIT; import static org.opensearch.ad.settings.AnomalyDetectorSettings.DEFAULT_MULTI_ENTITY_SHINGLE; import static org.opensearch.ad.settings.AnomalyDetectorSettings.DEFAULT_SHINGLE_SIZE; import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; @@ -46,6 +45,7 @@ import org.opensearch.ad.constant.CommonErrorMessages; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.settings.NumericSetting; import org.opensearch.ad.util.ParseUtils; import org.opensearch.common.ParseField; import org.opensearch.common.io.stream.StreamInput; @@ -214,8 +214,9 @@ public AnomalyDetector( if (shingleSize != null && shingleSize < 1) { throw new IllegalArgumentException("Shingle size must be a positive integer"); } - if (categoryFields != null && categoryFields.size() > CATEGORY_FIELD_LIMIT) { - throw new IllegalArgumentException(CommonErrorMessages.CATEGORICAL_FIELD_NUMBER_SURPASSED + CATEGORY_FIELD_LIMIT); + int maxCategoryFields = NumericSetting.maxCategoricalFields(); + if (categoryFields != null && categoryFields.size() > maxCategoryFields) { + throw new IllegalArgumentException(CommonErrorMessages.getTooManyCategoricalFieldErr(maxCategoryFields)); } if (((IntervalTimeConfiguration) detectionInterval).getInterval() <= 0) { throw new IllegalArgumentException("Detection interval must be a positive integer"); diff --git a/src/main/java/org/opensearch/ad/model/AnomalyResult.java b/src/main/java/org/opensearch/ad/model/AnomalyResult.java index d49c701a3..0fc962919 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyResult.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyResult.java @@ -76,6 +76,7 @@ public class AnomalyResult implements ToXContentObject, Writeable { public static final String ENTITY_FIELD = "entity"; public static final String USER_FIELD = "user"; public static final String TASK_ID_FIELD = "task_id"; + public static final String MODEL_ID_FIELD = "model_id"; private final String detectorId; private final String taskId; @@ -88,9 +89,20 @@ public class AnomalyResult implements ToXContentObject, Writeable { private final Instant executionStartTime; private final Instant executionEndTime; private final String error; - private final List entity; + private final Entity entity; private User user; private final 
Integer schemaVersion; + /* + * model id for easy aggregations of entities. The front end needs to query + * for entities ordered by the descending order of anomaly grades and the + * number of anomalies. After supporting multi-category fields, it is hard + * to write such queries since the entity information is stored in a nested + * object array. Also, the front end has all code/queries/ helper functions + * in place to rely on a single key per entity combo. This PR adds model id + * to anomaly result to help the transition to multi-categorical field less + * painful. + */ + private final String modelId; public AnomalyResult( String detectorId, @@ -134,7 +146,7 @@ public AnomalyResult( Instant executionStartTime, Instant executionEndTime, String error, - List entity, + Entity entity, User user, Integer schemaVersion ) { @@ -152,7 +164,8 @@ public AnomalyResult( error, entity, user, - schemaVersion + schemaVersion, + null ); } @@ -168,9 +181,10 @@ public AnomalyResult( Instant executionStartTime, Instant executionEndTime, String error, - List entity, + Entity entity, User user, - Integer schemaVersion + Integer schemaVersion, + String modelId ) { this.detectorId = detectorId; this.taskId = taskId; @@ -186,6 +200,7 @@ public AnomalyResult( this.entity = entity; this.user = user; this.schemaVersion = schemaVersion; + this.modelId = modelId; } public AnomalyResult(StreamInput input) throws IOException { @@ -204,11 +219,7 @@ public AnomalyResult(StreamInput input) throws IOException { this.executionEndTime = input.readInstant(); this.error = input.readOptionalString(); if (input.readBoolean()) { - int entitySize = input.readVInt(); - this.entity = new ArrayList<>(entitySize); - for (int i = 0; i < entitySize; i++) { - entity.add(new Entity(input)); - } + this.entity = new Entity(input); } else { this.entity = null; } @@ -219,6 +230,7 @@ public AnomalyResult(StreamInput input) throws IOException { } this.schemaVersion = input.readInt(); this.taskId = input.readOptionalString(); + this.modelId = input.readOptionalString(); } @Override @@ -254,7 +266,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.field(ERROR_FIELD, error); } if (entity != null) { - xContentBuilder.field(ENTITY_FIELD, entity.toArray()); + xContentBuilder.field(ENTITY_FIELD, entity); } if (user != null) { xContentBuilder.field(USER_FIELD, user); @@ -262,6 +274,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (taskId != null) { xContentBuilder.field(TASK_ID_FIELD, taskId); } + if (modelId != null) { + xContentBuilder.field(MODEL_ID_FIELD, modelId); + } return xContentBuilder.endObject(); } @@ -276,10 +291,11 @@ public static AnomalyResult parse(XContentParser parser) throws IOException { Instant executionStartTime = null; Instant executionEndTime = null; String error = null; - List entityList = null; + Entity entity = null; User user = null; Integer schemaVersion = CommonValue.NO_SCHEMA_VERSION; String taskId = null; + String modelId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -321,11 +337,7 @@ public static AnomalyResult parse(XContentParser parser) throws IOException { error = parser.text(); break; case ENTITY_FIELD: - entityList = new ArrayList<>(); - ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - 
entityList.add(Entity.parse(parser)); - } + entity = Entity.parse(parser); break; case USER_FIELD: user = User.parse(parser); break; @@ -336,6 +348,9 @@ public static AnomalyResult parse(XContentParser parser) throws IOException { case TASK_ID_FIELD: taskId = parser.text(); break; + case MODEL_ID_FIELD: + modelId = parser.text(); + break; default: parser.skipChildren(); break; @@ -353,9 +368,10 @@ public static AnomalyResult parse(XContentParser parser) throws IOException { executionStartTime, executionEndTime, error, - entityList, + entity, user, - schemaVersion + schemaVersion, + modelId ); } @@ -378,7 +394,8 @@ public boolean equals(Object o) { && Objects.equal(getExecutionStartTime(), that.getExecutionStartTime()) && Objects.equal(getExecutionEndTime(), that.getExecutionEndTime()) && Objects.equal(getError(), that.getError()) - && Objects.equal(getEntity(), that.getEntity()); + && Objects.equal(getEntity(), that.getEntity()) + && Objects.equal(getModelId(), that.getModelId()); } @Generated @@ -397,7 +414,8 @@ public int hashCode() { getExecutionStartTime(), getExecutionEndTime(), getError(), - getEntity() + getEntity(), + getModelId() ); } @@ -417,6 +435,7 @@ public String toString() { .append("executionEndTime", executionEndTime) .append("error", error) .append("entity", entity) + .append("modelId", modelId) .toString(); } @@ -464,10 +483,25 @@ public String getError() { return error; } - public List getEntity() { + public Entity getEntity() { return entity; } + public String getModelId() { + return modelId; + } + + /** + * The anomaly result index consists overwhelmingly (99.5%) of zero-grade, non-error documents. + * This function excludes the majority case. + * @return whether the anomaly result is important, i.e., the anomaly grade is not 0 + * or an error is present. 
+ */ + public boolean isHighPriority() { + // AnomalyResult.toXContent won't record Double.NaN and thus make it null + return (getAnomalyGrade() != null && getAnomalyGrade() > 0) || getError() != null; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(detectorId); @@ -485,10 +519,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(error); if (entity != null) { out.writeBoolean(true); - out.writeVInt(entity.size()); - for (Entity entityItem : entity) { - entityItem.writeTo(out); - } + entity.writeTo(out); } else { out.writeBoolean(false); } @@ -500,5 +531,6 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeInt(schemaVersion); out.writeOptionalString(taskId); + out.writeOptionalString(modelId); } } diff --git a/src/main/java/org/opensearch/ad/model/DetectorProfile.java b/src/main/java/org/opensearch/ad/model/DetectorProfile.java index b0d2c2fba..e6ba6d479 100644 --- a/src/main/java/org/opensearch/ad/model/DetectorProfile.java +++ b/src/main/java/org/opensearch/ad/model/DetectorProfile.java @@ -42,7 +42,7 @@ public class DetectorProfile implements Writeable, ToXContentObject, Mergeable { private DetectorState state; private String error; - private ModelProfile[] modelProfile; + private ModelProfileOnNode[] modelProfile; private int shingleSize; private String coordinatingNode; private long totalSizeInBytes; @@ -61,7 +61,7 @@ public DetectorProfile(StreamInput in) throws IOException { } this.error = in.readOptionalString(); - this.modelProfile = in.readOptionalArray(ModelProfile::new, ModelProfile[]::new); + this.modelProfile = in.readOptionalArray(ModelProfileOnNode::new, ModelProfileOnNode[]::new); this.shingleSize = in.readOptionalInt(); this.coordinatingNode = in.readOptionalString(); this.totalSizeInBytes = in.readOptionalLong(); @@ -78,7 +78,7 @@ private DetectorProfile() {} public static class Builder { private DetectorState state = null; private String error = null; - private ModelProfile[] modelProfile = null; + private ModelProfileOnNode[] modelProfile = null; private int shingleSize = -1; private String coordinatingNode = null; private long totalSizeInBytes = -1; @@ -99,7 +99,7 @@ public Builder error(String error) { return this; } - public Builder modelProfile(ModelProfile[] modelProfile) { + public Builder modelProfile(ModelProfileOnNode[] modelProfile) { this.modelProfile = modelProfile; return this; } @@ -196,7 +196,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } if (modelProfile != null && modelProfile.length > 0) { xContentBuilder.startArray(CommonName.MODELS); - for (ModelProfile profile : modelProfile) { + for (ModelProfileOnNode profile : modelProfile) { profile.toXContent(xContentBuilder, params); } xContentBuilder.endArray(); @@ -241,11 +241,11 @@ public void setError(String error) { this.error = error; } - public ModelProfile[] getModelProfile() { + public ModelProfileOnNode[] getModelProfile() { return modelProfile; } - public void setModelProfile(ModelProfile[] modelProfile) { + public void setModelProfile(ModelProfileOnNode[] modelProfile) { this.modelProfile = modelProfile; } diff --git a/src/main/java/org/opensearch/ad/model/Entity.java b/src/main/java/org/opensearch/ad/model/Entity.java index f7d76b145..814e71a3f 100644 --- a/src/main/java/org/opensearch/ad/model/Entity.java +++ b/src/main/java/org/opensearch/ad/model/Entity.java @@ -29,66 +29,181 @@ import static 
org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.SortedMap; +import java.util.TreeMap; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.util.SetOnce; import org.opensearch.ad.annotation.Generated; +import org.opensearch.ad.constant.CommonName; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.common.Numbers; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.hash.MurmurHash3; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.common.xcontent.ToXContentObject; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.common.xcontent.XContentParser.Token; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.index.query.TermQueryBuilder; +import com.google.common.base.Joiner; import com.google.common.base.Objects; /** * Categorical field name and its value - * @author kaituo * */ public class Entity implements ToXContentObject, Writeable { - public static final String ENTITY_NAME_FIELD = "name"; - public static final String ENTITY_VALUE_FIELD = "value"; + private static final Logger LOG = LogManager.getLogger(Entity.class); - private final String name; - private final String value; + private static final long RANDOM_SEED = 42; + private static final String MODEL_ID_INFIX = "_entity_"; - public Entity(String name, String value) { - this.name = name; - this.value = value; + public static final String ATTRIBUTE_NAME_FIELD = "name"; + public static final String ATTRIBUTE_VALUE_FIELD = "value"; + + // model id + private SetOnce modelId = new SetOnce<>(); + // a map from attribute name like "host" to its value like "server_1" + // Use SortedMap so that the attributes are ordered and we can derive the unique + // string representation used in the hash ring. + private final SortedMap attributes; + + /** + * Create an entity that has multiple attributes + * @param detectorId Detector Id + * @param attrs what we parsed from query output as a map of attribute and its values. 
+ * @return the created entity + */ + public static Entity createEntityByReordering(String detectorId, Map attrs) { + SortedMap sortedMap = new TreeMap<>(); + for (Map.Entry categoryValuePair : attrs.entrySet()) { + sortedMap.put(categoryValuePair.getKey(), categoryValuePair.getValue().toString()); + } + return new Entity(sortedMap); + } + + /** + * Create an entity that has only one attribute + * @param detectorId Detector Id + * @param attributeName the attribute's name + * @param attributeVal the attribute's value + * @return the created entity + */ + public static Entity createSingleAttributeEntity(String detectorId, String attributeName, String attributeVal) { + SortedMap sortedMap = new TreeMap<>(); + sortedMap.put(attributeName, attributeVal); + return new Entity(sortedMap); + } + + /** + * Create an entity from ordered attributes based on attribute names + * @param detectorId Detector Id + * @param attrs attribute map + * @return the created entity + */ + public static Entity createEntityFromOrderedMap(String detectorId, SortedMap attrs) { + return new Entity(attrs); + } + + private Entity(SortedMap orderedAttrs) { + this.attributes = orderedAttrs; } public Entity(StreamInput input) throws IOException { - this.name = input.readString(); - this.value = input.readString(); + this.attributes = new TreeMap<>(input.readMap(StreamInput::readString, StreamInput::readString)); } + /** + * Formatter when serializing to json. Used in cases such as saving an anomaly result for HCAD. + * The order is alphabetical (the one used by the JDK to compare Strings). + * Example: + * z0 + * z11 + * z2 + */ @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - XContentBuilder xContentBuilder = builder.startObject().field(ENTITY_NAME_FIELD, name).field(ENTITY_VALUE_FIELD, value); - return xContentBuilder.endObject(); + builder.startArray(); + for (Map.Entry attr : attributes.entrySet()) { + builder.startObject().field(ATTRIBUTE_NAME_FIELD, attr.getKey()).field(ATTRIBUTE_VALUE_FIELD, attr.getValue()).endObject(); + } + builder.endArray(); + return builder; + } + + /** + * Return a map representing the entity, used in the stats API. + * + * A stats API broadcasts requests to all nodes and renders node responses using toXContent. + * + * For the local node, the stats API calls toXContent on the node response directly. + * For a remote node, the coordinating node gets the serialized content from + * ADStatsNodeResponse.writeTo, deserializes the content, and renders the result using toXContent. + * Since ADStatsNodeResponse.writeTo uses StreamOutput::writeGenericValue, we can only use + * a List<Map<String, String>> instead of the Entity object itself as + * StreamOutput::writeGenericValue only recognizes built-in types. + * + * This function returns a map consistent with what toXContent returns. 
+ * + * @return a map representing the entity + */ + public List> toStat() { + List> res = new ArrayList<>(attributes.size() * 2); + for (Map.Entry attr : attributes.entrySet()) { + Map elements = new TreeMap<>(); + elements.put(ATTRIBUTE_NAME_FIELD, attr.getKey()); + elements.put(ATTRIBUTE_VALUE_FIELD, attr.getValue()); + res.add(elements); + } + return res; } public static Entity parse(XContentParser parser) throws IOException { + SortedMap entities = new TreeMap<>(); String parsedValue = null; String parsedName = null; - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = parser.currentName(); - parser.nextToken(); - - switch (fieldName) { - case ENTITY_NAME_FIELD: - parsedName = parser.text(); - break; - case ENTITY_VALUE_FIELD: - parsedValue = parser.text(); - break; - default: - break; + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != Token.END_OBJECT) { + String fieldName = parser.currentName(); + // move to the field value + parser.nextToken(); + switch (fieldName) { + case ATTRIBUTE_NAME_FIELD: + parsedName = parser.text(); + break; + case ATTRIBUTE_VALUE_FIELD: + parsedValue = parser.text(); + break; + default: + break; + } + } + // reset every time we have seen a name-value pair. + if (parsedName != null && parsedValue != null) { + entities.put(parsedName, parsedValue); + parsedValue = null; + parsedName = null; } } - return new Entity(parsedName, parsedValue); + return new Entity(entities); } @Generated @@ -99,28 +214,209 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; Entity that = (Entity) o; - return Objects.equal(getName(), that.getName()) && Objects.equal(getValue(), that.getValue()); + return Objects.equal(attributes, that.attributes); } @Generated @Override public int hashCode() { - return Objects.hashCode(getName(), getValue()); + return Objects.hashCode(attributes); } - @Generated - public String getName() { - return name; + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(attributes, StreamOutput::writeString, StreamOutput::writeString); } - @Generated - public String getValue() { - return value; + /** + * Used to print Entity info and to localize a node in a hash ring. + * @return a normalized String representing the entity. + */ + @Override + public String toString() { + return normalizedAttributes(attributes); } - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(name); - out.writeString(value); + /** + * Return a string of the attributes in the ascending order of attribute names + * @return a normalized String corresponding to the Map. The string is + * deterministic (i.e., no matter in what order we insert values, + * the returned string is the same). This is to ensure that keys with the + * same content map to the same node in our hash ring. 
+ * + */ + private static String normalizedAttributes(SortedMap attributes) { + return Joiner.on(",").withKeyValueSeparator("=").join(attributes); + } + + /** + * Create model Id out of detector Id and attribute name and value pairs + * + * HCAD v1 uses the categorical value as part of the model document Id, but + * OpenSearch’s document Id can be at most 512 bytes. Categorical values are + * usually less than 256 characters, but can grow to 32766 in theory. + * HCAD v1 skips an entity if the entity's name is more than 256 characters. + * We cannot do that in v2 as that can reject a lot of entities. To overcome + * the obstacle, we hash categorical values to a 128-bit string (like SHA-1 + * that git uses) and use the hash as part of the model document Id. + * + * We have choices to make regarding when to use the hash as part of a model + * document Id: for all HC detectors or a HC detector with multiple categorical + * fields. The challenge lies in providing backward compatibility of looking for + * a model checkpoint in the case of a HC detector with one categorical field. + * If using hashes for all HC detectors, we need two get requests to ensure that + * a model checkpoint exists. One uses the document Id without a hash, while one + * uses the document Id with a hash. The dual get requests are ineffective. If + * limiting hashes to a HC detector with multiple categorical fields, there is + * no backward compatibility issue. However, the code will be branchy. One may + * wonder if backward compatibility can be ignored; indeed, the old checkpoints + * will be gone after a transition period during upgrading. During the transition + * period, HC detectors can experience unnecessary cold starts as if the + * detectors were just started. Checkpoint index size can double if every model + * has two model documents. The transition period can be as long as 3 days since + * our checkpoint retention period is 3 days. There is no perfect solution. We + * prefer limiting hashes to an HC detector with multiple categorical fields as + * its customer impact is none. + * + * @param detectorId Detector Id + * @param attributes Attributes of an entity + * @return the model Id + */ + public static Optional getModelId(String detectorId, SortedMap attributes) { + if (attributes.isEmpty()) { + return Optional.empty(); + } else if (attributes.size() == 1) { + for (Map.Entry categoryValuePair : attributes.entrySet()) { + // For OpenSearch, the limit of the document ID is 512 bytes. + // skip an entity if the entity's name is more than 256 characters + // since we are using it as part of document id. + String categoricalValue = categoryValuePair.getValue().toString(); + if (categoricalValue.length() > AnomalyDetectorSettings.MAX_ENTITY_LENGTH) { + return Optional.empty(); + } + return Optional.of(detectorId + MODEL_ID_INFIX + categoricalValue); + } + return Optional.empty(); + } else { + String normalizedFields = normalizedAttributes(attributes); + MurmurHash3.Hash128 hashFunc = MurmurHash3 + .hash128( + normalizedFields.getBytes(StandardCharsets.UTF_8), + 0, + normalizedFields.length(), + RANDOM_SEED, + new MurmurHash3.Hash128() + ); + // 16 bytes = 128 bits + byte[] bytes = new byte[16]; + System.arraycopy(Numbers.longToBytes(hashFunc.h1), 0, bytes, 0, 8); + System.arraycopy(Numbers.longToBytes(hashFunc.h2), 0, bytes, 8, 8); + // Some bytes like 10 in ascii is corrupted in some systems. 
Base64 ensures we use safe bytes: https://tinyurl.com/mxmrhmhf + return Optional.of(detectorId + MODEL_ID_INFIX + Base64.getUrlEncoder().withoutPadding().encodeToString(bytes)); + } + } + + /** + * Get the cached model Id if present, or recompute one if missing. + * + * @param detectorId Detector Id. Used as part of model Id. + * @return Model Id. Can be missing (e.g., the field value is too long for a single-category detector) + */ + public Optional getModelId(String detectorId) { + if (modelId.get() == null) { + // computing model id is not cheap and the result is deterministic. We only do it once. + Optional computedModelId = Entity.getModelId(detectorId, attributes); + if (computedModelId.isPresent()) { + this.modelId.set(computedModelId.get()); + } else { + this.modelId.set(null); + } + } + return Optional.ofNullable(modelId.get()); + } + + public Map getAttributes() { + return attributes; + } + + /** + * Generate a multi-term query filter like + * GET /company/_search + { + "query": { + "bool": { + "filter": [ + { + "term": { + "ip": "1.2.3.4" + } + }, + { + "term": { + "name.keyword": "Kaituo" + } + } + ] + } + } + } + * + * @return a list of term query builders + */ + public List getTermQueryBuilders() { + List res = new ArrayList<>(); + for (Map.Entry attribute : attributes.entrySet()) { + res.add(new TermQueryBuilder(attribute.getKey(), attribute.getValue())); + } + return res; + } + + public List getTermQueryBuilders(String pathPrefix) { + List res = new ArrayList<>(); + for (Map.Entry attribute : attributes.entrySet()) { + res.add(new TermQueryBuilder(pathPrefix + attribute.getKey(), attribute.getValue())); + } + return res; + } + + /** + * From a json array to an Entity instance + * @param entityValue json array consisting of attributes + * @return Entity instance + * @throws IOException when there is a deserialization issue. 
+ */ + public static Entity fromJsonArray(Object entityValue) throws IOException { + XContentBuilder content = JsonXContent.contentBuilder(); + content.startObject(); + content.field(CommonName.ENTITY_KEY, entityValue); + content.endObject(); + + try ( + InputStream stream = BytesReference.bytes(content).streamInput(); + XContentParser parser = JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, stream) + ) { + // move to content.StartObject + parser.nextToken(); + // move to CommonName.ENTITY_KEY + parser.nextToken(); + // move to start of the array + parser.nextToken(); + return Entity.parse(parser); + } + } + + public static Optional fromJsonObject(XContentParser parser) throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + if (false == CommonName.ENTITY_KEY.equals(parser.currentName())) { + // not an object with "entity" as the root key + return Optional.empty(); + } + // move to start of the array + parser.nextToken(); + return Optional.of(Entity.parse(parser)); + } + return Optional.empty(); } } diff --git a/src/main/java/org/opensearch/ad/model/EntityProfile.java b/src/main/java/org/opensearch/ad/model/EntityProfile.java index bc89b2e62..a3c8d2af8 100644 --- a/src/main/java/org/opensearch/ad/model/EntityProfile.java +++ b/src/main/java/org/opensearch/ad/model/EntityProfile.java @@ -44,34 +44,26 @@ */ public class EntityProfile implements Writeable, ToXContent, Mergeable { // field name in toXContent - public static final String CATEGORY_FIELD = "category_field"; - public static final String ENTITY_VALUE = "value"; public static final String IS_ACTIVE = "is_active"; public static final String LAST_ACTIVE_TIMESTAMP = "last_active_timestamp"; public static final String LAST_SAMPLE_TIMESTAMP = "last_sample_timestamp"; - private final String categoryField; - private final String value; private Boolean isActive; private long lastActiveTimestampMs; private long lastSampleTimestampMs; private InitProgressProfile initProgress; - private ModelProfile modelProfile; + private ModelProfileOnNode modelProfile; private EntityState state; public EntityProfile( - String categoryField, - String value, Boolean isActive, long lastActiveTimeStamp, long lastSampleTimestamp, InitProgressProfile initProgress, - ModelProfile modelProfile, + ModelProfileOnNode modelProfile, EntityState state ) { super(); - this.categoryField = categoryField; - this.value = value; this.isActive = isActive; this.lastActiveTimestampMs = lastActiveTimeStamp; this.lastSampleTimestampMs = lastSampleTimestamp; @@ -81,20 +73,13 @@ public EntityProfile( } public static class Builder { - private final String categoryField; - private final String value; private Boolean isActive = null; private long lastActiveTimestampMs = -1L; private long lastSampleTimestampMs = -1L; private InitProgressProfile initProgress = null; - private ModelProfile modelProfile = null; + private ModelProfileOnNode modelProfile = null; private EntityState state = EntityState.UNKNOWN; - public Builder(String categoryField, String value) { - this.categoryField = categoryField; - this.value = value; - } - public Builder isActive(Boolean isActive) { this.isActive = isActive; return this; @@ -115,7 +100,7 @@ public Builder initProgress(InitProgressProfile initProgress) { return this; } - public Builder modelProfile(ModelProfile modelProfile) { + public Builder modelProfile(ModelProfileOnNode 
modelProfile) { this.modelProfile = modelProfile; return this; } @@ -126,22 +111,11 @@ public Builder state(EntityState state) { } public EntityProfile build() { - return new EntityProfile( - categoryField, - value, - isActive, - lastActiveTimestampMs, - lastSampleTimestampMs, - initProgress, - modelProfile, - state - ); + return new EntityProfile(isActive, lastActiveTimestampMs, lastSampleTimestampMs, initProgress, modelProfile, state); } } public EntityProfile(StreamInput in) throws IOException { - this.categoryField = in.readString(); - this.value = in.readString(); this.isActive = in.readOptionalBoolean(); this.lastActiveTimestampMs = in.readLong(); this.lastSampleTimestampMs = in.readLong(); @@ -149,19 +123,11 @@ public EntityProfile(StreamInput in) throws IOException { this.initProgress = new InitProgressProfile(in); } if (in.readBoolean()) { - this.modelProfile = new ModelProfile(in); + this.modelProfile = new ModelProfileOnNode(in); } this.state = in.readEnum(EntityState.class); } - public String getCategoryField() { - return categoryField; - } - - public String getValue() { - return value; - } - public Optional getActive() { return Optional.ofNullable(isActive); } @@ -192,7 +158,7 @@ public InitProgressProfile getInitProgress() { return initProgress; } - public ModelProfile getModelProfile() { + public ModelProfileOnNode getModelProfile() { return modelProfile; } @@ -207,8 +173,6 @@ public XContentBuilder toXContent(XContentBuilder builder) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CATEGORY_FIELD, categoryField); - builder.field(ENTITY_VALUE, value); if (isActive != null) { builder.field(IS_ACTIVE, isActive); } @@ -233,8 +197,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public void writeTo(StreamOutput out) throws IOException { - out.writeString(categoryField); - out.writeString(value); out.writeOptionalBoolean(isActive); out.writeLong(lastActiveTimestampMs); out.writeLong(lastSampleTimestampMs); @@ -256,8 +218,6 @@ public void writeTo(StreamOutput out) throws IOException { @Override public String toString() { ToStringBuilder builder = new ToStringBuilder(this); - builder.append(CATEGORY_FIELD, categoryField); - builder.append(ENTITY_VALUE, value); if (isActive != null) { builder.append(IS_ACTIVE, isActive); } @@ -290,8 +250,6 @@ public boolean equals(Object obj) { if (obj instanceof EntityProfile) { EntityProfile other = (EntityProfile) obj; EqualsBuilder equalsBuilder = new EqualsBuilder(); - equalsBuilder.append(categoryField, other.categoryField); - equalsBuilder.append(value, other.value); equalsBuilder.append(isActive, other.isActive); equalsBuilder.append(lastActiveTimestampMs, other.lastActiveTimestampMs); equalsBuilder.append(lastSampleTimestampMs, other.lastSampleTimestampMs); @@ -307,8 +265,6 @@ public boolean equals(Object obj) { @Override public int hashCode() { return new HashCodeBuilder() - .append(categoryField) - .append(value) .append(isActive) .append(lastActiveTimestampMs) .append(lastSampleTimestampMs) diff --git a/src/main/java/org/opensearch/ad/model/ModelProfile.java b/src/main/java/org/opensearch/ad/model/ModelProfile.java index 725400363..9cb29b629 100644 --- a/src/main/java/org/opensearch/ad/model/ModelProfile.java +++ b/src/main/java/org/opensearch/ad/model/ModelProfile.java @@ -10,7 +10,7 @@ */ /* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -26,86 +26,80 @@ package org.opensearch.ad.model; -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - import java.io.IOException; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.ad.constant.CommonName; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.Writeable; -import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.ToXContentObject; import org.opensearch.common.xcontent.XContentBuilder; -public class ModelProfile implements Writeable, ToXContent { - // field name in toXContent - public static final String MODEL_ID = "model_id"; - public static final String MODEL_SIZE_IN_BYTES = "model_size_in_bytes"; - public static final String NODE_ID = "node_id"; - +/** + * Used to show model information in profile API + * + */ +public class ModelProfile implements Writeable, ToXContentObject { private final String modelId; + private final Entity entity; private final long modelSizeInBytes; - private final String nodeId; - public ModelProfile(String modelId, long modelSize, String nodeId) { + public ModelProfile(String modelId, Entity entity, long modelSizeInBytes) { super(); this.modelId = modelId; - this.modelSizeInBytes = modelSize; - this.nodeId = nodeId; + this.entity = entity; + this.modelSizeInBytes = modelSizeInBytes; } public ModelProfile(StreamInput in) throws IOException { - modelId = in.readString(); - modelSizeInBytes = in.readLong(); - nodeId = in.readString(); + this.modelId = in.readString(); + if (in.readBoolean()) { + this.entity = new Entity(in); + } else { + this.entity = null; + } + this.modelSizeInBytes = in.readLong(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + if (entity != null) { + out.writeBoolean(true); + entity.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeLong(modelSizeInBytes); } public String getModelId() { return modelId; } - public long getModelSize() { - return modelSizeInBytes; + public Entity getEntity() { + return entity; } - public String getNodeId() { - return nodeId; + public long getModelSizeInBytes() { + return modelSizeInBytes; } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(MODEL_ID, modelId); + builder.field(CommonName.MODEL_ID_KEY, modelId); + if (entity != null) { + builder.field(CommonName.ENTITY_KEY, entity); + } if (modelSizeInBytes > 0) { - builder.field(MODEL_SIZE_IN_BYTES, modelSizeInBytes); + builder.field(CommonName.MODEL_SIZE_IN_BYTES, 
modelSizeInBytes); } - builder.field(NODE_ID, nodeId); - builder.endObject(); return builder; } - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(modelId); - out.writeLong(modelSizeInBytes); - out.writeString(nodeId); - } - @Override public boolean equals(Object obj) { if (this == obj) @@ -118,8 +112,6 @@ public boolean equals(Object obj) { ModelProfile other = (ModelProfile) obj; EqualsBuilder equalsBuilder = new EqualsBuilder(); equalsBuilder.append(modelId, other.modelId); - equalsBuilder.append(modelSizeInBytes, other.modelSizeInBytes); - equalsBuilder.append(nodeId, other.nodeId); return equalsBuilder.isEquals(); } @@ -128,17 +120,19 @@ public boolean equals(Object obj) { @Override public int hashCode() { - return new HashCodeBuilder().append(modelId).append(modelSizeInBytes).append(nodeId).toHashCode(); + return new HashCodeBuilder().append(modelId).toHashCode(); } @Override public String toString() { ToStringBuilder builder = new ToStringBuilder(this); - builder.append(MODEL_ID, modelId); + builder.append(CommonName.MODEL_ID_KEY, modelId); if (modelSizeInBytes > 0) { - builder.append(MODEL_SIZE_IN_BYTES, modelSizeInBytes); + builder.append(CommonName.MODEL_SIZE_IN_BYTES, modelSizeInBytes); + } + if (entity != null) { + builder.append(CommonName.ENTITY_KEY, entity); } - builder.append(NODE_ID, nodeId); return builder.toString(); } } diff --git a/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java b/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java new file mode 100644 index 000000000..d11978f70 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java @@ -0,0 +1,101 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.model; + +import java.io.IOException; + +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.ad.constant.CommonName; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; + +public class ModelProfileOnNode implements Writeable, ToXContent { + // field name in toXContent + public static final String NODE_ID = "node_id"; + + private final String nodeId; + private final ModelProfile modelProfile; + + public ModelProfileOnNode(String nodeId, ModelProfile modelProfile) { + this.nodeId = nodeId; + this.modelProfile = modelProfile; + } + + public ModelProfileOnNode(StreamInput in) throws IOException { + this.nodeId = in.readString(); + this.modelProfile = new ModelProfile(in); + } + + public String getModelId() { + return modelProfile.getModelId(); + } + + public long getModelSize() { + return modelProfile.getModelSizeInBytes(); + } + + public String getNodeId() { + return nodeId; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + modelProfile.toXContent(builder, params); + builder.field(NODE_ID, nodeId); + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(nodeId); + modelProfile.writeTo(out); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + if (obj instanceof ModelProfileOnNode) { + ModelProfileOnNode other = (ModelProfileOnNode) obj; + EqualsBuilder equalsBuilder = new EqualsBuilder(); + equalsBuilder.append(modelProfile, other.modelProfile); + equalsBuilder.append(nodeId, other.nodeId); + + return equalsBuilder.isEquals(); + } + return false; + } + + @Override + public int hashCode() { + return new HashCodeBuilder().append(modelProfile).append(nodeId).toHashCode(); + } + + @Override + public String toString() { + ToStringBuilder builder = new ToStringBuilder(this); + builder.append(CommonName.MODEL, modelProfile); + builder.append(NODE_ID, nodeId); + return builder.toString(); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java b/src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java new file mode 100644 index 000000000..bbe3ffc00 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java @@ -0,0 +1,135 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ThreadedActionListener; +import org.opensearch.ad.AnomalyDetectorPlugin; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; + +/** + * + * @param Individual request type that is a subtype of ADRequest + * @param Batch request type like BulkRequest + * @param Response type like BulkResponse + */ +public abstract class BatchWorker extends + ConcurrentWorker { + private static final Logger LOG = LogManager.getLogger(BatchWorker.class); + protected int batchSize; + + public BatchWorker( + String queueName, + long heapSize, + int singleRequestSize, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Setting concurrencySetting, + Duration executionTtl, + Setting batchSizeSetting, + Duration stateTtl, + NodeStateManager nodeStateManager + ) { + super( + queueName, + heapSize, + singleRequestSize, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + concurrencySetting, + executionTtl, + stateTtl, + nodeStateManager + ); + this.batchSize = batchSizeSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(batchSizeSetting, it -> batchSize = it); + } + + /** + * Used by subclasses to create customized logic to send batch requests. + * After everything finishes, the method should call the listener. + * @param request Batch request to execute + * @param listener customized listener + */ + protected abstract void executeBatchRequest(BatchRequestType request, ActionListener listener); + + /** + * Convert the queued requests understood by AD to a batch request understood by OpenSearch. + * @param toProcess Queued requests + * @return batch requests + */ + protected abstract BatchRequestType toBatchRequest(List toProcess); + + @Override + protected void execute(Runnable afterProcessCallback, Runnable emptyQueueCallback) { + + List toProcess = getRequests(batchSize); + + // it is possible that other concurrent threads have drained the queue + if (false == toProcess.isEmpty()) { + BatchRequestType batchRequest = toBatchRequest(toProcess); + + ThreadedActionListener listener = new ThreadedActionListener<>( + LOG, + threadPool, + AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, + getResponseListener(toProcess, batchRequest), + false + ); + + final ActionListener listenerWithRelease = ActionListener.runAfter(listener, afterProcessCallback); + executeBatchRequest(batchRequest, listenerWithRelease); + } else { + emptyQueueCallback.run(); + } + } + + /** + * Used by subclasses to create customized logic to handle batch responses + * or errors. 
+ * @param toProcess Queued request used to retrieve information of retrying requests + * @param batchRequest Batch request corresponding to toProcess. We convert + * from toProcess understood by AD to batchRequest understood by ES. + * @return Listener to BatchResponse + */ + protected abstract ActionListener getResponseListener(List toProcess, BatchRequestType batchRequest); +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java b/src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java new file mode 100644 index 000000000..ad9322a7f --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java @@ -0,0 +1,364 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Random; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.MultiGetItemResponse; +import org.opensearch.action.get.MultiGetRequest; +import org.opensearch.action.get.MultiGetResponse; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.constant.CommonName; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.AnomalyDetectionIndices; +import org.opensearch.ad.ml.CheckpointDao; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.model.Entity; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.ad.util.ParseUtils; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.threadpool.ThreadPool; + +/** + * a queue for loading model checkpoint. The read is a multi-get query. Possible results are: + * a). If a checkpoint is not found, we forward that request to the cold start queue. + * b). When a request gets errors, the queue does not change its expiry time and puts + * that request to the end of the queue and automatically retries them before they expire. + * c) When a checkpoint is found, we load that point to memory and score the input + * data point and save the result if a complete model exists. Otherwise, we enqueue + * the sample. If we can host that model in memory (e.g., there is enough memory), + * we put the loaded model to cache. 
Otherwise (e.g., a cold entity), we write the + * updated checkpoint back to disk. + * + */ +public class CheckpointReadWorker extends BatchWorker { + private static final Logger LOG = LogManager.getLogger(CheckpointReadWorker.class); + public static final String WORKER_NAME = "checkpoint-read"; + private final ModelManager modelManager; + private final CheckpointDao checkpointDao; + private final EntityColdStartWorker entityColdStartQueue; + private final ResultWriteWorker resultWriteQueue; + private final AnomalyDetectionIndices indexUtil; + private final CacheProvider cacheProvider; + private final CheckpointWriteWorker checkpointWriteQueue; + + public CheckpointReadWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ModelManager modelManager, + CheckpointDao checkpointDao, + EntityColdStartWorker entityColdStartQueue, + ResultWriteWorker resultWriteQueue, + NodeStateManager stateManager, + AnomalyDetectionIndices indexUtil, + CacheProvider cacheProvider, + Duration stateTtl, + CheckpointWriteWorker checkpointWriteQueue + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + CHECKPOINT_READ_QUEUE_CONCURRENCY, + executionTtl, + CHECKPOINT_READ_QUEUE_BATCH_SIZE, + stateTtl, + stateManager + ); + + this.modelManager = modelManager; + this.checkpointDao = checkpointDao; + this.entityColdStartQueue = entityColdStartQueue; + this.resultWriteQueue = resultWriteQueue; + this.indexUtil = indexUtil; + this.cacheProvider = cacheProvider; + this.checkpointWriteQueue = checkpointWriteQueue; + } + + @Override + protected void executeBatchRequest(MultiGetRequest request, ActionListener listener) { + checkpointDao.batchRead(request, listener); + } + + /** + * Convert the input list of EntityFeatureRequest to a multi-get request. + * RateLimitedRequestWorker.getRequests has already limited the number of + * requests in the input list. So toBatchRequest method can take the input + * and send the multi-get directly. 
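As a hedged aside (illustrative model ids, not part of the patch): each checkpoint lives in the checkpoint index as a single document keyed by model id, so the batch read this worker drives is an ordinary multi-get handed to CheckpointDao.

    // Editor's sketch of the read path; the ids are made up.
    MultiGetRequest request = new MultiGetRequest()
        .add(new MultiGetRequest.Item(CommonName.CHECKPOINT_INDEX_NAME, "detectorA_entity_host-1"))
        .add(new MultiGetRequest.Item(CommonName.CHECKPOINT_INDEX_NAME, "detectorA_entity_host-2"));
    checkpointDao.batchRead(request, ActionListener.wrap(
        response -> { /* per-item hits, misses, and failures, handled in getResponseListener below */ },
        exception -> { /* e.g. IndexNotFoundException when no checkpoint index exists yet */ }
    ));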
+ * @return The converted multi-get request + */ + @Override + protected MultiGetRequest toBatchRequest(List toProcess) { + MultiGetRequest multiGetRequest = new MultiGetRequest(); + for (EntityRequest request : toProcess) { + Optional modelId = request.getModelId(); + if (false == modelId.isPresent()) { + continue; + } + multiGetRequest.add(new MultiGetRequest.Item(CommonName.CHECKPOINT_INDEX_NAME, modelId.get())); + } + return multiGetRequest; + } + + @Override + protected ActionListener getResponseListener(List toProcess, MultiGetRequest batchRequest) { + return ActionListener.wrap(response -> { + final MultiGetItemResponse[] itemResponses = response.getResponses(); + Map successfulRequests = new HashMap<>(); + + // lazy init since we don't expect retryable requests to happen often + Set retryableRequests = null; + Set notFoundModels = null; + for (MultiGetItemResponse itemResponse : itemResponses) { + String modelId = itemResponse.getId(); + if (itemResponse.isFailed()) { + final Exception failure = itemResponse.getFailure().getFailure(); + if (failure instanceof IndexNotFoundException) { + for (EntityRequest origRequest : toProcess) { + // If it is checkpoint index not found exception, I don't + // need to retry as checkpoint read is bound to fail. Just + // send everything to the cold start queue and return. + entityColdStartQueue.put(origRequest); + } + return; + } else if (ExceptionUtil.isRetryAble(failure)) { + if (retryableRequests == null) { + retryableRequests = new HashSet<>(); + } + retryableRequests.add(modelId); + } else if (ExceptionUtil.isOverloaded(failure)) { + LOG.error("too many get AD model checkpoint requests or shard not available"); + setCoolDownStart(); + } else { + LOG.error("Unexpected failure", failure); + } + } else if (!itemResponse.getResponse().isExists()) { + // lazy init as we don't expect retrying happens often + if (notFoundModels == null) { + notFoundModels = new HashSet<>(); + } + notFoundModels.add(modelId); + } else { + successfulRequests.put(modelId, itemResponse); + } + } + + // deal with not found model + if (notFoundModels != null) { + for (EntityRequest origRequest : toProcess) { + Optional modelId = origRequest.getModelId(); + if (modelId.isPresent() && notFoundModels.contains(modelId.get())) { + // submit to cold start queue + entityColdStartQueue.put(origRequest); + } + } + } + + if (successfulRequests.isEmpty() && (retryableRequests == null || retryableRequests.isEmpty())) { + // don't need to proceed further since no checkpoint is available + return; + } + + processCheckpointIteration(0, toProcess, successfulRequests, retryableRequests); + }, exception -> { + if (ExceptionUtil.isOverloaded(exception)) { + LOG.error("too many get AD model checkpoint requests or shard not available"); + setCoolDownStart(); + } else if (ExceptionUtil.isRetryAble(exception)) { + // retry all of them + putAll(toProcess); + } else { + LOG.error("Fail to restore models", exception); + } + }); + } + + private void processCheckpointIteration( + int i, + List toProcess, + Map successfulRequests, + Set retryableRequests + ) { + if (i >= toProcess.size()) { + return; + } + + // whether we will process next response in callbacks + // if false, finally will process next checkpoints + boolean processNextInCallBack = false; + try { + EntityFeatureRequest origRequest = toProcess.get(i); + + Optional modelIdOptional = origRequest.getModelId(); + if (false == modelIdOptional.isPresent()) { + return; + } + + String detectorId = origRequest.getDetectorId(); + Entity entity = 
origRequest.getEntity(); + + String modelId = modelIdOptional.get(); + + MultiGetItemResponse checkpointResponse = successfulRequests.get(modelId); + + if (checkpointResponse != null) { + // successful requests + Optional> checkpoint = checkpointDao + .processGetResponse(checkpointResponse.getResponse(), modelId); + + if (false == checkpoint.isPresent()) { + // checkpoint is too big + return; + } + + ModelState modelState = modelManager.processEntityCheckpoint(checkpoint, entity, modelId, detectorId); + + EntityModel entityModel = modelState.getModel(); + + ThresholdingResult result = null; + if (entityModel.getRcf() != null && entityModel.getThreshold() != null) { + result = modelManager.score(origRequest.getCurrentFeature(), modelId, modelState); + } else { + entityModel.addSample(origRequest.getCurrentFeature()); + } + + nodeStateManager + .getAnomalyDetector( + detectorId, + onGetDetector(origRequest, i, detectorId, result, toProcess, successfulRequests, retryableRequests, modelState) + ); + processNextInCallBack = true; + } else if (retryableRequests != null && retryableRequests.contains(modelId)) { + // failed requests + super.put(origRequest); + } + } finally { + if (false == processNextInCallBack) { + processCheckpointIteration(i + 1, toProcess, successfulRequests, retryableRequests); + } + } + } + + private ActionListener> onGetDetector( + EntityFeatureRequest origRequest, + int index, + String detectorId, + ThresholdingResult result, + List toProcess, + Map successfulRequests, + Set retryableRequests, + ModelState modelState + ) { + return ActionListener.wrap(detectorOptional -> { + if (false == detectorOptional.isPresent()) { + LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); + return; + } + + AnomalyDetector detector = detectorOptional.get(); + + if (result != null && result.getRcfScore() > 0) { + resultWriteQueue + .put( + new ResultWriteRequest( + origRequest.getExpirationEpochMs(), + detectorId, + result.getGrade() > 0 ? RequestPriority.HIGH : RequestPriority.MEDIUM, + new AnomalyResult( + detectorId, + null, + result.getRcfScore(), + result.getGrade(), + result.getConfidence(), + ParseUtils.getFeatureData(origRequest.getCurrentFeature(), detector), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + detector.getDetectorIntervalInMilliseconds()), + Instant.now(), + Instant.now(), + null, + origRequest.getEntity(), + detector.getUser(), + indexUtil.getSchemaVersion(ADIndex.RESULT), + modelState.getModelId() + ) + ) + ); + } + + // try to load to cache + boolean loaded = cacheProvider.get().hostIfPossible(detector, modelState); + + if (false == loaded) { + // not in memory. Maybe cold entities or some other entities + // have filled the slot while waiting for loading checkpoints. 
+ checkpointWriteQueue.write(modelState, true, RequestPriority.LOW); + } + + processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); + }, exception -> { + LOG.error(new ParameterizedMessage("fail to get checkpoint [{}]", modelState.getModelId()), exception); + nodeStateManager.setException(detectorId, exception); + processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); + }); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteRequest.java b/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteRequest.java new file mode 100644 index 000000000..a18c05d14 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteRequest.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import org.opensearch.action.index.IndexRequest; + +public class CheckpointWriteRequest extends QueuedRequest { + private final IndexRequest indexRequest; + + public CheckpointWriteRequest(long expirationEpochMs, String detectorId, RequestPriority priority, IndexRequest indexRequest) { + super(expirationEpochMs, detectorId, priority); + this.indexRequest = indexRequest; + } + + public IndexRequest getIndexRequest() { + return indexRequest; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteWorker.java b/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteWorker.java new file mode 100644 index 000000000..5a98ceae7 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteWorker.java @@ -0,0 +1,277 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.ml.CheckpointDao; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Strings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; + +public class CheckpointWriteWorker extends BatchWorker { + private static final Logger LOG = LogManager.getLogger(CheckpointWriteWorker.class); + public static final String WORKER_NAME = "checkpoint-write"; + + private final CheckpointDao checkpoint; + private final String indexName; + private final Duration checkpointInterval; + + public CheckpointWriteWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + CheckpointDao checkpoint, + String indexName, + Duration checkpointInterval, + NodeStateManager stateManager, + Duration stateTtl + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + executionTtl, + CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + stateTtl, + stateManager + ); + this.checkpoint = checkpoint; + this.indexName = indexName; + this.checkpointInterval = checkpointInterval; + } + + @Override + protected void executeBatchRequest(BulkRequest request, ActionListener listener) { + checkpoint.batchWrite(request, listener); + } + + @Override + protected BulkRequest toBatchRequest(List toProcess) { + final BulkRequest bulkRequest = new BulkRequest(); + for (CheckpointWriteRequest request : toProcess) { + bulkRequest.add(request.getIndexRequest()); + } + return bulkRequest; + } + + @Override + protected ActionListener getResponseListener(List toProcess, BulkRequest batchRequest) { + return ActionListener.wrap(response -> { + for (BulkItemResponse r : response.getItems()) { + if (r.getFailureMessage() != null) { + // maybe indicating a bug + // don't retry failed 
requests since checkpoints are too large (250KB+) + // Later maintenance window or cold start or cache remove will retry saving + LOG.error(r.getFailureMessage()); + } + } + }, exception -> { + if (ExceptionUtil.isOverloaded(exception)) { + LOG.error("too many get AD model checkpoint requests or shard not avialble"); + setCoolDownStart(); + } + + for (CheckpointWriteRequest request : toProcess) { + nodeStateManager.setException(request.getDetectorId(), exception); + } + + // don't retry failed requests since checkpoints are too large (250KB+) + // Later maintenance window or cold start or cache remove will retry saving + LOG.error("Fail to save models", exception); + }); + } + + /** + * Prepare bulking the input model state to the checkpoint index. + * We don't save checkpoints within checkpointInterval again, except this + * is a high priority request (e.g., from cold start). + * This method will update the input state's last checkpoint time if the + * checkpoint is staged (ready to be written in the next batch). + * @param modelState Model state + * @param forceWrite whether we should write no matter what + * @param priority how urgent the write is + */ + public void write(ModelState modelState, boolean forceWrite, RequestPriority priority) { + Instant instant = modelState.getLastCheckpointTime(); + if (!shouldSave(instant, forceWrite)) { + return; + } + + if (modelState.getModel() != null) { + String detectorId = modelState.getDetectorId(); + String modelId = modelState.getModelId(); + if (modelId == null || detectorId == null) { + return; + } + + nodeStateManager.getAnomalyDetector(detectorId, onGetDetector(detectorId, modelId, modelState, priority)); + } + } + + private ActionListener> onGetDetector( + String detectorId, + String modelId, + ModelState modelState, + RequestPriority priority + ) { + return ActionListener.wrap(detectorOptional -> { + if (false == detectorOptional.isPresent()) { + LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + return; + } + + AnomalyDetector detector = detectorOptional.get(); + try { + Map source = checkpoint.toIndexSource(modelState); + + // the model state is bloated or we have bugs, skip + if (source == null || source.isEmpty()) { + return; + } + + CheckpointWriteRequest request = new CheckpointWriteRequest( + System.currentTimeMillis() + detector.getDetectorIntervalInMilliseconds(), + detectorId, + priority, + new IndexRequest(indexName).id(modelId).source(source) + ); + + put(request); + } catch (Exception e) { + // Example exception: + // ConcurrentModificationException when calling toCheckpoint + // and updating rcf model at the same time. To prevent this, + // we need to have a deep copy of models or have a lock. Both + // options are costly. + // As we are gonna retry serializing either when the entity is + // evicted out of cache or during the next maintenance period, + // don't do anything when the exception happens. 
+ LOG.error(new ParameterizedMessage("Exception while serializing models for [{}]", modelId), e); + } + + }, exception -> { LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception); }); + } + + public void writeAll(List> modelStates, String detectorId, boolean forceWrite, RequestPriority priority) { + ActionListener> onGetForAll = ActionListener.wrap(detectorOptional -> { + if (false == detectorOptional.isPresent()) { + LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + return; + } + + AnomalyDetector detector = detectorOptional.get(); + try { + List allRequests = new ArrayList<>(); + for (ModelState state : modelStates) { + Instant instant = state.getLastCheckpointTime(); + if (!shouldSave(instant, forceWrite)) { + continue; + } + + Map source = checkpoint.toIndexSource(state); + String modelId = state.getModelId(); + + // the model state is bloated, skip + if (source == null || source.isEmpty() || Strings.isEmpty(modelId)) { + continue; + } + + allRequests + .add( + new CheckpointWriteRequest( + System.currentTimeMillis() + detector.getDetectorIntervalInMilliseconds(), + detectorId, + priority, + new IndexRequest(indexName).id(modelId).source(source) + ) + ); + } + + putAll(allRequests); + } catch (Exception e) { + // Example exception: + // ConcurrentModificationException when calling toCheckpoint + // and updating rcf model at the same time. To prevent this, + // we need to have a deep copy of models or have a lock. Both + // options are costly. + // As we are gonna retry serializing either when the entity is + // evicted out of cache or during the next maintenance period, + // don't do anything when the exception happens. + LOG.info(new ParameterizedMessage("Exception while serializing models for [{}]", detectorId), e); + } + + }, exception -> { LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception); }); + + nodeStateManager.getAnomalyDetector(detectorId, onGetForAll); + } + + /** + * Should we save the checkpoint or not + * @param lastCheckpointTIme Last checkpoint time + * @param forceWrite Save no matter what + * @return true when forceWrite is true or we haven't saved checkpoint in the + * last checkpoint interval; false otherwise + */ + private boolean shouldSave(Instant lastCheckpointTIme, boolean forceWrite) { + return (lastCheckpointTIme != Instant.MIN && lastCheckpointTIme.plus(checkpointInterval).isBefore(clock.instant())) || forceWrite; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ColdEntityWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ColdEntityWorker.java new file mode 100644 index 000000000..7943062cd --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ColdEntityWorker.java @@ -0,0 +1,178 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
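A brief usage note (editor's sketch; the priorities shown are illustrative, not mandated by the patch): callers only choose between the two entry points above, and the worker applies the checkpoint-interval gate in shouldSave plus the queueing machinery before anything reaches OpenSearch.

    // Force-save one model right after cold start trains it, ignoring the checkpoint interval:
    checkpointWriteWorker.write(modelState, true, RequestPriority.HIGH);

    // During cache maintenance, stage all hosted models of a detector; models whose last
    // checkpoint is newer than the checkpoint interval are silently skipped:
    checkpointWriteWorker.writeAll(modelStates, detectorId, false, RequestPriority.MEDIUM);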
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_SECS; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.AnomalyDetectorPlugin; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.threadpool.ThreadPool; + +/** + * A queue slowly releasing low-priority requests to CheckpointReadQueue + * + * ColdEntityQueue is a queue to absorb cold entities. Like hot entities, we load a cold + * entity's model checkpoint from disk, train models if the checkpoint is not found, + * query for missed features to complete a shingle, use the models to check whether + * the incoming feature is normal, update models, and save the detection results to disks.  + * Implementation-wise, we reuse the queues we have developed for hot entities. + * The differences are: we process hot entities as long as resources (e.g., AD + * thread pool has availability) are available, while we release cold entity requests + * to other queues at a slow controlled pace. Also, cold entity requests' priority is low. + * So only when there are no hot entity requests to process are we going to process cold + * entity requests.  
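To make the hot/cold split concrete, here is a hedged sketch of how a producer might route entity requests; the actual routing logic lives outside this diff, and the variable names are illustrative.

    // Editor's sketch, not part of the patch. Hot entities go straight to the checkpoint read
    // worker; cold entities are parked here with LOW priority and trickle into the same
    // checkpoint read worker at the pace computed by getScheduleDelay further down.
    EntityFeatureRequest request = new EntityFeatureRequest(
        expirationEpochMs,                                   // typically the start of the next detection interval
        detectorId,
        isHot ? RequestPriority.MEDIUM : RequestPriority.LOW,
        entity,
        currentFeature,
        dataStartTimeMs
    );
    if (isHot) {
        checkpointReadWorker.put(request);
    } else {
        coldEntityWorker.put(request);
    }

For a sense of the pacing, if the expected execution time per request were 3 seconds (an assumed value) and 10 cold requests were just forwarded, getScheduleDelay would put the next pull 30 to 60 seconds out: a 30-second base plus up to the same amount of random jitter.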
+ * + */ +public class ColdEntityWorker extends RateLimitedRequestWorker { + private static final Logger LOG = LogManager.getLogger(ColdEntityWorker.class); + public static final String WORKER_NAME = "cold-entity"; + + private volatile int batchSize; + private final CheckpointReadWorker checkpointReadQueue; + // indicate whether a future pull over cold entity queues is scheduled + private boolean scheduled; + private volatile int expectedExecutionTimeInSecsPerRequest; + + public ColdEntityWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + CheckpointReadWorker checkpointReadQueue, + Duration stateTtl, + NodeStateManager nodeStateManager + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + stateTtl, + nodeStateManager + ); + + this.batchSize = CHECKPOINT_READ_QUEUE_BATCH_SIZE.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(CHECKPOINT_READ_QUEUE_BATCH_SIZE, it -> this.batchSize = it); + + this.checkpointReadQueue = checkpointReadQueue; + this.scheduled = false; + + this.expectedExecutionTimeInSecsPerRequest = AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_SECS.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_SECS, it -> this.expectedExecutionTimeInSecsPerRequest = it); + } + + private void pullRequests() { + int pulledRequestSize = 0; + int filteredRequestSize = 0; + try { + List requests = getRequests(batchSize); + if (requests == null || requests.isEmpty()) { + return; + } + // pulledRequestSize > batchSize means there are more requests in the queue + pulledRequestSize = requests.size(); + // guarantee we only send low priority requests + List filteredRequests = requests + .stream() + .filter(request -> request.priority == RequestPriority.LOW) + .collect(Collectors.toList()); + if (!filteredRequests.isEmpty()) { + checkpointReadQueue.putAll(filteredRequests); + filteredRequestSize = filteredRequests.size(); + } + } catch (Exception e) { + LOG.error("Error enqueuing cold entity requests", e); + } finally { + if (pulledRequestSize < batchSize) { + scheduled = false; + } else { + // there might be more to fetch + // schedule a pull from queue every few seconds. + scheduled = true; + if (filteredRequestSize == 0) { + pullRequests(); + } else { + schedulePulling(getScheduleDelay(filteredRequestSize)); + } + } + } + } + + private synchronized void schedulePulling(TimeValue delay) { + try { + threadPool.schedule(this::pullRequests, delay, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME); + } catch (Exception e) { + LOG.error("Fail to schedule cold entity pulling", e); + } + } + + /** + * only pull requests to process when there's no other scheduled run + */ + @Override + protected void triggerProcess() { + if (false == scheduled) { + pullRequests(); + } + } + + /** + * The method calculates the delay we have to set to control the rate of cold + * entity processing. 
We wait longer if the requestSize is larger to give the + * system more time to process requests. We add randomness to cope with the + * case that we want to execute at least 1 request every few seconds, but + * cannot guarantee that. + * @param requestSize requests to process + * @return the delay for the next scheduled run + */ + private TimeValue getScheduleDelay(int requestSize) { + int expectedSingleRequestExecutionMillis = 1000 * expectedExecutionTimeInSecsPerRequest; + int waitMilliSeconds = requestSize * expectedSingleRequestExecutionMillis; + return TimeValue.timeValueMillis(waitMilliSeconds + random.nextInt(waitMilliSeconds)); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java new file mode 100644 index 000000000..9861e5056 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java @@ -0,0 +1,161 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Random; +import java.util.concurrent.Semaphore; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.AnomalyDetectorPlugin; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; + +/** + * A queue to run concurrent requests (either batch or single request). + * The concurrency is configurable. The callers use the put method to put requests + * in and the queue tries to execute them if there are concurrency slots. + * + * @param Individual request type that is a subtype of ADRequest + */ +public abstract class ConcurrentWorker extends RateLimitedRequestWorker { + private static final Logger LOG = LogManager.getLogger(ConcurrentWorker.class); + + private Semaphore permits; + + private Instant lastExecuteTime; + private Duration executionTtl; + + /** + * + * Constructor with dependencies and configuration. + * + * @param queueName queue's name + * @param heapSizeInBytes ES heap size + * @param singleRequestSizeInBytes single request's size in bytes + * @param maxHeapPercentForQueueSetting max heap size used for the queue. Used to + * rate-limit AD's usage on ES threadpools. + * @param clusterService Cluster service accessor + * @param random Random number generator + * @param adCircuitBreakerService AD Circuit breaker service + * @param threadPool threadpool accessor + * @param settings Cluster settings getter + * @param maxQueuedTaskRatio maximum queued tasks ratio in ES threadpools + * @param clock Clock to get current time + * @param mediumSegmentPruneRatio the percent of medium priority requests to prune when the queue is full + * @param lowSegmentPruneRatio the percent of low priority requests to prune when the queue is full + * @param maintenanceFreqConstant a constant that helps define the frequency of maintenance. We cannot do + * the expensive maintenance too often.
+ * @param concurrencySetting Max concurrent processing of the queued events + * @param executionTtl Max execution time of a single request + * @param stateTtl max idle state duration. Used to clean unused states. + * @param nodeStateManager node state accessor + */ + public ConcurrentWorker( + String queueName, + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Setting concurrencySetting, + Duration executionTtl, + Duration stateTtl, + NodeStateManager nodeStateManager + ) { + super( + queueName, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + stateTtl, + nodeStateManager + ); + + this.permits = new Semaphore(concurrencySetting.get(settings)); + clusterService.getClusterSettings().addSettingsUpdateConsumer(concurrencySetting, it -> permits = new Semaphore(it)); + + this.lastExecuteTime = clock.instant(); + this.executionTtl = executionTtl; + } + + @Override + public void maintenance() { + super.maintenance(); + + if (lastExecuteTime.plus(executionTtl).isBefore(clock.instant()) && permits.availablePermits() == 0 && false == isQueueEmpty()) { + LOG.warn("previous execution has been running for too long. Maybe there are bugs."); + + // Release one permit. This is a stop gap solution as I don't know + // whether the system is under heavy workload or not. Release multiple + // permits might cause the situation even worse. So I am conservative here. + permits.release(); + } + } + + /** + * try to execute queued requests if there are concurrency slots and return right away. + */ + @Override + protected void triggerProcess() { + threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME).execute(() -> { + if (permits.tryAcquire()) { + try { + lastExecuteTime = clock.instant(); + execute(() -> { + permits.release(); + process(); + }, () -> { permits.release(); }); + } catch (Exception e) { + permits.release(); + // throw to the root level to catch + throw e; + } + } + }); + } + + /** + * Execute requests in toProcess. The implementation needs to call cleanUp after done. + * The 1st callback is executed after processing one request. So we keep looking for + * new requests if there is any after finishing one request. Otherwise, just release + * (the 2nd callback) without calling process. + * @param afterProcessCallback callback after processing requests + * @param emptyQueueCallback callback for empty queues + */ + protected abstract void execute(Runnable afterProcessCallback, Runnable emptyQueueCallback); +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java b/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java new file mode 100644 index 000000000..67ef06e34 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java @@ -0,0 +1,128 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.Locale; +import java.util.Optional; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.ml.EntityColdStarter; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelManager.ModelType; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; + +/** + * A queue for HCAD model training (a.k.a. cold start). As model training is a + * pretty expensive operation, we pull cold start requests from the queue in a + * serial fashion. Each detector has an equal chance of being pulled. The equal + * probability is achieved by putting model training requests for different + * detectors into different segments and pulling requests from segments in a + * round-robin fashion. + * + */ +public class EntityColdStartWorker extends SingleRequestWorker { + private static final Logger LOG = LogManager.getLogger(EntityColdStartWorker.class); + public static final String WORKER_NAME = "cold-start"; + + private final EntityColdStarter entityColdStarter; + + public EntityColdStartWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + EntityColdStarter entityColdStarter, + Duration stateTtl, + NodeStateManager nodeStateManager + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + ENTITY_COLD_START_QUEUE_CONCURRENCY, + executionTtl, + stateTtl, + nodeStateManager + ); + this.entityColdStarter = entityColdStarter; + } + + @Override + protected void executeRequest(EntityRequest coldStartRequest, ActionListener listener) { + String detectorId = coldStartRequest.getDetectorId(); + + Optional modelId = coldStartRequest.getModelId(); + + if (false == modelId.isPresent()) { + String error = String.format(Locale.ROOT, "Fail to get model id for request %s", coldStartRequest); + LOG.warn(error); + listener.onFailure(new RuntimeException(error)); + return; + } + + ModelState modelState = new ModelState<>( + new EntityModel(coldStartRequest.getEntity(), new ArrayDeque<>(), null, null), + modelId.get(), + detectorId, + ModelType.ENTITY.getName(), + clock, + 0 + ); + + ActionListener failureListener = ActionListener.delegateResponse(listener, (delegateListener, e) -> { + if (ExceptionUtil.isOverloaded(e)) { + 
LOG.error("OpenSearch is overloaded"); + setCoolDownStart(); + } + nodeStateManager.setException(detectorId, e); + delegateListener.onFailure(e); + }); + + entityColdStarter.trainModel(coldStartRequest.getEntity(), detectorId, modelState, failureListener); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/EntityFeatureRequest.java b/src/main/java/org/opensearch/ad/ratelimit/EntityFeatureRequest.java new file mode 100644 index 000000000..497670db9 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/EntityFeatureRequest.java @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import org.opensearch.ad.model.Entity; + +public class EntityFeatureRequest extends EntityRequest { + private final double[] currentFeature; + private final long dataStartTimeMillis; + + public EntityFeatureRequest( + long expirationEpochMs, + String detectorId, + RequestPriority priority, + Entity entity, + double[] currentFeature, + long dataStartTimeMs + ) { + super(expirationEpochMs, detectorId, priority, entity); + this.currentFeature = currentFeature; + this.dataStartTimeMillis = dataStartTimeMs; + } + + public double[] getCurrentFeature() { + return currentFeature; + } + + public long getDataStartTimeMillis() { + return dataStartTimeMillis; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/EntityRequest.java b/src/main/java/org/opensearch/ad/ratelimit/EntityRequest.java new file mode 100644 index 000000000..2fba9f4d9 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/EntityRequest.java @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import java.util.Optional; + +import org.opensearch.ad.model.Entity; + +public class EntityRequest extends QueuedRequest { + private final Entity entity; + + /** + * + * @param expirationEpochMs Expiry time of the request + * @param detectorId Detector Id + * @param priority the entity's priority + * @param entity the entity's attributes + */ + public EntityRequest(long expirationEpochMs, String detectorId, RequestPriority priority, Entity entity) { + super(expirationEpochMs, detectorId, priority); + this.entity = entity; + } + + public Entity getEntity() { + return entity; + } + + public Optional getModelId() { + return entity.getModelId(detectorId); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/QueuedRequest.java b/src/main/java/org/opensearch/ad/ratelimit/QueuedRequest.java new file mode 100644 index 000000000..adfbe8c4c --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/QueuedRequest.java @@ -0,0 +1,55 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +public abstract class QueuedRequest { + protected long expirationEpochMs; + protected String detectorId; + protected RequestPriority priority; + + /** + * + * @param expirationEpochMs Request expiry time in milliseconds + * @param detectorId Detector Id + * @param priority how urgent the request is + */ + protected QueuedRequest(long expirationEpochMs, String detectorId, RequestPriority priority) { + this.expirationEpochMs = expirationEpochMs; + this.detectorId = detectorId; + this.priority = priority; + } + + public long getExpirationEpochMs() { + return expirationEpochMs; + } + + /** + * A queue consists of various segments with different priority. A queued + * request belongs one segment. The subtype will define the id. + * @return Segment Id + */ + public RequestPriority getPriority() { + return priority; + } + + public void setPriority(RequestPriority priority) { + this.priority = priority; + } + + public String getDetectorId() { + return detectorId; + } + + public void setDetectorId(String detectorId) { + this.detectorId = detectorId; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java b/src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java new file mode 100644 index 000000000..652daf1c6 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java @@ -0,0 +1,572 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.COOLDOWN_MINUTES; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ad.AnomalyDetectorPlugin; +import org.opensearch.ad.ExpiringState; +import org.opensearch.ad.MaintenanceState; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.common.exception.AnomalyDetectionException; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.threadpool.ThreadPoolStats; + +/** + * HCAD can bombard Opensearch with “thundering herd” traffic, in which many entities + * make requests that need similar Opensearch reads/writes at approximately the same + * time. To remedy this issue we queue the requests and ensure that only a + * limited set of requests are out for Opensearch reads/writes. 
+ * + * @param Individual request type that is a subtype of ADRequest + */ +public abstract class RateLimitedRequestWorker implements MaintenanceState { + /** + * Each request is associated with a RequestQueue. That is, a queue consists of RequestQueues. + * RequestQueues have their corresponding priorities: HIGH, MEDIUM, and LOW. An example + * of HIGH priority requests is anomaly results with errors or its anomaly grade + * larger than zero. An example of MEDIUM priority requests is a cold start request + * for an entity. An example of LOW priority requests is checkpoint write requests + * for a cold entity. LOW priority requests have the slightest chance to be selected + * to be executed. MEDIUM and HIGH priority requests have higher stakes. LOW priority + * requests have higher chances of being deleted when the size of the queue reaches + * beyond a limit compared to MEDIUM/HIGH priority requests. + * + */ + class RequestQueue implements ExpiringState { + /* + * last access time of the RequestQueue + * This does not have to be precise, just a signal for unused old RequestQueue + * that can be removed. It is fine if we have race condition. Don't want + * to synchronize the access as this could penalize performance. + */ + private Instant lastAccessTime; + // data structure to hold requests. Cannot be reassigned. This is to + // guarantee a RequestQueue's content cannot be null. + private final BlockingQueue content; + + RequestQueue() { + this.lastAccessTime = clock.instant(); + this.content = new LinkedBlockingQueue(); + } + + @Override + public boolean expired(Duration stateTtl) { + return expired(lastAccessTime, stateTtl, clock.instant()); + } + + public void put(RequestType request) throws InterruptedException { + this.content.put(request); + } + + public int size() { + return this.content.size(); + } + + public boolean isEmpty() { + return content.size() == 0; + } + + /** + * Remove requests in the queue + * @param numberToRemove number of requests to remove + * @return removed requests + */ + public int drain(int numberToRemove) { + int removed = 0; + while (removed <= numberToRemove) { + if (content.poll() != null) { + removed++; + } else { + // stop if the queue is empty + break; + } + } + return removed; + } + + /** + * Remove requests in the queue + * @param removeRatio the removing ratio + * @return removed requests + */ + public int drain(float removeRatio) { + int numberToRemove = (int) (content.size() * removeRatio); + return drain(numberToRemove); + } + + /** + * Remove expired requests + * + * In terms of request duration, HCAD throws a request out if it + * is older than the detector frequency. This duration limit frees + * up HCAD to work on newer requests in the subsequent detection + * interval instead of piling up requests that no longer matter. + * For example, loading model checkpoints for cache misses requires + * a queue configured in front of it. A request contains the checkpoint + * document Id and the expiry time, and the queue can hold a considerable + * volume of such requests since the size of the request is small. + * The expiry time is the start timestamp of the next detector run. + * Enforcing the expiry time places an upper bound on each request’s + * lifetime. 
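A concrete illustration of that expiry contract (editor's sketch; the exact expiry is chosen by each producer, and the expression below is simply the pattern the checkpoint write path in this diff uses):

    // Producers stamp a request with the start of the next detector run...
    long expirationEpochMs = System.currentTimeMillis() + detector.getDetectorIntervalInMilliseconds();
    // ...and clearExpiredRequests() below drops, from the head of each queue, every request whose
    // expiration is already in the past (head.getExpirationEpochMs() < clock.millis()), so work
    // left over from one interval never competes with the next interval's requests.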
+ * + * @return the number of removed requests + */ + public int clearExpiredRequests() { + int removed = 0; + RequestType head = content.peek(); + while (head != null && head.getExpirationEpochMs() < clock.millis()) { + content.poll(); + removed++; + head = content.peek(); + } + return removed; + } + } + + private static final Logger LOG = LogManager.getLogger(RateLimitedRequestWorker.class); + + protected volatile int queueSize; + protected final String workerName; + private final long heapSize; + private final int singleRequestSize; + private float maxHeapPercentForQueue; + + // map from RequestQueue Id to its RequestQueue. + // For high priority requests, the RequestQueue id is RequestPriority.HIGH.name(). + // For low priority requests, the RequestQueue id is RequestPriority.LOW.name(). + // For medium priority requests, the RequestQueue id is detector id. The objective + // is to separate requests from different detectors and fairly process requests + // from each detector. + protected final ConcurrentSkipListMap requestQueues; + private String lastSelectedRequestQueueId; + protected Random random; + private ADCircuitBreakerService adCircuitBreakerService; + protected ThreadPool threadPool; + protected Instant cooldownStart; + protected int coolDownMinutes; + private float maxQueuedTaskRatio; + protected Clock clock; + private float mediumRequestQueuePruneRatio; + private float lowRequestQueuePruneRatio; + protected int maintenanceFreqConstant; + private final Duration stateTtl; + protected final NodeStateManager nodeStateManager; + + public RateLimitedRequestWorker( + String workerName, + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumRequestQueuePruneRatio, + float lowRequestQueuePruneRatio, + int maintenanceFreqConstant, + Duration stateTtl, + NodeStateManager nodeStateManager + ) { + this.heapSize = heapSizeInBytes; + this.singleRequestSize = singleRequestSizeInBytes; + this.maxHeapPercentForQueue = maxHeapPercentForQueueSetting.get(settings); + this.queueSize = (int) (heapSizeInBytes * maxHeapPercentForQueue / singleRequestSizeInBytes); + clusterService.getClusterSettings().addSettingsUpdateConsumer(maxHeapPercentForQueueSetting, it -> { + int oldQueueSize = queueSize; + this.maxHeapPercentForQueue = it; + this.queueSize = (int) (this.heapSize * maxHeapPercentForQueue / this.singleRequestSize); + LOG.info(new ParameterizedMessage("Queue size changed from [{}] to [{}]", oldQueueSize, queueSize)); + }); + + this.workerName = workerName; + this.random = random; + this.adCircuitBreakerService = adCircuitBreakerService; + this.threadPool = threadPool; + this.maxQueuedTaskRatio = maxQueuedTaskRatio; + this.clock = clock; + this.mediumRequestQueuePruneRatio = mediumRequestQueuePruneRatio; + this.lowRequestQueuePruneRatio = lowRequestQueuePruneRatio; + + this.lastSelectedRequestQueueId = null; + this.requestQueues = new ConcurrentSkipListMap<>(); + this.cooldownStart = Instant.MIN; + this.coolDownMinutes = (int) (COOLDOWN_MINUTES.get(settings).getMinutes()); + this.maintenanceFreqConstant = maintenanceFreqConstant; + this.stateTtl = stateTtl; + this.nodeStateManager = nodeStateManager; + } + + protected String getWorkerName() { + return workerName; + } + + /** + * To add fairness to multiple detectors, HCAD allocates queues at a per + * 
detector granularity and pulls off requests across similar queues in a + * round-robin fashion. This way, if one detector has a much higher + * cardinality than other detectors, the unfinished portion of that + * detector’s workload times out, and other detectors’ workloads continue + * operating with predictable performance. For example, for loading checkpoints, + * HCAD pulls off 10 requests from one detector’ queues, issues a mget request + * to ES, wait for it to finish, and then does it again for other detectors’ + * queues. If one queue does not have more than 10 requests, HCAD dequeues + * the next batches of messages in the round-robin schedule. + * @return next queue to fetch requests + */ + protected Optional> selectNextQueue() { + if (true == requestQueues.isEmpty()) { + return Optional.empty(); + } + + String startId = lastSelectedRequestQueueId; + try { + for (int i = 0; i < requestQueues.size(); i++) { + if (startId == null || requestQueues.size() == 1 || startId.equals(requestQueues.lastKey())) { + startId = requestQueues.firstKey(); + } else { + startId = requestQueues.higherKey(startId); + } + + if (startId.equals(RequestPriority.LOW.name())) { + continue; + } + + RequestQueue requestQueue = requestQueues.get(startId); + if (requestQueue == null) { + continue; + } + + requestQueue.clearExpiredRequests(); + + if (false == requestQueue.isEmpty()) { + return Optional.of(requestQueue.content); + } + } + + RequestQueue requestQueue = requestQueues.get(RequestPriority.LOW.name()); + + if (requestQueue != null) { + requestQueue.clearExpiredRequests(); + if (false == requestQueue.isEmpty()) { + return Optional.of(requestQueue.content); + } + } + // if we haven't find a non-empty queue , return empty. + return Optional.empty(); + } finally { + // it is fine we may have race conditions. We are not trying to + // be precise. The objective is to select each RequestQueue with equal probability. + lastSelectedRequestQueueId = startId; + } + } + + protected void putOnly(RequestType request) { + try { + // consider MEDIUM priority here because only medium priority RequestQueues use + // detector id as the key of the RequestQueue map. low and high priority requests + // just use the RequestQueue priority (i.e., low or high) as the key of the RequestQueue map. + RequestQueue requestQueue = requestQueues + .computeIfAbsent( + RequestPriority.MEDIUM == request.getPriority() ? 
request.getDetectorId() : request.getPriority().name(), + k -> new RequestQueue() + ); + + requestQueue.lastAccessTime = clock.instant(); + requestQueue.put(request); + } catch (Exception e) { + LOG.error(new ParameterizedMessage("Failed to add requests to [{}]", this.workerName), e); + } + } + + private void maintainForThreadPool() { + for (final ThreadPoolStats.Stats stats : threadPool.stats()) { + String name = stats.getName(); + // we mostly use these 3 threadpools + if (ThreadPool.Names.SEARCH.equals(name) || ThreadPool.Names.GET.equals(name) || ThreadPool.Names.WRITE.equals(name)) { + int maxQueueSize = (int) (maxQueuedTaskRatio * threadPool.info(name).getQueueSize().singles()); + // in case that users set queue size to -1 (unbounded) + if (maxQueueSize > 0 && stats.getQueue() > maxQueueSize) { + setCoolDownStart(); + break; + } + } + } + } + + private void prune(Map requestQueues) { + // pruning expired requests + pruneExpired(); + + // prune a few requests in each queue + for (Map.Entry requestQueueEntry : requestQueues.entrySet()) { + if (requestQueueEntry.getKey().equals(RequestPriority.HIGH.name())) { + continue; + } + + RequestQueue requestQueue = requestQueueEntry.getValue(); + + if (requestQueue == null || requestQueue.isEmpty()) { + continue; + } + + // remove more requests in the low priority RequestQueue + if (requestQueueEntry.getKey().equals(RequestPriority.LOW.name())) { + requestQueue.drain(lowRequestQueuePruneRatio); + } else { + requestQueue.drain(mediumRequestQueuePruneRatio); + } + } + } + + /** + * pruning expired requests + * + * @return the total number of deleted requests + */ + private int pruneExpired() { + int deleted = 0; + for (Map.Entry requestQueueEntry : requestQueues.entrySet()) { + RequestQueue requestQueue = requestQueueEntry.getValue(); + + if (requestQueue == null) { + continue; + } + + deleted += requestQueue.clearExpiredRequests(); + } + return deleted; + } + + private void prune(Map requestQueues, int exceededSize) { + // pruning expired requests + int leftItemsToRemove = exceededSize - pruneExpired(); + + if (leftItemsToRemove <= 0) { + return; + } + + // used to compute the average number of requests to remove in medium priority queues + int numberOfQueuesToExclude = 0; + + // remove low-priority requests + RequestQueue requestQueue = requestQueues.get(RequestPriority.LOW.name()); + if (requestQueue != null) { + int removedFromLow = requestQueue.drain(leftItemsToRemove); + if (removedFromLow >= leftItemsToRemove) { + return; + } else { + numberOfQueuesToExclude++; + leftItemsToRemove -= removedFromLow; + } + } + + // skip high-priority requests + if (requestQueues.get(RequestPriority.HIGH.name()) != null) { + numberOfQueuesToExclude++; + } + + int numberOfRequestsToRemoveInMediumQueues = leftItemsToRemove / (requestQueues.size() - numberOfQueuesToExclude); + + for (Map.Entry requestQueueEntry : requestQueues.entrySet()) { + if (requestQueueEntry.getKey().equals(RequestPriority.HIGH.name()) + || requestQueueEntry.getKey().equals(RequestPriority.LOW.name())) { + continue; + } + + requestQueue = requestQueueEntry.getValue(); + + if (requestQueue == null) { + continue; + } + + requestQueue.drain(numberOfRequestsToRemoveInMediumQueues); + } + } + + private void maintainForMemory() { + // removed expired RequestQueue + maintenance(requestQueues, stateTtl); + + int exceededSize = exceededSize(); + if (exceededSize > 0) { + prune(requestQueues, exceededSize); + } else if (adCircuitBreakerService.isOpen()) { + // remove a few items in each RequestQueue 
+ prune(requestQueues); + } + } + + private int exceededSize() { + Collection queues = requestQueues.values(); + int totalSize = 0; + + // When faced with a backlog beyond the limit, we prefer fresh requests + // and throws away old requests. + // release space so that put won't block + for (RequestQueue q : queues) { + totalSize += q.size(); + } + return totalSize - queueSize; + } + + public boolean isQueueEmpty() { + Collection queues = requestQueues.values(); + for (RequestQueue q : queues) { + if (q.size() > 0) { + return false; + } + } + return true; + } + + @Override + public void maintenance() { + try { + maintainForMemory(); + maintainForThreadPool(); + } catch (Exception e) { + LOG.warn("Failed to maintain", e); + } + } + + /** + * Start cooldown during a overloaded situation + */ + protected void setCoolDownStart() { + cooldownStart = clock.instant(); + } + + /** + * @param batchSize the max number of requests to fetch + * @return a list of batchSize requests (can be less) + */ + protected List getRequests(int batchSize) { + List toProcess = new ArrayList<>(batchSize); + + Set> selectedQueue = new HashSet<>(); + + while (toProcess.size() < batchSize) { + Optional> queue = selectNextQueue(); + if (false == queue.isPresent()) { + // no queue has requests + break; + } + + BlockingQueue nextToProcess = queue.get(); + if (selectedQueue.contains(nextToProcess)) { + // we have gone around all of the queues + break; + } + selectedQueue.add(nextToProcess); + + List requests = new ArrayList<>(); + // concurrent requests will wait to prevent concurrent draining. + // This is fine since the operation is fast + nextToProcess.drainTo(requests, batchSize); + toProcess.addAll(requests); + } + + return toProcess; + } + + /** + * Enqueuing runs asynchronously: we put requests in a queue, try to execute + * them. The thread executing requests won't block the thread inserting + * requests to the queue. + * @param request Individual request + */ + public void put(RequestType request) { + if (request == null) { + return; + } + putOnly(request); + + process(); + } + + public void putAll(List requests) { + if (requests == null || requests.isEmpty()) { + return; + } + try { + for (RequestType request : requests) { + putOnly(request); + } + + process(); + } catch (Exception e) { + LOG.error(new ParameterizedMessage("Failed to add requests to [{}]", getWorkerName()), e); + } + } + + protected void process() { + if (random.nextInt(maintenanceFreqConstant) == 1) { + maintenance(); + } + + // still in cooldown period + if (cooldownStart.plus(Duration.ofMinutes(coolDownMinutes)).isAfter(clock.instant())) { + threadPool.schedule(() -> { + try { + process(); + } catch (Exception e) { + LOG.error(new ParameterizedMessage("Fail to process requests in [{}].", this.workerName), e); + } + }, new TimeValue(coolDownMinutes, TimeUnit.MINUTES), AnomalyDetectorPlugin.AD_THREAD_POOL_NAME); + } else { + try { + triggerProcess(); + } catch (Exception e) { + LOG.error(String.format(Locale.ROOT, "Failed to process requests from %s", getWorkerName()), e); + if (e != null && e instanceof AnomalyDetectionException) { + AnomalyDetectionException adExep = (AnomalyDetectionException) e; + nodeStateManager.setException(adExep.getAnomalyDetectorId(), adExep); + } + } + + } + } + + /** + * How to execute requests is abstracted out and left to RateLimitedQueue's subclasses to implement. 
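[Editor's illustrative aside, not part of the patch.] The put/process flow above amortizes expensive queue maintenance by running it with probability 1/maintenanceFreqConstant and, once setCoolDownStart has been called, defers work for coolDownMinutes instead of executing immediately. A minimal, self-contained sketch of that pattern, with hypothetical names (CooldownSketch, MAINTENANCE_FREQ, COOL_DOWN):

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Random;

class CooldownSketch {
    // roughly one maintenance pass per 10,000 calls, mirroring MAINTENANCE_FREQ_CONSTANT
    private static final int MAINTENANCE_FREQ = 10_000;
    private static final Duration COOL_DOWN = Duration.ofMinutes(5);

    private final Random random = new Random();
    private final Clock clock = Clock.systemUTC();
    private Instant cooldownStart = Instant.MIN;

    void process(Runnable triggerProcess, Runnable scheduleRetryLater) {
        if (random.nextInt(MAINTENANCE_FREQ) == 1) {
            maintenance();                  // amortized cleanup, like maintainForMemory/maintainForThreadPool
        }
        if (cooldownStart.plus(COOL_DOWN).isAfter(clock.instant())) {
            scheduleRetryLater.run();       // still overloaded: defer instead of executing now
        } else {
            triggerProcess.run();           // normal path: the subclass drains its queues
        }
    }

    void markOverloaded() {
        cooldownStart = clock.instant();    // analogous to setCoolDownStart()
    }

    void maintenance() {
        // prune expired and low-priority requests, inspect thread pool stats, etc.
    }
}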
+ */ + protected abstract void triggerProcess(); +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/RequestPriority.java b/src/main/java/org/opensearch/ad/ratelimit/RequestPriority.java new file mode 100644 index 000000000..3193d2285 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/RequestPriority.java @@ -0,0 +1,18 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +public enum RequestPriority { + LOW, + MEDIUM, + HIGH +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java b/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java new file mode 100644 index 000000000..40240c948 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import org.opensearch.ad.model.AnomalyResult; + +public class ResultWriteRequest extends QueuedRequest { + private final AnomalyResult result; + + public ResultWriteRequest(long expirationEpochMs, String detectorId, RequestPriority priority, AnomalyResult result) { + super(expirationEpochMs, detectorId, priority); + this.result = result; + } + + public AnomalyResult getResult() { + return result; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java new file mode 100644 index 000000000..84ba24128 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java @@ -0,0 +1,218 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.RESULT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.RESULT_WRITE_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Optional; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.transport.ADResultBulkRequest; +import org.opensearch.ad.transport.ADResultBulkResponse; +import org.opensearch.ad.transport.handler.MultiEntityResultHandler; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.threadpool.ThreadPool; + +public class ResultWriteWorker extends BatchWorker { + private static final Logger LOG = LogManager.getLogger(ResultWriteWorker.class); + public static final String WORKER_NAME = "result-write"; + + private final MultiEntityResultHandler resultHandler; + private NamedXContentRegistry xContentRegistry; + + public ResultWriteWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + MultiEntityResultHandler resultHandler, + NamedXContentRegistry xContentRegistry, + NodeStateManager stateManager, + Duration stateTtl + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + RESULT_WRITE_QUEUE_CONCURRENCY, + executionTtl, + RESULT_WRITE_QUEUE_BATCH_SIZE, + stateTtl, + stateManager + ); + this.resultHandler = resultHandler; + this.xContentRegistry = xContentRegistry; + } + + @Override + protected void executeBatchRequest(ADResultBulkRequest request, ActionListener listener) { + if (request.numberOfActions() < 1) { + listener.onResponse(null); + return; + } + resultHandler.flush(request, listener); + } + + @Override + protected ADResultBulkRequest toBatchRequest(List toProcess) { + final ADResultBulkRequest bulkRequest = new ADResultBulkRequest(); + for (ResultWriteRequest request : toProcess) { + bulkRequest.add(request.getResult()); + } + return bulkRequest; + } + + @Override + protected 
ActionListener getResponseListener( + List toProcess, + ADResultBulkRequest bulkRequest + ) { + return ActionListener.wrap(adResultBulkResponse -> { + if (adResultBulkResponse == null || false == adResultBulkResponse.getRetryRequests().isPresent()) { + // all successful + return; + } + + enqueueRetryRequestIteration(adResultBulkResponse.getRetryRequests().get(), 0); + }, exception -> { + if (ExceptionUtil.isRetryAble(exception)) { + // retry all of them + super.putAll(toProcess); + } else if (ExceptionUtil.isOverloaded(exception)) { + LOG.error("too many get AD model checkpoint requests or shard not avialble"); + setCoolDownStart(); + } + + for (ResultWriteRequest request : toProcess) { + nodeStateManager.setException(request.getDetectorId(), exception); + } + LOG.error("Fail to save results", exception); + }); + } + + private void enqueueRetryRequestIteration(List requestToRetry, int index) { + if (index >= requestToRetry.size()) { + return; + } + DocWriteRequest currentRequest = requestToRetry.get(index); + Optional resultToRetry = getAnomalyResult(currentRequest); + if (false == resultToRetry.isPresent()) { + enqueueRetryRequestIteration(requestToRetry, index + 1); + return; + } + AnomalyResult result = resultToRetry.get(); + String detectorId = result.getDetectorId(); + nodeStateManager.getAnomalyDetector(detectorId, onGetDetector(requestToRetry, index, detectorId, result)); + } + + private ActionListener> onGetDetector( + List requestToRetry, + int index, + String detectorId, + AnomalyResult resultToRetry + ) { + return ActionListener.wrap(detectorOptional -> { + if (false == detectorOptional.isPresent()) { + LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + enqueueRetryRequestIteration(requestToRetry, index + 1); + return; + } + + AnomalyDetector detector = detectorOptional.get(); + super.put( + new ResultWriteRequest( + // expire based on execute start time + resultToRetry.getExecutionStartTime().toEpochMilli() + detector.getDetectorIntervalInMilliseconds(), + detectorId, + resultToRetry.isHighPriority() ? RequestPriority.HIGH : RequestPriority.MEDIUM, + resultToRetry + ) + ); + + enqueueRetryRequestIteration(requestToRetry, index + 1); + + }, exception -> { + LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception); + enqueueRetryRequestIteration(requestToRetry, index + 1); + }); + } + + private Optional getAnomalyResult(DocWriteRequest request) { + try { + if (false == (request instanceof IndexRequest)) { + LOG.error(new ParameterizedMessage("We should only send IndexRquest, but get [{}].", request)); + return Optional.empty(); + } + // we send IndexRequest previously + IndexRequest indexRequest = (IndexRequest) request; + BytesReference indexSource = indexRequest.source(); + XContentType indexContentType = indexRequest.getContentType(); + try ( + XContentParser xContentParser = XContentHelper + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, indexSource, indexContentType) + ) { + // the first character is null. 
Without skipping it, we get + // org.opensearch.common.ParsingException: Failed to parse object: expecting token of type [START_OBJECT] but found + // [null] + xContentParser.nextToken(); + return Optional.of(AnomalyResult.parse(xContentParser)); + } + } catch (Exception e) { + LOG.error(new ParameterizedMessage("Fail to parse index request [{}]", request), e); + } + return Optional.empty(); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/SingleRequestWorker.java b/src/main/java/org/opensearch/ad/ratelimit/SingleRequestWorker.java new file mode 100644 index 000000000..028a0643f --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/SingleRequestWorker.java @@ -0,0 +1,107 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.BlockingQueue; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; + +public abstract class SingleRequestWorker extends ConcurrentWorker { + private static final Logger LOG = LogManager.getLogger(SingleRequestWorker.class); + + public SingleRequestWorker( + String queueName, + long heapSize, + int singleRequestSize, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Setting concurrencySetting, + Duration executionTtl, + Duration stateTtl, + NodeStateManager nodeStateManager + ) { + super( + queueName, + heapSize, + singleRequestSize, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + concurrencySetting, + executionTtl, + stateTtl, + nodeStateManager + ); + } + + @Override + protected void execute(Runnable afterProcessCallback, Runnable emptyQueueCallback) { + RequestType request = null; + + Optional> queueOptional = selectNextQueue(); + if (false == queueOptional.isPresent()) { + // no queue has requests + emptyQueueCallback.run(); + return; + } + + BlockingQueue queue = queueOptional.get(); + if (false == queue.isEmpty()) { + request = queue.poll(); + } + + if (request == null) { + emptyQueueCallback.run(); + return; + } + + final ActionListener handlerWithRelease = ActionListener.wrap(afterProcessCallback); + executeRequest(request, handlerWithRelease); + } + + /** + * Used by subclasses to creates customized logic to send batch requests. + * After everything finishes, the method should call listener. 
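[Editor's illustrative aside, not part of the patch.] The retry path in ResultWriteWorker above (enqueueRetryRequestIteration) walks a list with an explicit index and only advances from inside each asynchronous callback, which serializes the steps without blocking a thread. A stripped-down sketch of that pattern with hypothetical names (Item handling is left to the caller-supplied step):

import java.util.List;
import java.util.function.BiConsumer;

class SequentialAsyncSketch {
    // Process items one at a time; the next item is only scheduled from the
    // completion callback of the previous one, mirroring enqueueRetryRequestIteration.
    static <T> void processSequentially(List<T> items, int index, BiConsumer<T, Runnable> asyncStep) {
        if (index >= items.size()) {
            return;                         // all retries handled
        }
        T current = items.get(index);
        // asyncStep must invoke the continuation exactly once, on success or failure
        asyncStep.accept(current, () -> processSequentially(items, index + 1, asyncStep));
    }
}

A caller would pass a step that performs the asynchronous lookup (for example, fetching the detector) and calls the continuation from both the success and failure listeners, exactly as onGetDetector does above.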
+ * @param request request to execute + * @param listener customized listener + */ + protected abstract void executeRequest(RequestType request, ActionListener listener); +} diff --git a/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java index b1be4e362..45bad818b 100644 --- a/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java @@ -27,22 +27,25 @@ package org.opensearch.ad.rest; import static org.opensearch.ad.util.RestHandlerUtils.DETECTOR_ID; -import static org.opensearch.ad.util.RestHandlerUtils.ENTITY; import static org.opensearch.ad.util.RestHandlerUtils.PROFILE; import static org.opensearch.ad.util.RestHandlerUtils.TYPE; import java.io.IOException; import java.util.List; import java.util.Locale; +import java.util.Optional; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.constant.CommonErrorMessages; +import org.opensearch.ad.constant.CommonName; +import org.opensearch.ad.model.Entity; import org.opensearch.ad.settings.EnabledSetting; import org.opensearch.ad.transport.GetAnomalyDetectorAction; import org.opensearch.ad.transport.GetAnomalyDetectorRequest; import org.opensearch.client.node.NodeClient; +import org.opensearch.common.Strings; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestActions; @@ -72,7 +75,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } String detectorId = request.param(DETECTOR_ID); String typesStr = request.param(TYPE); - String entityValue = request.param(ENTITY); + String rawPath = request.rawPath(); boolean returnJob = request.paramAsBoolean("job", false); boolean returnTask = request.paramAsBoolean("task", false); @@ -85,7 +88,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli typesStr, rawPath, all, - entityValue + buildEntity(request, detectorId) ); return channel -> client @@ -94,7 +97,18 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli @Override public List routes() { - return ImmutableList.of(); + return ImmutableList + .of( + // Opensearch-only API. Considering users may provide entity in the search body, support POST as well. 
+ new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE) + ), + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s/{%s}", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE, TYPE) + ) + ); } @Override @@ -129,4 +143,35 @@ public List replacedRoutes() { ) ); } + + private Entity buildEntity(RestRequest request, String detectorId) throws IOException { + if (Strings.isEmpty(detectorId)) { + throw new IllegalStateException(CommonErrorMessages.AD_ID_MISSING_MSG); + } + + String entityName = request.param(CommonName.CATEGORICAL_FIELD); + String entityValue = request.param(CommonName.ENTITY_KEY); + + if (entityName != null && entityValue != null) { + // single-stream profile request: + // GET _plugins/_anomaly_detection/detectors//_profile/init_progress?category_field=&entity= + return Entity.createSingleAttributeEntity(detectorId, entityName, entityValue); + } else if (request.hasContent()) { + /* HCAD profile request: + * GET _plugins/_anomaly_detection/detectors//_profile/init_progress + * { + * "entity": [{ + * "name": "clientip", + * "value": "13.24.0.0" + * }] + * } + */ + Optional entity = Entity.fromJsonObject(request.contentParser()); + if (entity.isPresent()) { + return entity.get(); + } + } + // not a valid profile request with correct entity information + return null; + } } diff --git a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java index 39170d138..c5913ffda 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java @@ -57,9 +57,11 @@ import org.opensearch.action.support.IndicesOptions; import org.opensearch.action.support.WriteRequest; import org.opensearch.action.support.replication.ReplicationResponse; +import org.opensearch.ad.constant.CommonErrorMessages; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.indices.AnomalyDetectionIndices; import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.settings.NumericSetting; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; import org.opensearch.ad.util.RestHandlerUtils; @@ -87,7 +89,6 @@ public class IndexAnomalyDetectorActionHandler { public static final String EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG = "Can't create multi-entity anomaly detectors more than "; public static final String EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG = "Can't create single-entity anomaly detectors more than "; public static final String NO_DOCS_IN_USER_INDEX_MSG = "Can't create anomaly detector as no document found in indices: "; - public static final String ONLY_ONE_CATEGORICAL_FIELD_ERR_MSG = "We can have only one categorical field."; public static final String CATEGORICAL_FIELD_TYPE_ERR_MSG = "A categorical field must be of type keyword or ip."; public static final String NOT_FOUND_ERR_MSG = "Cannot found the categorical field %s"; @@ -346,11 +347,12 @@ private void validateCategoricalField(String detectorId) { return; } - // we only support one categorical field - // If there is more than 1 field or none, AnomalyDetector's constructor + // we only support a certain number of categorical field + // If there is more fields than required, AnomalyDetector's constructor 
// throws IllegalArgumentException before reaching this line - if (categoryField.size() != 1) { - listener.onFailure(new IllegalArgumentException(ONLY_ONE_CATEGORICAL_FIELD_ERR_MSG)); + int maxCategoryFields = NumericSetting.maxCategoricalFields(); + if (categoryField.size() > maxCategoryFields) { + listener.onFailure(new IllegalArgumentException(CommonErrorMessages.getTooManyCategoricalFieldErr(maxCategoryFields))); return; } diff --git a/src/main/java/org/opensearch/ad/settings/AbstractSetting.java b/src/main/java/org/opensearch/ad/settings/AbstractSetting.java new file mode 100644 index 000000000..e80fcbde9 --- /dev/null +++ b/src/main/java/org/opensearch/ad/settings/AbstractSetting.java @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.settings; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; + +/** + * A container serving dynamic numeric setting. The caller does not have to call + * ClusterSettings.addSettingsUpdateConsumer and can access the most-up-to-date + * value using the singleton instance. This is convenient for a setting that's + * accessed by various places or it is not possible to install ClusterSettings.addSettingsUpdateConsumer + * as the enclosing instances are not singleton (i.e. deleted after use). + * + */ +public abstract class AbstractSetting { + private static Logger logger = LogManager.getLogger(AbstractSetting.class); + + private ClusterService clusterService; + /** Latest setting value for each registered key. Thread-safe is required. */ + private final Map latestSettings = new ConcurrentHashMap<>(); + + private final Map> settings; + + protected AbstractSetting(Map> settings) { + this.settings = settings; + } + + private void setSettingsUpdateConsumers() { + for (Setting setting : settings.values()) { + clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, newVal -> { + logger.info("[AD] The value of setting [{}] changed to [{}]", setting.getKey(), newVal); + latestSettings.put(setting.getKey(), newVal); + }); + } + } + + public void init(ClusterService clusterService) { + this.clusterService = clusterService; + setSettingsUpdateConsumers(); + } + + /** + * Get setting value by key. Return default value if not configured explicitly. + * + * @param key setting key. 
+ * @param Setting type + * @return T setting value or default + */ + @SuppressWarnings("unchecked") + public T getSettingValue(String key) { + return (T) latestSettings.getOrDefault(key, getSetting(key).getDefault(Settings.EMPTY)); + } + + private Setting getSetting(String key) { + if (settings.containsKey(key)) { + return settings.get(key); + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); + } + + public List> getSettings() { + return new ArrayList<>(settings.values()); + } +} diff --git a/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java b/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java index 1c5fbff31..0d67978d6 100644 --- a/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java +++ b/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java @@ -100,14 +100,20 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - public static final Setting AD_RESULT_HISTORY_MAX_DOCS = Setting + // Opensearch-only setting. Doesn't plan to use the value of the legacy setting + // AD_RESULT_HISTORY_MAX_DOCS as that's too low. If the master node uses opendistro code, + // it uses the legacy setting. If the master node uses opensearch code, it uses the new setting. + public static final Setting AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD = Setting .longSetting( - "plugins.anomaly_detection.ad_result_history_max_docs", - // Total documents in primary replica. - // A single feature result is roughly 150 bytes. Suppose a doc is - // of 200 bytes, 250 million docs is of 50 GB. We choose 50 GB - // because we have 1 shard at least. One shard can have at most 50 GB. - LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS, + "plugins.anomaly_detection.ad_result_history_max_docs_per_shard", + // Total documents in the primary shards. + // Note the count is for Lucene docs. Lucene considers a nested + // doc a doc too. One result corresponding to 4 Lucene docs. + // A single Lucene doc is roughly 46.8 bytes (measured by experiments). + // 1.35 billion docs is about 65 GB. One shard can have at most 65 GB. + // This number in Lucene doc count is used in RolloverRequest#addMaxIndexDocsCondition + // for adding condition to check if the index has at least numDocs. + 1_350_000_000L, 0L, Setting.Property.NodeScope, Setting.Property.Dynamic @@ -275,37 +281,71 @@ private AnomalyDetectorSettings() {} public static final int MULTI_ENTITY_NUM_TREES = 10; - // cache related - public static final int DEDICATED_CACHE_SIZE = 10; + // ====================================== + // cache related parameters + // ====================================== + /* + * Opensearch-only setting + * Each detector has its dedicated cache that stores ten entities' states per node. + * A detector's hottest entities load their states into the dedicated cache. + * Other detectors cannot use space reserved by a detector's dedicated cache. + * DEDICATED_CACHE_SIZE is a setting to make dedicated cache's size flexible. + * When that setting is changed, if the size decreases, we will release memory + * if required (e.g., when a user also decreased AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, + * the max memory percentage that AD can use); + * if the size increases, we may reject the setting change if we cannot fulfill + * that request (e.g., when it will uses more memory than allowed for AD). + * + * With compact rcf, rcf with 30 trees and shingle size 4 is of 500KB. + * The recommended max heap size is 32 GB. 
Even if users use all of the heap + * for AD, the max number of entity model cannot surpass + * 3.2 GB/500KB = 3.2 * 10^10 / 5*10^5 = 6.4 * 10 ^4 + * where 3.2 GB is from 10% memory limit of AD plugin. + * That's why I am using 60_000 as the max limit. + */ + public static final Setting DEDICATED_CACHE_SIZE = Setting + .intSetting("plugins.anomaly_detection.dedicated_cache_size", 10, 0, 60_000, Setting.Property.NodeScope, Setting.Property.Dynamic); // We only keep priority (4 bytes float) in inactive cache. 1 million priorities // take up 4 MB. public static final int MAX_INACTIVE_ENTITIES = 1_000_000; // 1 million insertion costs roughly 1 MB. - public static final int DOOR_KEEPER_MAX_INSERTION = 1_000_000; + public static final int DOOR_KEEPER_FOR_CACHE_MAX_INSERTION = 1_000_000; + + // 100,000 insertions costs roughly 1KB. + public static final int DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION = 100_000; public static final double DOOR_KEEPER_FAULSE_POSITIVE_RATE = 0.01; + // clean up door keeper every 60 intervals + public static final int DOOR_KEEPER_MAINTENANCE_FREQ = 60; + // Increase the value will adding pressure to indexing anomaly results and our feature query + // OpenSearch-only setting as previous the legacy default is too low (1000) public static final Setting MAX_ENTITIES_PER_QUERY = Setting .intSetting( "plugins.anomaly_detection.max_entities_per_query", - LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, - 1, - 100_000_000, + 1_000_000, + 0, + 2_000_000, Setting.Property.NodeScope, Setting.Property.Dynamic ); - // Default number of entities retrieved for Preview API - public static final int DEFAULT_ENTITIES_FOR_PREVIEW = 30; - // Maximum number of entities retrieved for Preview API + // Not using legacy value 30 as default. + // Setting default value to 30 of 2-categorical field detector causes heavy GC + // (half of the time is GC on my 1GB heap machine). This is because we use + // terms aggregation to find the top entities in preview. Terms aggregation + // does not support multiple terms. The current solution is concatenation of + // category fields using painless script, which tugs on memory. + // Default value 10 won't cause heavy GC. + // Since every entity is likely to give some anomalies, 10 is enough. public static final Setting MAX_ENTITIES_FOR_PREVIEW = Setting .intSetting( "plugins.anomaly_detection.max_entities_for_preview", - LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW, + 10, 1, 1000, Setting.Property.NodeScope, @@ -313,10 +353,26 @@ private AnomalyDetectorSettings() {} ); // save partial zero-anomaly grade results after indexing pressure reaching the limit + // Opendistro version has similar setting. I lowered the value to make room + // for INDEX_PRESSURE_HARD_LIMIT. I don't find a floatSetting that has both default + // and fallback values. I want users to use the new default value 0.6 instead of 0.8. + // So do not plan to use the value of legacy setting as fallback. 
public static final Setting INDEX_PRESSURE_SOFT_LIMIT = Setting .floatSetting( "plugins.anomaly_detection.index_pressure_soft_limit", - LegacyOpenDistroAnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT, + 0.6f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // save only error or larger-than-one anomaly grade results after indexing + // pressure reaching the limit + // opensearch-only setting + public static final Setting INDEX_PRESSURE_HARD_LIMIT = Setting + .floatSetting( + "plugins.anomaly_detection.index_pressure_hard_limit", + 0.9f, 0.0f, Setting.Property.NodeScope, Setting.Property.Dynamic @@ -342,22 +398,6 @@ private AnomalyDetectorSettings() {} // number of bulk checkpoints per second public static double CHECKPOINT_BULK_PER_SECOND = 0.02; - // responding to 100 cache misses per second allowed. - // 100 because the get threadpool (the one we need to get checkpoint) queue szie is 1000 - // and we may have 10 concurrent multi-entity detectors. So each detector can use: 1000 / 10 = 100 - // for 1m interval. if the max entity number is 3000 per node, it will need around 30m to get all of them cached - // Thus, for 5m internval, it will need 2.5 hours to cache all of them. for 1hour interval, it will be 30hours. - // but for 1 day interval, it will be 30 days. - public static Setting MAX_CACHE_MISS_HANDLING_PER_SECOND = Setting - .intSetting( - "plugins.anomaly_detection.max_cache_miss_handling_per_second", - LegacyOpenDistroAnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND, - 0, - 1000, - Setting.Property.NodeScope, - Setting.Property.Dynamic - ); - // Maximum number of batch tasks running on one node. // TODO: performance test and tune the setting. public static final Setting MAX_BATCH_TASK_PER_NODE = Setting @@ -404,4 +444,242 @@ private AnomalyDetectorSettings() {} Setting.Property.NodeScope, Setting.Property.Dynamic ); + + // ====================================== + // rate-limiting queue parameters + // ====================================== + // the percentage of heap usage allowed for queues holding small requests + // set it to 0 to disable the queue + public static final Setting COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.anomaly_detection.cold_entity_queue_max_heap_percent", + 0.001f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.anomaly_detection.checkpoint_read_queue_max_heap_percent", + 0.001f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.anomaly_detection.entity_cold_start_queue_max_heap_percent", + 0.001f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // the percentage of heap usage allowed for queues holding large requests + // set it to 0 to disable the queue + public static final Setting CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.anomaly_detection.checkpoint_write_queue_max_heap_percent", + 0.01f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.anomaly_detection.result_write_queue_max_heap_percent", + 0.01f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // expected execution time per cold entity request. 
This setting controls + // the speed of cold entity requests execution. The larger, the faster, and + // the more performance impact to customers' workload. + public static final Setting EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_SECS = Setting + .intSetting( + "plugins.anomaly_detection.expected_cold_entity_execution_time_in_secs", + 3, + 0, + 3600, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * EntityRequest has entityName (# category fields * 256, the recommended limit + * of a keyword field length), model Id (roughly 256 bytes), and QueuedRequest + * fields including detector Id(roughly 128 bytes), expirationEpochMs (long, + * 8 bytes), and priority (12 bytes). + * Plus Java object size (12 bytes), we have roughly 928 bytes per request + * assuming we have 2 categorical fields (plan to support 2 categorical fields now). + * We don't want the total size exceeds 0.1% of the heap. + * We can have at most 0.1% heap / 928 = heap / 928,000. + * For t3.small, 0.1% heap is of 1MB. The queue's size is up to + * 10^ 6 / 928 = 1078 + */ + public static int ENTITY_REQUEST_SIZE_IN_BYTES = 928; + + /** + * EntityFeatureRequest consists of EntityRequest (928 bytes, read comments + * of ENTITY_COLD_START_QUEUE_SIZE_CONSTANT), pointer to current feature + * (8 bytes), and dataStartTimeMillis (8 bytes). We have roughly + * 928 + 16 = 944 bytes per request. + * + * We don't want the total size exceeds 0.1% of the heap. + * We should have at most 0.1% heap / 944 = heap / 944,000 + * For t3.small, 0.1% heap is of 1MB. The queue's size is up to + * 10^ 6 / 944 = 1059 + */ + public static int ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES = 944; + + /** + * ResultWriteRequest consists of index request (roughly 1KB), and QueuedRequest + * fields (148 bytes, read comments of ENTITY_REQUEST_SIZE_CONSTANT). + * Plus Java object size (12 bytes), we have roughly 1160 bytes per request + * + * We don't want the total size exceeds 1% of the heap. + * We should have at most 1% heap / 1148 = heap / 116,000 + * For t3.small, 1% heap is of 10MB. The queue's size is up to + * 10^ 7 / 1160 = 8621 + */ + public static int RESULT_WRITE_QUEUE_SIZE_IN_BYTES = 1160; + + /** + * CheckpointWriteRequest consists of IndexRequest (200 KB), and QueuedRequest + * fields (148 bytes, read comments of ENTITY_REQUEST_SIZE_CONSTANT). + * The total is roughly 200 KB per request. + * + * We don't want the total size exceeds 1% of the heap. + * We should have at most 1% heap / 200KB = heap / 20,000,000 + * For t3.small, 1% heap is of 10MB. 
The queue's size is up to + * 10^ 7 / 2.0 * 10^5 = 50 + */ + public static int CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES = 200_000; + + /** + * Max concurrent entity cold starts per node + */ + public static final Setting ENTITY_COLD_START_QUEUE_CONCURRENCY = Setting + .intSetting( + "plugins.anomaly_detection.entity_cold_start_queue_concurrency", + 1, + 1, + 10, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * Max concurrent checkpoint reads per node + */ + public static final Setting CHECKPOINT_READ_QUEUE_CONCURRENCY = Setting + .intSetting( + "plugins.anomaly_detection.checkpoint_read_queue_concurrency", + 1, + 1, + 10, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * Max concurrent checkpoint writes per node + */ + public static final Setting CHECKPOINT_WRITE_QUEUE_CONCURRENCY = Setting + .intSetting( + "plugins.anomaly_detection.checkpoint_write_queue_concurrency", + 2, + 1, + 10, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * Max concurrent result writes per node. Since checkpoint is relatively large + * (250KB), we have 2 concurrent threads processing the queue. + */ + public static final Setting RESULT_WRITE_QUEUE_CONCURRENCY = Setting + .intSetting( + "plugins.anomaly_detection.result_write_queue_concurrency", + 2, + 1, + 10, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * Assume each checkpoint takes roughly 200KB. 25 requests are of 5 MB. + */ + public static final Setting CHECKPOINT_READ_QUEUE_BATCH_SIZE = Setting + .intSetting( + "plugins.anomaly_detection.checkpoint_read_queue_batch_size", + 25, + 1, + 60, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * ES recommends bulk size to be 5~15 MB. + * ref: https://tinyurl.com/3zdbmbwy + * Assume each checkpoint takes roughly 200KB. 25 requests are of 5 MB. + */ + public static final Setting CHECKPOINT_WRITE_QUEUE_BATCH_SIZE = Setting + .intSetting( + "plugins.anomaly_detection.checkpoint_write_queue_batch_size", + 25, + 1, + 60, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * ES recommends bulk size to be 5~15 MB. + * ref: https://tinyurl.com/3zdbmbwy + * Assume each result takes roughly 1KB. 5000 requests are of 5 MB. + */ + public static final Setting RESULT_WRITE_QUEUE_BATCH_SIZE = Setting + .intSetting( + "plugins.anomaly_detection.result_write_queue_batch_size", + 5000, + 1, + 15000, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Duration QUEUE_MAINTENANCE = Duration.ofMinutes(10); + + // we won't accept a checkpoint larger than 10MB. Or we risk OOM. + public static final int MAX_CHECKPOINT_BYTES = 10_000_000; + + public static final float MAX_QUEUED_TASKS_RATIO = 0.5f; + + public static final float MEDIUM_SEGMENT_PRUNE_RATIO = 0.1f; + + public static final float LOW_SEGMENT_PRUNE_RATIO = 0.3f; + + // expensive maintenance (e.g., queue maintenance) with 1/10000 probability + public static final int MAINTENANCE_FREQ_CONSTANT = 10000; + + // ====================================== + // pagination setting + // ====================================== + // pagination size + public static final Setting PAGE_SIZE = Setting + .intSetting("plugins.anomaly_detection.page_size", 1_000, 0, 10_000, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // within an interval, how many percents are used to process requests. + // 1.0 means we use all of the detection interval to process requests. 
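[Editor's illustrative aside, not part of the patch.] The per-request byte estimates above translate into a queue capacity of heap * max-heap-percent / request-size. A small worked example using the assumed figures from the comments (a ~1 GB heap as on a t3.small, the 0.1% default for small-request queues, ~928 bytes per entity request):

class QueueCapacityMath {
    public static void main(String[] args) {
        long heapBytes = 1_000_000_000L;      // ~1 GB heap (t3.small in the comments)
        float maxHeapPercent = 0.001f;        // COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT default
        int requestBytes = 928;               // ENTITY_REQUEST_SIZE_IN_BYTES
        int capacity = (int) (heapBytes * maxHeapPercent / requestBytes);
        // roughly 1,077 queued requests before pruning kicks in (the comments round to 1078)
        System.out.println(capacity);
    }
}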
+ // to ensure we don't block next interval, it is better to set it less than 1.0. + public static final float INTERVAL_RATIO_FOR_REQUESTS = 0.8f; } diff --git a/src/main/java/org/opensearch/ad/settings/EnabledSetting.java b/src/main/java/org/opensearch/ad/settings/EnabledSetting.java index 7ff395d2a..01295b9b7 100644 --- a/src/main/java/org/opensearch/ad/settings/EnabledSetting.java +++ b/src/main/java/org/opensearch/ad/settings/EnabledSetting.java @@ -31,21 +31,12 @@ import static org.opensearch.common.settings.Setting.Property.Dynamic; import static org.opensearch.common.settings.Setting.Property.NodeScope; -import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; -import org.opensearch.common.settings.Settings; -public class EnabledSetting { - - private static Logger logger = LogManager.getLogger(EnabledSetting.class); +public class EnabledSetting extends AbstractSetting { /** * Singleton instance @@ -63,7 +54,7 @@ public class EnabledSetting { public static final String LEGACY_OPENDISTRO_AD_BREAKER_ENABLED = "opendistro.anomaly_detection.breaker.enabled"; - private final Map> settings = unmodifiableMap(new HashMap>() { + private static final Map> settings = unmodifiableMap(new HashMap>() { { Setting LegacyADPluginEnabledSetting = Setting .boolSetting(LEGACY_OPENDISTRO_AD_PLUGIN_ENABLED, true, NodeScope, Dynamic, Deprecated); @@ -91,48 +82,17 @@ public class EnabledSetting { } }); - /** Latest setting value for each registered key. Thread-safe is required. */ - private final Map latestSettings = new ConcurrentHashMap<>(); - - private ClusterService clusterService; - - private EnabledSetting() {} + private EnabledSetting(Map> settings) { + super(settings); + } public static synchronized EnabledSetting getInstance() { if (INSTANCE == null) { - INSTANCE = new EnabledSetting(); + INSTANCE = new EnabledSetting(settings); } return INSTANCE; } - private void setSettingsUpdateConsumers() { - for (Setting setting : settings.values()) { - clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, newVal -> { - logger.info("[AD] The value of setting [{}] changed to [{}]", setting.getKey(), newVal); - latestSettings.put(setting.getKey(), newVal); - }); - } - } - - /** - * Get setting value by key. Return default value if not configured explicitly. - * - * @param key setting key. - * @param Setting type - * @return T setting value or default - */ - @SuppressWarnings("unchecked") - public T getSettingValue(String key) { - return (T) latestSettings.getOrDefault(key, getSetting(key).getDefault(Settings.EMPTY)); - } - - private Setting getSetting(String key) { - if (settings.containsKey(key)) { - return settings.get(key); - } - throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); - } - /** * Whether AD plugin is enabled. If disabled, AD plugin rejects RESTful requests and stop all AD jobs. * @return whether AD plugin is enabled. 
@@ -148,13 +108,4 @@ public static boolean isADPluginEnabled() { public static boolean isADBreakerEnabled() { return EnabledSetting.getInstance().getSettingValue(EnabledSetting.AD_BREAKER_ENABLED); } - - public void init(ClusterService clusterService) { - this.clusterService = clusterService; - setSettingsUpdateConsumers(); - } - - public List> getSettings() { - return new ArrayList<>(settings.values()); - } } diff --git a/src/main/java/org/opensearch/ad/settings/NumericSetting.java b/src/main/java/org/opensearch/ad/settings/NumericSetting.java new file mode 100644 index 000000000..eed8ac7ec --- /dev/null +++ b/src/main/java/org/opensearch/ad/settings/NumericSetting.java @@ -0,0 +1,66 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.settings; + +import static java.util.Collections.unmodifiableMap; + +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.common.settings.Setting; + +public class NumericSetting extends AbstractSetting { + + /** + * Singleton instance + */ + private static NumericSetting INSTANCE; + + /** + * Settings name + */ + public static final String CATEGORY_FIELD_LIMIT = "plugins.anomaly_detection.category_field_limit"; + + private static final Map> settings = unmodifiableMap(new HashMap>() { + { + // how many categorical fields we support + // The number of category field won't causes correctness issues for our + // implementation, but can cause performance issues. The more categorical + // fields, the larger of the anomaly results, intermediate states, and + // more expensive entities (e.g., to get top entities in preview API, we need + // to use scripts in terms aggregation. The more fields, the slower the query). + put( + CATEGORY_FIELD_LIMIT, + Setting.intSetting(CATEGORY_FIELD_LIMIT, 2, 0, 5, Setting.Property.NodeScope, Setting.Property.Dynamic) + ); + } + }); + + private NumericSetting(Map> settings) { + super(settings); + } + + public static synchronized NumericSetting getInstance() { + if (INSTANCE == null) { + INSTANCE = new NumericSetting(settings); + } + return INSTANCE; + } + + /** + * Whether AD plugin is enabled. If disabled, AD plugin rejects RESTful requests and stop all AD jobs. + * @return whether AD plugin is enabled. + */ + public static int maxCategoricalFields() { + return NumericSetting.getInstance().getSettingValue(NumericSetting.CATEGORY_FIELD_LIMIT); + } +} diff --git a/src/main/java/org/opensearch/ad/stats/ADStats.java b/src/main/java/org/opensearch/ad/stats/ADStats.java index b9337f03c..755f5f276 100644 --- a/src/main/java/org/opensearch/ad/stats/ADStats.java +++ b/src/main/java/org/opensearch/ad/stats/ADStats.java @@ -29,28 +29,19 @@ import java.util.HashMap; import java.util.Map; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.util.IndexUtils; - /** * This class is the main entry-point for access to the stats that the AD plugin keeps track of. 
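[Editor's illustrative aside, not part of the patch.] With the refactor above, EnabledSetting and NumericSetting both delegate to AbstractSetting, so a consumer initializes the singleton once against cluster settings and can then read the latest dynamic value anywhere; a short usage sketch (the clusterService variable is assumed to come from plugin start-up):

// wire the singleton to cluster settings once, so updates flow into latestSettings
NumericSetting.getInstance().init(clusterService);

// later, at any call site, read the most recent dynamic value (defaults to 2)
int maxCategoryFields = NumericSetting.maxCategoricalFields();
// e.g., reject a detector whose category_field list is longer than maxCategoryFields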
*/ public class ADStats { - private IndexUtils indexUtils; - private ModelManager modelManager; private Map> stats; /** * Constructor * - * @param indexUtils utility to get information about indices - * @param modelManager used to get information about which models are hosted on a particular node * @param stats Map of the stats that are to be kept */ - public ADStats(IndexUtils indexUtils, ModelManager modelManager, Map> stats) { - this.indexUtils = indexUtils; - this.modelManager = modelManager; + public ADStats(Map> stats) { this.stats = stats; } diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplier.java b/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplier.java index 65cbb37fa..db9564e67 100644 --- a/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplier.java +++ b/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplier.java @@ -26,8 +26,8 @@ package org.opensearch.ad.stats.suppliers; -import static org.opensearch.ad.ml.ModelState.DETECTOR_ID_KEY; -import static org.opensearch.ad.ml.ModelState.MODEL_ID_KEY; +import static org.opensearch.ad.ml.ModelState.LAST_CHECKPOINT_TIME_KEY; +import static org.opensearch.ad.ml.ModelState.LAST_USED_TIME_KEY; import static org.opensearch.ad.ml.ModelState.MODEL_TYPE_KEY; import java.util.ArrayList; @@ -41,6 +41,7 @@ import java.util.stream.Stream; import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.ml.ModelManager; /** @@ -53,7 +54,17 @@ public class ModelsOnNodeSupplier implements Supplier>> /** * Set that contains the model stats that should be exposed. */ - public static Set MODEL_STATE_STAT_KEYS = new HashSet<>(Arrays.asList(MODEL_ID_KEY, DETECTOR_ID_KEY, MODEL_TYPE_KEY)); + public static Set MODEL_STATE_STAT_KEYS = new HashSet<>( + Arrays + .asList( + CommonName.MODEL_ID_KEY, + CommonName.DETECTOR_ID_KEY, + MODEL_TYPE_KEY, + CommonName.ENTITY_KEY, + LAST_USED_TIME_KEY, + LAST_CHECKPOINT_TIME_KEY + ) + ); /** * Constructor diff --git a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java index f7b81d64b..97ecc26d4 100644 --- a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java +++ b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java @@ -597,7 +597,8 @@ private void detectAnomaly( error, null, adTask.getDetector().getUser(), - anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT) + anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT), + null ); anomalyResults.add(anomalyResult); } else { @@ -638,7 +639,8 @@ private void detectAnomaly( null, null, adTask.getDetector().getUser(), - anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT) + anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT), + null ); anomalyResults.add(anomalyResult); } diff --git a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java index f9dfedf52..30e3c6433 100644 --- a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java +++ b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java @@ -106,7 +106,7 @@ public synchronized void add(ADTask adTask) { } checkRunningTaskLimit(); long neededCacheSize = calculateADTaskCacheSize(adTask); - if (!memoryTracker.canAllocateReserved(adTask.getDetectorId(), neededCacheSize)) { + if (!memoryTracker.canAllocateReserved(neededCacheSize)) { throw new LimitExceededException("No enough memory to run detector"); } 
memoryTracker.consumeMemory(neededCacheSize, true, HISTORICAL_SINGLE_ENTITY_DETECTOR); diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java b/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java index f06c02342..37c9a1f5c 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java @@ -27,19 +27,18 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.action.bulk.BulkResponse; import org.opensearch.ad.constant.CommonValue; import org.opensearch.common.settings.Settings; import org.opensearch.transport.TransportRequestOptions; -public class ADResultBulkAction extends ActionType { +public class ADResultBulkAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "write/bulk"; public static final ADResultBulkAction INSTANCE = new ADResultBulkAction(); private ADResultBulkAction() { - super(NAME, BulkResponse::new); + super(NAME, ADResultBulkResponse::new); } @Override diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java b/src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java new file mode 100644 index 000000000..7d25e3eba --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java @@ -0,0 +1,85 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import org.opensearch.action.ActionResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; + +public class ADResultBulkResponse extends ActionResponse { + public static final String RETRY_REQUESTS_JSON_KEY = "retry_requests"; + + private List retryRequests; + + /** + * + * @param retryRequests a list of requests to retry + */ + public ADResultBulkResponse(List retryRequests) { + this.retryRequests = retryRequests; + } + + public ADResultBulkResponse() { + this.retryRequests = null; + } + + public ADResultBulkResponse(StreamInput in) throws IOException { + int size = in.readInt(); + if (size > 0) { + retryRequests = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + retryRequests.add(new IndexRequest(in)); + } + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (retryRequests == null || retryRequests.size() == 0) { + out.writeInt(0); + } else { + out.writeInt(retryRequests.size()); + for (IndexRequest result : retryRequests) { + result.writeTo(out); + } + } + } + + public boolean hasFailures() { + return retryRequests != null && retryRequests.size() > 0; + } + + public Optional> getRetryRequests() { + return Optional.ofNullable(retryRequests); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java index 83d6245ec..afd29c8a5 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java @@ -26,11 +26,13 @@ package org.opensearch.ad.transport; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.INDEX_PRESSURE_HARD_LIMIT; import static org.opensearch.ad.settings.AnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT; import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; import static org.opensearch.index.IndexingPressure.MAX_INDEXING_BYTES; import java.io.IOException; +import java.util.List; import java.util.Locale; import java.util.Random; @@ -45,6 +47,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.util.BulkUtil; import org.opensearch.ad.util.RestHandlerUtils; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -56,14 +59,16 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -public class ADResultBulkTransportAction extends HandledTransportAction { +public class ADResultBulkTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(ADResultBulkTransportAction.class); private IndexingPressure indexingPressure; private final long primaryAndCoordinatingLimits; private float softLimit; + private float hardLimit; private String indexName; private Client client; + private Random random; @Inject public ADResultBulkTransportAction( @@ -78,52 +83,59 @@ public ADResultBulkTransportAction( this.indexingPressure = indexingPressure; this.primaryAndCoordinatingLimits = MAX_INDEXING_BYTES.get(settings).getBytes(); this.softLimit = INDEX_PRESSURE_SOFT_LIMIT.get(settings); + 
this.hardLimit = INDEX_PRESSURE_HARD_LIMIT.get(settings); this.indexName = CommonName.ANOMALY_RESULT_INDEX_ALIAS; this.client = client; clusterService.getClusterSettings().addSettingsUpdateConsumer(INDEX_PRESSURE_SOFT_LIMIT, it -> softLimit = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(INDEX_PRESSURE_HARD_LIMIT, it -> hardLimit = it); + // random seed is 42. Can be any number + this.random = new Random(42); } @Override - protected void doExecute(Task task, ADResultBulkRequest request, ActionListener listener) { + protected void doExecute(Task task, ADResultBulkRequest request, ActionListener listener) { // Concurrent indexing memory limit = 10% of heap // indexing pressure = indexing bytes / indexing limit // Write all until index pressure (global indexing memory pressure) is less than 80% of 10% of heap. Otherwise, index // all non-zero anomaly grade index requests and index zero anomaly grade index requests with probability (1 - index pressure). long totalBytes = indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes() + indexingPressure.getCurrentReplicaBytes(); float indexingPressurePercent = (float) totalBytes / primaryAndCoordinatingLimits; + List results = request.getAnomalyResults(); + + if (results == null || results.size() < 1) { + listener.onResponse(new ADResultBulkResponse()); + } BulkRequest bulkRequest = new BulkRequest(); if (indexingPressurePercent <= softLimit) { - for (AnomalyResult result : request.getAnomalyResults()) { + for (AnomalyResult result : results) { addResult(bulkRequest, result); } - } else if (Float.compare(indexingPressurePercent, 1.0f) < 0) { - // exceed soft limit (80%) but smaller than hard limit (100%) - // random seed is 42. Can be any number - Random random = new Random(42); + } else if (indexingPressurePercent <= hardLimit) { + // exceed soft limit (60%) but smaller than hard limit (90%) float acceptProbability = 1 - indexingPressurePercent; - for (AnomalyResult result : request.getAnomalyResults()) { - if (result.getAnomalyGrade() > 0 || random.nextFloat() < acceptProbability) { + for (AnomalyResult result : results) { + if (result.isHighPriority() || random.nextFloat() < acceptProbability) { addResult(bulkRequest, result); } } } else { - // if exceeding 100% of hard limit, try our luck and only index non-zero grade result - for (AnomalyResult result : request.getAnomalyResults()) { - if (result.getAnomalyGrade() > 0) { + // if exceeding hard limit, only index non-zero grade or error result + for (AnomalyResult result : results) { + if (result.isHighPriority()) { addResult(bulkRequest, result); } } } if (bulkRequest.numberOfActions() > 0) { - client - .execute( - BulkAction.INSTANCE, - bulkRequest, - ActionListener.wrap(response -> listener.onResponse(response), listener::onFailure) - ); + client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(bulkResponse -> { + List failedRequests = BulkUtil.getFailedIndexRequest(bulkRequest, bulkResponse); + listener.onResponse(new ADResultBulkResponse(failedRequests)); + }, listener::onFailure)); + } else { + listener.onResponse(new ADResultBulkResponse()); } } diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java b/src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java index d270bb70c..9bb78cffd 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java @@ -100,6 +100,7 @@ public void writeTo(StreamOutput out) throws IOException { * 
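[Editor's illustrative aside, not part of the patch.] The soft/hard limit handling above admits every result below the soft limit, always admits high-priority results (non-zero grade or error) between the limits while sampling the rest with probability (1 - pressure), and keeps only high-priority results above the hard limit. A standalone sketch of that admission decision with hypothetical names:

import java.util.Random;

class ResultAdmissionSketch {
    private final Random random = new Random(42);   // any fixed seed works
    private final float softLimit = 0.6f;           // INDEX_PRESSURE_SOFT_LIMIT default
    private final float hardLimit = 0.9f;           // INDEX_PRESSURE_HARD_LIMIT default

    // pressure = current indexing bytes / indexing memory limit;
    // highPriority = non-zero anomaly grade or an error result
    boolean shouldIndex(float pressure, boolean highPriority) {
        if (pressure <= softLimit) {
            return true;                                            // plenty of headroom: keep everything
        } else if (pressure <= hardLimit) {
            return highPriority || random.nextFloat() < 1 - pressure; // sample low-priority results
        } else {
            return highPriority;                                    // overloaded: keep only important results
        }
    }
}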
@return XContentBuilder * @throws IOException thrown by builder for invalid field */ + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { for (String stat : statsMap.keySet()) { builder.field(stat, statsMap.get(stat)); diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java index 0e255c710..27a338bb8 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java @@ -36,7 +36,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.constant.CommonErrorMessages; -import org.opensearch.ad.constant.CommonMessageAttributes; +import org.opensearch.ad.constant.CommonName; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.InputStreamStreamInput; import org.opensearch.common.io.stream.OutputStreamStreamOutput; @@ -103,9 +103,9 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CommonMessageAttributes.ID_JSON_KEY, adID); - builder.field(CommonMessageAttributes.START_JSON_KEY, start); - builder.field(CommonMessageAttributes.END_JSON_KEY, end); + builder.field(CommonName.ID_JSON_KEY, adID); + builder.field(CommonName.START_JSON_KEY, start); + builder.field(CommonName.END_JSON_KEY, end); builder.endObject(); return builder; } diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java index f7fff91e8..521c05e0d 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java @@ -27,10 +27,13 @@ package org.opensearch.ad.transport; import static org.opensearch.ad.constant.CommonErrorMessages.INVALID_SEARCH_QUERY_MSG; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.PAGE_SIZE; import java.net.ConnectException; import java.util.ArrayList; import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; @@ -69,14 +72,16 @@ import org.opensearch.ad.common.exception.ResourceNotFoundException; import org.opensearch.ad.constant.CommonErrorMessages; import org.opensearch.ad.constant.CommonName; +import org.opensearch.ad.feature.CompositeRetriever; +import org.opensearch.ad.feature.CompositeRetriever.PageIterator; import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.feature.SinglePointFeatures; import org.opensearch.ad.ml.ModelManager; import org.opensearch.ad.ml.ModelPartitioner; import org.opensearch.ad.ml.RcfResult; import org.opensearch.ad.ml.rcf.CombinedRcfResult; import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.FeatureData; import org.opensearch.ad.model.IntervalTimeConfiguration; import org.opensearch.ad.settings.AnomalyDetectorSettings; @@ -96,6 +101,7 @@ import org.opensearch.common.lease.Releasable; import org.opensearch.common.settings.Settings; import 
org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.index.IndexNotFoundException; import org.opensearch.node.NodeClosedException; import org.opensearch.rest.RestStatus; @@ -119,12 +125,9 @@ public class AnomalyResultTransportAction extends HandledTransportAction hcDetectors; + private NamedXContentRegistry xContentRegistry; + private Settings settings; + // within an interval, how many percents are used to process requests. + // 1.0 means we use all of the detection interval to process requests. + // to ensure we don't block next interval, it is better to set it less than 1.0. + private final float intervalRatioForRequest; + private int maxEntitiesPerInterval; + private int pageSize; @Inject public AnomalyResultTransportAction( @@ -160,10 +170,11 @@ public AnomalyResultTransportAction( ADCircuitBreakerService adCircuitBreakerService, ADStats adStats, ThreadPool threadPool, - SearchFeatureDao searchFeatureDao + NamedXContentRegistry xContentRegistry ) { super(AnomalyResultAction.NAME, transportService, actionFilters, AnomalyResultRequest::new); this.transportService = transportService; + this.settings = settings; this.client = client; this.stateManager = manager; this.featureManager = featureManager; @@ -180,8 +191,15 @@ public AnomalyResultTransportAction( this.adCircuitBreakerService = adCircuitBreakerService; this.adStats = adStats; this.threadPool = threadPool; - this.searchFeatureDao = searchFeatureDao; this.hcDetectors = new HashSet<>(); + this.xContentRegistry = xContentRegistry; + this.intervalRatioForRequest = AnomalyDetectorSettings.INTERVAL_RATIO_FOR_REQUESTS; + + this.maxEntitiesPerInterval = MAX_ENTITIES_PER_QUERY.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_ENTITIES_PER_QUERY, it -> maxEntitiesPerInterval = it); + + this.pageSize = PAGE_SIZE.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(PAGE_SIZE, it -> pageSize = it); } /** @@ -213,6 +231,7 @@ public AnomalyResultTransportAction( * + cold start cannot succeed * + unknown prediction error * + memory circuit breaker tripped + * + invalid search query * * Known causes of EndRunException with endNow returning true: * + a model partition's memory size reached limit @@ -220,7 +239,7 @@ public AnomalyResultTransportAction( * + Having trouble querying feature data due to * * index does not exist * * all features have been disabled - * * invalid search query + * * + anomaly detector is not available * + AD plugin is disabled * + training data is invalid due to serious internal bug(s) @@ -278,6 +297,103 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< } } + // + + /** + * didn't use ActionListener.wrap so that I can + * 1) use this to refer to the listener inside the listener + * 2) pass parameters using constructors + * + */ + class PageListener implements ActionListener { + private PageIterator pageIterator; + private String detectorId; + private long dataStartTime; + private long dataEndTime; + + PageListener(PageIterator pageIterator, String detectorId, long dataStartTime, long dataEndTime) { + this.pageIterator = pageIterator; + this.detectorId = detectorId; + this.dataStartTime = dataStartTime; + this.dataEndTime = dataEndTime; + } + + @Override + public void onResponse(CompositeRetriever.Page entityFeatures) { + if (entityFeatures != null && false == entityFeatures.isEmpty()) { + // wrap expensive operation inside ad threadpool + 
threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME).execute(() -> { + Set>> node2Entities = entityFeatures + .getResults() + .entrySet() + .stream() + .collect( + Collectors + .groupingBy( + // from entity name to its node + e -> hashRing.getOwningNode(e.getKey().toString()).get(), + Collectors.toMap(Entry::getKey, Entry::getValue) + ) + ) + .entrySet(); + + Iterator>> iterator = node2Entities.iterator(); + + while (iterator.hasNext()) { + Entry> entry = iterator.next(); + DiscoveryNode modelNode = entry.getKey(); + if (modelNode == null) { + iterator.remove(); + continue; + } + String modelNodeId = modelNode.getId(); + if (stateManager.isMuted(modelNodeId)) { + LOG.info(String.format(Locale.ROOT, NODE_UNRESPONSIVE_ERR_MSG + " %s", modelNodeId)); + iterator.remove(); + } + } + + final AtomicReference failure = new AtomicReference<>(); + int nodeCount = node2Entities.size(); + AtomicInteger responseCount = new AtomicInteger(); + node2Entities.stream().forEach(nodeEntity -> { + DiscoveryNode node = nodeEntity.getKey(); + transportService + .sendRequest( + node, + EntityResultAction.NAME, + new EntityResultRequest(detectorId, nodeEntity.getValue(), dataStartTime, dataEndTime), + option, + new ActionListenerResponseHandler<>( + new EntityResultListener( + node.getId(), + detectorId, + failure, + nodeCount, + pageIterator, + this, + responseCount + ), + AcknowledgedResponse::new, + ThreadPool.Names.SAME + ) + ); + }); + }); + } + } + + @Override + public void onFailure(Exception e) { + Exception convertedException = convertedQueryFailureException(e, detectorId); + if (false == (convertedException instanceof AnomalyDetectionException)) { + Throwable cause = ExceptionsHelper.unwrapCause(convertedException); + convertedException = new InternalFailure(detectorId, cause); + } + stateManager.setException(detectorId, convertedException); + } + } + private ActionListener> onGetDetector( ActionListener listener, String adID, @@ -304,88 +420,62 @@ private ActionListener> onGetDetector( List categoryField = anomalyDetector.getCategoryField(); if (categoryField != null) { - Optional previousException = stateManager.fetchColdStartException(adID); + Optional previousException = stateManager.fetchExceptionAndClear(adID); if (previousException.isPresent()) { Exception exception = previousException.get(); LOG.error("Previous exception of {}: {}", adID, exception); if (exception instanceof EndRunException) { - listener.onFailure(exception); EndRunException endRunException = (EndRunException) exception; if (endRunException.isEndNow()) { + listener.onFailure(exception); return; } } } - ActionListener> getEntityFeatureslistener = ActionListener.wrap(entityFeatures -> { - if (entityFeatures.isEmpty()) { - // Feature not available is common when we have data holes. Respond empty response - // so that alerting will not print stack trace to avoid bloating our logs. 
- LOG.info("No data in current detection window between {} and {} for {}", dataStartTime, dataEndTime, adID); - listener - .onResponse( - new AnomalyResultResponse( - Double.NaN, - Double.NaN, - Double.NaN, - new ArrayList(), - "No data in current detection window" - ) - ); - } else { - Set>> node2Entities = entityFeatures - .entrySet() - .stream() - .collect( - Collectors - .groupingBy( - e -> hashRing.getOwningNode(e.getKey()).get(), - Collectors.toMap(Entry::getKey, Entry::getValue) - ) - ) - .entrySet(); - - int nodeCount = node2Entities.size(); - AtomicInteger responseCount = new AtomicInteger(); - - final AtomicReference failure = new AtomicReference<>(); - node2Entities.stream().forEach(nodeEntity -> { - DiscoveryNode node = nodeEntity.getKey(); - transportService - .sendRequest( - node, - EntityResultAction.NAME, - new EntityResultRequest(adID, nodeEntity.getValue(), dataStartTime, dataEndTime), - this.option, - new ActionListenerResponseHandler<>( - new EntityResultListener(node.getId(), adID, responseCount, nodeCount, failure, listener), - AcknowledgedResponse::new, - ThreadPool.Names.SAME - ) - ); - }); - } + // assume request are in epoch milliseconds + long nextDetectionStartTime = request.getEnd() + (long) (anomalyDetector.getDetectorIntervalInMilliseconds() + * intervalRatioForRequest); - }, exception -> handleFailure(exception, listener, adID)); - - threadPool - .executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME) - .execute( - () -> searchFeatureDao - .getFeaturesByEntities( - anomalyDetector, - dataStartTime, - dataEndTime, - new ThreadedActionListener<>( - LOG, - threadPool, - AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, - getEntityFeatureslistener, - false - ) - ) - ); + CompositeRetriever compositeRetriever = new CompositeRetriever( + dataStartTime, + dataEndTime, + anomalyDetector, + xContentRegistry, + client, + nextDetectionStartTime, + settings, + maxEntitiesPerInterval, + pageSize + ); + + PageIterator pageIterator = null; + + try { + pageIterator = compositeRetriever.iterator(); + } catch (Exception e) { + listener + .onFailure( + new EndRunException(anomalyDetector.getDetectorId(), CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, e, false) + ); + return; + } + + PageListener getEntityFeatureslistener = new PageListener(pageIterator, adID, dataStartTime, dataEndTime); + + if (pageIterator.hasNext()) { + pageIterator.next(getEntityFeatureslistener); + } + + // We don't know when the pagination will not finish. To not + // block the following interval request to start, we return immediately. + // Pagination will stop itself when the time is up. 
+ if (previousException.isPresent()) { + listener.onFailure(previousException.get()); + } else { + listener.onResponse(new AnomalyResultResponse(Double.NaN, Double.NaN, Double.NaN, new ArrayList())); + } return; } @@ -517,20 +607,48 @@ private ActionListener onFeatureResponse( new ActionListenerResponseHandler<>(rcfListener, RCFResultResponse::new) ); } - }, exception -> { handleFailure(exception, listener, adID); }); + }, exception -> { handleQueryFailure(exception, listener, adID); }); } - private void handleFailure(Exception exception, ActionListener listener, String adID) { - if (exception instanceof IndexNotFoundException) { - listener.onFailure(new EndRunException(adID, TROUBLE_QUERYING_ERR_MSG + exception.getMessage(), true).countedInStats(false)); - } else if (exception instanceof EndRunException) { + private void handleQueryFailure(Exception exception, ActionListener listener, String adID) { + Exception convertedQueryFailureException = convertedQueryFailureException(exception, adID); + + if (convertedQueryFailureException instanceof EndRunException) { // invalid feature query - listener.onFailure(exception); + listener.onFailure(convertedQueryFailureException); } else { - handleExecuteException(exception, listener, adID); + handleExecuteException(convertedQueryFailureException, listener, adID); } } + /** + * Convert a query related exception to EndRunException + * + * These query exception can happen during the starting phase of the OpenSearch + * process. Thus, set the stopNow parameter of these EndRunException to false + * and confirm the EndRunException is not a false positive. + * + * @param exception Exception + * @param adID detector Id + * @return the converted exception if the exception is query related + */ + private Exception convertedQueryFailureException(Exception exception, String adID) { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + return new EndRunException(adID, TROUBLE_QUERYING_ERR_MSG + exception.getMessage(), false).countedInStats(false); + } else if (exception instanceof SearchPhaseExecutionException && invalidQuery((SearchPhaseExecutionException) exception)) { + // This is to catch invalid aggregation on wrong field type. For example, + // sum aggregation on text field. We should end detector run for such case. + return new EndRunException( + adID, + INVALID_SEARCH_QUERY_MSG + " " + ((SearchPhaseExecutionException) exception).getDetailedMessage(), + exception, + false + ).countedInStats(false); + } + + return exception; + } + /** * Verify failure of rcf or threshold models. If there is no model, trigger cold * start. If there is an exception for the previous cold start of this detector, @@ -568,7 +686,7 @@ private AnomalyDetectionException coldStartIfNoModel(AtomicReference previousException = stateManager.fetchColdStartException(adID); + final Optional previousException = stateManager.fetchExceptionAndClear(adID); if (previousException.isPresent()) { Exception exception = previousException.get(); LOG.error("Previous exception of {}: {}", () -> adID, () -> exception); @@ -602,7 +720,7 @@ private void findException(Throwable cause, String adID, AtomicReference listener.onFailure(ex); } else if (ex instanceof AnomalyDetectionException) { listener.onFailure(new InternalFailure((AnomalyDetectionException) ex)); - } else if (ex instanceof SearchPhaseExecutionException && invalidQuery((SearchPhaseExecutionException) ex)) { - // This is to catch invalid aggregation on wrong field type. For example, - // sum aggregation on text field. 
We should end detector run for such case. - listener - .onFailure( - new EndRunException( - adID, - INVALID_SEARCH_QUERY_MSG + ((SearchPhaseExecutionException) ex).getDetailedMessage(), - ex, - true - ).countedInStats(false) - ); } else { Throwable cause = ExceptionsHelper.unwrapCause(ex); listener.onFailure(new InternalFailure(adID, cause)); @@ -885,7 +991,7 @@ private void handleConnectionException(String node) { if (!nodes.nodeExists(node) && hashRing.build()) { return; } - // rebuilt is not done or node is unresponsive + // rebuilding is not done or node is unresponsive stateManager.addPressure(node); } @@ -971,26 +1077,20 @@ private void coldStart(AnomalyDetector detector) { .wrap(res -> { LOG.info("Succeeded in training {}", detectorId); }, exception -> { if (exception instanceof AnomalyDetectionException) { // e.g., partitioned model exceeds memory limit - stateManager.setLastColdStartException(detectorId, (AnomalyDetectionException) exception); + stateManager.setException(detectorId, exception); } else if (exception instanceof IllegalArgumentException) { // IllegalArgumentException due to invalid training data stateManager - .setLastColdStartException( - detectorId, - new EndRunException(detectorId, "Invalid training data", exception, false) - ); + .setException(detectorId, new EndRunException(detectorId, "Invalid training data", exception, false)); } else if (exception instanceof OpenSearchTimeoutException) { stateManager - .setLastColdStartException( + .setException( detectorId, new InternalFailure(detectorId, "Time out while indexing cold start checkpoint", exception) ); } else { stateManager - .setLastColdStartException( - detectorId, - new EndRunException(detectorId, "Error while training model", exception, false) - ); + .setException(detectorId, new EndRunException(detectorId, "Error while training model", exception, false)); } }); @@ -1001,21 +1101,16 @@ private void coldStart(AnomalyDetector detector) { new ThreadedActionListener<>(LOG, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, trainModelListener, false) ); } else { - stateManager.setLastColdStartException(detectorId, new EndRunException(detectorId, "Cannot get training data", false)); + stateManager.setException(detectorId, new EndRunException(detectorId, "Cannot get training data", false)); } }, exception -> { if (exception instanceof OpenSearchTimeoutException) { - stateManager - .setLastColdStartException( - detectorId, - new InternalFailure(detectorId, "Time out while getting training data", exception) - ); + stateManager.setException(detectorId, new InternalFailure(detectorId, "Time out while getting training data", exception)); } else if (exception instanceof AnomalyDetectionException) { // e.g., Invalid search query - stateManager.setLastColdStartException(detectorId, (AnomalyDetectionException) exception); + stateManager.setException(detectorId, exception); } else { - stateManager - .setLastColdStartException(detectorId, new EndRunException(detectorId, "Error while cold start", exception, false)); + stateManager.setException(detectorId, new EndRunException(detectorId, "Error while cold start", exception, false)); } }); @@ -1048,7 +1143,7 @@ private void coldStart(AnomalyDetector detector) { private Optional coldStartIfNoCheckPoint(AnomalyDetector detector) { String detectorId = detector.getDetectorId(); - Optional previousException = stateManager.fetchColdStartException(detectorId); + Optional previousException = stateManager.fetchExceptionAndClear(detectorId); if (previousException.isPresent()) { 
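Editor's note: the convertedQueryFailureException helper introduced a little earlier in this file downgrades query-related failures (a missing index, or an invalid aggregation such as sum over a text field) to an EndRunException with endNow = false, so a transient startup condition does not stop the detector outright; any other exception passes through unchanged. A schematic version of that decision follows, with a local placeholder class standing in for the plugin's EndRunException and boolean flags standing in for the real exception checks.

    // Schematic of convertedQueryFailureException. EndRunException here is a
    // placeholder; endNow = false means "do not stop the detector immediately".
    final class QueryFailureConversion {
        static class EndRunException extends RuntimeException {
            final boolean endNow;
            EndRunException(String msg, Throwable cause, boolean endNow) {
                super(msg, cause);
                this.endNow = endNow;
            }
        }

        static Exception convert(Exception e, String detectorId,
                                 boolean indexMissing, boolean invalidAggregation) {
            if (indexMissing) {
                // index not created yet, common right after cluster start
                return new EndRunException("Trouble querying data for " + detectorId, e, false);
            }
            if (invalidAggregation) {
                // e.g. a sum aggregation on a text field: the feature query itself is wrong
                return new EndRunException("Invalid search query for " + detectorId, e, false);
            }
            return e; // not query related: leave untouched
        }
    }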
Exception exception = previousException.get(); @@ -1071,7 +1166,7 @@ private Optional coldStartIfNoCheckPoint(AnomalyDetec } else { String errorMsg = String.format(Locale.ROOT, "Fail to get checkpoint state for %s", detectorId); LOG.error(errorMsg, exception); - stateManager.setLastColdStartException(detectorId, new AnomalyDetectionException(errorMsg, exception)); + stateManager.setException(detectorId, new AnomalyDetectionException(errorMsg, exception)); } })); @@ -1081,75 +1176,67 @@ private Optional coldStartIfNoCheckPoint(AnomalyDetec class EntityResultListener implements ActionListener { private String nodeId; private final String adID; - private AtomicInteger responseCount; - private int nodeCount; - private ActionListener listener; - private List ackResponses; private AtomicReference failure; + private int nodeCount; + private AtomicInteger responseCount; + private PageIterator pageIterator; + private PageListener pageListener; EntityResultListener( String nodeId, String adID, - AtomicInteger responseCount, - int nodeCount, AtomicReference failure, - ActionListener listener + int nodeCount, + PageIterator pageIterator, + PageListener pageListener, + AtomicInteger responseCount ) { this.nodeId = nodeId; this.adID = adID; - this.responseCount = responseCount; - this.nodeCount = nodeCount; this.failure = failure; - this.listener = listener; - this.ackResponses = new ArrayList<>(); + this.nodeCount = nodeCount; + this.pageIterator = pageIterator; + this.responseCount = responseCount; + this.pageListener = pageListener; } @Override public void onResponse(AcknowledgedResponse response) { try { - stateManager.resetBackpressureCounter(nodeId); if (response.isAcknowledged() == false) { LOG.error("Cannot send entities' features to {} for {}", nodeId, adID); stateManager.addPressure(nodeId); } else { - ackResponses.add(response); + stateManager.resetBackpressureCounter(nodeId); } } catch (Exception ex) { LOG.error("Unexpected exception: {} for {}", ex, adID); } finally { - if (nodeCount == responseCount.incrementAndGet()) { - handleEntityResponses(); + if (nodeCount == responseCount.incrementAndGet() && pageIterator.hasNext()) { + pageIterator.next(pageListener); } } } @Override public void onFailure(Exception e) { - if (e == null) { - return; - } try { + // e.g., we have connection issues with all of the nodes while restarting clusters LOG.error(new ParameterizedMessage("Cannot send entities' features to {} for {}", nodeId, adID), e); handlePredictionFailure(e, adID, nodeId, failure); + if (failure.get() != null) { + stateManager.setException(adID, failure.get()); + } + } catch (Exception ex) { LOG.error("Unexpected exception: {} for {}", ex, adID); } finally { - if (nodeCount == responseCount.incrementAndGet()) { - handleEntityResponses(); + if (nodeCount == responseCount.incrementAndGet() && pageIterator.hasNext()) { + pageIterator.next(pageListener); } } } - - private void handleEntityResponses() { - if (failure.get() != null) { - listener.onFailure(failure.get()); - } else if (ackResponses.isEmpty()) { - listener.onFailure(new InternalFailure(adID, NO_ACK_ERR)); - } else { - listener.onResponse(new AnomalyResultResponse(0, 0, 0, new ArrayList())); - } - } } } diff --git a/src/main/java/org/opensearch/ad/transport/CronTransportAction.java b/src/main/java/org/opensearch/ad/transport/CronTransportAction.java index 4150b2e87..920789837 100644 --- a/src/main/java/org/opensearch/ad/transport/CronTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/CronTransportAction.java @@ 
-29,12 +29,16 @@ import java.io.IOException; import java.util.List; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.TransportNodesAction; import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.caching.CacheProvider; import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.ml.EntityColdStarter; import org.opensearch.ad.ml.ModelManager; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; @@ -43,11 +47,12 @@ import org.opensearch.transport.TransportService; public class CronTransportAction extends TransportNodesAction { - + private final Logger LOG = LogManager.getLogger(CronTransportAction.class); private NodeStateManager transportStateManager; private ModelManager modelManager; private FeatureManager featureManager; private CacheProvider cacheProvider; + private EntityColdStarter entityColdStarter; @Inject public CronTransportAction( @@ -58,7 +63,8 @@ public CronTransportAction( NodeStateManager tarnsportStatemanager, ModelManager modelManager, FeatureManager featureManager, - CacheProvider cacheProvider + CacheProvider cacheProvider, + EntityColdStarter entityColdStarter ) { super( CronAction.NAME, @@ -75,6 +81,7 @@ public CronTransportAction( this.modelManager = modelManager; this.featureManager = featureManager; this.cacheProvider = cacheProvider; + this.entityColdStarter = entityColdStarter; } @Override @@ -105,7 +112,8 @@ protected CronNodeResponse nodeOperation(CronNodeRequest request) { // makes checkpoints for hosted models and stop hosting models not actively // used. 
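Editor's note: the CronTransportAction hunk that continues just below switches model maintenance to an asynchronous, listener-based call and adds entityColdStarter.maintenance() to the same cron sweep; success is only logged at debug level and failures at error level. A minimal illustration of that fire-and-log callback pattern follows, with a small generic interface standing in for org.opensearch.action.ActionListener.

    import java.util.function.Consumer;

    // Minimal stand-in for the ActionListener.wrap(onResponse, onFailure) usage in
    // the maintenance call: the callback only logs, it never blocks the cron job.
    final class MaintenanceCallback {
        interface Listener<T> {
            void onResponse(T response);
            void onFailure(Exception e);
        }

        static <T> Listener<T> wrap(Consumer<T> onResponse, Consumer<Exception> onFailure) {
            return new Listener<T>() {
                @Override public void onResponse(T response) { onResponse.accept(response); }
                @Override public void onFailure(Exception e) { onFailure.accept(e); }
            };
        }

        public static void main(String[] args) {
            Listener<Void> listener = wrap(
                v -> System.out.println("model maintenance done"),          // ~ LOG.debug
                e -> System.err.println("Error maintaining model: " + e)    // ~ LOG.error
            );
            listener.onResponse(null);
        }
    }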
// for single-entity detector - modelManager.maintenance(); + modelManager + .maintenance(ActionListener.wrap(v -> LOG.debug("model maintenance done"), e -> LOG.error("Error maintaining model", e))); // for multi-entity detector cacheProvider.get().maintenance(); @@ -115,6 +123,8 @@ protected CronNodeResponse nodeOperation(CronNodeRequest request) { // delete unused transport state transportStateManager.maintenance(); + entityColdStarter.maintenance(); + return new CronNodeResponse(clusterService.localNode()); } } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java b/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java index a7915a962..3fd827f89 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java @@ -33,7 +33,7 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.ad.constant.CommonErrorMessages; -import org.opensearch.ad.constant.CommonMessageAttributes; +import org.opensearch.ad.constant.CommonName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.StreamInput; @@ -84,7 +84,7 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CommonMessageAttributes.ID_JSON_KEY, adID); + builder.field(CommonName.ID_JSON_KEY, adID); builder.endObject(); return builder; } diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java b/src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java index 069d320e8..e443c2dd1 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java +++ b/src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java @@ -35,7 +35,8 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.constant.CommonErrorMessages; -import org.opensearch.ad.constant.CommonMessageAttributes; +import org.opensearch.ad.constant.CommonName; +import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.EntityProfileName; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.StreamInput; @@ -47,13 +48,13 @@ public class EntityProfileRequest extends ActionRequest implements ToXContentObj public static final String ENTITY = "entity"; public static final String PROFILES = "profiles"; private String adID; - private String entityValue; + private Entity entityValue; private Set profilesToCollect; public EntityProfileRequest(StreamInput in) throws IOException { super(in); adID = in.readString(); - entityValue = in.readString(); + entityValue = new Entity(in); int size = in.readVInt(); profilesToCollect = new HashSet(); if (size != 0) { @@ -63,7 +64,7 @@ public EntityProfileRequest(StreamInput in) throws IOException { } } - public EntityProfileRequest(String adID, String entityValue, Set profilesToCollect) { + public EntityProfileRequest(String adID, Entity entityValue, Set profilesToCollect) { super(); this.adID = adID; this.entityValue = entityValue; @@ -74,7 +75,7 @@ public String getAdID() { return adID; } - public String getEntityValue() { + public Entity getEntityValue() { return entityValue; } @@ -86,7 +87,7 @@ public Set getProfilesToCollect() { public void writeTo(StreamOutput out) 
throws IOException { super.writeTo(out); out.writeString(adID); - out.writeString(entityValue); + entityValue.writeTo(out); out.writeVInt(profilesToCollect.size()); for (EntityProfileName profile : profilesToCollect) { out.writeEnum(profile); @@ -99,7 +100,7 @@ public ActionRequestValidationException validate() { if (Strings.isEmpty(adID)) { validationException = addValidationError(CommonErrorMessages.AD_ID_MISSING_MSG, validationException); } - if (Strings.isEmpty(entityValue)) { + if (entityValue == null) { validationException = addValidationError("Entity value is missing", validationException); } if (profilesToCollect == null || profilesToCollect.isEmpty()) { @@ -111,7 +112,7 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CommonMessageAttributes.ID_JSON_KEY, adID); + builder.field(CommonName.ID_JSON_KEY, adID); builder.field(ENTITY, entityValue); builder.field(PROFILES, profilesToCollect); builder.endObject(); diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java b/src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java index 49998d308..8205ab678 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java +++ b/src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java @@ -34,7 +34,7 @@ import org.apache.commons.lang.builder.ToStringBuilder; import org.opensearch.action.ActionResponse; import org.opensearch.ad.constant.CommonName; -import org.opensearch.ad.model.ModelProfile; +import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.xcontent.ToXContentObject; @@ -47,13 +47,13 @@ public class EntityProfileResponse extends ActionResponse implements ToXContentO private final Boolean isActive; private final long lastActiveMs; private final long totalUpdates; - private final ModelProfile modelProfile; + private final ModelProfileOnNode modelProfile; public static class Builder { private Boolean isActive = null; private long lastActiveMs = -1L; private long totalUpdates = -1L; - private ModelProfile modelProfile = null; + private ModelProfileOnNode modelProfile = null; public Builder() {} @@ -72,7 +72,7 @@ public Builder setTotalUpdates(long totalUpdates) { return this; } - public Builder setModelProfile(ModelProfile modelProfile) { + public Builder setModelProfile(ModelProfileOnNode modelProfile) { this.modelProfile = modelProfile; return this; } @@ -82,7 +82,7 @@ public EntityProfileResponse build() { } } - public EntityProfileResponse(Boolean isActive, long lastActiveTimeMs, long totalUpdates, ModelProfile modelProfile) { + public EntityProfileResponse(Boolean isActive, long lastActiveTimeMs, long totalUpdates, ModelProfileOnNode modelProfile) { this.isActive = isActive; this.lastActiveMs = lastActiveTimeMs; this.totalUpdates = totalUpdates; @@ -95,7 +95,7 @@ public EntityProfileResponse(StreamInput in) throws IOException { lastActiveMs = in.readLong(); totalUpdates = in.readLong(); if (in.readBoolean()) { - modelProfile = new ModelProfile(in); + modelProfile = new ModelProfileOnNode(in); } else { modelProfile = null; } @@ -113,7 +113,7 @@ public long getTotalUpdates() { return totalUpdates; } - public ModelProfile getModelProfile() { + public ModelProfileOnNode getModelProfile() { return modelProfile; } diff --git 
a/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java b/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java index 226552f01..6e9384519 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java @@ -40,9 +40,10 @@ import org.opensearch.ad.caching.EntityCache; import org.opensearch.ad.cluster.HashRing; import org.opensearch.ad.common.exception.AnomalyDetectionException; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.EntityProfileName; import org.opensearch.ad.model.ModelProfile; +import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; @@ -63,10 +64,10 @@ public class EntityProfileTransportAction extends HandledTransportAction listener) { String adID = request.getAdID(); - String entityValue = request.getEntityValue(); - String modelId = modelManager.getEntityModelId(adID, entityValue); - Optional node = hashRing.getOwningNode(modelId); - if (!node.isPresent()) { + Entity entityValue = request.getEntityValue(); + Optional modelIdOptional = entityValue.getModelId(adID); + if (false == modelIdOptional.isPresent()) { + listener.onFailure(new AnomalyDetectionException(adID, NO_MODEL_ID_FOUND_MSG)); + return; + } + // we use entity value (e.g., app_0) to find its node + // This should be consistent with how we land a model node in AnomalyResultTransportAction + Optional node = hashRing.getOwningNode(entityValue.toString()); + if (false == node.isPresent()) { listener.onFailure(new AnomalyDetectionException(adID, NO_NODE_FOUND_MSG)); return; } String nodeId = node.get().getId(); + String modelId = modelIdOptional.get(); DiscoveryNode localNode = clusterService.localNode(); if (localNode.getId().equals(nodeId)) { EntityCache cache = cacheProvider.get(); @@ -121,9 +127,9 @@ protected void doExecute(Task task, EntityProfileRequest request, ActionListener builder.setTotalUpdates(cache.getTotalUpdates(adID, modelId)); } if (profilesToCollect.contains(EntityProfileName.MODELS)) { - long modelSize = cache.getModelSize(adID, modelId); - if (modelSize > 0) { - builder.setModelProfile(new ModelProfile(modelId, modelSize, localNode.getId())); + Optional modleProfile = cache.getModelProfile(adID, modelId); + if (modleProfile.isPresent()) { + builder.setModelProfile(new ModelProfileOnNode(nodeId, modleProfile.get())); } } listener.onResponse(builder.build()); diff --git a/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java b/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java index edc099255..763c3c2cc 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java +++ b/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java @@ -32,10 +32,13 @@ import java.util.Locale; import java.util.Map; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.constant.CommonErrorMessages; -import org.opensearch.ad.constant.CommonMessageAttributes; +import org.opensearch.ad.constant.CommonName; +import org.opensearch.ad.model.Entity; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.StreamInput; import 
org.opensearch.common.io.stream.StreamOutput; @@ -43,21 +46,22 @@ import org.opensearch.common.xcontent.XContentBuilder; public class EntityResultRequest extends ActionRequest implements ToXContentObject { + private static final Logger LOG = LogManager.getLogger(EntityResultRequest.class); private String detectorId; - private Map entities; + private Map entities; private long start; private long end; public EntityResultRequest(StreamInput in) throws IOException { super(in); this.detectorId = in.readString(); - this.entities = in.readMap(StreamInput::readString, StreamInput::readDoubleArray); + this.entities = in.readMap(Entity::new, StreamInput::readDoubleArray); this.start = in.readLong(); this.end = in.readLong(); } - public EntityResultRequest(String detectorId, Map entities, long start, long end) { + public EntityResultRequest(String detectorId, Map entities, long start, long end) { super(); this.detectorId = detectorId; this.entities = entities; @@ -69,7 +73,7 @@ public String getDetectorId() { return this.detectorId; } - public Map getEntities() { + public Map getEntities() { return this.entities; } @@ -85,7 +89,7 @@ public long getEnd() { public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(this.detectorId); - out.writeMap(this.entities, StreamOutput::writeString, StreamOutput::writeDoubleArray); + out.writeMap(entities, (s, e) -> e.writeTo(s), StreamOutput::writeDoubleArray); out.writeLong(this.start); out.writeLong(this.end); } @@ -108,12 +112,19 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CommonMessageAttributes.ID_JSON_KEY, detectorId); - builder.field(CommonMessageAttributes.START_JSON_KEY, start); - builder.field(CommonMessageAttributes.END_JSON_KEY, end); - for (String entity : entities.keySet()) { - builder.field(entity, entities.get(entity)); + builder.field(CommonName.ID_JSON_KEY, detectorId); + builder.field(CommonName.START_JSON_KEY, start); + builder.field(CommonName.END_JSON_KEY, end); + builder.startArray(CommonName.ENTITIES_JSON_KEY); + for (final Map.Entry entry : entities.entrySet()) { + if (entry.getKey() != null) { + builder.startObject(); + builder.field(CommonName.ENTITY_KEY, entry.getKey()); + builder.field(CommonName.VALUE_JSON_KEY, entry.getValue()); + builder.endObject(); + } } + builder.endArray(); builder.endObject(); return builder; } diff --git a/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java index 79dd48570..04d352127 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java @@ -26,15 +26,15 @@ package org.opensearch.ad.transport; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.COOLDOWN_MINUTES; - -import java.time.Clock; -import java.time.Duration; import java.time.Instant; -import java.util.Arrays; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.Map.Entry; import java.util.Optional; +import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; @@ -42,15 +42,16 @@ import org.opensearch.action.support.ActionFilters; import 
org.opensearch.action.support.HandledTransportAction; import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.common.exception.AnomalyDetectionException; import org.opensearch.ad.common.exception.EndRunException; import org.opensearch.ad.common.exception.LimitExceededException; import org.opensearch.ad.constant.CommonErrorMessages; import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.AnomalyDetectionIndices; -import org.opensearch.ad.ml.CheckpointDao; import org.opensearch.ad.ml.EntityModel; import org.opensearch.ad.ml.ModelManager; import org.opensearch.ad.ml.ModelState; @@ -58,26 +59,50 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.Entity; -import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.transport.handler.MultiEntityResultHandler; +import org.opensearch.ad.ratelimit.CheckpointReadWorker; +import org.opensearch.ad.ratelimit.ColdEntityWorker; +import org.opensearch.ad.ratelimit.EntityFeatureRequest; +import org.opensearch.ad.ratelimit.RequestPriority; +import org.opensearch.ad.ratelimit.ResultWriteRequest; +import org.opensearch.ad.ratelimit.ResultWriteWorker; +import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.ad.util.ParseUtils; import org.opensearch.common.inject.Inject; -import org.opensearch.common.settings.Settings; import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +/** + * Entry-point for HCAD workflow. We have created multiple queues for coordinating + * the workflow. The overrall workflow is: + * 1. We store as many frequently used entity models in a cache as allowed by the + * memory limit (10% heap). If an entity feature is a hit, we use the in-memory model + * to detect anomalies and record results using the result write queue. + * 2. If an entity feature is a miss, we check if there is free memory or any other + * entity's model can be evacuated. An in-memory entity's frequency may be lower + * compared to the cache miss entity. If that's the case, we replace the lower + * frequency entity's model with the higher frequency entity's model. To load the + * higher frequency entity's model, we first check if a model exists on disk by + * sending a checkpoint read queue request. If there is a checkpoint, we load it + * to memory, perform detection, and save the result using the result write queue. + * Otherwise, we enqueue a cold start request to the cold start queue for model + * training. If training is successful, we save the learned model via the checkpoint + * write queue. + * 3. We also have the cold entity queue configured for cold entities, and the model + * training and inference are connected by serial juxtaposition to limit resource usage. 
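Editor's note: the class javadoc here describes the rate-limited HCAD flow: a cache hit is scored in memory and its result queued for bulk writing; a cache miss goes either to the checkpoint-read queue (model on disk) or, if no checkpoint exists, to the cold-start queue; rarely seen ("cold") entities go to a separate, lower-priority queue. A compressed, dependency-free sketch of that routing decision follows; the queue and cache types are placeholders, not the plugin's CheckpointReadWorker, EntityColdStartWorker, ColdEntityWorker, or ResultWriteWorker.

    import java.util.ArrayDeque;
    import java.util.Map;
    import java.util.Queue;

    // Placeholder routing of one entity's feature vector through the HCAD queues
    // described in the javadoc above.
    final class HcadRouting {
        static final class Model { double score(double[] features) { return 0.0; } }

        private final Map<String, Model> cache;                     // in-memory models, bounded by heap budget
        private final Queue<String> checkpointReadQueue = new ArrayDeque<>();
        private final Queue<String> coldEntityQueue = new ArrayDeque<>();
        private final Queue<Double> resultWriteQueue = new ArrayDeque<>();

        HcadRouting(Map<String, Model> cache) {
            this.cache = cache;
        }

        void route(String modelId, double[] features, boolean isHotEntity) {
            Model model = cache.get(modelId);
            if (model != null) {
                // cache hit: score in memory, queue the result for bulk indexing
                resultWriteQueue.add(model.score(features));
            } else if (isHotEntity) {
                // cache miss, frequently seen entity: try to restore its checkpoint from disk;
                // the checkpoint reader falls back to cold start when no checkpoint exists
                checkpointReadQueue.add(modelId);
            } else {
                // rarely seen entity: lowest priority, processed when resources allow
                coldEntityQueue.add(modelId);
            }
        }
    }

In the actual change, the hot/cold split is computed by the cache's selectUpdateCandidate call, and the cold-entity queue eventually feeds the same checkpoint-read path at a slower rate.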
+ */ public class EntityResultTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(EntityResultTransportAction.class); - private ModelManager manager; + private ModelManager modelManager; private ADCircuitBreakerService adCircuitBreakerService; - private MultiEntityResultHandler anomalyResultHandler; - private CheckpointDao checkpointDao; private CacheProvider cache; private final NodeStateManager stateManager; - private final int coolDownMinutes; - private final Clock clock; private AnomalyDetectionIndices indexUtil; + private ResultWriteWorker resultWriteQueue; + private CheckpointReadWorker checkpointReadQueue; + private ColdEntityWorker coldEntityQueue; + private ThreadPool threadPool; @Inject public EntityResultTransportAction( @@ -85,56 +110,30 @@ public EntityResultTransportAction( TransportService transportService, ModelManager manager, ADCircuitBreakerService adCircuitBreakerService, - MultiEntityResultHandler anomalyResultHandler, - CheckpointDao checkpointDao, - CacheProvider entityCache, - NodeStateManager stateManager, - Settings settings, - AnomalyDetectionIndices indexUtil - ) { - this( - actionFilters, - transportService, - manager, - adCircuitBreakerService, - anomalyResultHandler, - checkpointDao, - entityCache, - stateManager, - settings, - Clock.systemUTC(), - indexUtil - ); - } - - protected EntityResultTransportAction( - ActionFilters actionFilters, - TransportService transportService, - ModelManager manager, - ADCircuitBreakerService adCircuitBreakerService, - MultiEntityResultHandler anomalyResultHandler, - CheckpointDao checkpointDao, CacheProvider entityCache, NodeStateManager stateManager, - Settings settings, - Clock clock, - AnomalyDetectionIndices indexUtil + AnomalyDetectionIndices indexUtil, + ResultWriteWorker resultWriteQueue, + CheckpointReadWorker checkpointReadQueue, + ColdEntityWorker coldEntityQueue, + ThreadPool threadPool ) { super(EntityResultAction.NAME, transportService, actionFilters, EntityResultRequest::new); - this.manager = manager; + this.modelManager = manager; this.adCircuitBreakerService = adCircuitBreakerService; - this.anomalyResultHandler = anomalyResultHandler; - this.checkpointDao = checkpointDao; this.cache = entityCache; this.stateManager = stateManager; - this.coolDownMinutes = (int) (COOLDOWN_MINUTES.get(settings).getMinutes()); - this.clock = clock; this.indexUtil = indexUtil; + this.resultWriteQueue = resultWriteQueue; + this.checkpointReadQueue = checkpointReadQueue; + this.coldEntityQueue = coldEntityQueue; + this.threadPool = threadPool; } @Override protected void doExecute(Task task, EntityResultRequest request, ActionListener listener) { if (adCircuitBreakerService.isOpen()) { + threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME).execute(() -> cache.get().releaseMemoryForOpenCircuitBreaker()); listener .onFailure(new LimitExceededException(request.getDetectorId(), CommonErrorMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); return; @@ -142,18 +141,35 @@ protected void doExecute(Task task, EntityResultRequest request, ActionListener< try { String detectorId = request.getDetectorId(); - stateManager.getAnomalyDetector(detectorId, onGetDetector(listener, detectorId, request)); + + Optional previousException = stateManager.fetchExceptionAndClear(detectorId); + + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error("Previous exception of {}: {}", detectorId, exception); + if (exception instanceof EndRunException) { + 
EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + + listener = ExceptionUtil.wrapListener(listener, exception, detectorId); + } + + stateManager.getAnomalyDetector(detectorId, onGetDetector(listener, detectorId, request, previousException)); } catch (Exception exception) { LOG.error("fail to get entity's anomaly grade", exception); listener.onFailure(exception); } - } private ActionListener> onGetDetector( ActionListener listener, String detectorId, - EntityResultRequest request + EntityResultRequest request, + Optional prevException ) { return ActionListener.wrap(detectorOptional -> { if (!detectorOptional.isPresent()) { @@ -162,65 +178,121 @@ private ActionListener> onGetDetector( } AnomalyDetector detector = detectorOptional.get(); - // we only support 1 categorical field now - String categoricalField = detector.getCategoryField().get(0); - ADResultBulkRequest currentBulkRequest = new ADResultBulkRequest(); - // index pressure is high. Only save anomalies - boolean onlySaveAnomalies = stateManager - .getLastIndexThrottledTime() - .plus(Duration.ofMinutes(coolDownMinutes)) - .isAfter(clock.instant()); + if (request.getEntities() == null) { + listener.onResponse(null); + return; + } Instant executionStartTime = Instant.now(); - for (Entry entity : request.getEntities().entrySet()) { - String entityName = entity.getKey(); - // For ES, the limit of the document ID is 512 bytes. - // skip an entity if the entity's name is more than 256 characters - // since we are using it as part of document id. - if (entityName.length() > AnomalyDetectorSettings.MAX_ENTITY_LENGTH) { + Map cacheMissEntities = new HashMap<>(); + for (Entry entityEntry : request.getEntities().entrySet()) { + Entity categoricalValues = entityEntry.getKey(); + + Optional modelIdOptional = categoricalValues.getModelId(detectorId); + if (false == modelIdOptional.isPresent()) { continue; } - double[] datapoint = entity.getValue(); - String modelId = manager.getEntityModelId(detectorId, entityName); - ModelState entityModel = cache.get().get(modelId, detector, datapoint, entityName); + String modelId = modelIdOptional.get(); + double[] datapoint = entityEntry.getValue(); + ModelState entityModel = cache.get().get(modelId, detector); if (entityModel == null) { // cache miss + cacheMissEntities.put(categoricalValues, datapoint); continue; } - ThresholdingResult result = manager.getAnomalyResultForEntity(detectorId, datapoint, entityName, entityModel, modelId); + ThresholdingResult result = modelManager + .getAnomalyResultForEntity(datapoint, entityModel, modelId, detector, categoricalValues); // result.getRcfScore() = 0 means the model is not initialized // result.getGrade() = 0 means it is not an anomaly // So many OpenSearchRejectedExecutionException if we write no matter what - if (result.getRcfScore() > 0 && (!onlySaveAnomalies || result.getGrade() > 0)) { - currentBulkRequest - .add( - new AnomalyResult( + if (result.getRcfScore() > 0) { + resultWriteQueue + .put( + new ResultWriteRequest( + System.currentTimeMillis() + detector.getDetectorIntervalInMilliseconds(), detectorId, - result.getRcfScore(), - result.getGrade(), - result.getConfidence(), - ParseUtils.getFeatureData(datapoint, detector), - Instant.ofEpochMilli(request.getStart()), - Instant.ofEpochMilli(request.getEnd()), - executionStartTime, - Instant.now(), - null, - Arrays.asList(new Entity(categoricalField, entityName)), - detector.getUser(), - 
indexUtil.getSchemaVersion(ADIndex.RESULT) + result.getGrade() > 0 ? RequestPriority.HIGH : RequestPriority.MEDIUM, + new AnomalyResult( + detectorId, + null, + result.getRcfScore(), + result.getGrade(), + result.getConfidence(), + ParseUtils.getFeatureData(datapoint, detector), + Instant.ofEpochMilli(request.getStart()), + Instant.ofEpochMilli(request.getEnd()), + executionStartTime, + Instant.now(), + null, + categoricalValues, + detector.getUser(), + indexUtil.getSchemaVersion(ADIndex.RESULT), + modelId + ) ) ); } } - if (currentBulkRequest.numberOfActions() > 0) { - this.anomalyResultHandler.flush(currentBulkRequest, detectorId); + + // split hot and cold entities + Pair, List> hotColdEntities = cache + .get() + .selectUpdateCandidate(cacheMissEntities.keySet(), detectorId, detector); + + List hotEntityRequests = new ArrayList<>(); + List coldEntityRequests = new ArrayList<>(); + + for (Entity hotEntity : hotColdEntities.getLeft()) { + double[] hotEntityValue = cacheMissEntities.get(hotEntity); + if (hotEntityValue == null) { + LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", hotEntity)); + continue; + } + hotEntityRequests + .add( + new EntityFeatureRequest( + System.currentTimeMillis() + detector.getDetectorIntervalInMilliseconds(), + detectorId, + // hot entities has MEDIUM priority + RequestPriority.MEDIUM, + hotEntity, + hotEntityValue, + request.getStart() + ) + ); + } + + for (Entity coldEntity : hotColdEntities.getRight()) { + double[] coldEntityValue = cacheMissEntities.get(coldEntity); + if (coldEntityValue == null) { + LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", coldEntity)); + continue; + } + coldEntityRequests + .add( + new EntityFeatureRequest( + System.currentTimeMillis() + detector.getDetectorIntervalInMilliseconds(), + detectorId, + // cold entities has LOW priority + RequestPriority.LOW, + coldEntity, + coldEntityValue, + request.getStart() + ) + ); } - // bulk all accumulated checkpoint requests - this.checkpointDao.flush(); - listener.onResponse(new AcknowledgedResponse(true)); + checkpointReadQueue.putAll(hotEntityRequests); + coldEntityQueue.putAll(coldEntityRequests); + + // respond back + if (prevException.isPresent()) { + listener.onFailure(prevException.get()); + } else { + listener.onResponse(new AcknowledgedResponse(true)); + } }, exception -> { LOG .error( diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java index c53775dc0..6b390b61b 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java @@ -30,6 +30,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.model.Entity; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; @@ -42,7 +43,7 @@ public class GetAnomalyDetectorRequest extends ActionRequest { private String typeStr; private String rawPath; private boolean all; - private String entityValue; + private Entity entity; public GetAnomalyDetectorRequest(StreamInput in) throws IOException { super(in); @@ -54,7 +55,7 @@ public GetAnomalyDetectorRequest(StreamInput in) throws IOException { rawPath = in.readString(); all = in.readBoolean(); if (in.readBoolean()) { - entityValue = in.readString(); + entity = new Entity(in); } } @@ -66,7 +67,7 @@ public 
GetAnomalyDetectorRequest( String typeStr, String rawPath, boolean all, - String entityValue + Entity entity ) { super(); this.detectorID = detectorID; @@ -76,7 +77,7 @@ public GetAnomalyDetectorRequest( this.typeStr = typeStr; this.rawPath = rawPath; this.all = all; - this.entityValue = entityValue; + this.entity = entity; } public String getDetectorID() { @@ -107,8 +108,8 @@ public boolean isAll() { return all; } - public String getEntityValue() { - return entityValue; + public Entity getEntity() { + return entity; } @Override @@ -121,9 +122,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(typeStr); out.writeString(rawPath); out.writeBoolean(all); - if (this.entityValue != null) { + if (this.entity != null) { out.writeBoolean(true); - out.writeString(entityValue); + entity.writeTo(out); } else { out.writeBoolean(false); } diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java index 0507fbe5c..35aa7f3b5 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java @@ -59,6 +59,7 @@ import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.DetectorProfile; import org.opensearch.ad.model.DetectorProfileName; +import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.EntityProfileName; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; @@ -159,14 +160,14 @@ protected void getExecute(GetAnomalyDetectorRequest request, ActionListener entityProfilesToCollect = getEntityProfilesToCollect(typesStr, all); EntityProfileRunner profileRunner = new EntityProfileRunner( client, @@ -176,7 +177,7 @@ protected void getExecute(GetAnomalyDetectorRequest request, ActionListener modelSize; private int shingleSize; private long activeEntities; private long totalUpdates; + private List modelProfiles; /** * Constructor @@ -63,6 +63,9 @@ public ProfileNodeResponse(StreamInput in) throws IOException { shingleSize = in.readInt(); activeEntities = in.readVLong(); totalUpdates = in.readVLong(); + if (in.readBoolean()) { + modelProfiles = in.readList(ModelProfile::new); + } } /** @@ -73,13 +76,22 @@ public ProfileNodeResponse(StreamInput in) throws IOException { * @param shingleSize shingle size * @param activeEntity active entity count * @param totalUpdates RCF model total updates + * @param modelProfiles a collection of model profiles like model size */ - public ProfileNodeResponse(DiscoveryNode node, Map modelSize, int shingleSize, long activeEntity, long totalUpdates) { + public ProfileNodeResponse( + DiscoveryNode node, + Map modelSize, + int shingleSize, + long activeEntity, + long totalUpdates, + List modelProfiles + ) { super(node); this.modelSize = modelSize; this.shingleSize = shingleSize; this.activeEntities = activeEntity; this.totalUpdates = totalUpdates; + this.modelProfiles = modelProfiles; } /** @@ -106,6 +118,12 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInt(shingleSize); out.writeVLong(activeEntities); out.writeVLong(totalUpdates); + if (modelProfiles != null) { + out.writeBoolean(true); + out.writeList(modelProfiles); + } else { + out.writeBoolean(false); + } } /** @@ -118,7 +136,7 @@ public void writeTo(StreamOutput out) throws IOException { */ @Override public XContentBuilder toXContent(XContentBuilder builder, 
Params params) throws IOException { - builder.startObject(MODEL_SIZE_IN_BYTES); + builder.startObject(CommonName.MODEL_SIZE_IN_BYTES); for (Map.Entry entry : modelSize.entrySet()) { builder.field(entry.getKey(), entry.getValue()); } @@ -128,6 +146,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(CommonName.ACTIVE_ENTITIES, activeEntities); builder.field(CommonName.TOTAL_UPDATES, totalUpdates); + builder.startArray(CommonName.MODELS); + for (ModelProfile modelProfile : modelProfiles) { + builder.startObject(); + modelProfile.toXContent(builder, params); + builder.endObject(); + } + builder.endArray(); + return builder; } @@ -146,4 +172,8 @@ public long getActiveEntities() { public long getTotalUpdates() { return totalUpdates; } + + public List getModelProfiles() { + return modelProfiles; + } } diff --git a/src/main/java/org/opensearch/ad/transport/ProfileResponse.java b/src/main/java/org/opensearch/ad/transport/ProfileResponse.java index 4e77650e7..09a6d59aa 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ProfileResponse.java @@ -31,10 +31,13 @@ import java.util.List; import java.util.Map; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.model.ModelProfile; +import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.cluster.ClusterName; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; @@ -45,6 +48,7 @@ * This class consists of the aggregated responses from the nodes */ public class ProfileResponse extends BaseNodesResponse implements ToXContentFragment { + private static final Logger LOG = LogManager.getLogger(ProfileResponse.class); // filed name in toXContent static final String COORDINATING_NODE = CommonName.COORDINATING_NODE; static final String SHINGLE_SIZE = CommonName.SHINGLE_SIZE; @@ -53,7 +57,7 @@ public class ProfileResponse extends BaseNodesResponse impl static final String MODELS = CommonName.MODELS; static final String TOTAL_UPDATES = CommonName.TOTAL_UPDATES; - private ModelProfile[] modelProfile; + private ModelProfileOnNode[] modelProfile; private int shingleSize; private String coordinatingNode; private long totalSizeInBytes; @@ -69,9 +73,9 @@ public class ProfileResponse extends BaseNodesResponse impl public ProfileResponse(StreamInput in) throws IOException { super(in); int size = in.readVInt(); - modelProfile = new ModelProfile[size]; + modelProfile = new ModelProfileOnNode[size]; for (int i = 0; i < size; i++) { - modelProfile[i] = new ModelProfile(in); + modelProfile[i] = new ModelProfileOnNode(in); } shingleSize = in.readInt(); coordinatingNode = in.readString(); @@ -93,7 +97,7 @@ public ProfileResponse(ClusterName clusterName, List nodes, activeEntities = 0L; totalUpdates = 0L; shingleSize = -1; - List modelProfileList = new ArrayList<>(); + List modelProfileList = new ArrayList<>(); for (ProfileNodeResponse response : nodes) { String curNodeId = response.getNode().getId(); if (response.getShingleSize() >= 0) { @@ -103,7 +107,16 @@ public ProfileResponse(ClusterName clusterName, List nodes, if (response.getModelSize() != null) { for (Map.Entry entry : response.getModelSize().entrySet()) { totalSizeInBytes += entry.getValue(); - modelProfileList.add(new 
ModelProfile(entry.getKey(), entry.getValue(), curNodeId)); + } + } + if (response.getModelProfiles() != null && response.getModelProfiles().size() > 0) { + for (ModelProfile profile : response.getModelProfiles()) { + modelProfileList.add(new ModelProfileOnNode(curNodeId, profile)); + } + } else if (response.getModelSize() != null && response.getModelSize().size() > 0) { + for (Map.Entry entry : response.getModelSize().entrySet()) { + // single-stream detectors have no entity info + modelProfileList.add(new ModelProfileOnNode(curNodeId, new ModelProfile(entry.getKey(), null, entry.getValue()))); } } @@ -117,14 +130,14 @@ public ProfileResponse(ClusterName clusterName, List nodes, if (coordinatingNode == null) { coordinatingNode = ""; } - this.modelProfile = modelProfileList.toArray(new ModelProfile[0]); + this.modelProfile = modelProfileList.toArray(new ModelProfileOnNode[0]); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeVInt(modelProfile.length); - for (ModelProfile profile : modelProfile) { + for (ModelProfileOnNode profile : modelProfile) { profile.writeTo(out); } out.writeInt(shingleSize); @@ -152,14 +165,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(ACTIVE_ENTITY, activeEntities); builder.field(TOTAL_UPDATES, totalUpdates); builder.startArray(MODELS); - for (ModelProfile profile : modelProfile) { + for (ModelProfileOnNode profile : modelProfile) { profile.toXContent(builder, params); } builder.endArray(); return builder; } - public ModelProfile[] getModelProfile() { + public ModelProfileOnNode[] getModelProfile() { return modelProfile; } diff --git a/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java b/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java index 0495552b6..7668f18fa 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java @@ -38,6 +38,7 @@ import org.opensearch.ad.feature.FeatureManager; import org.opensearch.ad.ml.ModelManager; import org.opensearch.ad.model.DetectorProfileName; +import org.opensearch.ad.model.ModelProfile; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.io.stream.StreamInput; @@ -113,6 +114,7 @@ protected ProfileNodeResponse nodeOperation(ProfileNodeRequest request) { long activeEntity = 0; long totalUpdates = 0; Map modelSize = null; + List modelProfiles = null; if (request.isForMultiEntityDetector()) { if (profiles.contains(DetectorProfileName.ACTIVE_ENTITIES)) { activeEntity = cacheProvider.get().getActiveEntities(detectorId); @@ -120,9 +122,13 @@ protected ProfileNodeResponse nodeOperation(ProfileNodeRequest request) { if (profiles.contains(DetectorProfileName.INIT_PROGRESS)) { totalUpdates = cacheProvider.get().getTotalUpdates(detectorId); } - if (profiles.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES) || profiles.contains(DetectorProfileName.MODELS)) { + if (profiles.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES)) { modelSize = cacheProvider.get().getModelSize(detectorId); } + // need to provide entity info for HCAD + if (profiles.contains(DetectorProfileName.MODELS)) { + modelProfiles = cacheProvider.get().getAllModelProfile(detectorId); + } } else { if (profiles.contains(DetectorProfileName.COORDINATING_NODE) || profiles.contains(DetectorProfileName.SHINGLE_SIZE)) { shingleSize = 
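Because the aggregated ProfileResponse now carries ModelProfileOnNode entries (a node id paired with a per-model profile), a consumer can regroup model memory by node. A short sketch using only the getNodeId()/getModelSize() accessors that the tests further below also exercise; the wrapper class is illustrative:

import java.util.HashMap;
import java.util.Map;

import org.opensearch.ad.model.ModelProfileOnNode;

final class ProfileSummaries {
    // Sum model memory per data node from ProfileResponse#getModelProfile().
    static Map<String, Long> bytesPerNode(ModelProfileOnNode[] profiles) {
        Map<String, Long> perNode = new HashMap<>();
        for (ModelProfileOnNode profile : profiles) {
            perNode.merge(profile.getNodeId(), profile.getModelSize(), Long::sum);
        }
        return perNode;
    }
}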
featureManager.getShingleSize(detectorId); @@ -133,6 +139,6 @@ protected ProfileNodeResponse nodeOperation(ProfileNodeRequest request) { } } - return new ProfileNodeResponse(clusterService.localNode(), modelSize, shingleSize, activeEntity, totalUpdates); + return new ProfileNodeResponse(clusterService.localNode(), modelSize, shingleSize, activeEntity, totalUpdates, modelProfiles); } } diff --git a/src/main/java/org/opensearch/ad/transport/RCFPollingRequest.java b/src/main/java/org/opensearch/ad/transport/RCFPollingRequest.java index 92de00a3d..25ccfa0fc 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFPollingRequest.java +++ b/src/main/java/org/opensearch/ad/transport/RCFPollingRequest.java @@ -33,7 +33,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.constant.CommonErrorMessages; -import org.opensearch.ad.constant.CommonMessageAttributes; +import org.opensearch.ad.constant.CommonName; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; @@ -75,7 +75,7 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CommonMessageAttributes.ID_JSON_KEY, adID); + builder.field(CommonName.ID_JSON_KEY, adID); builder.endObject(); return builder; } diff --git a/src/main/java/org/opensearch/ad/transport/RCFResultRequest.java b/src/main/java/org/opensearch/ad/transport/RCFResultRequest.java index 8c63d0bb3..4a88c3cfe 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFResultRequest.java +++ b/src/main/java/org/opensearch/ad/transport/RCFResultRequest.java @@ -33,7 +33,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.constant.CommonErrorMessages; -import org.opensearch.ad.constant.CommonMessageAttributes; +import org.opensearch.ad.constant.CommonName; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; @@ -107,9 +107,9 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CommonMessageAttributes.ID_JSON_KEY, adID); - builder.field(CommonMessageAttributes.MODEL_ID_JSON_KEY, modelID); - builder.startArray(CommonMessageAttributes.FEATURE_JSON_KEY); + builder.field(CommonName.ID_JSON_KEY, adID); + builder.field(CommonName.MODEL_ID_KEY, modelID); + builder.startArray(CommonName.FEATURE_JSON_KEY); for (double feature : features) { builder.value(feature); } diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java b/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java index c02da2ef6..22ff7e23b 100644 --- a/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java +++ b/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java @@ -35,7 +35,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.constant.CommonErrorMessages; -import org.opensearch.ad.constant.CommonMessageAttributes; +import org.opensearch.ad.constant.CommonName; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.InputStreamStreamInput; 
import org.opensearch.common.io.stream.OutputStreamStreamOutput; @@ -87,7 +87,7 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CommonMessageAttributes.ID_JSON_KEY, adID); + builder.field(CommonName.ID_JSON_KEY, adID); builder.endObject(); return builder; } diff --git a/src/main/java/org/opensearch/ad/transport/ThresholdResultRequest.java b/src/main/java/org/opensearch/ad/transport/ThresholdResultRequest.java index b824475b1..d020b226c 100644 --- a/src/main/java/org/opensearch/ad/transport/ThresholdResultRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ThresholdResultRequest.java @@ -33,7 +33,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.constant.CommonErrorMessages; -import org.opensearch.ad.constant.CommonMessageAttributes; +import org.opensearch.ad.constant.CommonName; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; @@ -99,9 +99,9 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CommonMessageAttributes.ID_JSON_KEY, adID); - builder.field(CommonMessageAttributes.MODEL_ID_JSON_KEY, modelID); - builder.field(CommonMessageAttributes.RCF_SCORE_JSON_KEY, rcfScore); + builder.field(CommonName.ID_JSON_KEY, adID); + builder.field(CommonName.MODEL_ID_KEY, modelID); + builder.field(CommonName.RCF_SCORE_JSON_KEY, rcfScore); builder.endObject(); return builder; } diff --git a/src/main/java/org/opensearch/ad/transport/ThresholdResultResponse.java b/src/main/java/org/opensearch/ad/transport/ThresholdResultResponse.java index e877c0fa3..e59f37989 100644 --- a/src/main/java/org/opensearch/ad/transport/ThresholdResultResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ThresholdResultResponse.java @@ -29,7 +29,7 @@ import java.io.IOException; import org.opensearch.action.ActionResponse; -import org.opensearch.ad.constant.CommonMessageAttributes; +import org.opensearch.ad.constant.CommonName; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.xcontent.ToXContentObject; @@ -67,8 +67,8 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CommonMessageAttributes.ANOMALY_GRADE_JSON_KEY, anomalyGrade); - builder.field(CommonMessageAttributes.CONFIDENCE_JSON_KEY, confidence); + builder.field(CommonName.ANOMALY_GRADE_JSON_KEY, anomalyGrade); + builder.field(CommonName.CONFIDENCE_JSON_KEY, confidence); builder.endObject(); return builder; } diff --git a/src/main/java/org/opensearch/ad/transport/handler/AnomalyIndexHandler.java b/src/main/java/org/opensearch/ad/transport/handler/AnomalyIndexHandler.java index 8b6198ae5..86ef797d0 100644 --- a/src/main/java/org/opensearch/ad/transport/handler/AnomalyIndexHandler.java +++ b/src/main/java/org/opensearch/ad/transport/handler/AnomalyIndexHandler.java @@ -44,6 +44,7 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.ad.common.exception.AnomalyDetectionException; import 
org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.util.BulkUtil; import org.opensearch.ad.util.ClientUtil; import org.opensearch.ad.util.IndexUtils; import org.opensearch.ad.util.RestHandlerUtils; @@ -212,12 +213,12 @@ void saveIteration(IndexRequest indexRequest, String detectorId, Iterator saveIteration(newReuqest, detectorId, backoff), nextDelay, ThreadPool.Names.SAME); + threadPool + .schedule( + () -> saveIteration(BulkUtil.cloneIndexRequest(indexRequest), detectorId, backoff), + nextDelay, + ThreadPool.Names.SAME + ); } } ) diff --git a/src/main/java/org/opensearch/ad/transport/handler/DetectionStateHandler.java b/src/main/java/org/opensearch/ad/transport/handler/DetectionStateHandler.java index 766b47198..6115745ed 100644 --- a/src/main/java/org/opensearch/ad/transport/handler/DetectionStateHandler.java +++ b/src/main/java/org/opensearch/ad/transport/handler/DetectionStateHandler.java @@ -167,6 +167,7 @@ private void update(String detectorId, GetStateStrategy handler) { if (cause instanceof IndexNotFoundException) { super.index(handler.createNewState(null), detectorId); } else { + // e.g., can happen during node reboot LOG.error("Failed to get detector state " + detectorId, exception); } })); diff --git a/src/main/java/org/opensearch/ad/transport/handler/MultiEntityResultHandler.java b/src/main/java/org/opensearch/ad/transport/handler/MultiEntityResultHandler.java index 22523c9bb..7f7981ba8 100644 --- a/src/main/java/org/opensearch/ad/transport/handler/MultiEntityResultHandler.java +++ b/src/main/java/org/opensearch/ad/transport/handler/MultiEntityResultHandler.java @@ -26,25 +26,18 @@ package org.opensearch.ad.transport.handler; -import java.time.Clock; -import java.util.Locale; -import java.util.concurrent.RejectedExecutionException; - import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.core.util.Throwables; import org.opensearch.ExceptionsHelper; import org.opensearch.ResourceAlreadyExistsException; import org.opensearch.action.ActionListener; -import org.opensearch.action.admin.indices.create.CreateIndexResponse; -import org.opensearch.action.bulk.BulkResponse; -import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.common.exception.AnomalyDetectionException; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.indices.AnomalyDetectionIndices; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.transport.ADResultBulkAction; import org.opensearch.ad.transport.ADResultBulkRequest; +import org.opensearch.ad.transport.ADResultBulkResponse; import org.opensearch.ad.util.ClientUtil; import org.opensearch.ad.util.IndexUtils; import org.opensearch.ad.util.ThrowingConsumerWrapper; @@ -66,8 +59,8 @@ */ public class MultiEntityResultHandler extends AnomalyIndexHandler { private static final Logger LOG = LogManager.getLogger(MultiEntityResultHandler.class); - private final NodeStateManager nodeStateManager; - private final Clock clock; + private static final String SUCCESS_SAVING_RESULT_MSG = "Result saved successfully."; + private static final String CANNOT_SAVE_RESULT_ERR_MSG = "Cannot save results due to write block."; @Inject public MultiEntityResultHandler( @@ -77,32 +70,7 @@ public MultiEntityResultHandler( AnomalyDetectionIndices anomalyDetectionIndices, ClientUtil clientUtil, IndexUtils indexUtils, - ClusterService clusterService, - NodeStateManager nodeStateManager - ) { - this( - client, - settings, - threadPool, - anomalyDetectionIndices, 
- clientUtil, - indexUtils, - clusterService, - nodeStateManager, - Clock.systemUTC() - ); - } - - protected MultiEntityResultHandler( - Client client, - Settings settings, - ThreadPool threadPool, - AnomalyDetectionIndices anomalyDetectionIndices, - ClientUtil clientUtil, - IndexUtils indexUtils, - ClusterService clusterService, - NodeStateManager nodeStateManager, - Clock clock + ClusterService clusterService ) { super( client, @@ -115,77 +83,57 @@ protected MultiEntityResultHandler( indexUtils, clusterService ); - this.nodeStateManager = nodeStateManager; - this.clock = clock; } /** * Execute the bulk request * @param currentBulkRequest The bulk request - * @param detectorId Detector Id + * @param listener callback after flushing */ - public void flush(ADResultBulkRequest currentBulkRequest, String detectorId) { + public void flush(ADResultBulkRequest currentBulkRequest, ActionListener listener) { if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, this.indexName)) { - LOG.warn(String.format(Locale.ROOT, CANNOT_SAVE_ERR_MSG, detectorId)); + listener.onFailure(new AnomalyDetectionException(CANNOT_SAVE_RESULT_ERR_MSG)); return; } try { if (!indexExists.getAsBoolean()) { - createIndex - .accept( - ActionListener - .wrap(initResponse -> onCreateIndexResponse(initResponse, currentBulkRequest, detectorId), exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we sending the create request - bulk(currentBulkRequest, detectorId); - } else { - throw new AnomalyDetectionException( - detectorId, - String.format(Locale.ROOT, "Unexpected error creating index %s", indexName), - exception - ); - } - }) - ); + createIndex.accept(ActionListener.wrap(initResponse -> { + if (initResponse.isAcknowledged()) { + bulk(currentBulkRequest, listener); + } else { + LOG.warn("Creating result index with mappings call not acknowledged."); + listener.onFailure(new AnomalyDetectionException("", "Creating result index with mappings call not acknowledged.")); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + bulk(currentBulkRequest, listener); + } else { + LOG.warn("Unexpected error creating result index", exception); + listener.onFailure(exception); + } + })); } else { - bulk(currentBulkRequest, detectorId); + bulk(currentBulkRequest, listener); } } catch (Exception e) { - throw new AnomalyDetectionException( - detectorId, - String.format(Locale.ROOT, "Error in bulking %s for detector %s", indexName, detectorId), - e - ); - } - } - - private void onCreateIndexResponse(CreateIndexResponse response, ADResultBulkRequest bulkRequest, String detectorId) { - if (response.isAcknowledged()) { - bulk(bulkRequest, detectorId); - } else { - throw new AnomalyDetectionException(detectorId, "Creating %s with mappings call not acknowledged."); + LOG.warn("Error in bulking results", e); + listener.onFailure(e); } } - private void bulk(ADResultBulkRequest currentBulkRequest, String detectorId) { + private void bulk(ADResultBulkRequest currentBulkRequest, ActionListener listener) { if (currentBulkRequest.numberOfActions() <= 0) { + listener.onFailure(new AnomalyDetectionException("no result to save")); return; } - client - .execute( - ADResultBulkAction.INSTANCE, - currentBulkRequest, - ActionListener - .wrap(response -> 
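With the reworked flush, the caller observes the outcome through the listener instead of relying on detector-scoped log lines. A hedged usage sketch (the wrapping class and the comments' retry policy are hypothetical; ADResultBulkRequest/ADResultBulkResponse are the types introduced in the hunks above):

import org.opensearch.action.ActionListener;
import org.opensearch.ad.transport.ADResultBulkRequest;
import org.opensearch.ad.transport.ADResultBulkResponse;
import org.opensearch.ad.transport.handler.MultiEntityResultHandler;

final class ResultFlusher {
    // Flush buffered HC results; the caller decides how to react to failures
    // (e.g., re-enqueue retryable writes) rather than only logging them.
    static void flush(MultiEntityResultHandler handler, ADResultBulkRequest request) {
        handler.flush(request, ActionListener.wrap(
            (ADResultBulkResponse response) -> { /* record success, clear the buffer */ },
            exception -> { /* surface or retry based on ExceptionUtil checks */ }
        ));
    }
}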
LOG.debug(String.format(Locale.ROOT, SUCCESS_SAVING_MSG, detectorId)), exception -> { - LOG.error(String.format(Locale.ROOT, FAIL_TO_SAVE_ERR_MSG, detectorId), exception); - Throwable cause = Throwables.getRootCause(exception); - // too much indexing pressure - // TODO: pause indexing a bit before trying again, ideally with randomized exponential backoff. - if (cause instanceof RejectedExecutionException) { - nodeStateManager.setLastIndexThrottledTime(clock.instant()); - } - }) - ); + client.execute(ADResultBulkAction.INSTANCE, currentBulkRequest, ActionListener.wrap(response -> { + LOG.debug(SUCCESS_SAVING_RESULT_MSG); + listener.onResponse(response); + }, exception -> { + LOG.error("Error in bulking results", exception); + listener.onFailure(exception); + })); } } diff --git a/src/main/java/org/opensearch/ad/util/BulkUtil.java b/src/main/java/org/opensearch/ad/util/BulkUtil.java index 60c7ab3b2..167243250 100644 --- a/src/main/java/org/opensearch/ad/util/BulkUtil.java +++ b/src/main/java/org/opensearch/ad/util/BulkUtil.java @@ -37,25 +37,73 @@ import org.opensearch.action.bulk.BulkItemResponse; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexRequest; public class BulkUtil { private static final Logger logger = LogManager.getLogger(BulkUtil.class); - public static List> getIndexRequestToRetry(BulkRequest bulkRequest, BulkResponse bulkResponse) { - List> res = new ArrayList<>(); + public static List getIndexRequestToRetry(BulkRequest bulkRequest, BulkResponse bulkResponse) { + List res = new ArrayList<>(); Set failedId = new HashSet<>(); for (BulkItemResponse response : bulkResponse.getItems()) { - if (response.isFailed()) { + if (response.isFailed() && ExceptionUtil.isRetryAble(response.getFailure().getStatus())) { failedId.add(response.getId()); } } for (DocWriteRequest request : bulkRequest.requests()) { - if (failedId.contains(request.id())) { - res.add(request); + try { + if (failedId.contains(request.id())) { + res.add(cloneIndexRequest((IndexRequest) request)); + } + } catch (ClassCastException e) { + logger.error("We only support IndexRequest", e); + throw e; } + } return res; } + + public static List getFailedIndexRequest(BulkRequest bulkRequest, BulkResponse bulkResponse) { + List res = new ArrayList<>(); + + if (bulkResponse == null || bulkRequest == null) { + return res; + } + + Set failedId = new HashSet<>(); + for (BulkItemResponse response : bulkResponse.getItems()) { + if (response.isFailed() && ExceptionUtil.isRetryAble(response.getFailure().getStatus())) { + failedId.add(response.getId()); + } + } + + for (DocWriteRequest request : bulkRequest.requests()) { + try { + if (failedId.contains(request.id())) { + res.add((IndexRequest) request); + } + } catch (ClassCastException e) { + logger.error("We only support IndexRequest"); + throw e; + } + + } + return res; + } + + /** + * Copy original request's source without other information like autoGeneratedTimestamp. + * otherwise, an exception will be thrown indicating autoGeneratedTimestamp should not be set + * while request id is already set (id is set because we have already sent the request before). 
+ * @param indexRequest request to be cloned + * @return cloned Request + */ + public static IndexRequest cloneIndexRequest(IndexRequest indexRequest) { + IndexRequest newRequest = new IndexRequest(indexRequest.index()); + newRequest.source(indexRequest.source(), indexRequest.getContentType()); + return newRequest; + } } diff --git a/src/main/java/org/opensearch/ad/util/ExceptionUtil.java b/src/main/java/org/opensearch/ad/util/ExceptionUtil.java index 89b4503a9..fb0058f0c 100644 --- a/src/main/java/org/opensearch/ad/util/ExceptionUtil.java +++ b/src/main/java/org/opensearch/ad/util/ExceptionUtil.java @@ -26,18 +26,35 @@ package org.opensearch.ad.util; +import java.util.EnumSet; +import java.util.concurrent.RejectedExecutionException; + import org.apache.commons.lang.exception.ExceptionUtils; +import org.apache.logging.log4j.core.util.Throwables; +import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.NoShardAvailableActionException; +import org.opensearch.action.UnavailableShardsException; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.replication.ReplicationResponse; import org.opensearch.ad.common.exception.AnomalyDetectionException; +import org.opensearch.ad.common.exception.EndRunException; +import org.opensearch.ad.common.exception.LimitExceededException; import org.opensearch.ad.common.exception.ResourceNotFoundException; import org.opensearch.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.rest.RestStatus; public class ExceptionUtil { public static final String RESOURCE_NOT_FOUND_EXCEPTION_NAME_UNDERSCORE = OpenSearchException .getExceptionName(new ResourceNotFoundException("", "")); + // a positive cache of retriable error rest status + private static final EnumSet RETRYABLE_STATUS = EnumSet + .of(RestStatus.REQUEST_TIMEOUT, RestStatus.CONFLICT, RestStatus.INTERNAL_SERVER_ERROR); + /** * OpenSearch restricts the kind of exceptions can be thrown over the wire * (See OpenSearchException.OpenSearchExceptionHandle). 
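Putting the two utilities together: after a partial bulk failure, only items whose status passes ExceptionUtil.isRetryAble are re-sent, and each comes back as a fresh clone so the previously assigned id and autoGeneratedTimestamp are not reused. A minimal caller-side sketch (the wrapper class is hypothetical; the BulkUtil call is the one defined above):

import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.ad.util.BulkUtil;

final class BulkRetry {
    // Build a follow-up bulk holding only the retryable, freshly cloned index requests.
    static BulkRequest buildRetry(BulkRequest original, BulkResponse response) {
        BulkRequest retry = new BulkRequest();
        for (IndexRequest request : BulkUtil.getIndexRequestToRetry(original, response)) {
            retry.add(request); // already filtered by retryable status and cloned
        }
        return retry;
    }
}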
Since we cannot @@ -116,4 +133,80 @@ public static String getErrorMessage(Exception e) { return ExceptionUtils.getFullStackTrace(e); } } + + /** + * + * @param exception Exception + * @return whether the cause indicates the cluster is overloaded + */ + public static boolean isOverloaded(Throwable exception) { + Throwable cause = Throwables.getRootCause(exception); + // LimitExceededException may indicate circuit breaker exception + // UnavailableShardsException can happen when the system cannot respond + // to requests + return cause instanceof RejectedExecutionException + || cause instanceof OpenSearchRejectedExecutionException + || cause instanceof UnavailableShardsException + || cause instanceof LimitExceededException; + } + + public static boolean isRetryAble(Exception e) { + Throwable cause = ExceptionsHelper.unwrapCause(e); + RestStatus status = ExceptionsHelper.status(cause); + return isRetryAble(status); + } + + public static boolean isRetryAble(RestStatus status) { + return RETRYABLE_STATUS.contains(status); + } + + /** + * Wrap a listener to return the given exception no matter what + * @param The type of listener response + * @param original Original listener + * @param exceptionToReturn The exception to return + * @param detectorId Detector Id + * @return the wrapped listener + */ + public static ActionListener wrapListener(ActionListener original, Exception exceptionToReturn, String detectorId) { + return ActionListener + .wrap( + r -> { original.onFailure(exceptionToReturn); }, + e -> { original.onFailure(selectHigherPriorityException(exceptionToReturn, e)); } + ); + } + + /** + * Return an exception that has higher priority. + * If an exception is EndRunException while another one is not, the former has + * higher priority. + * If both exceptions are EndRunException, the one with end now true has higher + * priority. + * Otherwise, return the second given exception. + * @param exception1 Exception 1 + * @param exception2 Exception 2 + * @return high priority exception + */ + public static Exception selectHigherPriorityException(Exception exception1, Exception exception2) { + if (exception1 instanceof EndRunException) { + // we have already had EndRunException. 
Don't replace it with something less severe + EndRunException endRunException = (EndRunException) exception1; + if (endRunException.isEndNow()) { + // don't proceed if recorded exception is ending now + return exception1; + } + if (false == (exception2 instanceof EndRunException) || false == ((EndRunException) exception2).isEndNow()) { + // don't proceed if the giving exception is not ending now + return exception1; + } + } + return exception2; + } + + public static boolean isIndexNotAvailable(Exception e) { + if (e == null) { + return false; + } + return e instanceof IndexNotFoundException || e instanceof NoShardAvailableActionException; + } } diff --git a/src/main/java/org/opensearch/ad/util/ParseUtils.java b/src/main/java/org/opensearch/ad/util/ParseUtils.java index ab3454f0b..ad621e7cf 100644 --- a/src/main/java/org/opensearch/ad/util/ParseUtils.java +++ b/src/main/java/org/opensearch/ad/util/ParseUtils.java @@ -56,6 +56,7 @@ import org.opensearch.ad.common.exception.AnomalyDetectionException; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.Feature; import org.opensearch.ad.model.FeatureData; import org.opensearch.ad.model.IntervalTimeConfiguration; @@ -392,12 +393,15 @@ public static String generateInternalFeatureQueryTemplate(AnomalyDetector detect public static SearchSourceBuilder generateEntityColdStartQuery( AnomalyDetector detector, List> ranges, - String entityName, + Entity entity, NamedXContentRegistry xContentRegistry ) throws IOException { - TermQueryBuilder term = new TermQueryBuilder(detector.getCategoryField().get(0), entityName); - BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(detector.getFilterQuery()).filter(term); + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(detector.getFilterQuery()); + + for (TermQueryBuilder term : entity.getTermQueryBuilders()) { + internalFilterQuery.filter(term); + } DateRangeAggregationBuilder dateRangeBuilder = dateRange("date_range").field(detector.getTimeField()).format("epoch_millis"); for (Entry range : ranges) { diff --git a/src/main/resources/mappings/anomaly-results.json b/src/main/resources/mappings/anomaly-results.json index 9fcaab831..d26bdb336 100644 --- a/src/main/resources/mappings/anomaly-results.json +++ b/src/main/resources/mappings/anomaly-results.json @@ -1,7 +1,7 @@ { "dynamic": false, "_meta": { - "schema_version": 3 + "schema_version": 4 }, "properties": { "detector_id": { @@ -103,6 +103,9 @@ }, "task_id": { "type": "keyword" + }, + "model_id": { + "type": "keyword" } } } diff --git a/src/main/resources/mappings/checkpoint.json b/src/main/resources/mappings/checkpoint.json index e058ec3d4..a413fc6fa 100644 --- a/src/main/resources/mappings/checkpoint.json +++ b/src/main/resources/mappings/checkpoint.json @@ -1,7 +1,7 @@ { "dynamic": true, "_meta": { - "schema_version": 2 + "schema_version": 3 }, "properties": { "detectorId": { @@ -21,6 +21,17 @@ }, "schema_version": { "type": "integer" + }, + "entity": { + "type": "nested", + "properties": { + "name": { + "type": "keyword" + }, + "value": { + "type": "keyword" + } + } } } } diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java index d50d8354b..2aaba3e27 100644 --- 
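The cold-start query change above derives one term filter per entity attribute instead of a single hard-coded category term, which is what lets a detector carry more than one category field. The same idea in isolation, grounded on entity.getTermQueryBuilders() from the hunk (the surrounding helper class and the example attribute values are illustrative):

import org.opensearch.ad.model.Entity;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;

final class EntityFilters {
    // One term filter per category attribute, e.g. {service: "a", error_code: "500"}.
    static BoolQueryBuilder entityFilter(QueryBuilder detectorFilter, Entity entity) {
        BoolQueryBuilder query = QueryBuilders.boolQuery().filter(detectorFilter);
        for (TermQueryBuilder term : entity.getTermQueryBuilders()) {
            query.filter(term);
        }
        return query;
    }
}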
a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java @@ -195,10 +195,11 @@ public void setUp() throws Exception { ); } - public void testTwoCategoricalFields() throws IOException { + // we support upto 2 category fields now + public void testThreeCategoricalFields() throws IOException { expectThrows( IllegalArgumentException.class, - () -> TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("a", "b")) + () -> TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("a", "b", "c")) ); } diff --git a/src/test/java/org/opensearch/ad/AbstractADTest.java b/src/test/java/org/opensearch/ad/AbstractADTest.java index 22fe58dc9..7b8b86501 100644 --- a/src/test/java/org/opensearch/ad/AbstractADTest.java +++ b/src/test/java/org/opensearch/ad/AbstractADTest.java @@ -36,8 +36,10 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; @@ -57,9 +59,11 @@ import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.DetectorInternalState; +import org.opensearch.ad.model.Entity; import org.opensearch.cluster.metadata.AliasMetadata; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.http.HttpRequest; @@ -212,7 +216,7 @@ protected static void setUpThreadPool(String name) { AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, 1, 1000, - "opendistro.ad." + AnomalyDetectorPlugin.AD_THREAD_POOL_NAME + "opensearch.ad." + AnomalyDetectorPlugin.AD_THREAD_POOL_NAME ) ); } @@ -223,17 +227,33 @@ protected static void tearDownThreadPool() { threadPool = null; } - public void setupTestNodes(Settings settings, TransportInterceptor transportInterceptor) { + /** + * + * @param transportInterceptor Interceptor to for transport requests. Used + * to mock transport layer. + * @param nodeSettings node override of setting + * @param setting the supported setting set. + */ + public void setupTestNodes(TransportInterceptor transportInterceptor, final Settings nodeSettings, Setting... setting) { nodesCount = randomIntBetween(2, 10); testNodes = new FakeNode[nodesCount]; + Set> settingSet = new HashSet<>(Arrays.asList(setting)); for (int i = 0; i < testNodes.length; i++) { - testNodes[i] = new FakeNode("node" + i, threadPool, settings, transportInterceptor); + testNodes[i] = new FakeNode("node" + i, threadPool, nodeSettings, settingSet, transportInterceptor); } FakeNode.connectNodes(testNodes); } - public void setupTestNodes(Settings settings) { - setupTestNodes(settings, TransportService.NOOP_TRANSPORT_INTERCEPTOR); + public void setupTestNodes(Setting... 
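As the renamed test signals, two category fields are now accepted and only a third trips validation. A compact fragment in the style of that test class (the field names are placeholders; expectThrows comes from the test framework the class already extends):

public void testTwoCategoricalFieldsAccepted() throws IOException {
    // Two fields: builds fine under the new limit.
    TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("service", "host"));
    // Three fields: rejected.
    expectThrows(
        IllegalArgumentException.class,
        () -> TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("service", "host", "region"))
    );
}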
setting) { + setupTestNodes(TransportService.NOOP_TRANSPORT_INTERCEPTOR, Settings.EMPTY, setting); + } + + public void setupTestNodes(Settings nodeSettings) { + setupTestNodes(TransportService.NOOP_TRANSPORT_INTERCEPTOR, nodeSettings); + } + + public void setupTestNodes(TransportInterceptor transportInterceptor) { + setupTestNodes(transportInterceptor, Settings.EMPTY); } public void tearDownTestNodes() { @@ -340,7 +360,7 @@ public HttpRequest releaseAndCopy() { }, null); } - protected boolean areEqualWithArrayValue(Map first, Map second) { + protected boolean areEqualWithArrayValue(Map first, Map second) { if (first.size() != second.size()) { return false; } diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java index ba687f2d3..79ee66a72 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java @@ -37,6 +37,7 @@ import java.io.IOException; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -63,7 +64,7 @@ import org.opensearch.ad.model.DetectorState; import org.opensearch.ad.model.InitProgressProfile; import org.opensearch.ad.model.IntervalTimeConfiguration; -import org.opensearch.ad.model.ModelProfile; +import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.ad.transport.ProfileAction; import org.opensearch.ad.transport.ProfileNodeResponse; import org.opensearch.ad.transport.ProfileResponse; @@ -404,8 +405,22 @@ private void setUpClientExecuteProfileAction() { } }; - ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse(discoveryNode1, modelSizeMap1, shingleSize, 0L, 0L); - ProfileNodeResponse profileNodeResponse2 = new ProfileNodeResponse(discoveryNode2, modelSizeMap2, -1, 0L, 0L); + ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse( + discoveryNode1, + modelSizeMap1, + shingleSize, + 0L, + 0L, + new ArrayList<>() + ); + ProfileNodeResponse profileNodeResponse2 = new ProfileNodeResponse( + discoveryNode2, + modelSizeMap2, + -1, + 0L, + 0L, + new ArrayList<>() + ); List profileNodeResponses = Arrays.asList(profileNodeResponse1, profileNodeResponse2); List failures = Collections.emptyList(); ProfileResponse profileResponse = new ProfileResponse(new ClusterName(clusterName), profileNodeResponses, failures); @@ -486,7 +501,7 @@ public void testProfileModels() throws InterruptedException, IOException { assertEquals(shingleSize, profileResponse.getShingleSize()); assertEquals(modelSize * 2, profileResponse.getTotalSizeInBytes()); assertEquals(2, profileResponse.getModelProfile().length); - for (ModelProfile profile : profileResponse.getModelProfile()) { + for (ModelProfileOnNode profile : profileResponse.getModelProfile()) { assertTrue(node1.equals(profile.getNodeId()) || node2.equals(profile.getNodeId())); assertEquals(modelSize, profile.getModelSize()); if (node1.equals(profile.getNodeId())) { diff --git a/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java b/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java index 0e74400e8..ff463c48e 100644 --- a/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java @@ -33,8 +33,6 @@ import static org.opensearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; import static 
org.opensearch.ad.model.AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX; -import java.io.IOException; -import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -53,16 +51,14 @@ import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorJob; -import org.opensearch.ad.model.EntityProfile; +import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.EntityProfileName; import org.opensearch.ad.model.EntityState; -import org.opensearch.ad.model.InitProgressProfile; -import org.opensearch.ad.model.IntervalTimeConfiguration; import org.opensearch.ad.model.ModelProfile; +import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.ad.transport.EntityProfileAction; import org.opensearch.ad.transport.EntityProfileResponse; import org.opensearch.client.Client; -import org.opensearch.index.IndexNotFoundException; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; @@ -90,6 +86,7 @@ public class EntityProfileRunnerTests extends AbstractADTest { private String modelId; private long modelSize; private String nodeId; + private Entity entity; enum InittedEverResultStatus { UNKNOWN, @@ -143,6 +140,8 @@ public void setUp() throws Exception { return null; }).when(client).get(any(), any()); + + entity = Entity.createSingleAttributeEntity(detectorId, categoryField, entityValue); } @SuppressWarnings("unchecked") @@ -182,7 +181,7 @@ private void setUpExecuteEntityProfileAction(InittedEverResultStatus initted) { smallUpdates = 1; latestActiveTimestamp = 1603999189758L; isActive = Boolean.TRUE; - modelId = "T4c3dXUBj-2IZN7itix__entity_app_6"; + modelId = "T4c3dXUBj-2IZN7itix__entity_" + entityValue; modelSize = 712480L; nodeId = "g6pmr547QR-CfpEvO67M4g"; doAnswer(invocation -> { @@ -198,7 +197,7 @@ private void setUpExecuteEntityProfileAction(InittedEverResultStatus initted) { profileResponseBuilder.setActive(isActive); } else { profileResponseBuilder.setTotalUpdates(requiredSamples + 1); - ModelProfile model = new ModelProfile(modelId, modelSize, nodeId); + ModelProfileOnNode model = new ModelProfileOnNode(nodeId, new ModelProfile(modelId, entity, modelSize)); profileResponseBuilder.setModelProfile(model); } @@ -208,42 +207,12 @@ private void setUpExecuteEntityProfileAction(InittedEverResultStatus initted) { }).when(client).execute(any(EntityProfileAction.class), any(), any()); } - @SuppressWarnings("unchecked") - public void testNotMultiEntityDetector() throws IOException, InterruptedException { - detector = TestHelpers.randomAnomalyDetectorWithInterval(new IntervalTimeConfiguration(detectorIntervalMin, ChronoUnit.MINUTES)); - - doAnswer(invocation -> { - Object[] args = invocation.getArguments(); - GetRequest request = (GetRequest) args[0]; - ActionListener listener = (ActionListener) args[1]; - - String indexName = request.index(); - if (indexName.equals(ANOMALY_DETECTORS_INDEX)) { - listener - .onResponse(TestHelpers.createGetResponse(detector, detector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX)); - } - - return null; - }).when(client).get(any(), any()); - - final CountDownLatch inProgressLatch = new CountDownLatch(1); - - runner.profile(detectorId, entityValue, state, ActionListener.wrap(response -> { - assertTrue("Should not reach here", false); - inProgressLatch.countDown(); - }, exception -> { - 
assertTrue(exception.getMessage().contains(EntityProfileRunner.NOT_HC_DETECTOR_ERR_MSG)); - inProgressLatch.countDown(); - })); - assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); - } - public void stateTestTemplate(InittedEverResultStatus returnedState, EntityState expectedState) throws InterruptedException { setUpExecuteEntityProfileAction(returnedState); final CountDownLatch inProgressLatch = new CountDownLatch(1); - runner.profile(detectorId, entityValue, state, ActionListener.wrap(response -> { + runner.profile(detectorId, entity, state, ActionListener.wrap(response -> { assertEquals(expectedState, response.getState()); inProgressLatch.countDown(); }, exception -> { @@ -253,53 +222,10 @@ public void stateTestTemplate(InittedEverResultStatus returnedState, EntityState assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); } - public void testUnknownState() throws InterruptedException { - stateTestTemplate(InittedEverResultStatus.UNKNOWN, EntityState.UNKNOWN); - } - - public void testInitState() throws InterruptedException { - stateTestTemplate(InittedEverResultStatus.NOT_INITTED, EntityState.INIT); - } - - public void testRunningState() throws InterruptedException { - stateTestTemplate(InittedEverResultStatus.INITTED, EntityState.RUNNING); - } - - public void testInitNInfo() throws InterruptedException { - setUpExecuteEntityProfileAction(InittedEverResultStatus.NOT_INITTED); - setUpSearch(); - - EntityProfile.Builder expectedProfile = new EntityProfile.Builder(categoryField, entityValue); - - // 1 / 128 rounded to 1% - int neededSamples = requiredSamples - smallUpdates; - InitProgressProfile profile = new InitProgressProfile( - "1%", - neededSamples * detector.getDetectorIntervalInSeconds() / 60, - neededSamples - ); - expectedProfile.initProgress(profile); - expectedProfile.isActive(isActive); - expectedProfile.lastActiveTimestampMs(latestActiveTimestamp); - expectedProfile.lastSampleTimestampMs(latestSampleTimestamp); - - final CountDownLatch inProgressLatch = new CountDownLatch(1); - - runner.profile(detectorId, entityValue, initNInfo, ActionListener.wrap(response -> { - assertEquals(expectedProfile.build(), response); - inProgressLatch.countDown(); - }, exception -> { - LOG.error("Unexpected error", exception); - assertTrue("Should not reach here", false); - inProgressLatch.countDown(); - })); - assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); - } - public void testEmptyProfile() throws InterruptedException { final CountDownLatch inProgressLatch = new CountDownLatch(1); - runner.profile(detectorId, entityValue, new HashSet<>(), ActionListener.wrap(response -> { + runner.profile(detectorId, entity, new HashSet<>(), ActionListener.wrap(response -> { assertTrue("Should not reach here", false); inProgressLatch.countDown(); }, exception -> { @@ -308,58 +234,4 @@ public void testEmptyProfile() throws InterruptedException { })); assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); } - - public void testModel() throws InterruptedException { - setUpExecuteEntityProfileAction(InittedEverResultStatus.INITTED); - - EntityProfile.Builder expectedProfile = new EntityProfile.Builder(categoryField, entityValue); - ModelProfile modelProfile = new ModelProfile(modelId, modelSize, nodeId); - expectedProfile.modelProfile(modelProfile); - - final CountDownLatch inProgressLatch = new CountDownLatch(1); - - runner.profile(detectorId, entityValue, model, ActionListener.wrap(response -> { - assertEquals(expectedProfile.build(), response); - inProgressLatch.countDown(); - }, 
exception -> { - assertTrue("Should not reach here", false); - inProgressLatch.countDown(); - })); - assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); - } - - @SuppressWarnings("unchecked") - public void testJobIndexNotFound() throws InterruptedException { - setUpExecuteEntityProfileAction(InittedEverResultStatus.INITTED); - - final CountDownLatch inProgressLatch = new CountDownLatch(1); - - doAnswer(invocation -> { - Object[] args = invocation.getArguments(); - GetRequest request = (GetRequest) args[0]; - ActionListener listener = (ActionListener) args[1]; - - String indexName = request.index(); - if (indexName.equals(ANOMALY_DETECTORS_INDEX)) { - listener - .onResponse(TestHelpers.createGetResponse(detector, detector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX)); - } else if (indexName.equals(ANOMALY_DETECTOR_JOB_INDEX)) { - listener.onFailure(new IndexNotFoundException(ANOMALY_DETECTOR_JOB_INDEX)); - } - - return null; - }).when(client).get(any(), any()); - - EntityProfile expectedProfile = new EntityProfile.Builder(categoryField, entityValue).build(); - - runner.profile(detectorId, entityValue, initNInfo, ActionListener.wrap(response -> { - assertEquals(expectedProfile, response); - inProgressLatch.countDown(); - }, exception -> { - LOG.error("Unexpected error", exception); - assertTrue("Should not reach here", false); - inProgressLatch.countDown(); - })); - assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); - } } diff --git a/src/test/java/org/opensearch/ad/MemoryTrackerTests.java b/src/test/java/org/opensearch/ad/MemoryTrackerTests.java index 07f6c275e..61d085136 100644 --- a/src/test/java/org/opensearch/ad/MemoryTrackerTests.java +++ b/src/test/java/org/opensearch/ad/MemoryTrackerTests.java @@ -33,6 +33,7 @@ import java.util.Collections; import java.util.HashSet; +import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.common.exception.LimitExceededException; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; @@ -67,6 +68,7 @@ public class MemoryTrackerTests extends OpenSearchTestCase { double modelDesiredSizePercentage; JvmService jvmService; AnomalyDetector detector; + ADCircuitBreakerService circuitBreaker; @Override public void setUp() throws Exception { @@ -115,18 +117,35 @@ public void setUp() throws Exception { detector = mock(AnomalyDetector.class); when(detector.getEnabledFeatureIds()).thenReturn(Collections.singletonList("a")); when(detector.getShingleSize()).thenReturn(1); + + circuitBreaker = mock(ADCircuitBreakerService.class); + when(circuitBreaker.isOpen()).thenReturn(false); } private void setUpBigHeap() { ByteSizeValue value = new ByteSizeValue(largeHeapSize); when(mem.getHeapMax()).thenReturn(value); - tracker = new MemoryTracker(jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, rcfSampleSize); + tracker = new MemoryTracker( + jvmService, + modelMaxSizePercentage, + modelDesiredSizePercentage, + clusterService, + rcfSampleSize, + circuitBreaker + ); } private void setUpSmallHeap() { ByteSizeValue value = new ByteSizeValue(smallHeapSize); when(mem.getHeapMax()).thenReturn(value); - tracker = new MemoryTracker(jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, rcfSampleSize); + tracker = new MemoryTracker( + jvmService, + modelMaxSizePercentage, + modelDesiredSizePercentage, + clusterService, + rcfSampleSize, + circuitBreaker + ); } public void testEstimateModelSize() { @@ -145,10 +164,10 @@ public 
void testCanAllocate() { assertTrue(!tracker.canAllocate((long) (largeHeapSize * modelMaxPercen + 10))); long bytesToUse = 100_000; - tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.MULTI_ENTITY_DETECTOR); + tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR); assertTrue(!tracker.canAllocate((long) (largeHeapSize * modelMaxPercen))); - tracker.releaseMemory(bytesToUse, false, MemoryTracker.Origin.MULTI_ENTITY_DETECTOR); + tracker.releaseMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR); assertTrue(tracker.canAllocate((long) (largeHeapSize * modelMaxPercen))); } @@ -162,11 +181,11 @@ public void testMemoryToShed() { long bytesToUse = 100_000; assertEquals(bytesToUse, tracker.getHeapLimit()); assertEquals((long) (smallHeapSize * modelDesiredSizePercentage), tracker.getDesiredModelSize()); - tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.MULTI_ENTITY_DETECTOR); - tracker.consumeMemory(bytesToUse, true, MemoryTracker.Origin.MULTI_ENTITY_DETECTOR); + tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR); + tracker.consumeMemory(bytesToUse, true, MemoryTracker.Origin.HC_DETECTOR); assertEquals(2 * bytesToUse, tracker.getTotalMemoryBytes()); assertEquals(bytesToUse, tracker.memoryToShed()); - assertTrue(!tracker.syncMemoryState(MemoryTracker.Origin.MULTI_ENTITY_DETECTOR, 2 * bytesToUse, bytesToUse)); + assertTrue(!tracker.syncMemoryState(MemoryTracker.Origin.HC_DETECTOR, 2 * bytesToUse, bytesToUse)); } } diff --git a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java index a3b10e71d..99618e6e2 100644 --- a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java @@ -35,6 +35,7 @@ import static org.opensearch.ad.model.AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX; import java.time.Instant; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -200,8 +201,22 @@ private void setUpClientExecuteProfileAction(InittedEverResultStatus initted) { if (InittedEverResultStatus.INITTED == initted) { updates = requiredSamples + 1; } - ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse(discoveryNode1, modelSizeMap1, shingleSize, 1L, updates); - ProfileNodeResponse profileNodeResponse2 = new ProfileNodeResponse(discoveryNode2, modelSizeMap2, shingleSize, 1L, updates); + ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse( + discoveryNode1, + modelSizeMap1, + shingleSize, + 1L, + updates, + new ArrayList<>() + ); + ProfileNodeResponse profileNodeResponse2 = new ProfileNodeResponse( + discoveryNode2, + modelSizeMap2, + shingleSize, + 1L, + updates, + new ArrayList<>() + ); List profileNodeResponses = Arrays.asList(profileNodeResponse1, profileNodeResponse2); List failures = Collections.emptyList(); ProfileResponse profileResponse = new ProfileResponse(new ClusterName(clusterName), profileNodeResponses, failures); diff --git a/src/test/java/org/opensearch/ad/NodeStateTests.java b/src/test/java/org/opensearch/ad/NodeStateTests.java index 6b3c3d6d8..a2af98d2c 100644 --- a/src/test/java/org/opensearch/ad/NodeStateTests.java +++ b/src/test/java/org/opensearch/ad/NodeStateTests.java @@ -104,14 +104,14 @@ public void testMaintenancFlagRemove() throws IOException { public void testMaintenanceLastColdStartRemoved() { when(clock.instant()).thenReturn(Instant.ofEpochMilli(1000)); - 
state.setLastColdStartException(new AnomalyDetectionException("123", "")); + state.setException(new AnomalyDetectionException("123", "")); when(clock.instant()).thenReturn(Instant.ofEpochSecond(3700)); assertTrue(state.expired(duration)); } public void testMaintenanceLastColdStartNotRemoved() { when(clock.instant()).thenReturn(Instant.ofEpochMilli(1_000_000L)); - state.setLastColdStartException(new AnomalyDetectionException("123", "")); + state.setException(new AnomalyDetectionException("123", "")); when(clock.instant()).thenReturn(Instant.ofEpochSecond(3700)); assertTrue(!state.expired(duration)); } diff --git a/src/test/java/org/opensearch/ad/TestHelpers.java b/src/test/java/org/opensearch/ad/TestHelpers.java index 4addb8f3b..291e9e032 100644 --- a/src/test/java/org/opensearch/ad/TestHelpers.java +++ b/src/test/java/org/opensearch/ad/TestHelpers.java @@ -639,15 +639,16 @@ public static AnomalyResult randomAnomalyDetectResult(double score, String error error, null, user, - CommonValue.NO_SCHEMA_VERSION + CommonValue.NO_SCHEMA_VERSION, + null ); } - public static AnomalyResult randomMultiEntityAnomalyDetectResult(double score, double grade) { - return randomMutlEntityAnomalyDetectResult(score, grade, null); + public static AnomalyResult randomHCADAnomalyDetectResult(double score, double grade) { + return randomHCADAnomalyDetectResult(score, grade, null); } - public static AnomalyResult randomMutlEntityAnomalyDetectResult(double score, double grade, String error) { + public static AnomalyResult randomHCADAnomalyDetectResult(double score, double grade, String error) { return new AnomalyResult( randomAlphaOfLength(5), score, @@ -659,7 +660,7 @@ public static AnomalyResult randomMutlEntityAnomalyDetectResult(double score, do Instant.now().truncatedTo(ChronoUnit.SECONDS), Instant.now().truncatedTo(ChronoUnit.SECONDS), error, - Arrays.asList(new Entity(randomAlphaOfLength(5), randomAlphaOfLength(5))), + Entity.createSingleAttributeEntity(randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5)), randomUser(), CommonValue.NO_SCHEMA_VERSION ); diff --git a/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java b/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java new file mode 100644 index 000000000..2a848c772 --- /dev/null +++ b/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java @@ -0,0 +1,127 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.caching; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayDeque; +import java.util.Random; + +import org.junit.Before; +import org.opensearch.ad.AbstractADTest; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelManager.ModelType; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; +import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.settings.AnomalyDetectorSettings; + +public class AbstractCacheTest extends AbstractADTest { + protected String modelId1, modelId2, modelId3, modelId4; + protected Entity entity1, entity2, entity3, entity4; + protected ModelState modelState1, modelState2, modelState3, modelState4; + protected String detectorId; + protected AnomalyDetector detector; + protected Clock clock; + protected Duration detectorDuration; + protected float initialPriority; + protected CacheBuffer cacheBuffer; + protected long memoryPerEntity; + protected MemoryTracker memoryTracker; + protected CheckpointWriteWorker checkpointWriteQueue; + protected Random random; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + detector = mock(AnomalyDetector.class); + detectorId = "123"; + when(detector.getDetectorId()).thenReturn(detectorId); + detectorDuration = Duration.ofMinutes(5); + when(detector.getDetectionIntervalDuration()).thenReturn(detectorDuration); + when(detector.getDetectorIntervalInSeconds()).thenReturn(detectorDuration.getSeconds()); + + entity1 = Entity.createSingleAttributeEntity(detectorId, "attributeName1", "attributeVal1"); + entity2 = Entity.createSingleAttributeEntity(detectorId, "attributeName1", "attributeVal2"); + entity3 = Entity.createSingleAttributeEntity(detectorId, "attributeName1", "attributeVal3"); + entity4 = Entity.createSingleAttributeEntity(detectorId, "attributeName1", "attributeVal4"); + modelId1 = entity1.getModelId(detectorId).get(); + modelId2 = entity2.getModelId(detectorId).get(); + modelId3 = entity3.getModelId(detectorId).get(); + modelId4 = entity4.getModelId(detectorId).get(); + + clock = mock(Clock.class); + when(clock.instant()).thenReturn(Instant.now()); + + memoryPerEntity = 81920; + memoryTracker = mock(MemoryTracker.class); + + checkpointWriteQueue = mock(CheckpointWriteWorker.class); + + cacheBuffer = new CacheBuffer( + 1, + 1, + memoryPerEntity, + memoryTracker, + clock, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + detectorId, + checkpointWriteQueue, + new Random(42) + ); + + initialPriority = cacheBuffer.getPriorityTracker().getUpdatedPriority(0); + + modelState1 = new ModelState<>( + new EntityModel(entity1, new ArrayDeque<>(), null, null), + modelId1, + detectorId, + ModelType.ENTITY.getName(), + clock, + 0 + ); + + modelState2 = new ModelState<>( + new EntityModel(entity2, new ArrayDeque<>(), null, null), + modelId2, + detectorId, + ModelType.ENTITY.getName(), + clock, + 0 + ); + + modelState3 = new ModelState<>( + new EntityModel(entity3, new ArrayDeque<>(), null, null), + modelId3, + detectorId, + ModelType.ENTITY.getName(), + clock, + 0 + ); + + modelState4 = new ModelState<>( + new EntityModel(entity4, new ArrayDeque<>(), null, null), + modelId4, + detectorId, + ModelType.ENTITY.getName(), + clock, + 0 + ); + } +} diff --git 
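The new base test above illustrates the convention the cache and checkpoint changes depend on: an entity's model id is derived from the detector id plus the entity attributes and returned as an Optional. A small standalone sketch of that derivation (attribute name and value are placeholders; the tests above call get() directly because their entities always have attributes):

import java.util.Optional;

import org.opensearch.ad.model.Entity;

final class EntityModelIds {
    static String modelIdFor(String detectorId) {
        Entity entity = Entity.createSingleAttributeEntity(detectorId, "attributeName1", "attributeVal1");
        Optional<String> modelId = entity.getModelId(detectorId);
        return modelId.orElseThrow(() -> new IllegalStateException("entity without attributes"));
    }
}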
a/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java b/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java index 4276055ba..3d871f54a 100644 --- a/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java +++ b/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java @@ -26,60 +26,23 @@ package org.opensearch.ad.caching; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.List; import java.util.Map.Entry; +import java.util.Optional; -import org.junit.Before; import org.mockito.ArgumentCaptor; import org.opensearch.ad.MemoryTracker; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.test.OpenSearchTestCase; import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; -public class CacheBufferTests extends OpenSearchTestCase { - CacheBuffer cacheBuffer; - CheckpointDao checkpointDao; - MemoryTracker memoryTracker; - Clock clock; - String detectorId; - float initialPriority; - long memoryPerEntity; - - @Override - @Before - public void setUp() throws Exception { - super.setUp(); - - checkpointDao = mock(CheckpointDao.class); - memoryTracker = mock(MemoryTracker.class); - clock = mock(Clock.class); - when(clock.instant()).thenReturn(Instant.now()); - - detectorId = "123"; - memoryPerEntity = 81920; - - cacheBuffer = new CacheBuffer( - 1, - 1, - checkpointDao, - memoryPerEntity, - memoryTracker, - clock, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - detectorId - ); - initialPriority = cacheBuffer.getPriorityTracker().getUpdatedPriority(0); - } +public class CacheBufferTests extends AbstractCacheTest { // cache.put(1, 1); // cache.put(2, 2); @@ -92,25 +55,20 @@ public void setUp() throws Exception { // cache.get(3); // returns 3 // cache.get(4); // returns 4 public void testRemovalCandidate() { - String modelId1 = "1"; - String modelId2 = "2"; - String modelId3 = "3"; - String modelId4 = "4"; - - cacheBuffer.put(modelId1, MLUtil.randomModelState(initialPriority, modelId1)); - cacheBuffer.put(modelId2, MLUtil.randomModelState(initialPriority, modelId2)); + cacheBuffer.put(modelId1, modelState1); + cacheBuffer.put(modelId2, modelState2); assertEquals(modelId1, cacheBuffer.get(modelId1).getModelId()); - Entry removalCandidate = cacheBuffer.getPriorityTracker().getMinimumScaledPriority(); - assertEquals(modelId2, removalCandidate.getKey()); + Optional> removalCandidate = cacheBuffer.getPriorityTracker().getMinimumScaledPriority(); + assertEquals(modelId2, removalCandidate.get().getKey()); cacheBuffer.remove(); - cacheBuffer.put(modelId3, MLUtil.randomModelState(initialPriority, modelId3)); + cacheBuffer.put(modelId3, modelState3); assertEquals(null, cacheBuffer.get(modelId2)); assertEquals(modelId3, cacheBuffer.get(modelId3).getModelId()); removalCandidate = cacheBuffer.getPriorityTracker().getMinimumScaledPriority(); - assertEquals(modelId1, removalCandidate.getKey()); + assertEquals(modelId1, removalCandidate.get().getKey()); cacheBuffer.remove(modelId1); assertEquals(null, cacheBuffer.get(modelId1)); - cacheBuffer.put(modelId4, MLUtil.randomModelState(initialPriority, modelId4)); + cacheBuffer.put(modelId4, modelState4); assertEquals(modelId3, cacheBuffer.get(modelId3).getModelId()); assertEquals(modelId4, cacheBuffer.get(modelId4).getModelId()); 
} @@ -121,14 +79,10 @@ public void testRemovalCandidate() { // cache.put(4, 4); // cache.get(2) => returns 2 public void testRemovalCandidate2() throws InterruptedException { - String modelId2 = "2"; - String modelId3 = "3"; - String modelId4 = "4"; - float initialPriority = cacheBuffer.getPriorityTracker().getUpdatedPriority(0); - cacheBuffer.put(modelId3, MLUtil.randomModelState(initialPriority, modelId3)); - cacheBuffer.put(modelId2, MLUtil.randomModelState(initialPriority, modelId2)); - cacheBuffer.put(modelId2, MLUtil.randomModelState(initialPriority, modelId2)); - cacheBuffer.put(modelId4, MLUtil.randomModelState(initialPriority, modelId4)); + cacheBuffer.put(modelId3, modelState3); + cacheBuffer.put(modelId2, modelState2); + cacheBuffer.put(modelId2, modelState2); + cacheBuffer.put(modelId4, modelState4); assertTrue(cacheBuffer.getModel(modelId2).isPresent()); ArgumentCaptor memoryReleased = ArgumentCaptor.forClass(Long.class); @@ -143,7 +97,7 @@ public void testRemovalCandidate2() throws InterruptedException { assertEquals(3 * memoryPerEntity, capturedMemoryReleased.stream().reduce(0L, (a, b) -> a + b).intValue()); assertTrue(capturedreserved.get(0)); assertTrue(!capturedreserved.get(1)); - assertEquals(MemoryTracker.Origin.MULTI_ENTITY_DETECTOR, capturedOrigin.get(0)); + assertEquals(MemoryTracker.Origin.HC_DETECTOR, capturedOrigin.get(0)); assertTrue(!cacheBuffer.expired(Duration.ofHours(1))); } @@ -155,13 +109,13 @@ public void testCanRemove() { assertTrue(cacheBuffer.dedicatedCacheAvailable()); assertTrue(!cacheBuffer.canReplaceWithinDetector(100)); - cacheBuffer.put(modelId1, MLUtil.randomModelState(initialPriority, modelId1)); + cacheBuffer.put(modelId1, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); assertTrue(cacheBuffer.canReplaceWithinDetector(100)); assertTrue(!cacheBuffer.dedicatedCacheAvailable()); assertTrue(!cacheBuffer.canRemove()); - cacheBuffer.put(modelId2, MLUtil.randomModelState(initialPriority, modelId2)); + cacheBuffer.put(modelId2, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); assertTrue(cacheBuffer.canRemove()); - cacheBuffer.replace(modelId3, MLUtil.randomModelState(initialPriority, modelId3)); + cacheBuffer.replace(modelId3, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); assertTrue(cacheBuffer.isActive(modelId2)); assertTrue(cacheBuffer.isActive(modelId3)); assertEquals(modelId3, cacheBuffer.getPriorityTracker().getHighestPriorityEntityId().get()); @@ -172,9 +126,9 @@ public void testMaintenance() { String modelId1 = "1"; String modelId2 = "2"; String modelId3 = "3"; - cacheBuffer.put(modelId1, MLUtil.randomModelState(initialPriority, modelId1)); - cacheBuffer.put(modelId2, MLUtil.randomModelState(initialPriority, modelId2)); - cacheBuffer.put(modelId3, MLUtil.randomModelState(initialPriority, modelId3)); + cacheBuffer.put(modelId1, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); + cacheBuffer.put(modelId2, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); + cacheBuffer.put(modelId3, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); cacheBuffer.maintenance(); assertEquals(3, cacheBuffer.getActiveEntities()); assertEquals(3, cacheBuffer.getAllModels().size()); diff --git a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java 
b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java index fbf0d08ee..c544f0ea5 100644 --- a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java +++ b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java @@ -30,112 +30,87 @@ import static org.mockito.Mockito.anyBoolean; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyLong; -import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import java.time.Clock; -import java.time.Duration; import java.time.Instant; +import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.HashSet; -import java.util.Map.Entry; -import java.util.Optional; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.opensearch.OpenSearchException; -import org.opensearch.action.ActionListener; import org.opensearch.ad.MemoryTracker; import org.opensearch.ad.common.exception.LimitExceededException; -import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.ml.CheckpointDao; import org.opensearch.ad.ml.EntityModel; import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.ml.ModelState; import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException; -import org.opensearch.index.IndexNotFoundException; -import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.Scheduler.ScheduledCancellable; import org.opensearch.threadpool.ThreadPool; -public class PriorityCacheTests extends OpenSearchTestCase { +public class PriorityCacheTests extends AbstractCacheTest { private static final Logger LOG = LogManager.getLogger(PriorityCacheTests.class); - String modelId1, modelId2, modelId3, modelId4; EntityCache cacheProvider; CheckpointDao checkpoint; - MemoryTracker memoryTracker; ModelManager modelManager; - Clock clock; + ClusterService clusterService; Settings settings; - ThreadPool threadPool; - float initialPriority; - CacheBuffer cacheBuffer; - long memoryPerEntity; - String detectorId, detectorId2; - AnomalyDetector detector, detector2; + String detectorId2; + AnomalyDetector detector2; double[] point; - String entityName; int dedicatedCacheSize; - Duration detectorDuration; - int numMinSamples; @SuppressWarnings("unchecked") @Override @Before public void setUp() throws Exception { super.setUp(); - modelId1 = "1"; - modelId2 = "2"; - modelId3 = "3"; - modelId4 = "4"; - checkpoint = mock(CheckpointDao.class); - doAnswer(invocation -> { - Object[] args = invocation.getArguments(); - ActionListener>> listener = - (ActionListener>>) 
args[1]; - listener.onResponse(Optional.empty()); - return null; - }).when(checkpoint).restoreModelCheckpoint(anyString(), any(ActionListener.class)); - memoryTracker = mock(MemoryTracker.class); - when(memoryTracker.memoryToShed()).thenReturn(0L); + checkpoint = mock(CheckpointDao.class); modelManager = mock(ModelManager.class); - doNothing().when(modelManager).processEntityCheckpoint(any(Optional.class), anyString(), anyString(), any(ModelState.class)); - - clock = mock(Clock.class); - when(clock.instant()).thenReturn(Instant.now()); clusterService = mock(ClusterService.class); - settings = Settings.EMPTY; - ClusterSettings clusterSettings = new ClusterSettings( - settings, - Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND))) + ClusterSettings settings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays.asList(AnomalyDetectorSettings.DEDICATED_CACHE_SIZE, AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE) + ) + ) ); - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getClusterSettings()).thenReturn(settings); - threadPool = mock(ThreadPool.class); dedicatedCacheSize = 1; - numMinSamples = 3; + + threadPool = mock(ThreadPool.class); + setUpADThreadPool(threadPool); EntityCache cache = new PriorityCache( checkpoint, @@ -143,30 +118,19 @@ public void setUp() throws Exception { AnomalyDetectorSettings.CHECKPOINT_TTL, AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES, memoryTracker, - modelManager, AnomalyDetectorSettings.MULTI_ENTITY_NUM_TREES, clock, clusterService, AnomalyDetectorSettings.HOURLY_MAINTENANCE, - numMinSamples, - settings, threadPool, - // put a large value since my tests uses a lot of permits in a burst manner - 2000 + checkpointWriteQueue, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT ); cacheProvider = new CacheProvider(cache).get(); - memoryPerEntity = 81920L; when(memoryTracker.estimateModelSize(any(AnomalyDetector.class), anyInt())).thenReturn(memoryPerEntity); - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); - - detector = mock(AnomalyDetector.class); - detectorId = "123"; - when(detector.getDetectorId()).thenReturn(detectorId); - detectorDuration = Duration.ofMinutes(5); - when(detector.getDetectionIntervalDuration()).thenReturn(detectorDuration); - when(detector.getDetectorIntervalInSeconds()).thenReturn(detectorDuration.getSeconds()); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); detector2 = mock(AnomalyDetector.class); detectorId2 = "456"; @@ -174,34 +138,24 @@ public void setUp() throws Exception { when(detector2.getDetectionIntervalDuration()).thenReturn(detectorDuration); when(detector2.getDetectorIntervalInSeconds()).thenReturn(detectorDuration.getSeconds()); - cacheBuffer = new CacheBuffer( - 1, - 1, - checkpoint, - memoryPerEntity, - memoryTracker, - clock, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - detectorId - ); - - initialPriority = cacheBuffer.getPriorityTracker().getUpdatedPriority(0); point = new double[] { 0.1 }; - entityName = "1.2.3.4"; } public void testCacheHit() { - // cache miss due to empty cache - assertEquals(null, cacheProvider.get(modelId1, detector, point, entityName)); // cache miss due to door keeper - assertEquals(null, cacheProvider.get(modelId1, detector, point, entityName)); + assertEquals(null, cacheProvider.get(modelState1.getModelId(), detector)); + // cache miss due to empty cache + assertEquals(null, 
cacheProvider.get(modelState1.getModelId(), detector)); + cacheProvider.hostIfPossible(detector, modelState1); assertEquals(1, cacheProvider.getTotalActiveEntities()); assertEquals(1, cacheProvider.getAllModels().size()); - ModelState hitState = cacheProvider.get(modelId1, detector, point, entityName); + ModelState hitState = cacheProvider.get(modelState1.getModelId(), detector); assertEquals(detectorId, hitState.getDetectorId()); EntityModel model = hitState.getModel(); assertEquals(null, model.getRcf()); assertEquals(null, model.getThreshold()); + assertTrue(model.getSamples().isEmpty()); + modelState1.getModel().addSample(point); assertTrue(Arrays.equals(point, model.getSamples().peek())); ArgumentCaptor memoryConsumed = ArgumentCaptor.forClass(Long.class); @@ -211,23 +165,25 @@ public void testCacheHit() { verify(memoryTracker, times(1)).consumeMemory(memoryConsumed.capture(), reserved.capture(), origin.capture()); assertEquals(dedicatedCacheSize * memoryPerEntity, memoryConsumed.getValue().intValue()); assertEquals(true, reserved.getValue().booleanValue()); - assertEquals(MemoryTracker.Origin.MULTI_ENTITY_DETECTOR, origin.getValue()); + assertEquals(MemoryTracker.Origin.HC_DETECTOR, origin.getValue()); for (int i = 0; i < 2; i++) { - cacheProvider.get(modelId2, detector, point, entityName); + cacheProvider.get(modelId2, detector); } } public void testInActiveCache() { // make modelId1 has enough priority for (int i = 0; i < 10; i++) { - cacheProvider.get(modelId1, detector, point, entityName); + cacheProvider.get(modelId1, detector); } + assertTrue(cacheProvider.hostIfPossible(detector, modelState1)); assertEquals(1, cacheProvider.getActiveEntities(detectorId)); when(memoryTracker.canAllocate(anyLong())).thenReturn(false); for (int i = 0; i < 2; i++) { - assertEquals(null, cacheProvider.get(modelId2, detector, point, entityName)); + assertEquals(null, cacheProvider.get(modelId2, detector)); } + assertTrue(false == cacheProvider.hostIfPossible(detector, modelState2)); // modelId2 gets put to inactive cache due to nothing in shared cache // and it cannot replace modelId1 assertEquals(1, cacheProvider.getActiveEntities(detectorId)); @@ -236,27 +192,46 @@ public void testInActiveCache() { public void testSharedCache() { // make modelId1 has enough priority for (int i = 0; i < 10; i++) { - cacheProvider.get(modelId1, detector, point, entityName); + cacheProvider.get(modelId1, detector); } + cacheProvider.hostIfPossible(detector, modelState1); assertEquals(1, cacheProvider.getActiveEntities(detectorId)); when(memoryTracker.canAllocate(anyLong())).thenReturn(true); for (int i = 0; i < 2; i++) { - cacheProvider.get(modelId2, detector, point, entityName); + cacheProvider.get(modelId2, detector); } + cacheProvider.hostIfPossible(detector, modelState2); // modelId2 should be in shared cache assertEquals(2, cacheProvider.getActiveEntities(detectorId)); for (int i = 0; i < 10; i++) { - // put in dedicated cache - cacheProvider.get(modelId3, detector2, point, entityName); + cacheProvider.get(modelId3, detector2); } + modelState3 = new ModelState<>( + new EntityModel(entity3, new ArrayDeque<>(), null, null), + modelId3, + detectorId2, + ModelType.ENTITY.getName(), + clock, + 0 + ); + cacheProvider.hostIfPossible(detector2, modelState3); assertEquals(1, cacheProvider.getActiveEntities(detectorId2)); when(memoryTracker.canAllocate(anyLong())).thenReturn(false); for (int i = 0; i < 4; i++) { // replace modelId2 in shared cache - cacheProvider.get(modelId4, detector2, point, entityName); + 
cacheProvider.get(modelId4, detector2); } + modelState4 = new ModelState<>( + new EntityModel(entity4, new ArrayDeque<>(), null, null), + modelId4, + detectorId2, + ModelType.ENTITY.getName(), + clock, + 0 + ); + cacheProvider.hostIfPossible(detector2, modelState4); assertEquals(2, cacheProvider.getActiveEntities(detectorId2)); assertEquals(3, cacheProvider.getTotalActiveEntities()); assertEquals(3, cacheProvider.getAllModels().size()); @@ -270,125 +245,44 @@ public void testSharedCache() { public void testReplace() { for (int i = 0; i < 2; i++) { - cacheProvider.get(modelId1, detector, point, entityName); + cacheProvider.get(modelState1.getModelId(), detector); } + + cacheProvider.hostIfPossible(detector, modelState1); assertEquals(1, cacheProvider.getActiveEntities(detectorId)); when(memoryTracker.canAllocate(anyLong())).thenReturn(false); ModelState state = null; + for (int i = 0; i < 4; i++) { - state = cacheProvider.get(modelId2, detector, point, entityName); + cacheProvider.get(modelId2, detector); } - // modelId2 replaced modelId1 + // modelState2 replaced modelState1 + cacheProvider.hostIfPossible(detector, modelState2); + state = cacheProvider.get(modelId2, detector); + assertEquals(modelId2, state.getModelId()); - assertTrue(Arrays.equals(point, state.getModel().getSamples().peek())); assertEquals(1, cacheProvider.getActiveEntities(detectorId)); } public void testCannotAllocateBuffer() { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(false); - expectThrows(LimitExceededException.class, () -> cacheProvider.get(modelId1, detector, point, entityName)); - } - - /** - * Test that even though we put more and more samples, there are only numMinSamples stored - */ - @SuppressWarnings("unchecked") - public void testTooManySamples() { - doAnswer(invocation -> { - Object[] args = invocation.getArguments(); - ModelState state = (ModelState) args[3]; - EntityModel model = state.getModel(); - for (int i = 0; i < 10; i++) { - model.addSample(point); - } - try { - // invalid samples cannot bed added - model.addSample(null); - model.addSample(new double[] {}); - } catch (Exception e) { - assertTrue("add invalid samples should not result in failure", false); - } - - return null; - }).when(modelManager).processEntityCheckpoint(any(Optional.class), anyString(), anyString(), any(ModelState.class)); - - ModelState state = null; - for (int i = 0; i < 10; i++) { - state = cacheProvider.get(modelId1, detector, point, entityName); - } - assertEquals(numMinSamples, state.getModel().getSamples().size()); - } - - /** - * We should have no problem when the checkpoint index does not exist yet.
- */ - @SuppressWarnings("unchecked") - public void testIndexNotFoundException() { - doAnswer(invocation -> { - Object[] args = invocation.getArguments(); - ActionListener>> listener = - (ActionListener>>) args[1]; - listener.onFailure(new IndexNotFoundException("", CommonName.CHECKPOINT_INDEX_NAME)); - return null; - }).when(checkpoint).restoreModelCheckpoint(anyString(), any(ActionListener.class)); - ModelState state = null; - for (int i = 0; i < 3; i++) { - state = cacheProvider.get(modelId1, detector, point, entityName); - } - assertEquals(1, state.getModel().getSamples().size()); - } - - @SuppressWarnings("unchecked") - public void testThrottledRestore() { - doAnswer(invocation -> { - Object[] args = invocation.getArguments(); - ActionListener>> listener = - (ActionListener>>) args[1]; - listener.onFailure(new OpenSearchRejectedExecutionException("", false)); - return null; - }).when(checkpoint).restoreModelCheckpoint(anyString(), any(ActionListener.class)); - for (int i = 0; i < 3; i++) { - cacheProvider.get(modelId1, detector, point, entityName); - } - for (int i = 0; i < 3; i++) { - cacheProvider.get(modelId2, detector, point, entityName); - } - - // due to throttling cool down, we should only restore once - verify(checkpoint, times(1)).restoreModelCheckpoint(anyString(), any(ActionListener.class)); - } - - // we only log error for this - @SuppressWarnings("unchecked") - public void testUnexpectedRestoreError() { - doAnswer(invocation -> { - Object[] args = invocation.getArguments(); - ActionListener>> listener = - (ActionListener>>) args[1]; - listener.onFailure(new RuntimeException()); - return null; - }).when(checkpoint).restoreModelCheckpoint(anyString(), any(ActionListener.class)); - when(memoryTracker.canAllocate(anyLong())).thenReturn(true); - for (int i = 0; i < 3; i++) { - cacheProvider.get(modelId1, detector, point, entityName); - } - for (int i = 0; i < 3; i++) { - cacheProvider.get(modelId2, detector, point, entityName); - } - - verify(checkpoint, times(2)).restoreModelCheckpoint(anyString(), any(ActionListener.class)); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(false); + expectThrows(LimitExceededException.class, () -> cacheProvider.get(modelId1, detector)); } public void testExpiredCacheBuffer() { when(clock.instant()).thenReturn(Instant.MIN); when(memoryTracker.canAllocate(anyLong())).thenReturn(true); for (int i = 0; i < 3; i++) { - cacheProvider.get(modelId1, detector, point, entityName); + cacheProvider.get(modelId1, detector); } for (int i = 0; i < 3; i++) { - cacheProvider.get(modelId2, detector, point, entityName); + cacheProvider.get(modelId2, detector); } + + cacheProvider.hostIfPossible(detector, modelState1); + cacheProvider.hostIfPossible(detector, modelState2); + assertEquals(2, cacheProvider.getTotalActiveEntities()); assertEquals(2, cacheProvider.getAllModels().size()); when(clock.instant()).thenReturn(Instant.now()); @@ -398,20 +292,29 @@ public void testExpiredCacheBuffer() { for (int i = 0; i < 2; i++) { // doorkeeper should have been reset - assertEquals(null, cacheProvider.get(modelId2, detector, point, entityName)); + assertEquals(null, cacheProvider.get(modelId2, detector)); } } public void testClear() { when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + for (int i = 0; i < 3; i++) { - cacheProvider.get(modelId1, detector, point, entityName); + // make modelId1 have higher priority + cacheProvider.get(modelId1, detector); } - for (int i = 0; i < 3; i++) { - cacheProvider.get(modelId2, detector, point, entityName); 
+ + for (int i = 0; i < 2; i++) { + cacheProvider.get(modelId2, detector); } + + cacheProvider.hostIfPossible(detector, modelState1); + cacheProvider.hostIfPossible(detector, modelState2); + assertEquals(2, cacheProvider.getTotalActiveEntities()); assertTrue(cacheProvider.isActive(detectorId, modelId1)); + assertEquals(0, cacheProvider.getTotalUpdates(detectorId)); + modelState1.getModel().addSample(point); assertEquals(1, cacheProvider.getTotalUpdates(detectorId)); assertEquals(1, cacheProvider.getTotalUpdates(detectorId, modelId1)); cacheProvider.clear(detectorId); @@ -419,7 +322,7 @@ public void testClear() { for (int i = 0; i < 2; i++) { // doorkeeper should have been reset - assertEquals(null, cacheProvider.get(modelId2, detector, point, entityName)); + assertEquals(null, cacheProvider.get(modelId2, detector)); } } @@ -433,14 +336,19 @@ public void run() { private void setUpConcurrentMaintenance() { when(memoryTracker.canAllocate(anyLong())).thenReturn(true); for (int i = 0; i < 2; i++) { - cacheProvider.get(modelId1, detector, point, entityName); + cacheProvider.get(modelId1, detector); } for (int i = 0; i < 2; i++) { - cacheProvider.get(modelId2, detector, point, entityName); + cacheProvider.get(modelId2, detector); } for (int i = 0; i < 2; i++) { - cacheProvider.get(modelId3, detector, point, entityName); + cacheProvider.get(modelId3, detector); } + + cacheProvider.hostIfPossible(detector, modelState1); + cacheProvider.hostIfPossible(detector, modelState2); + cacheProvider.hostIfPossible(detector, modelState3); + when(memoryTracker.memoryToShed()).thenReturn(memoryPerEntity); assertEquals(3, cacheProvider.getTotalActiveEntities()); } @@ -528,4 +436,164 @@ public void testFailedConcurrentMaintenance() throws InterruptedException { // we should return here return; } + + private void selectTestCommon(int entityFreq) { + for (int i = 0; i < entityFreq; i++) { + // bypass doorkeeper + cacheProvider.get(entity1.getModelId(detectorId).get(), detector); + } + Collection cacheMissEntities = new ArrayList<>(); + cacheMissEntities.add(entity1); + Pair, List> selectedAndOther = cacheProvider.selectUpdateCandidate(cacheMissEntities, detectorId, detector); + List selected = selectedAndOther.getLeft(); + assertEquals(1, selected.size()); + assertEquals(entity1, selected.get(0)); + assertEquals(0, selectedAndOther.getRight().size()); + } + + public void testSelectToDedicatedCache() { + selectTestCommon(2); + } + + public void testSelectToSharedCache() { + for (int i = 0; i < 2; i++) { + // bypass doorkeeper + cacheProvider.get(entity2.getModelId(detectorId).get(), detector); + } + when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + + // fill in dedicated cache + cacheProvider.hostIfPossible(detector, modelState2); + selectTestCommon(2); + verify(memoryTracker, times(1)).canAllocate(anyLong()); + } + + public void testSelectToReplaceInCache() { + for (int i = 0; i < 2; i++) { + // bypass doorkeeper + cacheProvider.get(entity2.getModelId(detectorId).get(), detector); + } + when(memoryTracker.canAllocate(anyLong())).thenReturn(false); + + // fill in dedicated cache + cacheProvider.hostIfPossible(detector, modelState2); + // make entity1 have enough priority to replace entity2 + selectTestCommon(10); + verify(memoryTracker, times(1)).canAllocate(anyLong()); + } + + private void replaceInOtherCacheSetUp() { + Entity entity5 = Entity.createSingleAttributeEntity(detectorId2, "attributeName1", "attributeVal5"); + Entity entity6 = Entity.createSingleAttributeEntity(detectorId2, "attributeName1", 
"attributeVal6"); + ModelState modelState5 = new ModelState<>( + new EntityModel(entity5, new ArrayDeque<>(), null, null), + entity5.getModelId(detectorId2).get(), + detectorId2, + ModelType.ENTITY.getName(), + clock, + 0 + ); + ModelState modelState6 = new ModelState<>( + new EntityModel(entity6, new ArrayDeque<>(), null, null), + entity6.getModelId(detectorId2).get(), + detectorId2, + ModelType.ENTITY.getName(), + clock, + 0 + ); + + for (int i = 0; i < 3; i++) { + // bypass doorkeeper and leave room for lower frequency entity in testSelectToCold + cacheProvider.get(entity5.getModelId(detectorId2).get(), detector2); + cacheProvider.get(entity6.getModelId(detectorId2).get(), detector2); + } + for (int i = 0; i < 10; i++) { + // entity1 cannot replace entity2 due to frequency + cacheProvider.get(entity2.getModelId(detectorId).get(), detector); + } + // put modelState5 in dedicated and modelState6 in shared cache + when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + cacheProvider.hostIfPossible(detector2, modelState5); + cacheProvider.hostIfPossible(detector2, modelState6); + + // fill in dedicated cache + cacheProvider.hostIfPossible(detector, modelState2); + + // don't allow to use shared cache afterwards + when(memoryTracker.canAllocate(anyLong())).thenReturn(false); + } + + public void testSelectToReplaceInOtherCache() { + replaceInOtherCacheSetUp(); + + // make entity1 have enough priority to replace entity2 + selectTestCommon(10); + // once when deciding whether to host modelState6; + // once when calling selectUpdateCandidate on entity1 + verify(memoryTracker, times(2)).canAllocate(anyLong()); + } + + public void testSelectToCold() { + replaceInOtherCacheSetUp(); + + for (int i = 0; i < 2; i++) { + // bypass doorkeeper + cacheProvider.get(entity1.getModelId(detectorId).get(), detector); + } + Collection cacheMissEntities = new ArrayList<>(); + cacheMissEntities.add(entity1); + Pair, List> selectedAndOther = cacheProvider.selectUpdateCandidate(cacheMissEntities, detectorId, detector); + List cold = selectedAndOther.getRight(); + assertEquals(1, cold.size()); + assertEquals(entity1, cold.get(0)); + assertEquals(0, selectedAndOther.getLeft().size()); + } + + /* + * Test the scenario: + * 1. A detector's buffer uses dedicated and shared memory + * 2. a new detector's buffer is created and triggers clearMemory (every new + * CacheBuffer creation will trigger it) + * 3. 
clearMemory found we can reclaim shared memory + */ + public void testClearMemory() { + for (int i = 0; i < 2; i++) { + // bypass doorkeeper + cacheProvider.get(entity2.getModelId(detectorId).get(), detector); + } + + for (int i = 0; i < 10; i++) { + // bypass doorkeeper and make entity1 have higher frequency + cacheProvider.get(entity1.getModelId(detectorId).get(), detector); + } + + // put modelState1 in dedicated and modelState2 in shared cache + when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + cacheProvider.hostIfPossible(detector, modelState1); + cacheProvider.hostIfPossible(detector, modelState2); + + // two entities get inserted to cache + assertTrue(null != cacheProvider.get(entity1.getModelId(detectorId).get(), detector)); + assertTrue(null != cacheProvider.get(entity2.getModelId(detectorId).get(), detector)); + + Entity entity5 = Entity.createSingleAttributeEntity(detectorId2, "attributeName1", "attributeVal5"); + when(memoryTracker.memoryToShed()).thenReturn(memoryPerEntity); + for (int i = 0; i < 2; i++) { + // bypass doorkeeper, CacheBuffer created, and trigger clearMemory + cacheProvider.get(entity5.getModelId(detectorId2).get(), detector2); + } + + assertTrue(null != cacheProvider.get(entity1.getModelId(detectorId).get(), detector)); + // entity 2 removed + assertTrue(null == cacheProvider.get(entity2.getModelId(detectorId).get(), detector)); + assertTrue(null == cacheProvider.get(entity5.getModelId(detectorId2).get(), detector)); + } + + public void testSelectEmpty() { + Collection cacheMissEntities = new ArrayList<>(); + cacheMissEntities.add(entity1); + Pair, List> selectedAndOther = cacheProvider.selectUpdateCandidate(cacheMissEntities, detectorId, detector); + assertEquals(0, selectedAndOther.getLeft().size()); + assertEquals(0, selectedAndOther.getRight().size()); + } } diff --git a/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java b/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java index c802543b0..42de4e99e 100644 --- a/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java +++ b/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java @@ -81,18 +81,18 @@ public void testNormal() { public void testOverflow() { when(clock.instant()).thenReturn(now); tracker.updatePriority(entity1); - float priority1 = tracker.getMinimumScaledPriority().getValue(); + float priority1 = tracker.getMinimumScaledPriority().get().getValue(); // when(clock.instant()).thenReturn(now.plusSeconds(60L)); tracker.updatePriority(entity1); - float priority2 = tracker.getMinimumScaledPriority().getValue(); + float priority2 = tracker.getMinimumScaledPriority().get().getValue(); // we incremented the priority assertTrue("The following is expected: " + priority2 + " > " + priority1, priority2 > priority1); when(clock.instant()).thenReturn(now.plus(3, ChronoUnit.DAYS)); tracker.updatePriority(entity1); // overflow happens, we use increment as the new priority - assertEquals(0, tracker.getMinimumScaledPriority().getValue().floatValue(), 0.001); + assertEquals(0, tracker.getMinimumScaledPriority().get().getValue().floatValue(), 0.001); } public void testTooManyEntities() { @@ -105,4 +105,11 @@ public void testTooManyEntities() { // one entity is kicked out due to the size limit is reached.
assertEquals(2, tracker.size()); } + + public void testEmptyTracker() { + assertTrue(!tracker.getMinimumScaledPriority().isPresent()); + assertTrue(!tracker.getMinimumPriority().isPresent()); + assertTrue(!tracker.getMinimumPriorityEntityId().isPresent()); + assertTrue(!tracker.getHighestPriorityEntityId().isPresent()); + } } diff --git a/src/test/java/org/opensearch/ad/cluster/HashRingTests.java b/src/test/java/org/opensearch/ad/cluster/HashRingTests.java index f2ef20f45..c92cc34f7 100644 --- a/src/test/java/org/opensearch/ad/cluster/HashRingTests.java +++ b/src/test/java/org/opensearch/ad/cluster/HashRingTests.java @@ -114,7 +114,7 @@ public void setUp() throws Exception { settings = Settings .builder() - .put("opendistro.anomaly_detection.cluster_state_change_cooldown_minutes", TimeValue.timeValueMinutes(5)) + .put("plugins.anomaly_detection.cluster_state_change_cooldown_minutes", TimeValue.timeValueMinutes(5)) .build(); clock = mock(Clock.class); when(clock.millis()).thenReturn(700000L); diff --git a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java index c5c4e370c..1fe0e0810 100644 --- a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java +++ b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java @@ -119,6 +119,8 @@ public class FeatureManagerTests { private FeatureManager featureManager; + private String detectorId; + @Before public void setup() { MockitoAnnotations.initMocks(this); @@ -134,7 +136,8 @@ public void setup() { maxPreviewSamples = 2; featureBufferTtl = Duration.ofMillis(1_000L); - when(detector.getDetectorId()).thenReturn("id"); + detectorId = "id"; + when(detector.getDetectorId()).thenReturn(detectorId); when(detector.getShingleSize()).thenReturn(shingleSize); IntervalTimeConfiguration detectorIntervalTimeConfig = new IntervalTimeConfiguration(1, ChronoUnit.MINUTES); intervalInMilliseconds = detectorIntervalTimeConfig.toDuration().toMillis(); @@ -524,7 +527,7 @@ public void getPreviewFeatures_returnExceptionToListener_whenQueryFail() throws public void getPreviewFeatureForEntity() throws IOException { long start = 0L; long end = 240_000L; - Entity entity = new Entity("fieldName", "value"); + Entity entity = Entity.createSingleAttributeEntity(detectorId, "fieldName", "value"); List> coldStartSamples = new ArrayList<>(); coldStartSamples.add(Optional.of(new double[] { 10.0 })); @@ -552,7 +555,7 @@ public void getPreviewFeatureForEntity() throws IOException { public void getPreviewFeatureForEntity_noDataToPreview() throws IOException { long start = 0L; long end = 240_000L; - Entity entity = new Entity("fieldName", "value"); + Entity entity = Entity.createSingleAttributeEntity(detectorId, "fieldName", "value"); doAnswer(invocation -> { ActionListener>> listener = invocation.getArgument(4); @@ -572,8 +575,8 @@ public void getPreviewEntities() { long start = 0L; long end = 240_000L; - Entity entity1 = new Entity("fieldName", "value1"); - Entity entity2 = new Entity("fieldName", "value2"); + Entity entity1 = Entity.createSingleAttributeEntity(detectorId, "fieldName", "value1"); + Entity entity2 = Entity.createSingleAttributeEntity(detectorId, "fieldName", "value2"); List entities = asList(entity1, entity2); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); diff --git a/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java index 33a21dbc1..5129fe27a 100644 --- 
a/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java +++ b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java @@ -34,7 +34,6 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.anyObject; import static org.mockito.Matchers.eq; @@ -45,8 +44,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.time.ZoneId; import java.time.temporal.ChronoUnit; import java.util.AbstractMap.SimpleEntry; @@ -70,7 +67,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.search.TotalHits; -import org.apache.lucene.util.BytesRef; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -96,7 +92,6 @@ import org.opensearch.ad.dataprocessor.SingleFeatureLinearUniformInterpolator; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.Entity; -import org.opensearch.ad.model.Feature; import org.opensearch.ad.model.IntervalTimeConfiguration; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.util.ClientUtil; @@ -122,13 +117,8 @@ import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.Aggregations; import org.opensearch.search.aggregations.AggregatorFactories; -import org.opensearch.search.aggregations.BucketOrder; -import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.InternalAggregations; -import org.opensearch.search.aggregations.InternalOrder; import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; -import org.opensearch.search.aggregations.bucket.terms.StringTerms; -import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.opensearch.search.aggregations.metrics.InternalMax; import org.opensearch.search.aggregations.metrics.InternalMin; import org.opensearch.search.aggregations.metrics.InternalTDigestPercentiles; @@ -145,12 +135,12 @@ import org.powermock.modules.junit4.PowerMockRunner; import org.powermock.modules.junit4.PowerMockRunnerDelegate; -import com.google.common.collect.ImmutableList; +import com.google.gson.Gson; @PowerMockIgnore("javax.management.*") @RunWith(PowerMockRunner.class) @PowerMockRunnerDelegate(JUnitParamsRunner.class) -@PrepareForTest({ ParseUtils.class }) +@PrepareForTest({ ParseUtils.class, Gson.class }) public class SearchFeatureDaoTests { private final Logger LOG = LogManager.getLogger(SearchFeatureDaoTests.class); @@ -195,15 +185,13 @@ public class SearchFeatureDaoTests { @Mock private ClusterService clusterService; - private SearchSourceBuilder featureQuery = new SearchSourceBuilder(); - // private Map searchRequestParams; private SearchRequest searchRequest; private SearchSourceBuilder searchSourceBuilder; private MultiSearchRequest multiSearchRequest; private Map aggsMap; - // private List aggsList; private IntervalTimeConfiguration detectionInterval; - // private Settings settings; + private String detectorId; + private Gson gson; @Before public void setup() throws Exception { @@ -223,18 +211,16 @@ public void setup() throws Exception { Settings settings = Settings.EMPTY; ClusterSettings clusterSettings = new ClusterSettings( 
Settings.EMPTY, - Collections - .unmodifiableSet( - new HashSet<>( - Arrays.asList(AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW) - ) - ) + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW))) ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - searchFeatureDao = spy(new SearchFeatureDao(client, xContent, interpolator, clientUtil, threadPool, settings, clusterService)); + searchFeatureDao = spy(new SearchFeatureDao(client, xContent, interpolator, clientUtil, settings, clusterService, gson)); detectionInterval = new IntervalTimeConfiguration(1, ChronoUnit.MINUTES); + detectorId = "123"; + + when(detector.getDetectorId()).thenReturn(detectorId); when(detector.getTimeField()).thenReturn("testTimeField"); when(detector.getIndices()).thenReturn(Arrays.asList("testIndices")); when(detector.getDetectionInterval()).thenReturn(detectionInterval); @@ -281,6 +267,8 @@ public void setup() throws Exception { ); when(multiSearchResponse.getResponses()).thenReturn(new Item[] { multiSearchResponseItem }); when(multiSearchResponseItem.getResponse()).thenReturn(searchResponse); + + gson = PowerMockito.mock(Gson.class); } @Test @@ -735,156 +723,6 @@ private Entry pair(K key, V value) { return new SimpleEntry<>(key, value); } - @Test - @SuppressWarnings("unchecked") - public void testNormalGetFeaturesByEntities() throws IOException { - SearchHits hits = new SearchHits(new SearchHit[] {}, null, Float.NaN); - - String aggregationId = "deny_max"; - String featureName = "deny max"; - AggregationBuilder builder = new MaxAggregationBuilder("deny_max").field("deny"); - AggregatorFactories.Builder aggBuilder = AggregatorFactories.builder(); - aggBuilder.addAggregator(builder); - when(detector.getEnabledFeatureIds()).thenReturn(Collections.singletonList(aggregationId)); - when(detector.getFeatureAttributes()).thenReturn(Collections.singletonList(new Feature(aggregationId, featureName, true, builder))); - when(ParseUtils.parseAggregators(anyString(), any(), anyString())).thenReturn(aggBuilder); - - String app0Name = "app_0"; - double app0Max = 1976.0; - InternalAggregation app0Agg = new InternalMax(aggregationId, app0Max, DocValueFormat.RAW, Collections.emptyMap()); - StringTerms.Bucket app0Bucket = new StringTerms.Bucket( - new BytesRef(app0Name.getBytes(StandardCharsets.UTF_8), 0, app0Name.getBytes(StandardCharsets.UTF_8).length), - 3, - InternalAggregations.from(Collections.singletonList(app0Agg)), - false, - 0, - DocValueFormat.RAW - ); - - String app1Name = "app_1"; - double app1Max = 3604.0; - InternalAggregation app1Agg = new InternalMax(aggregationId, app1Max, DocValueFormat.RAW, Collections.emptyMap()); - StringTerms.Bucket app1Bucket = new StringTerms.Bucket( - new BytesRef(app1Name.getBytes(StandardCharsets.UTF_8), 0, app1Name.getBytes(StandardCharsets.UTF_8).length), - 3, - InternalAggregations.from(Collections.singletonList(app1Agg)), - false, - 0, - DocValueFormat.RAW - ); - - List stringBuckets = ImmutableList.of(app0Bucket, app1Bucket); - - StringTerms termsAgg = new StringTerms( - "term_agg", - InternalOrder.key(false), - BucketOrder.count(false), - 1, - 0, - Collections.emptyMap(), - DocValueFormat.RAW, - 1, - false, - 0, - stringBuckets, - 0 - ); - - InternalAggregations internalAggregations = InternalAggregations.from(Collections.singletonList(termsAgg)); - - SearchResponseSections searchSections = new SearchResponseSections(hits, internalAggregations, null, false, false, 
null, 1); - - // Simulate response: - // {"took":507,"timed_out":false,"_shards":{"total":1,"successful":1, - // "skipped":0,"failed":0},"hits":{"max_score":null,"hits":[]}, - // "aggregations":{"term_agg":{"doc_count_error_upper_bound":0, - // "sum_other_doc_count":0,"buckets":[{"key":"app_0","doc_count":3, - // "deny_max":{"value":1976.0}},{"key":"app_1","doc_count":3, - // "deny_max":{"value":3604.0}}]}}} - SearchResponse searchResponse = new SearchResponse( - searchSections, - null, - 1, - 1, - 0, - 507, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY - ); - - doAnswer(invocation -> { - SearchRequest request = invocation.getArgument(0); - assertEquals(1, request.indices().length); - assertTrue(detector.getIndices().contains(request.indices()[0])); - AggregatorFactories.Builder aggs = request.source().aggregations(); - assertEquals(1, aggs.count()); - Collection factory = aggs.getAggregatorFactories(); - assertTrue(!factory.isEmpty()); - assertThat(factory.iterator().next(), instanceOf(TermsAggregationBuilder.class)); - - ActionListener listener = invocation.getArgument(1); - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); - - ActionListener> listener = mock(ActionListener.class); - searchFeatureDao.getFeaturesByEntities(detector, 10L, 20L, listener); - - ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(listener).onResponse(captor.capture()); - Map result = captor.getValue(); - assertEquals(2, result.size()); - assertEquals(app0Max, result.get(app0Name)[0], 0.001); - assertEquals(app1Max, result.get(app1Name)[0], 0.001); - } - - @SuppressWarnings("unchecked") - @Test - public void testEmptyGetFeaturesByEntities() { - SearchResponseSections searchSections = new SearchResponseSections(null, null, null, false, false, null, 1); - - SearchResponse searchResponse = new SearchResponse( - searchSections, - null, - 1, - 1, - 0, - 507, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY - ); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); - - ActionListener> listener = mock(ActionListener.class); - searchFeatureDao.getFeaturesByEntities(detector, 10L, 20L, listener); - - ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(listener).onResponse(captor.capture()); - Map result = captor.getValue(); - assertEquals(0, result.size()); - } - - @SuppressWarnings("unchecked") - @Test(expected = EndRunException.class) - public void testParseIOException() throws Exception { - String aggregationId = "deny_max"; - String featureName = "deny max"; - AggregationBuilder builder = new MaxAggregationBuilder("deny_max").field("deny"); - AggregatorFactories.Builder aggBuilder = AggregatorFactories.builder(); - aggBuilder.addAggregator(builder); - when(detector.getEnabledFeatureIds()).thenReturn(Collections.singletonList(aggregationId)); - when(detector.getFeatureAttributes()).thenReturn(Collections.singletonList(new Feature(aggregationId, featureName, true, builder))); - PowerMockito.doThrow(new IOException()).when(ParseUtils.class, "parseAggregators", anyString(), any(), anyString()); - - ActionListener> listener = mock(ActionListener.class); - searchFeatureDao.getFeaturesByEntities(detector, 10L, 20L, listener); - } - @SuppressWarnings("unchecked") @Test public void testGetEntityMinMaxDataTime() { 
@@ -936,7 +774,8 @@ public void testGetEntityMinMaxDataTime() { }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); ActionListener, Optional>> listener = mock(ActionListener.class); - searchFeatureDao.getEntityMinMaxDataTime(detector, "app_1", listener); + Entity entity = Entity.createSingleAttributeEntity(detectorId, "field", "app_1"); + searchFeatureDao.getEntityMinMaxDataTime(detector, entity, listener); ArgumentCaptor, Optional>> captor = ArgumentCaptor.forClass(Entry.class); verify(listener).onResponse(captor.capture()); @@ -944,88 +783,4 @@ public void testGetEntityMinMaxDataTime() { assertEquals((long) earliest, result.getKey().get().longValue()); assertEquals((long) latest, result.getValue().get().longValue()); } - - @SuppressWarnings("unchecked") - @Test - public void testGetHighestCountEntities() { - SearchHits hits = new SearchHits(new SearchHit[] {}, null, Float.NaN); - - String entity1Name = "value1"; - long entity1Count = 3; - StringTerms.Bucket entity1Bucket = new StringTerms.Bucket( - new BytesRef(entity1Name.getBytes(StandardCharsets.UTF_8), 0, entity1Name.getBytes(StandardCharsets.UTF_8).length), - entity1Count, - null, - false, - 0, - DocValueFormat.RAW - ); - String entity2Name = "value2"; - long entity2Count = 1; - StringTerms.Bucket entity2Bucket = new StringTerms.Bucket( - new BytesRef(entity2Name.getBytes(StandardCharsets.UTF_8), 0, entity2Name.getBytes(StandardCharsets.UTF_8).length), - entity2Count, - null, - false, - 0, - DocValueFormat.RAW - ); - List stringBuckets = ImmutableList.of(entity1Bucket, entity2Bucket); - StringTerms termsAgg = new StringTerms( - "term_agg", - InternalOrder.key(false), - BucketOrder.count(false), - 1, - 0, - Collections.emptyMap(), - DocValueFormat.RAW, - 1, - false, - 0, - stringBuckets, - 0 - ); - - InternalAggregations internalAggregations = InternalAggregations.from(Collections.singletonList(termsAgg)); - - SearchResponseSections searchSections = new SearchResponseSections(hits, internalAggregations, null, false, false, null, 1); - - SearchResponse searchResponse = new SearchResponse( - searchSections, - null, - 1, - 1, - 0, - 30, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY - ); - - doAnswer(invocation -> { - SearchRequest request = invocation.getArgument(0); - assertEquals(1, request.indices().length); - assertTrue(detector.getIndices().contains(request.indices()[0])); - AggregatorFactories.Builder aggs = request.source().aggregations(); - assertEquals(1, aggs.count()); - Collection factory = aggs.getAggregatorFactories(); - assertTrue(!factory.isEmpty()); - assertThat(factory.iterator().next(), instanceOf(TermsAggregationBuilder.class)); - - ActionListener listener = invocation.getArgument(1); - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); - - when(detector.getCategoryField()).thenReturn(Collections.singletonList("fieldName")); - ActionListener> listener = mock(ActionListener.class); - - searchFeatureDao.getHighestCountEntities(detector, 10L, 20L, listener); - - ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); - verify(listener).onResponse(captor.capture()); - List result = captor.getValue(); - assertEquals(2, result.size()); - assertEquals(entity1Name, result.get(0).getValue()); - assertEquals(entity2Name, result.get(1).getValue()); - } } diff --git a/src/test/java/org/opensearch/ad/indices/AnomalyDetectionIndicesTests.java 
b/src/test/java/org/opensearch/ad/indices/AnomalyDetectionIndicesTests.java index 8b522b8e9..d13847858 100644 --- a/src/test/java/org/opensearch/ad/indices/AnomalyDetectionIndicesTests.java +++ b/src/test/java/org/opensearch/ad/indices/AnomalyDetectionIndicesTests.java @@ -64,10 +64,10 @@ protected Collection> nodePlugins() { public void setup() { settings = Settings .builder() - .put("opendistro.anomaly_detection.ad_result_history_rollover_period", TimeValue.timeValueHours(12)) - .put("opendistro.anomaly_detection.ad_result_history_max_age", TimeValue.timeValueHours(24)) - .put("opendistro.anomaly_detection.ad_result_history_max_docs", 10000L) - .put("opendistro.anomaly_detection.request_timeout", TimeValue.timeValueSeconds(10)) + .put("plugins.anomaly_detection.ad_result_history_rollover_period", TimeValue.timeValueHours(12)) + .put("plugins.anomaly_detection.ad_result_history_max_age", TimeValue.timeValueHours(24)) + .put("plugins.anomaly_detection.ad_result_history_max_docs", 10000L) + .put("plugins.anomaly_detection.request_timeout", TimeValue.timeValueSeconds(10)) .build(); nodeFilter = new DiscoveryNodeFilterer(clusterService()); diff --git a/src/test/java/org/opensearch/ad/indices/InitAnomalyDetectionIndicesTests.java b/src/test/java/org/opensearch/ad/indices/InitAnomalyDetectionIndicesTests.java index e5da21b13..2ccab5117 100644 --- a/src/test/java/org/opensearch/ad/indices/InitAnomalyDetectionIndicesTests.java +++ b/src/test/java/org/opensearch/ad/indices/InitAnomalyDetectionIndicesTests.java @@ -99,7 +99,7 @@ public void setUp() throws Exception { new HashSet<>( Arrays .asList( - AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS, + AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD, AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD, AnomalyDetectorSettings.MAX_PRIMARY_SHARDS diff --git a/src/test/java/org/opensearch/ad/indices/RolloverTests.java b/src/test/java/org/opensearch/ad/indices/RolloverTests.java index f8d811f79..ce7b92cad 100644 --- a/src/test/java/org/opensearch/ad/indices/RolloverTests.java +++ b/src/test/java/org/opensearch/ad/indices/RolloverTests.java @@ -72,6 +72,7 @@ public class RolloverTests extends AbstractADTest { private ClusterState clusterState; private ClusterService clusterService; private long defaultMaxDocs; + private int numberOfNodes; @Override public void setUp() throws Exception { @@ -87,7 +88,7 @@ public void setUp() throws Exception { new HashSet<>( Arrays .asList( - AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS, + AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD, AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD, AnomalyDetectorSettings.MAX_PRIMARY_SHARDS @@ -106,6 +107,8 @@ public void setUp() throws Exception { when(adminClient.indices()).thenReturn(indicesClient); DiscoveryNodeFilterer nodeFilter = mock(DiscoveryNodeFilterer.class); + numberOfNodes = 2; + when(nodeFilter.getNumberOfEligibleDataNodes()).thenReturn(numberOfNodes); adIndices = new AnomalyDetectionIndices(client, clusterService, threadPool, settings, nodeFilter); @@ -121,7 +124,7 @@ public void setUp() throws Exception { return null; }).when(clusterAdminClient).state(any(), any()); - defaultMaxDocs = AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS.getDefault(Settings.EMPTY); + defaultMaxDocs = AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD.getDefault(Settings.EMPTY); } private void 
assertRolloverRequest(RolloverRequest request) { @@ -129,7 +132,7 @@ private void assertRolloverRequest(RolloverRequest request) { Map> conditions = request.getConditions(); assertEquals(1, conditions.size()); - assertEquals(new MaxDocsCondition(defaultMaxDocs), conditions.get(MaxDocsCondition.NAME)); + assertEquals(new MaxDocsCondition(defaultMaxDocs * numberOfNodes), conditions.get(MaxDocsCondition.NAME)); CreateIndexRequest createIndexRequest = request.getCreateIndexRequest(); assertEquals(AnomalyDetectionIndices.AD_RESULT_HISTORY_INDEX_PATTERN, createIndexRequest.index()); @@ -169,7 +172,7 @@ public void testRolledOverButNotDeleted() { Map> conditions = request.getConditions(); assertEquals(1, conditions.size()); - assertEquals(new MaxDocsCondition(defaultMaxDocs), conditions.get(MaxDocsCondition.NAME)); + assertEquals(new MaxDocsCondition(defaultMaxDocs * numberOfNodes), conditions.get(MaxDocsCondition.NAME)); CreateIndexRequest createIndexRequest = request.getCreateIndexRequest(); assertEquals(AnomalyDetectionIndices.AD_RESULT_HISTORY_INDEX_PATTERN, createIndexRequest.index()); @@ -208,7 +211,7 @@ public void testRolledOverDeleted() { Map> conditions = request.getConditions(); assertEquals(1, conditions.size()); - assertEquals(new MaxDocsCondition(defaultMaxDocs), conditions.get(MaxDocsCondition.NAME)); + assertEquals(new MaxDocsCondition(defaultMaxDocs * numberOfNodes), conditions.get(MaxDocsCondition.NAME)); CreateIndexRequest createIndexRequest = request.getCreateIndexRequest(); assertEquals(AnomalyDetectionIndices.AD_RESULT_HISTORY_INDEX_PATTERN, createIndexRequest.index()); diff --git a/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java b/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java index 940efe8d3..43f639bec 100644 --- a/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java +++ b/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java @@ -30,11 +30,11 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyObject; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -56,6 +56,8 @@ import java.util.Optional; import java.util.Queue; import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.function.BiConsumer; import org.apache.logging.log4j.LogManager; @@ -81,14 +83,18 @@ import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; +import org.opensearch.action.get.MultiGetAction; +import org.opensearch.action.get.MultiGetItemResponse; +import org.opensearch.action.get.MultiGetRequest; +import org.opensearch.action.get.MultiGetResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.update.UpdateResponse; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.indices.AnomalyDetectionIndices; -import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.util.ClientUtil; import org.opensearch.client.Client; +import org.opensearch.index.IndexNotFoundException; import 
org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.index.shard.ShardId; import org.powermock.api.mockito.PowerMockito; @@ -96,6 +102,7 @@ import org.powermock.modules.junit4.PowerMockRunner; import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; import com.amazon.randomcutforest.serialize.RandomCutForestSerDe; import com.google.gson.Gson; @@ -136,7 +143,8 @@ public class CheckpointDaoTests { private Gson gson; private Class thresholdingModelClass; - private int maxBulkSize; + + private int maxCheckpointBytes = 1_000_000; @Before public void setup() { @@ -150,8 +158,6 @@ public void setup() { when(clock.instant()).thenReturn(Instant.now()); - maxBulkSize = 10; - checkpointDao = new CheckpointDao( client, clientUtil, @@ -159,11 +165,8 @@ public void setup() { gson, rcfSerde, thresholdingModelClass, - clock, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, indexUtil, - maxBulkSize, - 200.0 + maxCheckpointBytes ); when(indexUtil.doesCheckpointIndexExist()).thenReturn(true); @@ -438,282 +441,208 @@ public void deleteModelCheckpoint_callListener_whenCompleted() { assertEquals(null, response); } - private BulkResponse createBulkResponse(int succeeded, int failed, String[] failedId) { - BulkItemResponse[] bulkItemResponses = new BulkItemResponse[succeeded + failed]; - - ShardId shardId = new ShardId(CommonName.CHECKPOINT_INDEX_NAME, "", 1); - int i = 0; - for (; i < failed; i++) { - bulkItemResponses[i] = new BulkItemResponse( - i, - DocWriteRequest.OpType.UPDATE, - new BulkItemResponse.Failure( - CommonName.CHECKPOINT_INDEX_NAME, - CommonName.MAPPING_TYPE, - failedId[i], - new VersionConflictEngineException(shardId, "id", "test") - ) - ); - } - - for (; i < failed + succeeded; i++) { - bulkItemResponses[i] = new BulkItemResponse( - i, - DocWriteRequest.OpType.UPDATE, - new UpdateResponse(shardId, CommonName.MAPPING_TYPE, "1", 0L, 1L, 1L, DocWriteResponse.Result.CREATED) - ); - } - - return new BulkResponse(bulkItemResponses, 507); - } - @SuppressWarnings("unchecked") @Test - public void flush_less_than_1k() { - int writeRequests = maxBulkSize - 1; - for (int i = 0; i < writeRequests; i++) { - ModelState state = MLUtil.randomModelState(); - checkpointDao.write(state, state.getModelId(), true); - } - - doAnswer(invocation -> { - BulkRequest request = invocation.getArgument(1); - assertEquals(writeRequests, request.numberOfActions()); - ActionListener listener = invocation.getArgument(2); - - listener.onResponse(createBulkResponse(request.numberOfActions(), 0, null)); - return null; - }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); - - checkpointDao.flush(); + public void restore() throws IOException { + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + EntityModel modelToSave = state.getModel(); - verify(clientUtil, times(1)).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); - } + checkpointDao = new CheckpointDao( + client, + clientUtil, + indexName, + new Gson(), + new RandomCutForestSerDe(), + thresholdingModelClass, + indexUtil, + maxCheckpointBytes + ); - @SuppressWarnings("unchecked") - public void flush_more_than_1k() { - int writeRequests = maxBulkSize + 1; + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + Map source = new HashMap<>(); + source.put(CheckpointDao.DETECTOR_ID, state.getDetectorId()); + 
source.put(CheckpointDao.FIELD_MODEL, checkpointDao.toCheckpoint(modelToSave)); + source.put(CheckpointDao.TIMESTAMP, "2020-10-11T22:58:23.610392Z"); + when(getResponse.getSource()).thenReturn(source); doAnswer(invocation -> { - BulkRequest request = invocation.getArgument(1); - assertEquals(maxBulkSize, request.numberOfActions()); - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(2); - listener.onResponse(createBulkResponse(request.numberOfActions(), 0, null)); + listener.onResponse(getResponse); return null; - }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(BiConsumer.class), any(ActionListener.class)); - for (int i = 0; i < writeRequests; i++) { - ModelState state = MLUtil.randomModelState(); - // should trigger auto flush - checkpointDao.write(state, state.getModelId(), true); - } + ActionListener>> listener = mock(ActionListener.class); + checkpointDao.restoreModelCheckpoint(modelId, listener); - verify(clientUtil, times(1)).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); - } + ArgumentCaptor>> responseCaptor = ArgumentCaptor.forClass(Optional.class); + verify(listener).onResponse(responseCaptor.capture()); + Optional> response = responseCaptor.getValue(); + assertTrue(response.isPresent()); + Entry entry = response.get(); + OffsetDateTime utcTime = entry.getValue().atOffset(ZoneOffset.UTC); + assertEquals(2020, utcTime.getYear()); + assertEquals(Month.OCTOBER, utcTime.getMonth()); + assertEquals(11, utcTime.getDayOfMonth()); + assertEquals(22, utcTime.getHour()); + assertEquals(58, utcTime.getMinute()); + assertEquals(23, utcTime.getSecond()); - @Test - public void flush_more_than_1k_has_index() { - flush_more_than_1k(); + EntityModel model = entry.getKey(); + Queue queue = model.getSamples(); + Queue samplesToSave = modelToSave.getSamples(); + assertEquals(samplesToSave.size(), queue.size()); + assertTrue(Arrays.equals(samplesToSave.peek(), queue.peek())); + logger.info(modelToSave.getRcf()); + logger.info(model.getRcf()); + assertEquals(modelToSave.getRcf().getTotalUpdates(), model.getRcf().getTotalUpdates()); + assertTrue(model.getThreshold() != null); } @Test - public void flush_more_than_1k_no_index() { + public void batch_write_no_index() { when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + checkpointDao.batchWrite(new BulkRequest(), null); + verify(indexUtil, times(1)).initCheckpointIndex(any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); listener.onResponse(new CreateIndexResponse(true, true, CommonName.CHECKPOINT_INDEX_NAME)); return null; }).when(indexUtil).initCheckpointIndex(any()); - - flush_more_than_1k(); + checkpointDao.batchWrite(new BulkRequest(), null); + verify(clientUtil, times(1)).execute(any(), any(), any()); } @Test - public void flush_more_than_1k_race_condition() { + public void batch_write_index_init_no_ack() throws InterruptedException { when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); - listener.onFailure(new ResourceAlreadyExistsException(CommonName.CHECKPOINT_INDEX_NAME)); + listener.onResponse(new CreateIndexResponse(false, false, CommonName.CHECKPOINT_INDEX_NAME)); return null; }).when(indexUtil).initCheckpointIndex(any()); - flush_more_than_1k(); + final CountDownLatch processingLatch = new 
CountDownLatch(1); + checkpointDao.batchWrite(new BulkRequest(), ActionListener.wrap(response -> assertTrue(false), e -> { + assertTrue(e.getMessage(), e != null); + processingLatch.countDown(); + })); + + processingLatch.await(100, TimeUnit.SECONDS); } - @SuppressWarnings("unchecked") @Test - public void flush_more_than_1k_unexpected_exception() { + public void batch_write_index_already_exists() { when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); - listener.onFailure(new RuntimeException("")); + listener.onFailure(new ResourceAlreadyExistsException("blah")); return null; }).when(indexUtil).initCheckpointIndex(any()); - verify(clientUtil, never()).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + checkpointDao.batchWrite(new BulkRequest(), null); + verify(clientUtil, times(1)).execute(any(), any(), any()); } - @SuppressWarnings("unchecked") @Test - public void bulk_has_failure() throws InterruptedException { - int writeRequests = maxBulkSize - 1; - int failureCount = 1; - String[] failedId = new String[failureCount]; - for (int i = 0; i < writeRequests; i++) { - ModelState state = MLUtil.randomModelState(); - checkpointDao.write(state, state.getModelId(), true); - if (i < failureCount) { - failedId[i] = state.getModelId(); - } - } + public void batch_write_init_exception() throws InterruptedException { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); doAnswer(invocation -> { - BulkRequest request = invocation.getArgument(1); - assertEquals(writeRequests, request.numberOfActions()); - ActionListener listener = invocation.getArgument(2); - - listener.onResponse(createBulkResponse(request.numberOfActions(), failureCount, failedId)); + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException("blah")); return null; - }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + }).when(indexUtil).initCheckpointIndex(any()); - checkpointDao.flush(); + final CountDownLatch processingLatch = new CountDownLatch(1); + checkpointDao.batchWrite(new BulkRequest(), ActionListener.wrap(response -> assertTrue(false), e -> { + assertTrue(e.getMessage(), e != null); + processingLatch.countDown(); + })); - doAnswer(invocation -> { - BulkRequest request = invocation.getArgument(1); - assertEquals(failureCount, request.numberOfActions()); - ActionListener listener = invocation.getArgument(2); + processingLatch.await(100, TimeUnit.SECONDS); + } - listener.onResponse(createBulkResponse(request.numberOfActions(), 0, null)); - return null; - }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + private BulkResponse createBulkResponse(int succeeded, int failed, String[] failedId) { + BulkItemResponse[] bulkItemResponses = new BulkItemResponse[succeeded + failed]; + + ShardId shardId = new ShardId(CommonName.CHECKPOINT_INDEX_NAME, "", 1); + int i = 0; + for (; i < failed; i++) { + bulkItemResponses[i] = new BulkItemResponse( + i, + DocWriteRequest.OpType.UPDATE, + new BulkItemResponse.Failure( + CommonName.CHECKPOINT_INDEX_NAME, + CommonName.MAPPING_TYPE, + failedId[i], + new VersionConflictEngineException(shardId, "id", "test") + ) + ); + } - checkpointDao.flush(); + for (; i < failed + succeeded; i++) { + bulkItemResponses[i] = new BulkItemResponse( + i, + DocWriteRequest.OpType.UPDATE, + new UpdateResponse(shardId, CommonName.MAPPING_TYPE, 
"1", 0L, 1L, 1L, DocWriteResponse.Result.CREATED) + ); + } - verify(clientUtil, times(2)).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + return new BulkResponse(bulkItemResponses, 507); } @SuppressWarnings("unchecked") @Test - public void bulk_all_failure() throws InterruptedException { - int writeRequests = maxBulkSize - 1; - for (int i = 0; i < writeRequests; i++) { - ModelState state = MLUtil.randomModelState(); - checkpointDao.write(state, state.getModelId(), true); - } - - doAnswer(invocation -> { - BulkRequest request = invocation.getArgument(1); - assertEquals(writeRequests, request.numberOfActions()); - ActionListener listener = invocation.getArgument(2); - - listener.onFailure(new RuntimeException("")); - return null; - }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); - - checkpointDao.flush(); + public void batch_write_no_init() throws InterruptedException { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(true); doAnswer(invocation -> { - BulkRequest request = invocation.getArgument(1); - assertEquals(writeRequests, request.numberOfActions()); ActionListener listener = invocation.getArgument(2); - listener.onResponse(createBulkResponse(request.numberOfActions(), 0, null)); + listener.onResponse(createBulkResponse(2, 0, null)); return null; }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); - checkpointDao.flush(); - - verify(clientUtil, times(2)).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); - } - - @SuppressWarnings("unchecked") - @Test - public void checkpoint_saved_less_than_1_hr() { - ModelState state = MLUtil.randomModelState(); - state.setLastCheckpointTime(Instant.now()); - checkpointDao.write(state, state.getModelId()); - - checkpointDao.flush(); - - verify(clientUtil, never()).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); - } - - @SuppressWarnings("unchecked") - @Test - public void checkpoint_coldstart_checkpoint() { - ModelState state = MLUtil.randomModelState(); - state.setLastCheckpointTime(Instant.now()); - // cold start checkpoint will save whatever - checkpointDao.write(state, state.getModelId(), true); - - checkpointDao.flush(); + final CountDownLatch processingLatch = new CountDownLatch(1); + checkpointDao + .batchWrite(new BulkRequest(), ActionListener.wrap(response -> processingLatch.countDown(), e -> { assertTrue(false); })); - verify(clientUtil, times(1)).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + // we don't expect the waiting time elapsed before the count reached zero + assertTrue(processingLatch.await(100, TimeUnit.SECONDS)); + verify(clientUtil, times(1)).execute(any(), any(), any()); } @SuppressWarnings("unchecked") @Test - public void restore() throws IOException { - ModelState state = MLUtil.randomNonEmptyModelState(); - EntityModel modelToSave = state.getModel(); - - checkpointDao = new CheckpointDao( - client, - clientUtil, - indexName, - new Gson(), - new RandomCutForestSerDe(), - thresholdingModelClass, - clock, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - indexUtil, - maxBulkSize, - 2 - ); - - GetResponse getResponse = mock(GetResponse.class); - when(getResponse.isExists()).thenReturn(true); - Map source = new HashMap<>(); - source.put(CheckpointDao.DETECTOR_ID, state.getDetectorId()); - source.put(CheckpointDao.FIELD_MODEL, checkpointDao.toCheckpoint(modelToSave)); - 
source.put(CheckpointDao.TIMESTAMP, "2020-10-11T22:58:23.610392Z"); - when(getResponse.getSource()).thenReturn(source); - + public void batch_read() throws InterruptedException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(2); - listener.onResponse(getResponse); + MultiGetItemResponse[] items = new MultiGetItemResponse[1]; + items[0] = new MultiGetItemResponse( + null, + new MultiGetResponse.Failure( + CommonName.CHECKPOINT_INDEX_NAME, + "_doc", + "modelId", + new IndexNotFoundException(CommonName.CHECKPOINT_INDEX_NAME) + ) + ); + listener.onResponse(new MultiGetResponse(items)); return null; - }).when(clientUtil).asyncRequest(any(GetRequest.class), any(BiConsumer.class), any(ActionListener.class)); + }).when(clientUtil).execute(eq(MultiGetAction.INSTANCE), any(MultiGetRequest.class), any(ActionListener.class)); - ActionListener>> listener = mock(ActionListener.class); - checkpointDao.restoreModelCheckpoint(modelId, listener); + final CountDownLatch processingLatch = new CountDownLatch(1); + checkpointDao + .batchRead(new MultiGetRequest(), ActionListener.wrap(response -> processingLatch.countDown(), e -> { assertTrue(false); })); - ArgumentCaptor>> responseCaptor = ArgumentCaptor.forClass(Optional.class); - verify(listener).onResponse(responseCaptor.capture()); - Optional> response = responseCaptor.getValue(); - assertTrue(response.isPresent()); - Entry entry = response.get(); - OffsetDateTime utcTime = entry.getValue().atOffset(ZoneOffset.UTC); - assertEquals(2020, utcTime.getYear()); - assertEquals(Month.OCTOBER, utcTime.getMonth()); - assertEquals(11, utcTime.getDayOfMonth()); - assertEquals(22, utcTime.getHour()); - assertEquals(58, utcTime.getMinute()); - assertEquals(23, utcTime.getSecond()); - - EntityModel model = entry.getKey(); - Queue queue = model.getSamples(); - Queue samplesToSave = modelToSave.getSamples(); - assertEquals(samplesToSave.size(), queue.size()); - assertTrue(Arrays.equals(samplesToSave.peek(), queue.peek())); - logger.info(modelToSave.getRcf()); - logger.info(model.getRcf()); - assertEquals(modelToSave.getRcf().getTotalUpdates(), model.getRcf().getTotalUpdates()); - assertTrue(model.getThreshold() != null); + // we don't expect the waiting time elapsed before the count reached zero + assertTrue(processingLatch.await(100, TimeUnit.SECONDS)); + verify(clientUtil, times(1)).execute(any(), any(), any()); } } diff --git a/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java b/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java index 77a01b66f..b705ff0a6 100644 --- a/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java +++ b/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java @@ -32,7 +32,6 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import java.time.Clock; import java.util.Arrays; import java.util.Collections; @@ -43,7 +42,6 @@ import org.opensearch.ad.AbstractADTest; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.indices.AnomalyDetectionIndices; -import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.util.ClientUtil; import org.opensearch.client.Client; import org.opensearch.index.IndexNotFoundException; @@ -73,9 +71,9 @@ private enum DeleteExecutionMode { private ClientUtil clientUtil; private Gson gson; private RandomCutForestSerDe rcfSerde; - private Clock clock; private AnomalyDetectionIndices indexUtil; private String detectorId; + private int 
maxCheckpointBytes; @Override @Before @@ -87,9 +85,9 @@ public void setUp() throws Exception { clientUtil = mock(ClientUtil.class); gson = null; rcfSerde = mock(RandomCutForestSerDe.class); - clock = mock(Clock.class); indexUtil = mock(AnomalyDetectionIndices.class); detectorId = "123"; + maxCheckpointBytes = 1_000_000; checkpointDao = new CheckpointDao( client, @@ -98,11 +96,8 @@ public void setUp() throws Exception { gson, rcfSerde, HybridThresholdingModel.class, - clock, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, indexUtil, - AnomalyDetectorSettings.MAX_BULK_CHECKPOINT_SIZE, - AnomalyDetectorSettings.CHECKPOINT_BULK_PER_SECOND + maxCheckpointBytes ); } diff --git a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java index de7b9bea6..b2c88bc6a 100644 --- a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java +++ b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java @@ -31,7 +31,6 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -46,6 +45,9 @@ import java.util.Map.Entry; import java.util.Optional; import java.util.Queue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; @@ -63,7 +65,9 @@ import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.IntervalTimeConfiguration; +import org.opensearch.ad.ratelimit.CheckpointWriteWorker; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.util.ClientUtil; import org.opensearch.client.Client; @@ -91,6 +95,12 @@ public class EntityColdStarterTests extends AbstractADTest { FeatureManager featureManager; Settings settings; ThreadPool threadPool; + AtomicBoolean released; + Runnable releaseSemaphore; + ActionListener listener; + CountDownLatch inProgressLatch; + CheckpointWriteWorker checkpointWriteQueue; + Entity entity; @SuppressWarnings("unchecked") @Override @@ -152,6 +162,8 @@ public void setUp() throws Exception { AnomalyDetectorPlugin.AD_THREAD_POOL_NAME ); + checkpointWriteQueue = mock(CheckpointWriteWorker.class); + entityColdStarter = new EntityColdStarter( clock, threadPool, @@ -172,33 +184,49 @@ public void setUp() throws Exception { AnomalyDetectorSettings.THRESHOLD_DOWNSAMPLES, AnomalyDetectorSettings.THRESHOLD_MAX_SAMPLES, featureManager, + settings, AnomalyDetectorSettings.HOURLY_MAINTENANCE, - AnomalyDetectorSettings.MAX_SMALL_STATES, - checkpoint, - settings + checkpointWriteQueue ); detectorId = "123"; modelId = "123_entity_abc"; entityName = "abc"; priority = 0.3f; + entity = Entity.createSingleAttributeEntity(detectorId, "field", entityName); + + released = new AtomicBoolean(); + + inProgressLatch = new CountDownLatch(1); + releaseSemaphore = () -> { + released.set(true); + inProgressLatch.countDown(); + }; + listener = ActionListener.wrap(releaseSemaphore); + } + + private void checkSemaphoreRelease() throws InterruptedException { + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + assertTrue(released.get()); } // 
train using samples directly - public void testTrainUsingSamples() { + public void testTrainUsingSamples() throws InterruptedException { Queue samples = MLUtil.createQueueSamples(numMinSamples); - EntityModel model = new EntityModel(modelId, samples, null, null); + EntityModel model = new EntityModel(entity, samples, null, null); modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); - entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState); + entityColdStarter.trainModel(entity, detectorId, modelState, listener); RandomCutForest forest = model.getRcf(); assertTrue(forest != null); assertEquals(numMinSamples, forest.getTotalUpdates()); assertTrue(model.getThreshold() != null); + + checkSemaphoreRelease(); } public void testColdStart() throws InterruptedException, IOException { Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(modelId, samples, null, null); + EntityModel model = new EntityModel(entity, samples, null, null); modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); doAnswer(invocation -> { @@ -217,7 +245,7 @@ public void testColdStart() throws InterruptedException, IOException { return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); - entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState); + entityColdStarter.trainModel(entity, detectorId, modelState, listener); waitForColdStartFinish(); RandomCutForest forest = model.getRcf(); @@ -228,17 +256,20 @@ public void testColdStart() throws InterruptedException, IOException { // sleep 1 secs to give time for the last timestamp record to expire when superShortLastColdStartTimeState = true Thread.sleep(1000L); + checkSemaphoreRelease(); + released.set(false); // too frequent cold start of the same detector will fail samples = MLUtil.createQueueSamples(1); - model = new EntityModel(modelId, samples, null, null); - entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState); + model = new EntityModel(entity, samples, null, null); + entityColdStarter.trainModel(entity, detectorId, modelState, listener); waitForColdStartFinish(); forest = model.getRcf(); assertTrue(forest == null); assertTrue(model.getThreshold() == null); + checkSemaphoreRelease(); } private void waitForColdStartFinish() throws InterruptedException { @@ -251,23 +282,10 @@ private void waitForColdStartFinish() throws InterruptedException { } } - // cold start running, return immediately - public void testColdStartRunning() { - Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(modelId, samples, null, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); - - NodeStateManager spyNodeStateManager = spy(stateManager); - spyNodeStateManager.markColdStartRunning(detectorId); - entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState); - - verify(spyNodeStateManager, never()).getAnomalyDetector(any(), any()); - } - // min max: miss one - public void testMissMin() throws IOException { + public void testMissMin() throws IOException, InterruptedException { Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(modelId, samples, null, null); + EntityModel model = new EntityModel(entity, samples, null, null); modelState = new ModelState<>(model, modelId, detectorId, 
ModelType.ENTITY.getName(), clock, priority); doAnswer(invocation -> { @@ -276,19 +294,20 @@ public void testMissMin() throws IOException { return null; }).when(searchFeatureDao).getEntityMinMaxDataTime(any(), any(), any()); - entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState); + entityColdStarter.trainModel(entity, detectorId, modelState, listener); verify(searchFeatureDao, never()).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); RandomCutForest forest = model.getRcf(); assertTrue(forest == null); assertTrue(model.getThreshold() == null); + checkSemaphoreRelease(); } // two segments of samples, one segment has 3 samples, while another one has only 1 public void testTwoSegmentsWithSingleSample() throws InterruptedException, IOException { Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(modelId, samples, null, null); + EntityModel model = new EntityModel(entity, samples, null, null); modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); doAnswer(invocation -> { @@ -309,7 +328,7 @@ public void testTwoSegmentsWithSingleSample() throws InterruptedException, IOExc return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); - entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState); + entityColdStarter.trainModel(entity, detectorId, modelState, listener); int maxWaitTimes = 20; int i = 0; @@ -324,12 +343,13 @@ public void testTwoSegmentsWithSingleSample() throws InterruptedException, IOExc // 2nd segment: 1 assertEquals(130, forest.getTotalUpdates()); assertTrue(model.getThreshold() != null); + checkSemaphoreRelease(); } // two segments of samples, one segment has 3 samples, while another one 2 samples public void testTwoSegments() throws InterruptedException, IOException { Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(modelId, samples, null, null); + EntityModel model = new EntityModel(entity, samples, null, null); modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); doAnswer(invocation -> { @@ -351,7 +371,7 @@ public void testTwoSegments() throws InterruptedException, IOException { return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); - entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState); + entityColdStarter.trainModel(entity, detectorId, modelState, listener); int maxWaitTimes = 20; int i = 0; @@ -366,11 +386,12 @@ public void testTwoSegments() throws InterruptedException, IOException { // 2nd segment: maxSampleStride * (continuousSampledArray.length - 1) + 1 = 64 * 1 + 1 = 65 assertEquals(194, forest.getTotalUpdates()); assertTrue(model.getThreshold() != null); + checkSemaphoreRelease(); } - public void testThrottledColdStart() { + public void testThrottledColdStart() throws InterruptedException { Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(modelId, samples, null, null); + EntityModel model = new EntityModel(entity, samples, null, null); modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); doAnswer(invocation -> { @@ -379,17 +400,18 @@ public void testThrottledColdStart() { return null; }).when(searchFeatureDao).getEntityMinMaxDataTime(any(), any(), any()); - entityColdStarter.trainModel(samples, modelId, 
entityName, detectorId, modelState); + entityColdStarter.trainModel(entity, detectorId, modelState, listener); - entityColdStarter.trainModel(samples, modelId, entityName, "456", modelState); + entityColdStarter.trainModel(entity, "456", modelState, listener); // only the first one makes the call verify(searchFeatureDao, times(1)).getEntityMinMaxDataTime(any(), any(), any()); + checkSemaphoreRelease(); } - public void testColdStartException() { + public void testColdStartException() throws InterruptedException { Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(modelId, samples, null, null); + EntityModel model = new EntityModel(entity, samples, null, null); modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); doAnswer(invocation -> { @@ -398,14 +420,15 @@ public void testColdStartException() { return null; }).when(searchFeatureDao).getEntityMinMaxDataTime(any(), any(), any()); - entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState); + entityColdStarter.trainModel(entity, detectorId, modelState, listener); assertTrue(stateManager.getLastDetectionError(detectorId) != null); + checkSemaphoreRelease(); } public void testNotEnoughSamples() throws InterruptedException, IOException { Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(modelId, samples, null, null); + EntityModel model = new EntityModel(entity, samples, null, null); modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); doAnswer(invocation -> { @@ -423,7 +446,7 @@ public void testNotEnoughSamples() throws InterruptedException, IOException { return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); - entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState); + entityColdStarter.trainModel(entity, detectorId, modelState, listener); int maxWaitTimes = 20; int i = 0; @@ -441,7 +464,7 @@ public void testNotEnoughSamples() throws InterruptedException, IOException { public void testEmptyDataRange() throws InterruptedException { Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(modelId, samples, null, null); + EntityModel model = new EntityModel(entity, samples, null, null); modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); doAnswer(invocation -> { @@ -450,7 +473,7 @@ public void testEmptyDataRange() throws InterruptedException { return null; }).when(searchFeatureDao).getEntityMinMaxDataTime(any(), any(), any()); - entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState); + entityColdStarter.trainModel(entity, detectorId, modelState, listener); waitForColdStartFinish(); assertTrue(model.getRcf() == null); diff --git a/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java b/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java index 82b4aae43..8ea4aa9ff 100644 --- a/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java +++ b/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java @@ -71,6 +71,7 @@ import org.opensearch.action.ActionListener; import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.caching.EntityCache; import org.opensearch.ad.common.exception.LimitExceededException; import 
org.opensearch.ad.common.exception.ResourceNotFoundException; @@ -91,8 +92,6 @@ import org.powermock.modules.junit4.PowerMockRunner; import org.powermock.modules.junit4.PowerMockRunnerDelegate; -import test.org.opensearch.ad.util.MLUtil; - import com.amazon.randomcutforest.RandomCutForest; import com.amazon.randomcutforest.returntypes.DiVector; import com.amazon.randomcutforest.serialize.RandomCutForestSerDe; @@ -187,6 +186,9 @@ public class ModelManagerTests { private MemoryTracker memoryTracker; private Instant now; + @Mock + private ADCircuitBreakerService adCircuitBreakerService; + @Before public void setup() { MockitoAnnotations.initMocks(this); @@ -226,7 +228,7 @@ public void setup() { gson = PowerMockito.mock(Gson.class); - settings = Settings.builder().put("opendistro.anomaly_detection.model_max_size_percent", modelMaxSizePercentage).build(); + settings = Settings.builder().put("plugins.anomaly_detection.model_max_size_percent", modelMaxSizePercentage).build(); ClusterSettings clusterSettings = PowerMockito.mock(ClusterSettings.class); clusterService = new ClusterService(settings, clusterSettings, null); MemoryTracker memoryTracker = new MemoryTracker( @@ -234,7 +236,8 @@ public void setup() { modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, - numSamples + numSamples, + adCircuitBreakerService ); ExecutorService executorService = mock(ExecutorService.class); @@ -411,7 +414,14 @@ public void getPartitionedForestSizes_returnExpected( ) { when(jvmService.info().getMem().getHeapMax().getBytes()).thenReturn(heapSize); MemoryTracker memoryTracker = spy( - new MemoryTracker(jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, numSamples) + new MemoryTracker( + jvmService, + modelMaxSizePercentage, + modelDesiredSizePercentage, + clusterService, + numSamples, + adCircuitBreakerService + ) ); when(memoryTracker.estimateModelSize(rcf)).thenReturn(totalModelSize); @@ -441,7 +451,14 @@ public void getPartitionedForestSizes_throwLimitExceeded( ) { when(jvmService.info().getMem().getHeapMax().getBytes()).thenReturn(heapSize); MemoryTracker memoryTracker = spy( - new MemoryTracker(jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, numSamples) + new MemoryTracker( + jvmService, + modelMaxSizePercentage, + modelDesiredSizePercentage, + clusterService, + numSamples, + adCircuitBreakerService + ) ); when(memoryTracker.estimateModelSize(rcf)).thenReturn(totalModelSize); modelPartitioner = spy(new ModelPartitioner(numSamples, numTrees, nodeFilter, memoryTracker)); @@ -528,7 +545,8 @@ public void getRcfResult_throwToListener_whenHeapLimitExceed() { modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, - numSamples + numSamples, + adCircuitBreakerService ); ActionListener listener = mock(ActionListener.class); @@ -1208,56 +1226,4 @@ public void getPreviewResults_returnAnomalies_forLastAnomaly() { public void getPreviewResults_throwIllegalArgument_forInvalidInput() { modelManager.getPreviewResults(new double[0][0]); } - - @Test - public void getNullState() { - assertEquals(new ThresholdingResult(0, 0, 0), modelManager.getAnomalyResultForEntity("", new double[] {}, "", null, "")); - } - - @Test - public void getEmptyStateFullSamples() { - ModelState state = MLUtil.randomModelStateWithSample(false, numMinSamples); - assertEquals( - new ThresholdingResult(0, 0, 0), - modelManager.getAnomalyResultForEntity(detectorId, new double[] { -1 }, entityName, state, modelId) - ); - assertEquals(numMinSamples, 
state.getModel().getSamples().size()); - } - - @Test - public void getEmptyStateNotFullSamples() { - ModelState state = MLUtil.randomModelStateWithSample(false, numMinSamples - 1); - assertEquals( - new ThresholdingResult(0, 0, 0), - modelManager.getAnomalyResultForEntity(detectorId, new double[] { -1 }, entityName, state, modelId) - ); - assertEquals(numMinSamples, state.getModel().getSamples().size()); - } - - @Test - public void scoreSamples() { - ModelState state = MLUtil.randomNonEmptyModelState(); - modelManager.getAnomalyResultForEntity(detectorId, new double[] { -1 }, entityName, state, modelId); - assertEquals(0, state.getModel().getSamples().size()); - assertEquals(now, state.getLastUsedTime()); - } - - @Test - public void processEmptyCheckpoint() { - ModelState modelState = MLUtil.randomModelStateWithSample(false, numMinSamples - 1); - modelManager.processEntityCheckpoint(Optional.empty(), modelId, entityName, modelState); - assertEquals(now.minus(checkpointInterval), modelState.getLastCheckpointTime()); - } - - @Test - public void processNonEmptyCheckpoint() { - EntityModel model = MLUtil.createNonEmptyModel(modelId); - ModelState modelState = MLUtil.randomModelStateWithSample(false, numMinSamples); - Instant checkpointTime = Instant.ofEpochMilli(1000); - modelManager - .processEntityCheckpoint(Optional.of(new SimpleImmutableEntry<>(model, checkpointTime)), modelId, entityName, modelState); - assertEquals(checkpointTime, modelState.getLastCheckpointTime()); - assertEquals(0, modelState.getModel().getSamples().size()); - assertEquals(now, modelState.getLastUsedTime()); - } } diff --git a/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java index c06f6cc68..7e9369231 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java @@ -93,7 +93,8 @@ public void testParseAnomalyDetectorWithoutNormalResult() throws IOException { randomAlphaOfLength(5), null, TestHelpers.randomUser(), - CommonValue.NO_SCHEMA_VERSION + CommonValue.NO_SCHEMA_VERSION, + null ); String detectResultString = TestHelpers .xContentBuilderToString(detectResult.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); @@ -131,7 +132,8 @@ public void testParseAnomalyDetectorWithNanAnomalyResult() throws IOException { randomAlphaOfLength(5), null, null, - CommonValue.NO_SCHEMA_VERSION + CommonValue.NO_SCHEMA_VERSION, + null ); String detectResultString = TestHelpers .xContentBuilderToString(detectResult.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); @@ -165,7 +167,7 @@ public void testParseAnomalyDetectorWithTaskId() throws IOException { } public void testParseAnomalyDetectorWithEntity() throws IOException { - AnomalyResult detectResult = TestHelpers.randomMultiEntityAnomalyDetectResult(0.8, 0.5); + AnomalyResult detectResult = TestHelpers.randomHCADAnomalyDetectResult(0.8, 0.5); String detectResultString = TestHelpers .xContentBuilderToString(detectResult.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); detectResultString = detectResultString @@ -193,7 +195,7 @@ public void testSerializeAnomalyResultWithoutUser() throws IOException { } public void testSerializeAnomalyResultWithEntity() throws IOException { - AnomalyResult detectResult = TestHelpers.randomMultiEntityAnomalyDetectResult(0.8, 0.5); + AnomalyResult detectResult = TestHelpers.randomHCADAnomalyDetectResult(0.8, 0.5); BytesStreamOutput output = new BytesStreamOutput(); 
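The serialization tests around this point write an AnomalyResult to a BytesStreamOutput and read it back through a stream input to confirm nothing is lost. The same round-trip idea, reduced to plain JDK streams so it runs standalone (the value and class names here are illustrative stand-ins, not plugin types):

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

public class RoundTripSketch {
    public static void main(String[] args) throws IOException, ClassNotFoundException {
        String original = "sample-anomaly-result";   // stand-in for the object under test
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(original);               // write side of the round trip
        }
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
            Object copy = in.readObject();           // read side of the round trip
            if (!original.equals(copy)) {
                throw new AssertionError("round trip changed the value");
            }
        }
        System.out.println("round trip ok");
    }
}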
detectResult.writeTo(output); NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); diff --git a/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java b/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java index 5af505f8d..56697d6c8 100644 --- a/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java @@ -27,7 +27,7 @@ package org.opensearch.ad.model; import static java.util.Arrays.asList; -import static org.opensearch.ad.TestHelpers.randomMutlEntityAnomalyDetectResult; +import static org.opensearch.ad.TestHelpers.randomHCADAnomalyDetectResult; import java.util.ArrayList; import java.util.List; @@ -40,8 +40,8 @@ public class EntityAnomalyResultTests extends OpenSearchTestCase { @Test public void testGetAnomalyResults() { - AnomalyResult anomalyResult1 = randomMutlEntityAnomalyDetectResult(0.25, 0.25, "error"); - AnomalyResult anomalyResult2 = randomMutlEntityAnomalyDetectResult(0.5, 0.5, "error"); + AnomalyResult anomalyResult1 = randomHCADAnomalyDetectResult(0.25, 0.25, "error"); + AnomalyResult anomalyResult2 = randomHCADAnomalyDetectResult(0.5, 0.5, "error"); List anomalyResults = new ArrayList() { { add(anomalyResult1); @@ -55,8 +55,8 @@ public void testGetAnomalyResults() { @Test public void testMerge() { - AnomalyResult anomalyResult1 = randomMutlEntityAnomalyDetectResult(0.25, 0.25, "error"); - AnomalyResult anomalyResult2 = randomMutlEntityAnomalyDetectResult(0.5, 0.5, "error"); + AnomalyResult anomalyResult1 = randomHCADAnomalyDetectResult(0.25, 0.25, "error"); + AnomalyResult anomalyResult2 = randomHCADAnomalyDetectResult(0.5, 0.5, "error"); EntityAnomalyResult entityAnomalyResult1 = new EntityAnomalyResult(new ArrayList() { { @@ -75,7 +75,7 @@ public void testMerge() { @Test public void testMerge_null() { - AnomalyResult anomalyResult = randomMutlEntityAnomalyDetectResult(0.25, 0.25, "error"); + AnomalyResult anomalyResult = randomHCADAnomalyDetectResult(0.25, 0.25, "error"); EntityAnomalyResult entityAnomalyResult = new EntityAnomalyResult(new ArrayList() { { @@ -90,7 +90,7 @@ public void testMerge_null() { @Test public void testMerge_self() { - AnomalyResult anomalyResult = randomMutlEntityAnomalyDetectResult(0.25, 0.25, "error"); + AnomalyResult anomalyResult = randomHCADAnomalyDetectResult(0.25, 0.25, "error"); EntityAnomalyResult entityAnomalyResult = new EntityAnomalyResult(new ArrayList() { { @@ -106,7 +106,7 @@ public void testMerge_self() { @Test public void testMerge_otherClass() { ADStatsResponse adStatsResponse = new ADStatsResponse(); - AnomalyResult anomalyResult = randomMutlEntityAnomalyDetectResult(0.25, 0.25, "error"); + AnomalyResult anomalyResult = randomHCADAnomalyDetectResult(0.25, 0.25, "error"); EntityAnomalyResult entityAnomalyResult = new EntityAnomalyResult(new ArrayList() { { diff --git a/src/test/java/org/opensearch/ad/model/EntityProfileTests.java b/src/test/java/org/opensearch/ad/model/EntityProfileTests.java index 48d897650..a089c34d1 100644 --- a/src/test/java/org/opensearch/ad/model/EntityProfileTests.java +++ b/src/test/java/org/opensearch/ad/model/EntityProfileTests.java @@ -41,16 +41,16 @@ public class EntityProfileTests extends AbstractADTest { public void testMerge() { - EntityProfile profile1 = new EntityProfile(null, null, null, -1, -1, null, null, EntityState.INIT); + EntityProfile profile1 = new EntityProfile(null, -1, -1, null, null, EntityState.INIT); - 
EntityProfile profile2 = new EntityProfile(null, null, null, -1, -1, null, null, EntityState.UNKNOWN); + EntityProfile profile2 = new EntityProfile(null, -1, -1, null, null, EntityState.UNKNOWN); profile1.merge(profile2); assertEquals(profile1.getState(), EntityState.INIT); } public void testToXContent() throws IOException, JsonPathNotFoundException { - EntityProfile profile1 = new EntityProfile(null, null, null, -1, -1, null, null, EntityState.INIT); + EntityProfile profile1 = new EntityProfile(null, -1, -1, null, null, EntityState.INIT); XContentBuilder builder = jsonBuilder(); profile1.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -58,7 +58,7 @@ public void testToXContent() throws IOException, JsonPathNotFoundException { assertEquals("INIT", JsonDeserializer.getTextValue(json, CommonName.STATE)); - EntityProfile profile2 = new EntityProfile(null, null, null, -1, -1, null, null, EntityState.UNKNOWN); + EntityProfile profile2 = new EntityProfile(null, -1, -1, null, null, EntityState.UNKNOWN); builder = jsonBuilder(); profile2.toXContent(builder, ToXContent.EMPTY_PARAMS); diff --git a/src/test/java/org/opensearch/ad/model/EntityTests.java b/src/test/java/org/opensearch/ad/model/EntityTests.java new file mode 100644 index 000000000..95d2384f1 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/EntityTests.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.util.TreeMap; + +import org.opensearch.ad.AbstractADTest; + +public class EntityTests extends AbstractADTest { + /** + * Test that toString has no random string, but only attributes + */ + public void testToString() { + TreeMap attributes = new TreeMap<>(); + String name1 = "host"; + String val1 = "server_2"; + String name2 = "service"; + String val2 = "app_4"; + attributes.put(name1, val1); + attributes.put(name2, val2); + String detectorId = "detectorId"; + Entity entity = Entity.createEntityFromOrderedMap(detectorId, attributes); + assertEquals("host=server_2,service=app_4", entity.toString()); + } +} diff --git a/src/test/java/org/opensearch/ad/ratelimit/AbstractRateLimitingTest.java b/src/test/java/org/opensearch/ad/ratelimit/AbstractRateLimitingTest.java new file mode 100644 index 000000000..08d24b3dc --- /dev/null +++ b/src/test/java/org/opensearch/ad/ratelimit/AbstractRateLimitingTest.java @@ -0,0 +1,66 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.time.Clock; +import java.time.Instant; +import java.util.Arrays; +import java.util.Optional; + +import org.opensearch.action.ActionListener; +import org.opensearch.ad.AbstractADTest; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.TestHelpers; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; +import org.opensearch.threadpool.ThreadPool; + +public class AbstractRateLimitingTest extends AbstractADTest { + Clock clock; + AnomalyDetector detector; + NodeStateManager nodeStateManager; + String detectorId; + String categoryField; + Entity entity, entity2, entity3; + + @SuppressWarnings("unchecked") + @Override + public void setUp() throws Exception { + super.setUp(); + + clock = mock(Clock.class); + when(clock.instant()).thenReturn(Instant.now()); + + threadPool = mock(ThreadPool.class); + setUpADThreadPool(threadPool); + + categoryField = "a"; + detectorId = "123"; + detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(categoryField)); + + nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + entity = Entity.createSingleAttributeEntity(detectorId, categoryField, "value"); + entity2 = Entity.createSingleAttributeEntity(detectorId, categoryField, "value2"); + entity3 = Entity.createSingleAttributeEntity(detectorId, categoryField, "value3"); + } +} diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java new file mode 100644 index 000000000..6f95886ee --- /dev/null +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java @@ -0,0 +1,817 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static java.util.AbstractMap.SimpleImmutableEntry; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.mockito.Mockito; +import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.get.MultiGetItemResponse; +import org.opensearch.action.get.MultiGetResponse; +import org.opensearch.ad.AnomalyDetectorPlugin; +import org.opensearch.ad.TestHelpers; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.common.exception.LimitExceededException; +import org.opensearch.ad.constant.CommonName; +import org.opensearch.ad.indices.AnomalyDetectionIndices; +import org.opensearch.ad.ml.CheckpointDao; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.get.GetResult; +import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.rest.RestStatus; +import org.opensearch.threadpool.ThreadPoolStats; +import org.opensearch.threadpool.ThreadPoolStats.Stats; + +import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; + +import com.fasterxml.jackson.core.JsonParseException; + +public class CheckpointReadWorkerTests extends AbstractRateLimitingTest { + CheckpointReadWorker worker; + + CheckpointDao checkpoint; + ClusterService clusterService; + + ModelState state; + + CheckpointWriteWorker checkpointWriteQueue; + ModelManager modelManager; + EntityColdStartWorker coldstartQueue; + ResultWriteWorker resultWriteQueue; + AnomalyDetectionIndices anomalyDetectionIndices; + CacheProvider cacheProvider; + EntityCache entityCache; + EntityFeatureRequest request, request2, request3; + ClusterSettings clusterSettings; + + @Override + public void setUp() throws Exception { + super.setUp(); + clusterService = mock(ClusterService.class); + clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + 
new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_BATCH_SIZE + ) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + + checkpoint = mock(CheckpointDao.class); + + Map.Entry entry = new SimpleImmutableEntry(state.getModel(), Instant.now()); + when(checkpoint.processGetResponse(any(), anyString())).thenReturn(Optional.of(entry)); + + checkpointWriteQueue = mock(CheckpointWriteWorker.class); + + modelManager = mock(ModelManager.class); + when(modelManager.processEntityCheckpoint(any(), any(), anyString(), anyString())).thenReturn(state); + when(modelManager.score(any(), anyString(), any())).thenReturn(new ThresholdingResult(0, 1, 0.7)); + + coldstartQueue = mock(EntityColdStartWorker.class); + resultWriteQueue = mock(ResultWriteWorker.class); + anomalyDetectionIndices = mock(AnomalyDetectionIndices.class); + + cacheProvider = mock(CacheProvider.class); + entityCache = mock(EntityCache.class); + when(cacheProvider.get()).thenReturn(entityCache); + when(entityCache.hostIfPossible(any(), any())).thenReturn(true); + + // Integer.MAX_VALUE makes a huge heap + worker = new CheckpointReadWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + modelManager, + checkpoint, + coldstartQueue, + resultWriteQueue, + nodeStateManager, + anomalyDetectionIndices, + cacheProvider, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue + ); + + request = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity, new double[] { 0 }, 0); + request2 = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity2, new double[] { 0 }, 0); + request3 = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity3, new double[] { 0 }, 0); + } + + static class RegularSetUpConfig { + private final boolean canHostModel; + private final boolean fullModel; + + RegularSetUpConfig(Builder builder) { + this.canHostModel = builder.canHostModel; + this.fullModel = builder.fullModel; + } + + public static class Builder { + boolean canHostModel = true; + boolean fullModel = true; + + Builder canHostModel(boolean canHostModel) { + this.canHostModel = canHostModel; + return this; + } + + Builder fullModel(boolean fullModel) { + this.fullModel = fullModel; + return this; + } + + public RegularSetUpConfig build() { + return new RegularSetUpConfig(this); + } + } + } + + private void regularTestSetUp(RegularSetUpConfig config) { + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[1]; + items[0] = new MultiGetItemResponse( + new GetResponse( + new GetResult( + CommonName.CHECKPOINT_INDEX_NAME, + "_doc", + entity.getModelId(detectorId).get(), + 1, + 1, + 0, + true, + null, + null, + null + ) + ), + null + ); + ActionListener listener = 
invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + when(entityCache.hostIfPossible(any(), any())).thenReturn(config.canHostModel); + + state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(config.fullModel).build()); + when(modelManager.processEntityCheckpoint(any(), any(), anyString(), anyString())).thenReturn(state); + + List requests = new ArrayList<>(); + requests.add(request); + worker.putAll(requests); + } + + public void testRegular() { + regularTestSetUp(new RegularSetUpConfig.Builder().build()); + + verify(resultWriteQueue, times(1)).put(any()); + verify(checkpointWriteQueue, never()).write(any(), anyBoolean(), any()); + } + + public void testCannotLoadModel() { + regularTestSetUp(new RegularSetUpConfig.Builder().canHostModel(false).build()); + + verify(resultWriteQueue, times(1)).put(any()); + verify(checkpointWriteQueue, times(1)).write(any(), anyBoolean(), any()); + } + + public void testNoFullModel() { + regularTestSetUp(new RegularSetUpConfig.Builder().fullModel(false).build()); + verify(resultWriteQueue, never()).put(any()); + verify(checkpointWriteQueue, never()).write(any(), anyBoolean(), any()); + } + + public void testIndexNotFound() { + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[1]; + items[0] = new MultiGetItemResponse( + null, + new MultiGetResponse.Failure( + CommonName.CHECKPOINT_INDEX_NAME, + "_doc", + entity.getModelId(detectorId).get(), + new IndexNotFoundException(CommonName.CHECKPOINT_INDEX_NAME) + ) + ); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + worker.put(request); + verify(coldstartQueue, times(1)).put(any()); + verify(entityCache, never()).hostIfPossible(any(), any()); + } + + public void testAllDocNotFound() { + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[2]; + items[0] = new MultiGetItemResponse( + new GetResponse( + new GetResult( + CommonName.CHECKPOINT_INDEX_NAME, + "_doc", + entity.getModelId(detectorId).get(), + SequenceNumbers.UNASSIGNED_SEQ_NO, + SequenceNumbers.UNASSIGNED_PRIMARY_TERM, + 0, + false, + null, + null, + null + ) + ), + null + ); + items[1] = new MultiGetItemResponse( + new GetResponse( + new GetResult( + CommonName.CHECKPOINT_INDEX_NAME, + "_doc", + entity2.getModelId(detectorId).get(), + SequenceNumbers.UNASSIGNED_SEQ_NO, + SequenceNumbers.UNASSIGNED_PRIMARY_TERM, + 0, + false, + null, + null, + null + ) + ), + null + ); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + List requests = new ArrayList<>(); + requests.add(request); + requests.add(request2); + worker.putAll(requests); + + verify(coldstartQueue, times(2)).put(any()); + verify(entityCache, never()).hostIfPossible(any(), any()); + } + + public void testSingleDocNotFound() { + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[2]; + items[0] = new MultiGetItemResponse( + new GetResponse( + new GetResult( + CommonName.CHECKPOINT_INDEX_NAME, + "_doc", + entity.getModelId(detectorId).get(), + 1, + 1, + 0, + true, + null, + null, + null + ) + ), + null + ); + items[1] = new MultiGetItemResponse( + new GetResponse( + new GetResult( + CommonName.CHECKPOINT_INDEX_NAME, + "_doc", + 
entity2.getModelId(detectorId).get(), + SequenceNumbers.UNASSIGNED_SEQ_NO, + SequenceNumbers.UNASSIGNED_PRIMARY_TERM, + 0, + false, + null, + null, + null + ) + ), + null + ); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + List requests = new ArrayList<>(); + requests.add(request); + requests.add(request2); + worker.putAll(requests); + verify(coldstartQueue, times(1)).put(any()); + verify(entityCache, times(1)).hostIfPossible(any(), any()); + } + + public void testTimeout() { + AtomicBoolean retried = new AtomicBoolean(); + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[2]; + if (!retried.get()) { + items[0] = new MultiGetItemResponse( + null, + new MultiGetResponse.Failure( + CommonName.CHECKPOINT_INDEX_NAME, + "_doc", + entity.getModelId(detectorId).get(), + new OpenSearchStatusException("blah", RestStatus.REQUEST_TIMEOUT) + ) + ); + items[1] = new MultiGetItemResponse( + null, + new MultiGetResponse.Failure( + CommonName.CHECKPOINT_INDEX_NAME, + "_doc", + entity2.getModelId(detectorId).get(), + new OpenSearchStatusException("blah", RestStatus.CONFLICT) + ) + ); + retried.set(true); + } else { + items[0] = new MultiGetItemResponse( + new GetResponse( + new GetResult( + CommonName.CHECKPOINT_INDEX_NAME, + "_doc", + entity.getModelId(detectorId).get(), + 1, + 1, + 0, + true, + null, + null, + null + ) + ), + null + ); + items[1] = new MultiGetItemResponse( + new GetResponse( + new GetResult( + CommonName.CHECKPOINT_INDEX_NAME, + "_doc", + entity2.getModelId(detectorId).get(), + 1, + 1, + 0, + true, + null, + null, + null + ) + ), + null + ); + } + + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + List requests = new ArrayList<>(); + requests.add(request); + requests.add(request2); + worker.putAll(requests); + // two retried requests and the original putAll trigger 3 batchRead in total. 
+ // It is possible the two retries requests get combined into one batchRead + verify(checkpoint, Mockito.atLeast(2)).batchRead(any(), any()); + assertTrue(retried.get()); + } + + public void testOverloadedExceptionFromResponse() { + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[1]; + items[0] = new MultiGetItemResponse( + null, + new MultiGetResponse.Failure( + CommonName.CHECKPOINT_INDEX_NAME, + "_doc", + entity.getModelId(detectorId).get(), + new OpenSearchRejectedExecutionException("blah") + ) + ); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + worker.put(request); + verify(coldstartQueue, never()).put(any()); + verify(entityCache, never()).hostIfPossible(any(), any()); + worker.put(request); + // the 2nd put won't trigger batchRead as we are in cool down mode + verify(checkpoint, times(1)).batchRead(any(), any()); + } + + public void testOverloadedExceptionFromFailure() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new OpenSearchRejectedExecutionException("blah")); + return null; + }).when(checkpoint).batchRead(any(), any()); + + worker.put(request); + verify(coldstartQueue, never()).put(any()); + verify(entityCache, never()).hostIfPossible(any(), any()); + worker.put(request); + // the 2nd put won't trigger batchRead as we are in cool down mode + verify(checkpoint, times(1)).batchRead(any(), any()); + } + + public void testUnexpectedException() { + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[1]; + items[0] = new MultiGetItemResponse( + null, + new MultiGetResponse.Failure( + CommonName.CHECKPOINT_INDEX_NAME, + "_doc", + entity.getModelId(detectorId).get(), + new IllegalArgumentException("blah") + ) + ); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + worker.put(request); + verify(coldstartQueue, never()).put(any()); + verify(entityCache, never()).hostIfPossible(any(), any()); + } + + public void testRetryableException() { + AtomicBoolean retried = new AtomicBoolean(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + if (retried.get()) { + // not retryable + listener.onFailure(new JsonParseException(null, "blah")); + } else { + // retryable + retried.set(true); + listener.onFailure(new OpenSearchException("blah")); + } + + return null; + }).when(checkpoint).batchRead(any(), any()); + + worker.put(request); + verify(coldstartQueue, never()).put(any()); + verify(entityCache, never()).hostIfPossible(any(), any()); + assertTrue(retried.get()); + } + + public void testRemoveUnusedQueues() { + // do nothing when putting a request to keep queues not empty + ExecutorService executorService = mock(ExecutorService.class); + when(threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + + worker = new CheckpointReadWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + 
AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + modelManager, + checkpoint, + coldstartQueue, + resultWriteQueue, + nodeStateManager, + anomalyDetectionIndices, + cacheProvider, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue + ); + + regularTestSetUp(new RegularSetUpConfig.Builder().build()); + + assertTrue(!worker.isQueueEmpty()); + assertEquals(CheckpointReadWorker.WORKER_NAME, worker.getWorkerName()); + + // make RequestQueue.expired return true + when(clock.instant()).thenReturn(Instant.now().plusSeconds(AnomalyDetectorSettings.HOURLY_MAINTENANCE.getSeconds() + 1)); + + // removed the expired queue + worker.maintenance(); + + assertTrue(worker.isQueueEmpty()); + } + + private void maintenanceSetup() { + // do nothing when putting a request to keep queues not empty + ExecutorService executorService = mock(ExecutorService.class); + when(threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + when(threadPool.stats()).thenReturn(new ThreadPoolStats(new ArrayList())); + } + + public void testSettingUpdatable() { + maintenanceSetup(); + + // can host two requests in the queue + worker = new CheckpointReadWorker( + 2000, + 1, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + modelManager, + checkpoint, + coldstartQueue, + resultWriteQueue, + nodeStateManager, + anomalyDetectionIndices, + cacheProvider, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue + ); + + List requests = new ArrayList<>(); + requests.add(request); + requests.add(request2); + worker.putAll(requests); + // size not exceeded, thus no effect + worker.maintenance(); + assertTrue(!worker.isQueueEmpty()); + + Settings newSettings = Settings + .builder() + .put(AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT.getKey(), "0.0001") + .build(); + Settings.Builder target = Settings.builder(); + clusterSettings.updateDynamicSettings(newSettings, target, Settings.builder(), "test"); + clusterSettings.applySettings(target.build()); + // size not exceeded after changing setting + worker.maintenance(); + assertTrue(worker.isQueueEmpty()); + } + + public void testOpenCircuitBreaker() { + maintenanceSetup(); + + ADCircuitBreakerService breaker = mock(ADCircuitBreakerService.class); + when(breaker.isOpen()).thenReturn(true); + + worker = new CheckpointReadWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + breaker, + threadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + modelManager, + checkpoint, + coldstartQueue, + resultWriteQueue, + nodeStateManager, + anomalyDetectionIndices, + cacheProvider, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue + ); + + List requests = new ArrayList<>(); + requests.add(request); + 
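// with the circuit breaker open, the maintenance call below is expected to drop one of these queued requests +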
requests.add(request2); + worker.putAll(requests); + + // due to open circuit breaker, removed one request + worker.maintenance(); + assertTrue(!worker.isQueueEmpty()); + + // one request per batch + Settings newSettings = Settings.builder().put(AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_BATCH_SIZE.getKey(), "1").build(); + Settings.Builder target = Settings.builder(); + clusterSettings.updateDynamicSettings(newSettings, target, Settings.builder(), "test"); + clusterSettings.applySettings(target.build()); + + // enable executing requests + setUpADThreadPool(threadPool); + + // listener returns response back and trigger calls to process extra requests + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[1]; + items[0] = new MultiGetItemResponse( + new GetResponse( + new GetResult( + CommonName.CHECKPOINT_INDEX_NAME, + "_doc", + entity.getModelId(detectorId).get(), + 1, + 1, + 0, + true, + null, + null, + null + ) + ), + null + ); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + // trigger request execution + worker.put(request3); + assertTrue(worker.isQueueEmpty()); + + // two requests in the queue trigger two batches + verify(checkpoint, times(2)).batchRead(any(), any()); + } + + public void testChangePriority() { + assertEquals(RequestPriority.MEDIUM, request.getPriority()); + RequestPriority newPriority = RequestPriority.HIGH; + request.setPriority(newPriority); + assertEquals(newPriority, request.getPriority()); + } + + public void testDetectorId() { + assertEquals(detectorId, request.getDetectorId()); + String newDetectorId = "456"; + request.setDetectorId(newDetectorId); + assertEquals(newDetectorId, request.getDetectorId()); + } + + @SuppressWarnings("unchecked") + public void testHostException() throws IOException { + String detectorId2 = "456"; + Entity entity4 = Entity.createSingleAttributeEntity(detectorId2, categoryField, "value4"); + EntityFeatureRequest request4 = new EntityFeatureRequest( + Integer.MAX_VALUE, + detectorId2, + RequestPriority.MEDIUM, + entity4, + new double[] { 0 }, + 0 + ); + + AnomalyDetector detector2 = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId2, Arrays.asList(categoryField)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector2)); + return null; + }).when(nodeStateManager).getAnomalyDetector(eq(detectorId2), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(eq(detectorId), any(ActionListener.class)); + + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[2]; + items[0] = new MultiGetItemResponse( + new GetResponse( + new GetResult( + CommonName.CHECKPOINT_INDEX_NAME, + "_doc", + entity.getModelId(detectorId).get(), + 1, + 1, + 0, + true, + null, + null, + null + ) + ), + null + ); + items[1] = new MultiGetItemResponse( + new GetResponse( + new GetResult( + CommonName.CHECKPOINT_INDEX_NAME, + "_doc", + entity4.getModelId(detectorId2).get(), + 1, + 1, + 0, + true, + null, + null, + null + ) + ), + null + ); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + 
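// simulate the cache running out of memory for detector2 only: hostIfPossible for detector2 throws, + // so only detector2's node state should record a LimitExceededException below +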
doThrow(LimitExceededException.class).when(entityCache).hostIfPossible(eq(detector2), any()); + + List requests = new ArrayList<>(); + requests.add(request); + requests.add(request4); + worker.putAll(requests); + verify(coldstartQueue, never()).put(any()); + verify(entityCache, times(2)).hostIfPossible(any(), any()); + + verify(nodeStateManager, times(1)).setException(eq(detectorId2), any(LimitExceededException.class)); + verify(nodeStateManager, never()).setException(eq(detectorId), any(LimitExceededException.class)); + } +} diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java new file mode 100644 index 000000000..4b8f638ff --- /dev/null +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java @@ -0,0 +1,429 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.ConcurrentModificationException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkItemResponse.Failure; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.ad.AnomalyDetectorPlugin; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.constant.CommonName; +import org.opensearch.ad.ml.CheckpointDao; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException; +import org.opensearch.index.Index; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.index.shard.ShardId; +import org.opensearch.rest.RestStatus; +import org.opensearch.threadpool.ThreadPool; + +import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; + +public class CheckpointWriteWorkerTests 
extends AbstractRateLimitingTest { + CheckpointWriteWorker worker; + + CheckpointDao checkpoint; + ClusterService clusterService; + + ModelState state; + + @Override + @SuppressWarnings("unchecked") + public void setUp() throws Exception { + super.setUp(); + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_BATCH_SIZE + ) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + checkpoint = mock(CheckpointDao.class); + Map checkpointMap = new HashMap<>(); + checkpointMap.put(CheckpointDao.FIELD_MODEL, "a"); + when(checkpoint.toIndexSource(any())).thenReturn(checkpointMap); + + // Integer.MAX_VALUE makes a huge heap + worker = new CheckpointWriteWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + checkpoint, + CommonName.CHECKPOINT_INDEX_NAME, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + nodeStateManager, + AnomalyDetectorSettings.HOURLY_MAINTENANCE + ); + + state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().build()); + } + + public void testTriggerSave() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + BulkItemResponse[] responses = new BulkItemResponse[1]; + ShardId shardId = new ShardId(new Index("index_name", "uuid"), 0); + responses[0] = new BulkItemResponse( + 0, + randomFrom(DocWriteRequest.OpType.values()), + new IndexResponse(shardId, "_doc", "id", 1, 1, 1, true) + ); + listener.onResponse(new BulkResponse(responses, 1)); + + return null; + }).when(checkpoint).batchWrite(any(), any()); + + worker.write(state, true, RequestPriority.MEDIUM); + + verify(checkpoint, times(1)).batchWrite(any(), any()); + } + + public void testTriggerSaveAll() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + BulkItemResponse[] responses = new BulkItemResponse[1]; + ShardId shardId = new ShardId(new Index("index_name", "uuid"), 0); + responses[0] = new BulkItemResponse( + 0, + randomFrom(DocWriteRequest.OpType.values()), + new IndexResponse(shardId, "_doc", "id", 1, 1, 1, true) + ); + listener.onResponse(new BulkResponse(responses, 1)); + + return null; + }).when(checkpoint).batchWrite(any(), any()); + + List> states = new ArrayList<>(); + states.add(state); + worker.writeAll(states, detectorId, true, RequestPriority.MEDIUM); + + verify(checkpoint, times(1)).batchWrite(any(), any()); + } + + /** + * Test that when more requests are coming than concurrency allowed, queues will be + * auto-flushed given enough time. 
+ * @throws InterruptedException when thread.sleep gets interrupted + */ + public void testTriggerAutoFlush() throws InterruptedException { + final CountDownLatch processingLatch = new CountDownLatch(1); + + ExecutorService executorService = mock(ExecutorService.class); + + ThreadPool mockThreadPool = mock(ThreadPool.class); + when(mockThreadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + doAnswer(invocation -> { + Runnable runnable = () -> { + try { + processingLatch.await(100, TimeUnit.SECONDS); + } catch (InterruptedException e) { + LOG.error(e); + assertTrue("Unexpected exception", false); + } + Runnable toInvoke = invocation.getArgument(0); + toInvoke.run(); + }; + // start a new thread so it won't block main test thread's execution + new Thread(runnable).start(); + return null; + }).when(executorService).execute(any(Runnable.class)); + + // make sure permits are released and the next request probe starts + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(checkpoint).batchWrite(any(), any()); + + // Integer.MAX_VALUE makes a huge heap + // create a worker to use mockThreadPool + worker = new CheckpointWriteWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + mockThreadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + checkpoint, + CommonName.CHECKPOINT_INDEX_NAME, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + nodeStateManager, + AnomalyDetectorSettings.HOURLY_MAINTENANCE + ); + + // our concurrency is 2, so first 2 requests cause two batches. And the + // remaining 1 stays in the queue until the 2 concurrent runs finish. + // first 2 batch account for one checkpoint.batchWrite; the remaining one + // calls checkpoint.batchWrite + // CHECKPOINT_WRITE_QUEUE_BATCH_SIZE is the largest batch size + int numberOfRequests = 2 * CHECKPOINT_WRITE_QUEUE_BATCH_SIZE.getDefault(Settings.EMPTY) + 1; + for (int i = 0; i < numberOfRequests; i++) { + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().build()); + worker.write(state, true, RequestPriority.MEDIUM); + } + + // Here, we allow the first 2 pulling batch from queue operations to start. + processingLatch.countDown(); + + // wait until queues get emptied + int waitIntervals = 20; + while (!worker.isQueueEmpty() && waitIntervals-- >= 0) { + Thread.sleep(500); + } + + assertTrue(worker.isQueueEmpty()); + // of requests cause at least one batch. 
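+ // 2 * batch size + 1 requests form two full batches plus one leftover request, hence the 3 batchWrite calls verified below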
+ verify(checkpoint, times(3)).batchWrite(any(), any()); + } + + public void testOverloaded() { + doAnswer(invocation -> { + ActionListener<BulkResponse> listener = invocation.getArgument(1); + listener.onFailure(new OpenSearchRejectedExecutionException("blah", true)); + + return null; + }).when(checkpoint).batchWrite(any(), any()); + + worker.write(state, true, RequestPriority.MEDIUM); + + verify(checkpoint, times(1)).batchWrite(any(), any()); + verify(nodeStateManager, times(1)).setException(eq(state.getDetectorId()), any(OpenSearchRejectedExecutionException.class)); + } + + public void testRetryException() { + doAnswer(invocation -> { + ActionListener<BulkResponse> listener = invocation.getArgument(1); + listener.onFailure(new OpenSearchStatusException("blah", RestStatus.REQUEST_TIMEOUT)); + + return null; + }).when(checkpoint).batchWrite(any(), any()); + + worker.write(state, true, RequestPriority.MEDIUM); + // we don't retry checkpoint write + verify(checkpoint, times(1)).batchWrite(any(), any()); + verify(nodeStateManager, times(1)).setException(eq(state.getDetectorId()), any(OpenSearchStatusException.class)); + } + + /** + * Test that we don't retry a failed request + */ + public void testFailedRequest() { + doAnswer(invocation -> { + ActionListener<BulkResponse> listener = invocation.getArgument(1); + BulkItemResponse[] responses = new BulkItemResponse[1]; + ShardId shardId = new ShardId(new Index("index_name", "uuid"), 0); + responses[0] = new BulkItemResponse( + 0, + randomFrom(DocWriteRequest.OpType.values()), + new Failure(shardId.getIndexName(), "_doc", "id1", new VersionConflictEngineException(shardId, "id1", "blah")) + ); + listener.onResponse(new BulkResponse(responses, 1)); + + return null; + }).when(checkpoint).batchWrite(any(), any()); + + worker.write(state, true, RequestPriority.MEDIUM); + // we don't retry checkpoint write + verify(checkpoint, times(1)).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testEmptyTimeStamp() { + ModelState<EntityModel> state = mock(ModelState.class); + when(state.getLastCheckpointTime()).thenReturn(Instant.MIN); + worker.write(state, false, RequestPriority.MEDIUM); + + verify(checkpoint, never()).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testTooSoonToSaveSingleWrite() { + ModelState<EntityModel> state = mock(ModelState.class); + when(state.getLastCheckpointTime()).thenReturn(Instant.now()); + worker.write(state, false, RequestPriority.MEDIUM); + + verify(checkpoint, never()).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testTooSoonToSaveWriteAll() { + ModelState<EntityModel> state = mock(ModelState.class); + when(state.getLastCheckpointTime()).thenReturn(Instant.now()); + + List<ModelState<EntityModel>> states = new ArrayList<>(); + states.add(state); + + worker.writeAll(states, detectorId, false, RequestPriority.MEDIUM); + + verify(checkpoint, never()).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testEmptyModel() { + ModelState<EntityModel> state = mock(ModelState.class); + when(state.getLastCheckpointTime()).thenReturn(Instant.now()); + when(state.getModel()).thenReturn(null); + worker.write(state, true, RequestPriority.MEDIUM); + + verify(checkpoint, never()).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testEmptyModelId() { + ModelState<EntityModel> state = mock(ModelState.class); + when(state.getLastCheckpointTime()).thenReturn(Instant.now()); + EntityModel model = mock(EntityModel.class); + when(state.getModel()).thenReturn(model); + when(state.getDetectorId()).thenReturn("1"); + 
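// a null model id below means there is nothing to checkpoint, so batchWrite should never be called +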
when(state.getModelId()).thenReturn(null); + worker.write(state, true, RequestPriority.MEDIUM); + + verify(checkpoint, never()).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testEmptyDetectorId() { + ModelState state = mock(ModelState.class); + when(state.getLastCheckpointTime()).thenReturn(Instant.now()); + EntityModel model = mock(EntityModel.class); + when(state.getModel()).thenReturn(model); + when(state.getDetectorId()).thenReturn(null); + when(state.getModelId()).thenReturn("a"); + worker.write(state, true, RequestPriority.MEDIUM); + + verify(checkpoint, never()).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testDetectorNotAvailableSingleWrite() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + worker.write(state, true, RequestPriority.MEDIUM); + verify(checkpoint, never()).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testDetectorNotAvailableWriteAll() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + List> states = new ArrayList<>(); + states.add(state); + worker.writeAll(states, detectorId, true, RequestPriority.MEDIUM); + verify(checkpoint, never()).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testDetectorFetchException() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException()); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + worker.write(state, true, RequestPriority.MEDIUM); + verify(checkpoint, never()).batchWrite(any(), any()); + } + + public void testCheckpointNullSource() throws IOException { + when(checkpoint.toIndexSource(any())).thenReturn(null); + worker.write(state, true, RequestPriority.MEDIUM); + verify(checkpoint, never()).batchWrite(any(), any()); + } + + public void testCheckpointEmptySource() throws IOException { + Map checkpointMap = new HashMap<>(); + when(checkpoint.toIndexSource(any())).thenReturn(checkpointMap); + worker.write(state, true, RequestPriority.MEDIUM); + verify(checkpoint, never()).batchWrite(any(), any()); + } + + public void testConcurrentModificationException() throws IOException { + doThrow(ConcurrentModificationException.class).when(checkpoint).toIndexSource(any()); + worker.write(state, true, RequestPriority.MEDIUM); + verify(checkpoint, never()).batchWrite(any(), any()); + } +} diff --git a/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java new file mode 100644 index 000000000..2c5d95199 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java @@ -0,0 +1,181 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Random; + +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; + +public class ColdEntityWorkerTests extends AbstractRateLimitingTest { + ClusterService clusterService; + ColdEntityWorker coldWorker; + CheckpointReadWorker readWorker; + EntityFeatureRequest request, request2, invalidRequest; + List requests; + + @Override + public void setUp() throws Exception { + super.setUp(); + clusterService = mock(ClusterService.class); + Settings settings = Settings.builder().put(AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_BATCH_SIZE.getKey(), 1).build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_SECS, + AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_BATCH_SIZE + ) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + readWorker = mock(CheckpointReadWorker.class); + + // Integer.MAX_VALUE makes a huge heap + coldWorker = new ColdEntityWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + settings, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + readWorker, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + nodeStateManager + ); + + request = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, entity, new double[] { 0 }, 0); + request2 = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, entity2, new double[] { 0 }, 0); + invalidRequest = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity2, new double[] { 0 }, 0); + + requests = new ArrayList<>(); + requests.add(request); + requests.add(request2); + requests.add(invalidRequest); + + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + + TimeValue value = invocation.getArgument(1); + // since we have only 1 request each time + long expectedExecutionPerRequestMilli = 1000 * AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_SECS + .getDefault(Settings.EMPTY); + long delay = value.getMillis(); + assertTrue(delay >= expectedExecutionPerRequestMilli); + assertTrue(delay <= expectedExecutionPerRequestMilli * 2); + return null; + }).when(threadPool).schedule(any(), any(), any()); + } + + public void 
testPutRequests() { + coldWorker.putAll(requests); + + verify(readWorker, times(2)).putAll(any()); + verify(threadPool, times(2)).schedule(any(), any(), any()); + } + + /** + * We will log a line and continue trying despite exception + */ + public void testCheckpointReadPutException() { + doThrow(RuntimeException.class).when(readWorker).putAll(any()); + coldWorker.putAll(requests); + verify(readWorker, times(2)).putAll(any()); + verify(threadPool, never()).schedule(any(), any(), any()); + } + + /** + * First, invalidRequest gets pulled out and we re-pull; Then we have schedule exception. + * Will not schedule others anymore. + */ + public void testScheduleException() { + doThrow(RuntimeException.class).when(threadPool).schedule(any(), any(), any()); + coldWorker.putAll(requests); + verify(readWorker, times(1)).putAll(any()); + verify(threadPool, times(1)).schedule(any(), any(), any()); + } + + public void testDelay() { + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_SECS, + AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_BATCH_SIZE + ) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + // Integer.MAX_VALUE makes a huge heap + coldWorker = new ColdEntityWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + readWorker, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + nodeStateManager + ); + + coldWorker.putAll(requests); + + verify(readWorker, times(1)).putAll(any()); + verify(threadPool, never()).schedule(any(), any(), any()); + } +} diff --git a/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java new file mode 100644 index 000000000..77dc78015 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java @@ -0,0 +1,138 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Optional; +import java.util.Random; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.ml.EntityColdStarter; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException; +import org.opensearch.rest.RestStatus; + +public class EntityColdStartWorkerTests extends AbstractRateLimitingTest { + ClusterService clusterService; + EntityColdStartWorker worker; + EntityColdStarter entityColdStarter; + + @Override + public void setUp() throws Exception { + super.setUp(); + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_CONCURRENCY + ) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + entityColdStarter = mock(EntityColdStarter.class); + + // Integer.MAX_VALUE makes a huge heap + worker = new EntityColdStartWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.ENTITY_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + entityColdStarter, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + nodeStateManager + ); + } + + public void testEmptyModelId() { + EntityRequest request = mock(EntityRequest.class); + when(request.getPriority()).thenReturn(RequestPriority.LOW); + when(request.getModelId()).thenReturn(Optional.empty()); + worker.put(request); + verify(entityColdStarter, never()).trainModel(any(), anyString(), any(), any()); + verify(request, times(1)).getModelId(); + } + + public void testOverloaded() { + EntityRequest request = new EntityRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new OpenSearchRejectedExecutionException("blah", true)); + + return null; + }).when(entityColdStarter).trainModel(any(), anyString(), any(), any()); + + worker.put(request); + + verify(entityColdStarter, times(1)).trainModel(any(), anyString(), any(), any()); + verify(nodeStateManager, times(1)).setException(eq(detectorId), any(OpenSearchRejectedExecutionException.class)); + + // 2nd 
put request won't trigger anything as we are in cooldown mode + worker.put(request); + verify(entityColdStarter, times(1)).trainModel(any(), anyString(), any(), any()); + } + + public void testException() { + EntityRequest request = new EntityRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new OpenSearchStatusException("blah", RestStatus.REQUEST_TIMEOUT)); + + return null; + }).when(entityColdStarter).trainModel(any(), anyString(), any(), any()); + + worker.put(request); + + verify(entityColdStarter, times(1)).trainModel(any(), anyString(), any(), any()); + verify(nodeStateManager, times(1)).setException(eq(detectorId), any(OpenSearchStatusException.class)); + + // 2nd put request triggers another setException + worker.put(request); + verify(entityColdStarter, times(2)).trainModel(any(), anyString(), any(), any()); + verify(nodeStateManager, times(2)).setException(eq(detectorId), any(OpenSearchStatusException.class)); + } +} diff --git a/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java new file mode 100644 index 000000000..dc6f5752f --- /dev/null +++ b/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java @@ -0,0 +1,213 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.constant.CommonName; +import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.transport.ADResultBulkRequest; +import org.opensearch.ad.transport.ADResultBulkResponse; +import org.opensearch.ad.transport.handler.MultiEntityResultHandler; +import org.opensearch.ad.util.RestHandlerUtils; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.rest.RestStatus; +import org.opensearch.threadpool.ThreadPool; + +import com.google.common.collect.ImmutableList; + +public class ResultWriteWorkerTests extends 
AbstractRateLimitingTest { + ResultWriteWorker resultWriteQueue; + ClusterService clusterService; + MultiEntityResultHandler resultHandler; + AnomalyResult detectResult; + + @Override + public void setUp() throws Exception { + super.setUp(); + + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_BATCH_SIZE + ) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + threadPool = mock(ThreadPool.class); + setUpADThreadPool(threadPool); + + resultHandler = mock(MultiEntityResultHandler.class); + + resultWriteQueue = new ResultWriteWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_SIZE_IN_BYTES, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + resultHandler, + xContentRegistry(), + nodeStateManager, + AnomalyDetectorSettings.HOURLY_MAINTENANCE + ); + + detectResult = new AnomalyResult( + randomAlphaOfLength(5), + randomAlphaOfLength(5), + 0.8, + Double.NaN, + Double.NaN, + ImmutableList.of(), + Instant.now().truncatedTo(ChronoUnit.SECONDS), + Instant.now().truncatedTo(ChronoUnit.SECONDS), + Instant.now().truncatedTo(ChronoUnit.SECONDS), + Instant.now().truncatedTo(ChronoUnit.SECONDS), + randomAlphaOfLength(5), + null, + null, + CommonValue.NO_SCHEMA_VERSION, + randomAlphaOfLength(5) + ); + + } + + public void testRegular() { + List retryRequests = new ArrayList<>(); + + ADResultBulkResponse resp = new ADResultBulkResponse(retryRequests); + + ADResultBulkRequest request = new ADResultBulkRequest(); + request.add(detectResult); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(resp); + return null; + }).when(resultHandler).flush(any(), any()); + + resultWriteQueue.put(new ResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult)); + + // the request results one flush + verify(resultHandler, times(1)).flush(any(), any()); + } + + public void testSingleRetryRequest() throws IOException { + List retryRequests = new ArrayList<>(); + try (XContentBuilder builder = jsonBuilder()) { + IndexRequest indexRequest = new IndexRequest(CommonName.ANOMALY_RESULT_INDEX_ALIAS) + .source(detectResult.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)); + retryRequests.add(indexRequest); + } + + ADResultBulkResponse resp = new ADResultBulkResponse(retryRequests); + + ADResultBulkRequest request = new ADResultBulkRequest(); + request.add(detectResult); + + final AtomicBoolean retried = new AtomicBoolean(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + if (retried.get()) { + listener.onResponse(new ADResultBulkResponse()); + } else { + retried.set(true); + listener.onResponse(resp); + } + return null; + }).when(resultHandler).flush(any(), any()); + + resultWriteQueue.put(new ResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, 
detectResult)); + + // one flush from the original request; and one due to retry + verify(resultHandler, times(2)).flush(any(), any()); + } + + public void testRetryException() { + final AtomicBoolean retried = new AtomicBoolean(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + if (retried.get()) { + listener.onResponse(new ADResultBulkResponse()); + } else { + retried.set(true); + listener.onFailure(new OpenSearchStatusException("blah", RestStatus.REQUEST_TIMEOUT)); + } + + return null; + }).when(resultHandler).flush(any(), any()); + + resultWriteQueue.put(new ResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult)); + // one flush from the original request; and one due to retry + verify(resultHandler, times(2)).flush(any(), any()); + verify(nodeStateManager, times(1)).setException(eq(detectorId), any(OpenSearchStatusException.class)); + } + + public void testOverloaded() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new OpenSearchRejectedExecutionException("blah", true)); + + return null; + }).when(resultHandler).flush(any(), any()); + + resultWriteQueue.put(new ResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult)); + // one flush from the original request; and one due to retry + verify(resultHandler, times(1)).flush(any(), any()); + verify(nodeStateManager, times(1)).setException(eq(detectorId), any(OpenSearchRejectedExecutionException.class)); + } +} diff --git a/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java b/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java index 3b51722db..8c2b40b0a 100644 --- a/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java +++ b/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java @@ -86,7 +86,7 @@ public void testAllLegacyOpenDistroSettingsReturned() { public void testAllOpenSearchSettingsReturned() { List> settings = plugin.getSettings(); assertTrue( - "legacy setting must be returned from settings", + "opensearch setting must be returned from settings", settings .containsAll( Arrays @@ -98,7 +98,7 @@ public void testAllOpenSearchSettingsReturned() { AnomalyDetectorSettings.DETECTION_INTERVAL, AnomalyDetectorSettings.DETECTION_WINDOW_DELAY, AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, - AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS, + AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD, AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE, AnomalyDetectorSettings.COOLDOWN_MINUTES, AnomalyDetectorSettings.BACKOFF_MINUTES, @@ -109,13 +109,29 @@ public void testAllOpenSearchSettingsReturned() { AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW, AnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT, + AnomalyDetectorSettings.INDEX_PRESSURE_HARD_LIMIT, AnomalyDetectorSettings.MAX_PRIMARY_SHARDS, AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES, - AnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND, AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE, AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS, AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, - AnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE + AnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + 
AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_BATCH_SIZE, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_BATCH_SIZE, + AnomalyDetectorSettings.DEDICATED_CACHE_SIZE, + AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_SECS, + AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, + AnomalyDetectorSettings.PAGE_SIZE ) ) ); @@ -150,10 +166,6 @@ public void testAllLegacyOpenDistroSettingsFallback() { AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD.get(Settings.EMPTY), LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD.get(Settings.EMPTY) ); - assertEquals( - AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS.get(Settings.EMPTY), - LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS.get(Settings.EMPTY) - ); assertEquals( AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(Settings.EMPTY), LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(Settings.EMPTY) @@ -182,18 +194,10 @@ public void testAllLegacyOpenDistroSettingsFallback() { AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(Settings.EMPTY), LegacyOpenDistroAnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(Settings.EMPTY) ); - assertEquals( - AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY.get(Settings.EMPTY), - LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY.get(Settings.EMPTY) - ); - assertEquals( - AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW.get(Settings.EMPTY), - LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW.get(Settings.EMPTY) - ); - assertEquals( - AnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT.get(Settings.EMPTY), - LegacyOpenDistroAnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT.get(Settings.EMPTY) - ); + // MAX_ENTITIES_FOR_PREVIEW does not use legacy setting + assertEquals(Integer.valueOf(10), AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW.get(Settings.EMPTY)); + // INDEX_PRESSURE_SOFT_LIMIT does not use legacy setting + assertEquals(Float.valueOf(0.6f), AnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT.get(Settings.EMPTY)); assertEquals( AnomalyDetectorSettings.MAX_PRIMARY_SHARDS.get(Settings.EMPTY), LegacyOpenDistroAnomalyDetectorSettings.MAX_PRIMARY_SHARDS.get(Settings.EMPTY) @@ -202,10 +206,6 @@ public void testAllLegacyOpenDistroSettingsFallback() { AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(Settings.EMPTY), LegacyOpenDistroAnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(Settings.EMPTY) ); - assertEquals( - AnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND.get(Settings.EMPTY), - LegacyOpenDistroAnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND.get(Settings.EMPTY) - ); assertEquals( AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE.get(Settings.EMPTY), LegacyOpenDistroAnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE.get(Settings.EMPTY) @@ -256,8 +256,8 @@ public void testSettingsGetValue() { assertEquals(AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD.get(settings), TimeValue.timeValueHours(94)); 
assertEquals(LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD.get(settings), TimeValue.timeValueHours(12)); - settings = Settings.builder().put("plugins.anomaly_detection.ad_result_history_max_docs", 93).build(); - assertEquals(AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS.get(settings), Long.valueOf(93)); + settings = Settings.builder().put("plugins.anomaly_detection.ad_result_history_max_docs_per_shard", 93).build(); + assertEquals(AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD.get(settings), Long.valueOf(93)); assertEquals(LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS.get(settings), Long.valueOf(250000000)); settings = Settings @@ -316,7 +316,6 @@ public void testSettingsGetValue() { assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_PRIMARY_SHARDS.get(settings), Integer.valueOf(10)); settings = Settings.builder().put("plugins.anomaly_detection.max_cache_miss_handling_per_second", 79).build(); - assertEquals(AnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND.get(settings), Integer.valueOf(79)); assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND.get(settings), Integer.valueOf(100)); settings = Settings.builder().put("plugins.anomaly_detection.max_batch_task_per_node", 78).build(); @@ -356,7 +355,6 @@ public void testSettingsGetValueWithLegacyFallback() { .put("opendistro.anomaly_detection.max_retry_for_end_run_exception", 15) .put("opendistro.anomaly_detection.filter_by_backend_roles", true) .put("opendistro.anomaly_detection.model_max_size_percent", 0.6D) - .put("opendistro.anomaly_detection.max_entities_per_query", 18) .put("opendistro.anomaly_detection.max_entities_for_preview", 19) .put("opendistro.anomaly_detection.index_pressure_soft_limit", 20F) .put("opendistro.anomaly_detection.max_primary_shards", 21) @@ -374,7 +372,8 @@ public void testSettingsGetValueWithLegacyFallback() { assertEquals(AnomalyDetectorSettings.DETECTION_INTERVAL.get(settings), TimeValue.timeValueMinutes(5)); assertEquals(AnomalyDetectorSettings.DETECTION_WINDOW_DELAY.get(settings), TimeValue.timeValueMinutes(6)); assertEquals(AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD.get(settings), TimeValue.timeValueHours(7)); - assertEquals(AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS.get(settings), Long.valueOf(8L)); + // AD_RESULT_HISTORY_MAX_DOCS is removed in the new release + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS.get(settings), Long.valueOf(8L)); assertEquals(AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD.get(settings), TimeValue.timeValueDays(9)); assertEquals(AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(settings), Integer.valueOf(10)); assertEquals(AnomalyDetectorSettings.COOLDOWN_MINUTES.get(settings), TimeValue.timeValueMinutes(11)); @@ -384,11 +383,13 @@ public void testSettingsGetValueWithLegacyFallback() { assertEquals(AnomalyDetectorSettings.MAX_RETRY_FOR_END_RUN_EXCEPTION.get(settings), Integer.valueOf(15)); assertEquals(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings), Boolean.valueOf(true)); assertEquals(AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(settings), Double.valueOf(0.6D)); - assertEquals(AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY.get(settings), Integer.valueOf(18)); - assertEquals(AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW.get(settings), Integer.valueOf(19)); - assertEquals(AnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT.get(settings), Float.valueOf(20F)); + // 
MAX_ENTITIES_FOR_PREVIEW uses default instead of legacy fallback + assertEquals(AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW.get(settings), Integer.valueOf(10)); + // INDEX_PRESSURE_SOFT_LIMIT uses default instead of legacy fallback + assertEquals(AnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT.get(settings), Float.valueOf(0.6F)); assertEquals(AnomalyDetectorSettings.MAX_PRIMARY_SHARDS.get(settings), Integer.valueOf(21)); - assertEquals(AnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND.get(settings), Integer.valueOf(22)); + // MAX_CACHE_MISS_HANDLING_PER_SECOND is removed in the new release + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND.get(settings), Integer.valueOf(22)); assertEquals(AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE.get(settings), Integer.valueOf(23)); assertEquals(AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR.get(settings), Integer.valueOf(24)); assertEquals(AnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE.get(settings), Integer.valueOf(25)); @@ -412,9 +413,6 @@ public void testSettingsGetValueWithLegacyFallback() { LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_END_RUN_EXCEPTION, LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD, LegacyOpenDistroAnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, - LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, - LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW, - LegacyOpenDistroAnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT, LegacyOpenDistroAnomalyDetectorSettings.MAX_PRIMARY_SHARDS, LegacyOpenDistroAnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES, LegacyOpenDistroAnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND, diff --git a/src/test/java/org/opensearch/ad/stats/ADStatsTests.java b/src/test/java/org/opensearch/ad/stats/ADStatsTests.java index e423ea83d..0fe6a8fbd 100644 --- a/src/test/java/org/opensearch/ad/stats/ADStatsTests.java +++ b/src/test/java/org/opensearch/ad/stats/ADStatsTests.java @@ -56,6 +56,7 @@ import org.opensearch.test.OpenSearchTestCase; import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; import com.amazon.randomcutforest.RandomCutForest; @@ -96,8 +97,8 @@ public void setup() { when(modelManager.getAllModels()).thenReturn(modelsInformation); - ModelState entityModel1 = MLUtil.randomNonEmptyModelState(); - ModelState entityModel2 = MLUtil.randomNonEmptyModelState(); + ModelState entityModel1 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState entityModel2 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); List> entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); EntityCache cache = mock(EntityCache.class); @@ -124,7 +125,7 @@ public void setup() { } }; - adStats = new ADStats(indexUtils, modelManager, statsMap); + adStats = new ADStats(statsMap); } @Test diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java b/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java index b9eeb790f..51b30f700 100644 --- a/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java +++ b/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java @@ -51,6 +51,7 @@ import org.opensearch.test.OpenSearchTestCase; import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; import 
com.amazon.randomcutforest.RandomCutForest; @@ -87,8 +88,8 @@ public void setup() { when(modelManager.getAllModels()).thenReturn(expectedResults); - ModelState entityModel1 = MLUtil.randomNonEmptyModelState(); - ModelState entityModel2 = MLUtil.randomNonEmptyModelState(); + ModelState entityModel1 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState entityModel2 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); EntityCache cache = mock(EntityCache.class); diff --git a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java index 2f6be3e45..ba0da64fa 100644 --- a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java +++ b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java @@ -28,7 +28,6 @@ import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -88,7 +87,7 @@ public void tearDown() throws Exception { } public void testPutTask() throws IOException { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); ADTask adTask = TestHelpers.randomAdTask(); adTaskCacheManager.add(adTask); assertEquals(1, adTaskCacheManager.size()); @@ -104,7 +103,7 @@ public void testPutTask() throws IOException { } public void testPutDuplicateTask() throws IOException { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); ADTask adTask1 = TestHelpers.randomAdTask(); adTaskCacheManager.add(adTask1); assertEquals(1, adTaskCacheManager.size()); @@ -125,7 +124,7 @@ public void testPutDuplicateTask() throws IOException { } public void testPutTaskWithMemoryExceedLimit() { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(false); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(false); LimitExceededException exception = expectThrows( LimitExceededException.class, () -> adTaskCacheManager.add(TestHelpers.randomAdTask()) @@ -134,7 +133,7 @@ public void testPutTaskWithMemoryExceedLimit() { } public void testThresholdModelTrained() throws IOException { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); ADTask adTask = TestHelpers.randomAdTask(); adTaskCacheManager.add(adTask); assertEquals(1, adTaskCacheManager.size()); @@ -147,7 +146,7 @@ public void testThresholdModelTrained() throws IOException { } public void testCancel() throws IOException { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); ADTask adTask = TestHelpers.randomAdTask(); adTaskCacheManager.add(adTask); assertEquals(1, adTaskCacheManager.size()); @@ -174,7 +173,7 @@ public void testRemoveTaskWhichNotExist() { } public void testExceedRunningTaskLimit() throws IOException { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); + 
when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); adTaskCacheManager.add(TestHelpers.randomAdTask()); adTaskCacheManager.add(TestHelpers.randomAdTask()); assertEquals(2, adTaskCacheManager.size()); diff --git a/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java index 265450e54..e7a772456 100644 --- a/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java @@ -31,13 +31,11 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; import java.io.IOException; import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; +import java.util.Locale; import org.junit.After; import org.junit.AfterClass; @@ -56,7 +54,6 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.index.IndexingPressure; import org.opensearch.transport.TransportService; @@ -87,14 +84,12 @@ public void setUp() throws Exception { .put(IndexingPressure.MAX_INDEXING_BYTES.getKey(), "1KB") .put(AnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT.getKey(), 0.8) .build(); - setupTestNodes(settings); + + // without registering these settings, the constructor of ADResultBulkTransportAction cannot invoke the update consumer + setupTestNodes(AnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT, AnomalyDetectorSettings.INDEX_PRESSURE_HARD_LIMIT); transportService = testNodes[0].transportService; - clusterService = spy(testNodes[0].clusterService); - ClusterSettings clusterSettings = new ClusterSettings( - settings, - Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT))) - ); - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + clusterService = testNodes[0].clusterService; + ActionFilters actionFilters = mock(ActionFilters.class); indexingPressure = mock(IndexingPressure.class); @@ -116,8 +111,8 @@ public void testSendAll() { when(indexingPressure.getCurrentReplicaBytes()).thenReturn(0L); ADResultBulkRequest originalRequest = new ADResultBulkRequest(); - originalRequest.add(TestHelpers.randomMultiEntityAnomalyDetectResult(0.8d, 0d)); - originalRequest.add(TestHelpers.randomMultiEntityAnomalyDetectResult(8d, 0.2d)); + originalRequest.add(TestHelpers.randomHCADAnomalyDetectResult(0.8d, 0d)); + originalRequest.add(TestHelpers.randomHCADAnomalyDetectResult(8d, 0.2d)); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -133,7 +128,7 @@ public void testSendAll() { return null; }).when(client).execute(any(), any(), any()); - PlainActionFuture future = PlainActionFuture.newFuture(); + PlainActionFuture future = PlainActionFuture.newFuture(); resultBulk.doExecute(null, originalRequest, future); future.actionGet(); @@ -146,8 +141,8 @@ public void testSendPartial() { when(indexingPressure.getCurrentReplicaBytes()).thenReturn(24L); ADResultBulkRequest originalRequest = new ADResultBulkRequest(); - originalRequest.add(TestHelpers.randomMultiEntityAnomalyDetectResult(0.8d, 0d)); -
originalRequest.add(TestHelpers.randomMultiEntityAnomalyDetectResult(8d, 0.2d)); + originalRequest.add(TestHelpers.randomHCADAnomalyDetectResult(0.8d, 0d)); + originalRequest.add(TestHelpers.randomHCADAnomalyDetectResult(8d, 0.2d)); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -163,7 +158,7 @@ public void testSendPartial() { return null; }).when(client).execute(any(), any(), any()); - PlainActionFuture future = PlainActionFuture.newFuture(); + PlainActionFuture future = PlainActionFuture.newFuture(); resultBulk.doExecute(null, originalRequest, future); future.actionGet(); @@ -171,16 +166,16 @@ public void testSendPartial() { @SuppressWarnings("unchecked") public void testSendRandomPartial() { - // 400 + 421 > 1024 * 0.8. 1024 is 1KB, our INDEX_PRESSURE_SOFT_LIMIT + // 1024 * 0.9 > 400 + 421 > 1024 * 0.6. 1024 is 1KB, our INDEX_PRESSURE_SOFT_LIMIT when(indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes()).thenReturn(400L); when(indexingPressure.getCurrentReplicaBytes()).thenReturn(421L); ADResultBulkRequest originalRequest = new ADResultBulkRequest(); for (int i = 0; i < 1000; i++) { - originalRequest.add(TestHelpers.randomMultiEntityAnomalyDetectResult(0.8d, 0d)); + originalRequest.add(TestHelpers.randomHCADAnomalyDetectResult(0.8d, 0d)); } - originalRequest.add(TestHelpers.randomMultiEntityAnomalyDetectResult(8d, 0.2d)); + originalRequest.add(TestHelpers.randomHCADAnomalyDetectResult(8d, 0.2d)); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -194,12 +189,12 @@ public void testSendRandomPartial() { int size = request.requests().size(); assertTrue(1 < size); // at least 1 half should be removed - assertTrue(String.format("size is actually %d", size), size < 500); + assertTrue(String.format(Locale.ROOT, "size is actually %d", size), size < 500); listener.onResponse(null); return null; }).when(client).execute(any(), any(), any()); - PlainActionFuture future = PlainActionFuture.newFuture(); + PlainActionFuture future = PlainActionFuture.newFuture(); resultBulk.doExecute(null, originalRequest, future); future.actionGet(); @@ -207,8 +202,8 @@ public void testSendRandomPartial() { public void testSerialzationRequest() throws IOException { ADResultBulkRequest request = new ADResultBulkRequest(); - request.add(TestHelpers.randomMultiEntityAnomalyDetectResult(0.8d, 0d)); - request.add(TestHelpers.randomMultiEntityAnomalyDetectResult(8d, 0.2d)); + request.add(TestHelpers.randomHCADAnomalyDetectResult(0.8d, 0d)); + request.add(TestHelpers.randomHCADAnomalyDetectResult(8d, 0.2d)); BytesStreamOutput output = new BytesStreamOutput(); request.writeTo(output); diff --git a/src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java index ad95dc996..e853913b4 100644 --- a/src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java @@ -105,7 +105,7 @@ public void setUp() throws Exception { } }; - adStats = new ADStats(indexUtils, modelManager, statsMap); + adStats = new ADStats(statsMap); JvmService jvmService = mock(JvmService.class); JvmStats jvmStats = mock(JvmStats.class); JvmStats.Mem mem = mock(JvmStats.Mem.class); diff --git a/src/test/java/org/opensearch/ad/transport/ADStatsTests.java b/src/test/java/org/opensearch/ad/transport/ADStatsTests.java index a463f80ca..16b7e6019 100644 --- 
a/src/test/java/org/opensearch/ad/transport/ADStatsTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADStatsTests.java @@ -28,15 +28,20 @@ import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; import java.io.IOException; +import java.time.Clock; +import java.time.Instant; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.TreeMap; import java.util.stream.Collectors; import org.junit.Before; @@ -44,6 +49,10 @@ import org.opensearch.Version; import org.opensearch.action.FailedNodeException; import org.opensearch.ad.common.exception.JsonPathNotFoundException; +import org.opensearch.ad.constant.CommonName; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.model.Entity; import org.opensearch.ad.stats.StatNames; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; @@ -57,6 +66,9 @@ import test.org.opensearch.ad.util.JsonDeserializer; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; + public class ADStatsTests extends OpenSearchTestCase { String node1, nodeName1, clusterName; Map clusterStats; @@ -102,7 +114,7 @@ public void testADStatsNodeRequest() throws IOException { } @Test - public void testADStatsNodeResponse() throws IOException, JsonPathNotFoundException { + public void testSimpleADStatsNodeResponse() throws IOException, JsonPathNotFoundException { Map stats = new HashMap() { { put("testKey", "testValue"); @@ -127,6 +139,72 @@ public void testADStatsNodeResponse() throws IOException, JsonPathNotFoundExcept } } + /** + * Test we can serialize stats with entity + * @throws IOException when writeTo and toXContent have errors. 
+ * @throws JsonPathNotFoundException when json deserialization cannot find a path + */ + @Test + public void testADStatsNodeResponseWithEntity() throws IOException, JsonPathNotFoundException { + TreeMap attributes = new TreeMap<>(); + String name1 = "a"; + String name2 = "b"; + String val1 = "a1"; + String val2 = "a2"; + attributes.put(name1, val1); + attributes.put(name2, val2); + String detectorId = "detectorId"; + Entity entity = Entity.createEntityFromOrderedMap(detectorId, attributes); + EntityModel entityModel = new EntityModel(entity, null, null, null); + Clock clock = mock(Clock.class); + when(clock.instant()).thenReturn(Instant.now()); + ModelState state = new ModelState( + entityModel, + entity.getModelId(detectorId).get(), + detectorId, + "entity", + clock, + 0.1f + ); + Map stats = state.getModelStateAsMap(); + + // Test serialization + ADStatsNodeResponse adStatsNodeResponse = new ADStatsNodeResponse(discoveryNode1, stats); + BytesStreamOutput output = new BytesStreamOutput(); + adStatsNodeResponse.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + ADStatsNodeResponse readResponse = ADStatsNodeResponse.readStats(streamInput); + assertEquals("readStats failed", readResponse.getStatsMap(), adStatsNodeResponse.getStatsMap()); + + // Test toXContent + XContentBuilder builder = jsonBuilder(); + adStatsNodeResponse.toXContent(builder.startObject(), ToXContent.EMPTY_PARAMS).endObject(); + String json = Strings.toString(builder); + + for (Map.Entry stat : stats.entrySet()) { + if (stat.getKey().equals(ModelState.LAST_CHECKPOINT_TIME_KEY) || stat.getKey().equals(ModelState.LAST_USED_TIME_KEY)) { + assertEquals("toXContent does not work", JsonDeserializer.getLongValue(json, stat.getKey()), stat.getValue()); + } else if (stat.getKey().equals(CommonName.ENTITY_KEY)) { + JsonArray array = JsonDeserializer.getArrayValue(json, stat.getKey()); + assertEquals(2, array.size()); + for (int i = 0; i < 2; i++) { + JsonElement element = array.get(i); + String entityName = JsonDeserializer.getChildNode(element, Entity.ATTRIBUTE_NAME_FIELD).getAsString(); + String entityValue = JsonDeserializer.getChildNode(element, Entity.ATTRIBUTE_VALUE_FIELD).getAsString(); + + assertTrue(entityName.equals(name1) || entityName.equals(name2)); + if (entityName.equals(name1)) { + assertEquals(val1, entityValue); + } else { + assertEquals(val2, entityValue); + } + } + } else { + assertEquals("toXContent does not work", JsonDeserializer.getTextValue(json, stat.getKey()), stat.getValue()); + } + } + } + @Test public void testADStatsRequest() throws IOException { List allStats = Arrays.stream(StatNames.values()).map(StatNames::getName).collect(Collectors.toList()); diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java index 545047eb3..7ec4efedd 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java @@ -52,7 +52,6 @@ import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; import java.io.IOException; -import java.time.Clock; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; @@ -90,10 +89,8 @@ import org.opensearch.ad.common.exception.LimitExceededException; import org.opensearch.ad.common.exception.ResourceNotFoundException; import org.opensearch.ad.constant.CommonErrorMessages; -import org.opensearch.ad.constant.CommonMessageAttributes; import 
org.opensearch.ad.constant.CommonName; import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.feature.SinglePointFeatures; import org.opensearch.ad.ml.ModelManager; import org.opensearch.ad.ml.ModelPartitioner; @@ -103,13 +100,11 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorInternalState; import org.opensearch.ad.model.FeatureData; +import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.stats.ADStat; import org.opensearch.ad.stats.ADStats; import org.opensearch.ad.stats.StatNames; import org.opensearch.ad.stats.suppliers.CounterSupplier; -import org.opensearch.ad.util.ClientUtil; -import org.opensearch.ad.util.IndexUtils; -import org.opensearch.ad.util.Throttler; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -124,6 +119,7 @@ import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.common.xcontent.ToXContent; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.index.Index; @@ -146,7 +142,7 @@ import com.google.gson.JsonElement; public class AnomalyResultTests extends AbstractADTest { - private static Settings settings = Settings.EMPTY; + private Settings settings; private TransportService transportService; private ClusterService clusterService; private NodeStateManager stateManager; @@ -165,7 +161,6 @@ public class AnomalyResultTests extends AbstractADTest { private ADCircuitBreakerService adCircuitBreakerService; private ADStats adStats; private int partitionNum; - private SearchFeatureDao searchFeatureDao; @BeforeClass public static void setUpBeforeClass() { @@ -183,10 +178,13 @@ public static void tearDownAfterClass() { public void setUp() throws Exception { super.setUp(); super.setUpLog4jForJUnit(AnomalyResultTransportAction.class); - setupTestNodes(settings); + + setupTestNodes(AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, AnomalyDetectorSettings.PAGE_SIZE); transportService = testNodes[0].transportService; clusterService = testNodes[0].clusterService; + settings = clusterService.getSettings(); + stateManager = mock(NodeStateManager.class); // return 2 RCF partitions partitionNum = 2; @@ -206,7 +204,6 @@ public void setUp() throws Exception { adID = "123"; when(detector.getDetectorId()).thenReturn(adID); when(detector.getCategoryField()).thenReturn(null); - // when(detector.getDetectorId()).thenReturn("testDetectorId"); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.of(detector)); @@ -273,11 +270,6 @@ public void setUp() throws Exception { }).when(client).index(any(), any()); indexNameResolver = new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY)); - Clock clock = mock(Clock.class); - Throttler throttler = new Throttler(clock); - ThreadPool threadpool = mock(ThreadPool.class); - ClientUtil clientUtil = new ClientUtil(Settings.EMPTY, client, throttler, threadpool); - IndexUtils indexUtils = new IndexUtils(client, clientUtil, clusterService, indexNameResolver); Map> statsMap = new HashMap>() { { @@ -288,7 +280,7 @@ public void setUp() throws Exception { } }; - adStats = new ADStats(indexUtils, normalModelManager, statsMap); + adStats = new ADStats(statsMap); doAnswer(invocation -> { 
Object[] args = invocation.getArguments(); @@ -306,8 +298,6 @@ public void setUp() throws Exception { return null; }).when(client).get(any(), any()); - - searchFeatureDao = mock(SearchFeatureDao.class); } @Override @@ -349,7 +339,7 @@ public void testNormal() throws IOException { adCircuitBreakerService, adStats, threadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -434,7 +424,12 @@ public void sendRequest( // need to close nodes created in the setUp nodes and create new nodes // for the failure interceptor. Otherwise, we will get thread leak error. tearDownTestNodes(); - setupTestNodes(Settings.EMPTY, failureTransportInterceptor); + setupTestNodes( + failureTransportInterceptor, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, + AnomalyDetectorSettings.PAGE_SIZE + ); // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor when(hashRing.getOwningNode(any(String.class))).thenReturn(Optional.of(testNodes[1].discoveryNode())); @@ -464,7 +459,7 @@ public void sendRequest( adCircuitBreakerService, adStats, threadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -498,12 +493,12 @@ public void testNormalColdStartRemoteException() { } public void testNullPointerExceptionWhenRCF() { - noModelExceptionTemplate(new NullPointerException(), adID, EndRunException.class, AnomalyResultTransportAction.BUG_RESPONSE); + noModelExceptionTemplate(new NullPointerException(), adID, EndRunException.class, CommonErrorMessages.BUG_RESPONSE); } public void testADExceptionWhenColdStart() { String error = "blah"; - when(stateManager.fetchColdStartException(any(String.class))).thenReturn(Optional.of(new AnomalyDetectionException(adID, error))); + when(stateManager.fetchExceptionAndClear(any(String.class))).thenReturn(Optional.of(new AnomalyDetectionException(adID, error))); noModelExceptionTemplate(new ResourceNotFoundException(adID, ""), adID, AnomalyDetectionException.class, error); } @@ -516,7 +511,7 @@ public void testInsufficientCapacityExceptionDuringColdStart() { .when(rcfManager) .getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); - when(stateManager.fetchColdStartException(any(String.class))) + when(stateManager.fetchExceptionAndClear(any(String.class))) .thenReturn(Optional.of(new LimitExceededException(adID, CommonErrorMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG))); // These constructors register handler in transport service @@ -538,7 +533,7 @@ public void testInsufficientCapacityExceptionDuringColdStart() { adCircuitBreakerService, adStats, threadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -575,7 +570,7 @@ public void testInsufficientCapacityExceptionDuringRestoringModel() { adCircuitBreakerService, adStats, threadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -644,7 +639,12 @@ public void sendRequest( // need to close nodes created in the setUp nodes and create new nodes // for the failure interceptor. Otherwise, we will get thread leak error. 
tearDownTestNodes(); - setupTestNodes(Settings.EMPTY, failureTransportInterceptor); + setupTestNodes( + failureTransportInterceptor, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, + AnomalyDetectorSettings.PAGE_SIZE + ); // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor when(hashRing.getOwningNode(any(String.class))).thenReturn(Optional.of(testNodes[1].discoveryNode())); @@ -671,7 +671,7 @@ public void sendRequest( adCircuitBreakerService, adStats, threadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -683,7 +683,7 @@ public void sendRequest( } public void testThresholdException() { - thresholdExceptionTestTemplate(new NullPointerException(), adID, EndRunException.class, AnomalyResultTransportAction.BUG_RESPONSE); + thresholdExceptionTestTemplate(new NullPointerException(), adID, EndRunException.class, CommonErrorMessages.BUG_RESPONSE); } public void testCircuitBreaker() { @@ -710,7 +710,7 @@ public void testCircuitBreaker() { breakerService, adStats, threadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -777,7 +777,7 @@ private void nodeNotConnectedExceptionTemplate(boolean isRCF, boolean temporary, adCircuitBreakerService, adStats, threadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -845,7 +845,7 @@ public void testMute() { adCircuitBreakerService, adStats, threadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); PlainActionFuture listener = new PlainActionFuture<>(); @@ -880,7 +880,7 @@ public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOE adCircuitBreakerService, adStats, threadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); TransportRequestOptions option = TransportRequestOptions @@ -993,9 +993,9 @@ public void testJsonRequest() throws IOException, JsonPathNotFoundException { request.toXContent(builder, ToXContent.EMPTY_PARAMS); String json = Strings.toString(builder); - assertEquals(JsonDeserializer.getTextValue(json, CommonMessageAttributes.ID_JSON_KEY), request.getAdID()); - assertEquals(JsonDeserializer.getLongValue(json, CommonMessageAttributes.START_JSON_KEY), request.getStart()); - assertEquals(JsonDeserializer.getLongValue(json, CommonMessageAttributes.END_JSON_KEY), request.getEnd()); + assertEquals(JsonDeserializer.getTextValue(json, CommonName.ID_JSON_KEY), request.getAdID()); + assertEquals(JsonDeserializer.getLongValue(json, CommonName.START_JSON_KEY), request.getStart()); + assertEquals(JsonDeserializer.getLongValue(json, CommonName.END_JSON_KEY), request.getEnd()); } public void testEmptyID() { @@ -1035,7 +1035,7 @@ public void testOnFailureNull() throws IOException { adCircuitBreakerService, adStats, threadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultTransportAction.RCFActionListener listener = action.new RCFActionListener( null, null, null, null, null, null, null, null, null, 0, new AtomicInteger(), null, 1 @@ -1092,14 +1092,14 @@ public void testColdStartNoTrainingData() throws Exception { adCircuitBreakerService, adStats, mockThreadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); PlainActionFuture listener = new 
PlainActionFuture<>(); action.doExecute(null, request, listener); - verify(stateManager, times(1)).setLastColdStartException(eq(adID), any(EndRunException.class)); + verify(stateManager, times(1)).setException(eq(adID), any(EndRunException.class)); verify(stateManager, times(1)).markColdStartRunning(eq(adID)); } @@ -1129,14 +1129,14 @@ public void testConcurrentColdStart() throws Exception { adCircuitBreakerService, adStats, mockThreadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); PlainActionFuture listener = new PlainActionFuture<>(); action.doExecute(null, request, listener); - verify(stateManager, never()).setLastColdStartException(eq(adID), any(EndRunException.class)); + verify(stateManager, never()).setException(eq(adID), any(EndRunException.class)); verify(stateManager, never()).markColdStartRunning(eq(adID)); } @@ -1172,14 +1172,14 @@ public void testColdStartTimeoutPutCheckpoint() throws Exception { adCircuitBreakerService, adStats, mockThreadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); PlainActionFuture listener = new PlainActionFuture<>(); action.doExecute(null, request, listener); - verify(stateManager, times(1)).setLastColdStartException(eq(adID), any(InternalFailure.class)); + verify(stateManager, times(1)).setException(eq(adID), any(InternalFailure.class)); verify(stateManager, times(1)).markColdStartRunning(eq(adID)); } @@ -1215,14 +1215,14 @@ public void testColdStartIllegalArgumentException() throws Exception { adCircuitBreakerService, adStats, mockThreadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); PlainActionFuture listener = new PlainActionFuture<>(); action.doExecute(null, request, listener); - verify(stateManager, times(1)).setLastColdStartException(eq(adID), any(EndRunException.class)); + verify(stateManager, times(1)).setException(eq(adID), any(EndRunException.class)); verify(stateManager, times(1)).markColdStartRunning(eq(adID)); } @@ -1265,7 +1265,7 @@ public void featureTestTemplate(FeatureTestMode mode) throws IOException { adCircuitBreakerService, adStats, threadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -1351,7 +1351,7 @@ private void globalBlockTemplate(BlockType type, String errLogMsg, Settings inde adCircuitBreakerService, adStats, threadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -1398,7 +1398,7 @@ public void testNullRCFResult() { adCircuitBreakerService, adStats, threadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultTransportAction.RCFActionListener listener = action.new RCFActionListener( null, "123-rcf-0", null, "123", null, null, null, null, null, 0, new AtomicInteger(), null, 1 @@ -1430,7 +1430,7 @@ public void testAllFeaturesDisabled() throws IOException { adCircuitBreakerService, adStats, threadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -1453,7 +1453,7 @@ public void testEndRunDueToNoTrainingData() { return null; }).when(rcfManager).getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); - when(stateManager.fetchColdStartException(any(String.class))) + 
when(stateManager.fetchExceptionAndClear(any(String.class))) .thenReturn(Optional.of(new EndRunException(adID, "Cannot get training data", false))); doAnswer(invocation -> { @@ -1487,7 +1487,7 @@ public void testEndRunDueToNoTrainingData() { adCircuitBreakerService, adStats, mockThreadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); @@ -1521,7 +1521,7 @@ public void testRCFNodeCircuitBreakerBroken() { adCircuitBreakerService, adStats, threadPool, - searchFeatureDao + NamedXContentRegistry.EMPTY ); AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java index 8d045893f..3bb811406 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java @@ -34,6 +34,8 @@ import java.util.List; import java.util.Map; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.junit.Before; import org.opensearch.action.get.GetResponse; import org.opensearch.ad.ADIntegTestCase; @@ -50,6 +52,8 @@ import com.google.common.collect.ImmutableMap; public class AnomalyResultTransportActionTests extends ADIntegTestCase { + private static final Logger LOG = LogManager.getLogger(AnomalyResultTransportActionTests.class); + private String testIndex; private Instant testDataTimeStamp; private long start; @@ -144,57 +148,57 @@ public void testFeatureWithCardinalityOfTextField() throws IOException { public void testFeatureQueryWithTermsAggregationForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"terms\":{\"field\":\"type\"}}}", true); - assertErrorMessage(adId, "Failed to parse aggregation"); + assertErrorMessage(adId, "Failed to parse aggregation", true); } public void testFeatureWithSumOfTextFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"sum\":{\"field\":\"message\"}}}", true); - assertErrorMessage(adId, "Text fields are not optimised for operations"); + assertErrorMessage(adId, "Text fields are not optimised for operations", true); } public void testFeatureWithSumOfTypeFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"sum\":{\"field\":\"type\"}}}", true); - assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [sum]"); + assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [sum]", true); } public void testFeatureWithMaxOfTextFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"max\":{\"field\":\"message\"}}}", true); - assertErrorMessage(adId, "Text fields are not optimised for operations"); + assertErrorMessage(adId, "Text fields are not optimised for operations", true); } public void testFeatureWithMaxOfTypeFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"max\":{\"field\":\"type\"}}}", true); - assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [max]"); + assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [max]", true); } public void testFeatureWithMinOfTextFieldForHCDetector() throws IOException { String 
adId = createDetectorWithFeatureAgg("{\"test\":{\"min\":{\"field\":\"message\"}}}", true); - assertErrorMessage(adId, "Text fields are not optimised for operations"); + assertErrorMessage(adId, "Text fields are not optimised for operations", true); } public void testFeatureWithMinOfTypeFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"min\":{\"field\":\"type\"}}}", true); - assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [min]"); + assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [min]", true); } public void testFeatureWithAvgOfTextFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"avg\":{\"field\":\"message\"}}}", true); - assertErrorMessage(adId, "Text fields are not optimised for operations"); + assertErrorMessage(adId, "Text fields are not optimised for operations", true); } public void testFeatureWithAvgOfTypeFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"avg\":{\"field\":\"type\"}}}", true); - assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [avg]"); + assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [avg]", true); } public void testFeatureWithCountOfTextFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"value_count\":{\"field\":\"message\"}}}", true); - assertErrorMessage(adId, "Text fields are not optimised for operations"); + assertErrorMessage(adId, "Text fields are not optimised for operations", true); } public void testFeatureWithCardinalityOfTextFieldForHCDetector() throws IOException { String adId = createDetectorWithFeatureAgg("{\"test\":{\"cardinality\":{\"field\":\"message\"}}}", true); - assertErrorMessage(adId, "Text fields are not optimised for operations"); + assertErrorMessage(adId, "Text fields are not optimised for operations", true); } private String createDetectorWithFeatureAgg(String aggQuery) throws IOException { @@ -253,12 +257,33 @@ private AnomalyDetector randomHCDetector(List indices, List fea ); } - private void assertErrorMessage(String adId, String errorMessage) { + private void assertErrorMessage(String adId, String errorMessage, boolean hcDetector) { AnomalyResultRequest resultRequest = new AnomalyResultRequest(adId, start, end); - RuntimeException e = expectThrowsAnyOf( - ImmutableList.of(NotSerializableExceptionWrapper.class, AnomalyDetectionException.class), - () -> client().execute(AnomalyResultAction.INSTANCE, resultRequest).actionGet(30_000) - ); - assertTrue(e.getMessage().contains(errorMessage)); + // wait at most 20 seconds + int numberofTries = 40; + Exception e = null; + if (hcDetector) { + while (numberofTries-- > 0) { + try { + // HCAD records failures asynchronously. Before a failure is recorded, HCAD returns immediately without failure. 
+ client().execute(AnomalyResultAction.INSTANCE, resultRequest).actionGet(30_000); + Thread.sleep(500); + } catch (Exception exp) { + e = exp; + break; + } + } + } else { + e = expectThrowsAnyOf( + ImmutableList.of(NotSerializableExceptionWrapper.class, AnomalyDetectionException.class), + () -> client().execute(AnomalyResultAction.INSTANCE, resultRequest).actionGet(30_000) + ); + } + + assertTrue("Unexpected error: " + e.getMessage(), e.getMessage().contains(errorMessage)); + } + + private void assertErrorMessage(String adId, String errorMessage) { + assertErrorMessage(adId, errorMessage, false); } } diff --git a/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java index 42ebdc8da..1c2dad0a4 100644 --- a/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java @@ -44,6 +44,7 @@ import org.opensearch.ad.caching.EntityCache; import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.ml.EntityColdStarter; import org.opensearch.ad.ml.ModelManager; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; @@ -83,6 +84,7 @@ public void setUp() throws Exception { FeatureManager featureManager = mock(FeatureManager.class); CacheProvider cacheProvider = mock(CacheProvider.class); EntityCache entityCache = mock(EntityCache.class); + EntityColdStarter entityColdStarter = mock(EntityColdStarter.class); when(cacheProvider.get()).thenReturn(entityCache); action = new CronTransportAction( @@ -93,7 +95,8 @@ public void setUp() throws Exception { tarnsportStatemanager, modelManager, featureManager, - cacheProvider + cacheProvider, + entityColdStarter ); } diff --git a/src/test/java/org/opensearch/ad/transport/DeleteTests.java b/src/test/java/org/opensearch/ad/transport/DeleteTests.java index c35033e70..f4f8f7cb9 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteTests.java @@ -58,7 +58,7 @@ import org.opensearch.ad.AbstractADTest; import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.CommonErrorMessages; -import org.opensearch.ad.constant.CommonMessageAttributes; +import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.util.DiscoveryNodeFilterer; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; @@ -141,7 +141,7 @@ public void setUp() throws Exception { transportService = mock(TransportService.class); threadPool = mock(ThreadPool.class); actionFilters = mock(ActionFilters.class); - Settings settings = Settings.builder().put("opendistro.anomaly_detection.request_timeout", TimeValue.timeValueSeconds(10)).build(); + Settings settings = Settings.builder().put("plugins.anomaly_detection.request_timeout", TimeValue.timeValueSeconds(10)).build(); task = mock(Task.class); when(task.getId()).thenReturn(1000L); client = mock(Client.class); @@ -201,7 +201,7 @@ public void testJsonRequestTemplate(R request, Supplier modelSizeMap = new HashMap<>(); + modelSizeMap.put(modelId, modelSize); + when(cache.getModelSize(anyString())).thenReturn(modelSizeMap); when(cacheProvider.get()).thenReturn(cache); - action = new EntityProfileTransportAction( - actionFilters, - transportService, - settings, - modelManager, - hashRing, - clusterService, - cacheProvider - 
); + action = new EntityProfileTransportAction(actionFilters, transportService, settings, hashRing, clusterService, cacheProvider); future = new PlainActionFuture<>(); transportAddress1 = new TransportAddress(new InetSocketAddress(InetAddress.getByName("1.2.3.4"), 9300)); - request = new EntityProfileRequest(detectorId, entityValue, state); + entity = Entity.createSingleAttributeEntity(detectorId, categoryName, entityValue); + + request = new EntityProfileRequest(detectorId, entity, state); normalTransportInterceptor = new TransportInterceptor() { @Override @@ -275,7 +273,6 @@ private void registerHandler(FakeNode node) { new ActionFilters(Collections.emptySet()), node.transportService, Settings.EMPTY, - modelManager, hashRing, node.clusterService, cacheProvider @@ -304,21 +301,16 @@ public void testAllHit() { when(hashRing.getOwningNode(anyString())).thenReturn(Optional.of(localNode)); when(clusterService.localNode()).thenReturn(localNode); - request = new EntityProfileRequest(detectorId, entityValue, all); + request = new EntityProfileRequest(detectorId, entity, all); action.doExecute(task, request, future); - EntityProfileResponse expectedResponse = new EntityProfileResponse( - isActive, - lastActiveTimestamp, - updates, - new ModelProfile(modelId, modelSize, nodeId) - ); + EntityProfileResponse expectedResponse = new EntityProfileResponse(isActive, lastActiveTimestamp, updates, null); EntityProfileResponse response = future.actionGet(20_000); assertEquals(expectedResponse, response); } public void testGetRemoteUpdateResponse() { - setupTestNodes(Settings.EMPTY, normalTransportInterceptor); + setupTestNodes(normalTransportInterceptor); try { TransportService realTransportService = testNodes[0].transportService; clusterService = testNodes[0].clusterService; @@ -327,7 +319,6 @@ public void testGetRemoteUpdateResponse() { actionFilters, realTransportService, settings, - modelManager, hashRing, clusterService, cacheProvider @@ -348,7 +339,7 @@ public void testGetRemoteUpdateResponse() { } public void testGetRemoteFailureResponse() { - setupTestNodes(Settings.EMPTY, failureTransportInterceptor); + setupTestNodes(failureTransportInterceptor); try { TransportService realTransportService = testNodes[0].transportService; clusterService = testNodes[0].clusterService; @@ -357,7 +348,6 @@ public void testGetRemoteFailureResponse() { actionFilters, realTransportService, settings, - modelManager, hashRing, clusterService, cacheProvider @@ -378,17 +368,17 @@ public void testResponseToXContent() throws IOException, JsonPathNotFoundExcepti long lastActiveTimestamp = 10L; EntityProfileResponse.Builder builder = new EntityProfileResponse.Builder(); builder.setLastActiveMs(lastActiveTimestamp).build(); - builder.setModelProfile(new ModelProfile(modelId, modelSize, nodeId)); + builder.setModelProfile(new ModelProfileOnNode(nodeId, new ModelProfile(modelId, entity, modelSize))); EntityProfileResponse response = builder.build(); String json = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); assertEquals(lastActiveTimestamp, JsonDeserializer.getLongValue(json, EntityProfileResponse.LAST_ACTIVE_TS)); - assertEquals(modelSize, JsonDeserializer.getChildNode(json, CommonName.MODEL, ModelProfile.MODEL_SIZE_IN_BYTES).getAsLong()); + assertEquals(modelSize, JsonDeserializer.getChildNode(json, CommonName.MODEL, CommonName.MODEL_SIZE_IN_BYTES).getAsLong()); } public void testSerialzationResponse() throws IOException { EntityProfileResponse.Builder builder = new 
EntityProfileResponse.Builder(); builder.setLastActiveMs(lastActiveTimestamp).build(); - ModelProfile model = new ModelProfile(modelId, modelSize, nodeId); + ModelProfileOnNode model = new ModelProfileOnNode(nodeId, new ModelProfile(modelId, entity, modelSize)); builder.setModelProfile(model); EntityProfileResponse response = builder.build(); @@ -404,7 +394,7 @@ public void testSerialzationResponse() throws IOException { public void testResponseHashCodeEquals() { EntityProfileResponse.Builder builder = new EntityProfileResponse.Builder(); builder.setLastActiveMs(lastActiveTimestamp).build(); - ModelProfile model = new ModelProfile(modelId, modelSize, nodeId); + ModelProfileOnNode model = new ModelProfileOnNode(nodeId, new ModelProfile(modelId, entity, modelSize)); builder.setModelProfile(model); EntityProfileResponse response = builder.build(); diff --git a/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java index 45486a87e..c74ab847c 100644 --- a/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java @@ -43,18 +43,24 @@ import java.io.IOException; import java.time.Clock; import java.time.Instant; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.AfterClass; import org.junit.Before; +import org.junit.BeforeClass; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.ad.AbstractADTest; +import org.opensearch.ad.AnomalyDetectorJobRunnerTests; import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.breaker.ADCircuitBreakerService; @@ -64,17 +70,20 @@ import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.common.exception.LimitExceededException; import org.opensearch.ad.constant.CommonErrorMessages; -import org.opensearch.ad.constant.CommonMessageAttributes; +import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.constant.CommonValue; import org.opensearch.ad.indices.AnomalyDetectionIndices; import org.opensearch.ad.ml.CheckpointDao; +import org.opensearch.ad.ml.EntityColdStarter; import org.opensearch.ad.ml.EntityModel; import org.opensearch.ad.ml.ModelManager; import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; +import org.opensearch.ad.ratelimit.CheckpointReadWorker; +import org.opensearch.ad.ratelimit.ColdEntityWorker; +import org.opensearch.ad.ratelimit.ResultWriteWorker; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.transport.handler.MultiEntityResultHandler; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.StreamInput; @@ -85,6 +94,11 @@ import org.opensearch.transport.TransportService; import test.org.opensearch.ad.util.JsonDeserializer; +import test.org.opensearch.ad.util.MLUtil; +import 
test.org.opensearch.ad.util.RandomModelStateConfig; + +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; public class EntityResultTransportActionTests extends AbstractADTest { EntityResultTransportAction entityResult; @@ -92,7 +106,6 @@ public class EntityResultTransportActionTests extends AbstractADTest { TransportService transportService; ModelManager manager; ADCircuitBreakerService adCircuitBreakerService; - MultiEntityResultHandler anomalyResultHandler; CheckpointDao checkpointDao; CacheProvider provider; EntityCache entityCache; @@ -105,13 +118,31 @@ public class EntityResultTransportActionTests extends AbstractADTest { AnomalyDetector detector; String cacheMissEntity; String cacheHitEntity; + Entity cacheHitEntityObj; + Entity cacheMissEntityObj; long start; long end; - Map entities; + Map entities; double[] cacheMissData; double[] cacheHitData; String tooLongEntity; double[] tooLongData; + ResultWriteWorker resultWriteQueue; + CheckpointReadWorker checkpointReadQueue; + int minSamples; + Instant now; + EntityColdStarter coldStarter; + ColdEntityWorker coldEntityQueue; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyDetectorJobRunnerTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } @SuppressWarnings("unchecked") @Override @@ -124,41 +155,47 @@ public void setUp() throws Exception { adCircuitBreakerService = mock(ADCircuitBreakerService.class); when(adCircuitBreakerService.isOpen()).thenReturn(false); - anomalyResultHandler = mock(MultiEntityResultHandler.class); checkpointDao = mock(CheckpointDao.class); detectorId = "123"; entities = new HashMap<>(); - cacheMissEntity = "0.0.0.1"; - cacheMissData = new double[] { 0.1 }; - cacheHitEntity = "0.0.0.2"; - cacheHitData = new double[] { 0.2 }; - entities.put(cacheMissEntity, cacheMissData); - entities.put(cacheHitEntity, cacheHitData); - tooLongEntity = randomAlphaOfLength(AnomalyDetectorSettings.MAX_ENTITY_LENGTH + 1); - tooLongData = new double[] { 0.3 }; - entities.put(tooLongEntity, tooLongData); start = 10L; end = 20L; request = new EntityResultRequest(detectorId, entities, start, end); - manager = mock(ModelManager.class); - doAnswer(invocation -> { - Object[] args = invocation.getArguments(); - // return entity name - return args[1]; - }).when(manager).getEntityModelId(anyString(), anyString()); - when(manager.getAnomalyResultForEntity(anyString(), any(), anyString(), any(), anyString())) - .thenReturn(new ThresholdingResult(1, 1, 1)); + clock = mock(Clock.class); + now = Instant.now(); + when(clock.instant()).thenReturn(now); + + manager = new ModelManager( + null, + null, + null, + clock, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + null, + 0, + null, + null, + mock(EntityColdStarter.class), + null, + null, + null + ); provider = mock(CacheProvider.class); entityCache = mock(EntityCache.class); when(provider.get()).thenReturn(entityCache); - when(entityCache.get(eq(cacheMissEntity), any(), any(), anyString())).thenReturn(null); - - ModelState state = mock(ModelState.class); - when(entityCache.get(eq(cacheHitEntity), any(), any(), anyString())).thenReturn(state); String field = "a"; detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); @@ -168,27 +205,59 @@ public void setUp() throws Exception { listener.onResponse(Optional.of(detector)); return null; }).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); - 
when(stateManager.getLastIndexThrottledTime()).thenReturn(Instant.MIN); + + cacheMissEntity = "0.0.0.1"; + cacheMissData = new double[] { 0.1 }; + cacheHitEntity = "0.0.0.2"; + cacheHitData = new double[] { 0.2 }; + cacheMissEntityObj = Entity.createSingleAttributeEntity(detectorId, detector.getCategoryField().get(0), cacheMissEntity); + entities.put(cacheMissEntityObj, cacheMissData); + cacheHitEntityObj = Entity.createSingleAttributeEntity(detectorId, detector.getCategoryField().get(0), cacheHitEntity); + entities.put(cacheHitEntityObj, cacheHitData); + tooLongEntity = randomAlphaOfLength(AnomalyDetectorSettings.MAX_ENTITY_LENGTH + 1); + tooLongData = new double[] { 0.3 }; + entities.put(Entity.createSingleAttributeEntity(detectorId, detector.getCategoryField().get(0), tooLongEntity), tooLongData); + + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + when(entityCache.get(eq(cacheMissEntityObj.getModelId(detectorId).get()), any())).thenReturn(null); + when(entityCache.get(eq(cacheHitEntityObj.getModelId(detectorId).get()), any())).thenReturn(state); + + List coldEntities = new ArrayList<>(); + coldEntities.add(cacheMissEntityObj); + when(entityCache.selectUpdateCandidate(any(), anyString(), any())).thenReturn(Pair.of(new ArrayList<>(), coldEntities)); settings = Settings.builder().put(AnomalyDetectorSettings.COOLDOWN_MINUTES.getKey(), TimeValue.timeValueMinutes(5)).build(); - clock = mock(Clock.class); - when(clock.instant()).thenReturn(Instant.now()); AnomalyDetectionIndices indexUtil = mock(AnomalyDetectionIndices.class); when(indexUtil.getSchemaVersion(any())).thenReturn(CommonValue.NO_SCHEMA_VERSION); + resultWriteQueue = mock(ResultWriteWorker.class); + checkpointReadQueue = mock(CheckpointReadWorker.class); + + minSamples = 1; + + coldStarter = mock(EntityColdStarter.class); + + doAnswer(invocation -> { + ModelState modelState = invocation.getArgument(0); + modelState.getModel().clear(); + return null; + }).when(coldStarter).trainModelFromExistingSamples(any()); + + coldEntityQueue = mock(ColdEntityWorker.class); + entityResult = new EntityResultTransportAction( actionFilters, transportService, manager, adCircuitBreakerService, - anomalyResultHandler, - checkpointDao, provider, stateManager, - settings, - clock, - indexUtil + indexUtil, + resultWriteQueue, + checkpointReadQueue, + coldEntityQueue, + threadPool ); // timeout in 60 seconds @@ -211,7 +280,7 @@ public void testNormal() { future.actionGet(timeoutMs); - verify(anomalyResultHandler, times(1)).flush(any(), any()); + verify(resultWriteQueue, times(1)).put(any()); } // test get detector failure @@ -230,25 +299,10 @@ public void testFailtoGetDetector() { expectThrows(EndRunException.class, () -> future.actionGet(timeoutMs)); } - // test index pressure high, anomaly grade is 0 - public void testIndexPressureHigh() { - when(manager.getAnomalyResultForEntity(anyString(), any(), anyString(), any(), anyString())) - .thenReturn(new ThresholdingResult(0, 1, 1)); - when(stateManager.getLastIndexThrottledTime()).thenReturn(Instant.now()); - - PlainActionFuture future = PlainActionFuture.newFuture(); - - entityResult.doExecute(null, request, future); - - future.actionGet(timeoutMs); - - verify(anomalyResultHandler, never()).flush(any(), any()); - } - // test rcf score is 0 - public void testNotInitialized() { - when(manager.getAnomalyResultForEntity(anyString(), any(), anyString(), any(), anyString())) - .thenReturn(new ThresholdingResult(0, 0, 0)); + public void 
testNoResultsToSave() { + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(false).build()); + when(entityCache.get(eq(cacheHitEntityObj.getModelId(detectorId).get()), any())).thenReturn(state); PlainActionFuture future = PlainActionFuture.newFuture(); @@ -256,7 +310,7 @@ public void testNotInitialized() { future.actionGet(timeoutMs); - verify(anomalyResultHandler, never()).flush(any(), any()); + verify(resultWriteQueue, never()).put(any()); } public void testSerialzationRequest() throws IOException { @@ -299,11 +353,29 @@ public void testJsonResponse() throws IOException, JsonPathNotFoundException { request.toXContent(builder, ToXContent.EMPTY_PARAMS); String json = Strings.toString(builder); - assertEquals(JsonDeserializer.getTextValue(json, CommonMessageAttributes.ID_JSON_KEY), detectorId); - assertEquals(JsonDeserializer.getLongValue(json, CommonMessageAttributes.START_JSON_KEY), start); - assertEquals(JsonDeserializer.getLongValue(json, CommonMessageAttributes.END_JSON_KEY), end); - assertEquals(0, Double.compare(JsonDeserializer.getArrayValue(json, cacheMissEntity).get(0).getAsDouble(), cacheMissData[0])); - assertEquals(0, Double.compare(JsonDeserializer.getArrayValue(json, cacheHitEntity).get(0).getAsDouble(), cacheHitData[0])); - assertEquals(0, Double.compare(JsonDeserializer.getArrayValue(json, tooLongEntity).get(0).getAsDouble(), tooLongData[0])); + assertEquals(JsonDeserializer.getTextValue(json, CommonName.ID_JSON_KEY), detectorId); + assertEquals(JsonDeserializer.getLongValue(json, CommonName.START_JSON_KEY), start); + assertEquals(JsonDeserializer.getLongValue(json, CommonName.END_JSON_KEY), end); + JsonArray array = JsonDeserializer.getArrayValue(json, CommonName.ENTITIES_JSON_KEY); + assertEquals(3, array.size()); + for (int i = 0; i < 3; i++) { + JsonElement element = array.get(i); + JsonElement entity = JsonDeserializer.getChildNode(element, CommonName.ENTITY_KEY); + JsonArray entityArray = entity.getAsJsonArray(); + assertEquals(1, entityArray.size()); + + JsonElement attribute = entityArray.get(0); + String entityValue = JsonDeserializer.getChildNode(attribute, Entity.ATTRIBUTE_VALUE_FIELD).getAsString(); + + double value = JsonDeserializer.getChildNode(element, CommonName.VALUE_JSON_KEY).getAsJsonArray().get(0).getAsDouble(); + + if (entityValue.equals(cacheMissEntity)) { + assertEquals(0, Double.compare(cacheMissData[0], value)); + } else if (entityValue.equals(cacheHitEntity)) { + assertEquals(0, Double.compare(cacheHitData[0], value)); + } else { + assertEquals(0, Double.compare(tooLongData[0], value)); + } + } } } diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java index 8d41710d5..8b8105aa7 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java @@ -47,6 +47,7 @@ import org.opensearch.action.support.PlainActionFuture; import org.opensearch.ad.AbstractADTest; import org.opensearch.ad.constant.CommonErrorMessages; +import org.opensearch.ad.model.Entity; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.util.DiscoveryNodeFilterer; @@ -66,10 +67,12 @@ public class GetAnomalyDetectorTests extends AbstractADTest { private GetAnomalyDetectorRequest request; private String detectorId = "yecrdnUBqurvo9uKU_d8"; private String entityValue = "app_0"; 
+ private String categoryField = "categoryField"; private String typeStr; private String rawPath; private PlainActionFuture future; private ADTaskManager adTaskManager; + private Entity entity; @BeforeClass public static void setUpBeforeClass() { @@ -120,6 +123,8 @@ public void setUp() throws Exception { xContentRegistry(), adTaskManager ); + + entity = Entity.createSingleAttributeEntity(detectorId, categoryField, entityValue); } public void testInvalidRequest() throws IOException { @@ -127,7 +132,7 @@ public void testInvalidRequest() throws IOException { rawPath = "_opendistro/_anomaly_detection/detectors/T4c3dXUBj-2IZN7itix_/_profile"; - request = new GetAnomalyDetectorRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entityValue); + request = new GetAnomalyDetectorRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entity); future = new PlainActionFuture<>(); action.doExecute(null, request, future); @@ -152,7 +157,7 @@ public void testValidRequest() throws IOException { rawPath = "_opendistro/_anomaly_detection/detectors/T4c3dXUBj-2IZN7itix_/_profile"; - request = new GetAnomalyDetectorRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entityValue); + request = new GetAnomalyDetectorRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entity); future = new PlainActionFuture<>(); action.doExecute(null, request, future); diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java index 33cdca1ec..6f46029d3 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java @@ -43,10 +43,13 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.ad.TestHelpers; +import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.EntityProfile; +import org.opensearch.ad.model.InitProgressProfile; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.util.DiscoveryNodeFilterer; @@ -68,10 +71,14 @@ import com.google.common.collect.ImmutableMap; public class GetAnomalyDetectorTransportActionTests extends OpenSearchSingleNodeTestCase { + private GetAnomalyDetectorTransportAction action; private Task task; private ActionListener response; private ADTaskManager adTaskManager; + private Entity entity; + private String categoryField; + private String categoryValue; @Override @Before @@ -105,6 +112,9 @@ public void onResponse(GetAnomalyDetectorResponse getResponse) { @Override public void onFailure(Exception e) {} }; + categoryField = "catField"; + categoryValue = "app-0"; + entity = Entity.createSingleAttributeEntity("detectorId", categoryField, categoryValue); } @Override @@ -150,7 +160,7 @@ public void testGetAction() { @Test public void testGetAnomalyDetectorRequest() throws IOException { - GetAnomalyDetectorRequest request = new GetAnomalyDetectorRequest("1234", 4321, true, false, "", "abcd", false, "value"); + GetAnomalyDetectorRequest request = new GetAnomalyDetectorRequest("1234", 4321, true, false, "", "abcd", false, entity); BytesStreamOutput out = new BytesStreamOutput(); 
request.writeTo(out); StreamInput input = out.bytes().streamInput(); @@ -167,7 +177,7 @@ public void testGetAnomalyDetectorRequestNoEntityValue() throws IOException { request.writeTo(out); StreamInput input = out.bytes().streamInput(); GetAnomalyDetectorRequest newRequest = new GetAnomalyDetectorRequest(input); - Assert.assertNull(newRequest.getEntityValue()); + Assert.assertNull(newRequest.getEntity()); } @SuppressWarnings("unchecked") @@ -203,12 +213,14 @@ public void testGetAnomalyDetectorResponse() throws IOException { Assert.assertEquals(map1.get("name"), detector.getName()); } + @SuppressWarnings("unchecked") @Test public void testGetAnomalyDetectorProfileResponse() throws IOException { BytesStreamOutput out = new BytesStreamOutput(); AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); AnomalyDetectorJob adJob = TestHelpers.randomAnomalyDetectorJob(); - EntityProfile entityProfile = new EntityProfile.Builder("catField", "app-0").build(); + InitProgressProfile initProgress = new InitProgressProfile("99%", 2L, 2); + EntityProfile entityProfile = new EntityProfile.Builder().initProgress(initProgress).build(); GetAnomalyDetectorResponse response = new GetAnomalyDetectorResponse( 4321, "1234", @@ -230,8 +242,19 @@ public void testGetAnomalyDetectorProfileResponse() throws IOException { XContentBuilder builder = TestHelpers.builder(); Assert.assertNotNull(newResponse.toXContent(builder, ToXContent.EMPTY_PARAMS)); + // {init_progress={percentage=99%, estimated_minutes_left=2, needed_shingles=2}} Map map = TestHelpers.XContentBuilderToMap(builder); - Assert.assertEquals(map.get(EntityProfile.CATEGORY_FIELD), "catField"); - Assert.assertEquals(map.get(EntityProfile.ENTITY_VALUE), "app-0"); + Map parsedInitProgress = (Map) (map.get(CommonName.INIT_PROGRESS)); + Assert.assertEquals(initProgress.getPercentage(), parsedInitProgress.get(InitProgressProfile.PERCENTAGE).toString()); + Assert + .assertEquals( + String.valueOf(initProgress.getEstimatedMinutesLeft()), + parsedInitProgress.get(InitProgressProfile.ESTIMATED_MINUTES_LEFT).toString() + ); + Assert + .assertEquals( + String.valueOf(initProgress.getNeededDataPoints()), + parsedInitProgress.get(InitProgressProfile.NEEDED_SHINGLES).toString() + ); } } diff --git a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java new file mode 100644 index 000000000..f7c403e01 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java @@ -0,0 +1,766 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package org.opensearch.ad.transport; + +import static org.hamcrest.Matchers.containsString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.PAGE_SIZE; + +import java.io.IOException; +import java.time.Clock; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import org.apache.commons.lang3.tuple.Pair; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.mockito.stubbing.Answer; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchPhaseExecutionException; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponse.Clusters; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.ad.AbstractADTest; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.TestHelpers; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.common.exception.EndRunException; +import org.opensearch.ad.common.exception.LimitExceededException; +import org.opensearch.ad.constant.CommonErrorMessages; +import org.opensearch.ad.feature.CompositeRetriever; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.indices.AnomalyDetectionIndices; +import org.opensearch.ad.ml.EntityColdStarter; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ModelPartitioner; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; +import org.opensearch.ad.model.IntervalTimeConfiguration; +import org.opensearch.ad.ratelimit.CheckpointReadWorker; +import org.opensearch.ad.ratelimit.ColdEntityWorker; +import org.opensearch.ad.ratelimit.ResultWriteWorker; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.stats.ADStat; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.stats.StatNames; +import org.opensearch.ad.stats.suppliers.CounterSupplier; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.IndexUtils; +import 
org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodeRole; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; +import org.opensearch.search.aggregations.metrics.InternalMin; +import org.opensearch.test.ClusterServiceUtils; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportInterceptor; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.TransportResponse; +import org.opensearch.transport.TransportResponseHandler; +import org.opensearch.transport.TransportService; + +import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; + +public class MultiEntityResultTests extends AbstractADTest { + private AnomalyResultTransportAction action; + private AnomalyResultRequest request; + private TransportInterceptor entityResultInterceptor; + private Clock clock; + private AnomalyDetector detector; + private NodeStateManager stateManager; + private static Settings settings; + private TransportService transportService; + private SearchFeatureDao searchFeatureDao; + private Client client; + private FeatureManager featureQuery; + private ModelManager normalModelManager; + private ModelPartitioner normalModelPartitioner; + private HashRing hashRing; + private ClusterService clusterService; + private IndexNameExpressionResolver indexNameResolver; + private ADCircuitBreakerService adCircuitBreakerService; + private ADStats adStats; + private ThreadPool mockThreadPool; + private String detectorId; + private Instant now; + private String modelId; + private CacheProvider provider; + private AnomalyDetectionIndices indexUtil; + private ResultWriteWorker resultWriteQueue; + private CheckpointReadWorker checkpointReadQueue; + private EntityColdStarter coldStarer; + private ColdEntityWorker coldEntityQueue; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyResultTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @SuppressWarnings({ "serial", "unchecked" }) + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + now = Instant.now(); + clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + detectorId = "123"; + modelId = "abc"; + String categoryField = "a"; + detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Collections.singletonList(categoryField)); + + stateManager = mock(NodeStateManager.class); + // make sure parameters are not null, otherwise 
this mock won't get invoked + doAnswer(invocation -> { + ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(stateManager).getAnomalyDetector(anyString(), any(ActionListener.class)); + + // AnomalyDetector detector = TestHelpers + // .randomAnomalyDetectorWithInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES), true, true); + + settings = Settings.builder().put(AnomalyDetectorSettings.COOLDOWN_MINUTES.getKey(), TimeValue.timeValueMinutes(5)).build(); + + // make sure end time is sufficiently larger than Clock.systemUTC().millis() to get PageIterator.hasNext() to pass + request = new AnomalyResultRequest(detectorId, 100, Clock.systemUTC().millis() + 100_000); + + transportService = mock(TransportService.class); + + client = mock(Client.class); + ThreadContext threadContext = new ThreadContext(settings); + mockThreadPool = mock(ThreadPool.class); + setUpADThreadPool(mockThreadPool); + when(client.threadPool()).thenReturn(mockThreadPool); + when(mockThreadPool.getThreadContext()).thenReturn(threadContext); + + featureQuery = mock(FeatureManager.class); + + normalModelManager = mock(ModelManager.class); + + normalModelPartitioner = mock(ModelPartitioner.class); + + hashRing = mock(HashRing.class); + + Set<Setting<?>> anomalyResultSetting = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + anomalyResultSetting.add(MAX_ENTITIES_PER_QUERY); + anomalyResultSetting.add(PAGE_SIZE); + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, anomalyResultSetting); + + DiscoveryNode discoveryNode = new DiscoveryNode( + "node1", + OpenSearchTestCase.buildNewFakeTransportAddress(), + Collections.emptyMap(), + DiscoveryNodeRole.BUILT_IN_ROLES, + Version.CURRENT + ); + + clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); + + indexNameResolver = new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY)); + + adCircuitBreakerService = mock(ADCircuitBreakerService.class); + when(adCircuitBreakerService.isOpen()).thenReturn(false); + + IndexUtils indexUtils = new IndexUtils(client, mock(ClientUtil.class), clusterService, indexNameResolver); + Map<String, ADStat<?>> statsMap = new HashMap<String, ADStat<?>>() { + { + put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + } + }; + adStats = new ADStats(statsMap); + + searchFeatureDao = mock(SearchFeatureDao.class); + + action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + stateManager, + featureQuery, + normalModelManager, + normalModelPartitioner, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + mockThreadPool, + xContentRegistry() + ); + + provider = mock(CacheProvider.class); + EntityCache entityCache = mock(EntityCache.class); + when(provider.get()).thenReturn(entityCache); + when(entityCache.get(any(), any())) + .thenReturn(MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build())); + when(entityCache.selectUpdateCandidate(any(), any(), any())).thenReturn(Pair.of(new ArrayList<Entity>(), new ArrayList<Entity>())); + + indexUtil = mock(AnomalyDetectionIndices.class); + resultWriteQueue = 
mock(ResultWriteWorker.class); + checkpointReadQueue = mock(CheckpointReadWorker.class); + + coldStarer = mock(EntityColdStarter.class); + coldEntityQueue = mock(ColdEntityWorker.class); + } + + @Override + @After + public final void tearDown() throws Exception { + tearDownTestNodes(); + super.tearDown(); + } + + public void testColdStartEndRunException() { + when(stateManager.fetchExceptionAndClear(anyString())) + .thenReturn( + Optional + .of( + new EndRunException( + detectorId, + CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, + new NoSuchElementException("No value present"), + false + ) + ) + ); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + assertException(listener, EndRunException.class, CommonErrorMessages.INVALID_SEARCH_QUERY_MSG); + } + + // a handler that forwards the response or exception received from the network + private <T extends TransportResponse> TransportResponseHandler<T> entityResultHandler(TransportResponseHandler<T> handler) { + return new TransportResponseHandler<T>() { + @Override + public T read(StreamInput in) throws IOException { + return handler.read(in); + } + + @Override + @SuppressWarnings("unchecked") + public void handleResponse(T response) { + handler.handleResponse(response); + } + + @Override + public void handleException(TransportException exp) { + handler.handleException(exp); + } + + @Override + public String executor() { + return handler.executor(); + } + }; + } + + private <T extends TransportResponse> TransportResponseHandler<T> unackEntityResultHandler(TransportResponseHandler<T> handler) { + return new TransportResponseHandler<T>() { + @Override + public T read(StreamInput in) throws IOException { + return handler.read(in); + } + + @Override + @SuppressWarnings("unchecked") + public void handleResponse(T response) { + handler.handleResponse((T) new AcknowledgedResponse(false)); + } + + @Override + public void handleException(TransportException exp) { + handler.handleException(exp); + } + + @Override + public String executor() { + return handler.executor(); + } + }; + } + + private void setUpEntityResult() { + // register entity result action + new EntityResultTransportAction( + new ActionFilters(Collections.emptySet()), + // since we send requests to testNodes[1] + testNodes[1].transportService, + normalModelManager, + adCircuitBreakerService, + provider, + stateManager, + indexUtil, + resultWriteQueue, + checkpointReadQueue, + coldEntityQueue, + threadPool + ); + + when(normalModelManager.getAnomalyResultForEntity(any(), any(), any(), any(), any())).thenReturn(new ThresholdingResult(0, 1, 1)); + } + + @SuppressWarnings("unchecked") + public void setUpNormlaStateManager() throws IOException { + ClientUtil clientUtil = mock(ClientUtil.class); + + AnomalyDetector detector = TestHelpers + .randomAnomalyDetectorWithInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES), true, true); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + + ModelPartitioner modelPartitioner = mock(ModelPartitioner.class); + stateManager = new NodeStateManager( + client, + xContentRegistry(), + settings, + clientUtil, + clock, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + modelPartitioner + ); + + action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + 
stateManager, + featureQuery, + normalModelManager, + normalModelPartitioner, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + mockThreadPool, + xContentRegistry() + ); + } + + /** + * Test query error causes EndRunException but not end now + * @throws InterruptedException when the await are interrupted + * @throws IOException when failing to create anomaly detector + */ + public void testQueryErrorEndRunNotNow() throws InterruptedException, IOException { + setUpNormlaStateManager(); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + String allShardsFailedMsg = "all shards failed"; + // make PageIterator.next return failure + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener + .onFailure( + new SearchPhaseExecutionException( + "search", + allShardsFailedMsg, + new ShardSearchFailure[] { new ShardSearchFailure(new IllegalArgumentException("blah")) } + ) + ); + inProgressLatch.countDown(); + return null; + }).when(client).search(any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.001); + + assertTrue(inProgressLatch.await(10000L, TimeUnit.MILLISECONDS)); + + PlainActionFuture listener2 = new PlainActionFuture<>(); + action.doExecute(null, request, listener2); + Exception e = expectThrows(EndRunException.class, () -> listener2.actionGet(10000L)); + // wrapped INVALID_SEARCH_QUERY_MSG around SearchPhaseExecutionException by convertedQueryFailureException + assertThat("actual message: " + e.getMessage(), e.getMessage(), containsString(CommonErrorMessages.INVALID_SEARCH_QUERY_MSG)); + assertThat("actual message: " + e.getMessage(), e.getMessage(), containsString(allShardsFailedMsg)); + // not end now + assertTrue(!((EndRunException) e).isEndNow()); + } + + public void testIndexNotFound() throws InterruptedException, IOException { + setUpNormlaStateManager(); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + // make PageIterator.next return failure + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new IndexNotFoundException("", "")); + inProgressLatch.countDown(); + return null; + }).when(client).search(any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.001); + + assertTrue(inProgressLatch.await(10000L, TimeUnit.MILLISECONDS)); + + PlainActionFuture listener2 = new PlainActionFuture<>(); + action.doExecute(null, request, listener2); + Exception e = expectThrows(EndRunException.class, () -> listener2.actionGet(10000L)); + assertThat( + "actual message: " + e.getMessage(), + e.getMessage(), + containsString(AnomalyResultTransportAction.TROUBLE_QUERYING_ERR_MSG) + ); + assertTrue(!((EndRunException) e).isEndNow()); + } + + public void testEmptyFeatures() throws InterruptedException { + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(createEmptyResponse()); + inProgressLatch.countDown(); + return null; + }).when(client).search(any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, 
request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(inProgressLatch.await(10000L, TimeUnit.MILLISECONDS)); + + PlainActionFuture listener2 = new PlainActionFuture<>(); + action.doExecute(null, request, listener2); + + AnomalyResultResponse response2 = listener2.actionGet(10000L); + assertEquals(Double.NaN, response2.getAnomalyGrade(), 0.01); + } + + /** + * + * @return an empty response + */ + private SearchResponse createEmptyResponse() { + CompositeAggregation emptyComposite = mock(CompositeAggregation.class); + when(emptyComposite.getName()).thenReturn(CompositeRetriever.AGG_NAME_COMP); + when(emptyComposite.afterKey()).thenReturn(null); + // empty bucket + when(emptyComposite.getBuckets()) + .thenAnswer((Answer>) invocation -> { return new ArrayList(); }); + Aggregations emptyAggs = new Aggregations(Collections.singletonList(emptyComposite)); + SearchResponseSections emptySections = new SearchResponseSections(SearchHits.empty(), emptyAggs, null, false, null, null, 1); + return new SearchResponse(emptySections, null, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY, Clusters.EMPTY); + } + + private CountDownLatch setUpTransportInterceptor( + Function, TransportResponseHandler> interceptor + ) { + // set up a non-empty response + CompositeAggregation composite = mock(CompositeAggregation.class); + when(composite.getName()).thenReturn(CompositeRetriever.AGG_NAME_COMP); + Map afterKey = new HashMap<>(); + afterKey.put("service", "app_0"); + afterKey.put("host", "server_3"); + when(composite.afterKey()).thenReturn(afterKey); + + String featureID = detector.getFeatureAttributes().get(0).getId(); + List compositeBuckets = new ArrayList<>(); + CompositeAggregation.Bucket bucket = mock(CompositeAggregation.Bucket.class); + when(bucket.getKey()).thenReturn(Collections.singletonMap("app_0", "server_1")); + List aggList = new ArrayList<>(); + aggList.add(new InternalMin(featureID, randomDouble(), DocValueFormat.RAW, new HashMap<>())); + Aggregations aggregations = new Aggregations(aggList); + when(bucket.getAggregations()).thenReturn(aggregations); + compositeBuckets.add(bucket); + + bucket = mock(CompositeAggregation.Bucket.class); + when(bucket.getKey()).thenReturn(Collections.singletonMap("app_0", "server_2")); + aggList = new ArrayList<>(); + aggList.add(new InternalMin(featureID, randomDouble(), DocValueFormat.RAW, new HashMap<>())); + aggregations = new Aggregations(aggList); + when(bucket.getAggregations()).thenReturn(aggregations); + compositeBuckets.add(bucket); + + bucket = mock(CompositeAggregation.Bucket.class); + when(bucket.getKey()).thenReturn(Collections.singletonMap("app_0", "server_3")); + aggList = new ArrayList<>(); + aggList.add(new InternalMin(featureID, randomDouble(), DocValueFormat.RAW, new HashMap<>())); + aggregations = new Aggregations(aggList); + when(bucket.getAggregations()).thenReturn(aggregations); + compositeBuckets.add(bucket); + + when(composite.getBuckets()).thenAnswer((Answer>) invocation -> { return compositeBuckets; }); + Aggregations aggs = new Aggregations(Collections.singletonList(composite)); + + SearchResponseSections sections = new SearchResponseSections(SearchHits.empty(), aggs, null, false, null, null, 1); + SearchResponse response = new SearchResponse(sections, null, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY, Clusters.EMPTY); + + CountDownLatch inProgress = new CountDownLatch(2); + AtomicBoolean firstCalled = new AtomicBoolean(); + 
doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + if (firstCalled.get()) { + listener.onResponse(createEmptyResponse()); + inProgress.countDown(); + } else { + listener.onResponse(response); + firstCalled.set(true); + inProgress.countDown(); + } + return null; + }).when(client).search(any(), any()); + + entityResultInterceptor = new TransportInterceptor() { + @Override + public AsyncSender interceptSender(AsyncSender sender) { + return new AsyncSender() { + @SuppressWarnings("unchecked") + @Override + public void sendRequest( + Transport.Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + if (action.equals(EntityResultAction.NAME)) { + sender + .sendRequest( + connection, + action, + request, + options, + interceptor.apply((TransportResponseHandler) handler) + ); + } else { + sender.sendRequest(connection, action, request, options, handler); + } + } + }; + } + }; + + setupTestNodes(entityResultInterceptor, settings, MAX_ENTITIES_PER_QUERY, PAGE_SIZE); + + // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor + when(hashRing.getOwningNode(any(String.class))).thenReturn(Optional.of(testNodes[1].discoveryNode())); + + TransportService realTransportService = testNodes[0].transportService; + ClusterService realClusterService = testNodes[0].clusterService; + + action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + realTransportService, + settings, + client, + stateManager, + featureQuery, + normalModelManager, + normalModelPartitioner, + hashRing, + realClusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + xContentRegistry() + ); + + return inProgress; + } + + public void testNonEmptyFeatures() throws InterruptedException { + CountDownLatch inProgress = setUpTransportInterceptor(this::entityResultHandler); + setUpEntityResult(); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(inProgress.await(10000L, TimeUnit.MILLISECONDS)); + + // since we have 3 results in the first page + verify(resultWriteQueue, times(3)).put(any()); + } + + @SuppressWarnings("unchecked") + public void testCircuitBreakerOpen() throws InterruptedException { + ClientUtil clientUtil = mock(ClientUtil.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + + ModelPartitioner modelPartitioner = mock(ModelPartitioner.class); + stateManager = new NodeStateManager( + client, + xContentRegistry(), + settings, + clientUtil, + clock, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + modelPartitioner + ); + + action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + stateManager, + featureQuery, + normalModelManager, + normalModelPartitioner, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + mockThreadPool, + xContentRegistry() + ); + + CountDownLatch inProgress = setUpTransportInterceptor(this::entityResultHandler); + + 
ADCircuitBreakerService openBreaker = mock(ADCircuitBreakerService.class); + when(openBreaker.isOpen()).thenReturn(true); + // register entity result action + new EntityResultTransportAction( + new ActionFilters(Collections.emptySet()), + // since we send requests to testNodes[1] + testNodes[1].transportService, + normalModelManager, + openBreaker, + provider, + stateManager, + indexUtil, + resultWriteQueue, + checkpointReadQueue, + coldEntityQueue, + threadPool + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(inProgress.await(10000L, TimeUnit.MILLISECONDS)); + + listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + assertException(listener, LimitExceededException.class, CommonErrorMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG); + } + + // public void testNotAck() { + // setUpTransportInterceptor(this::unackEntityResultHandler); + // setUpEntityResult(); + // + // PlainActionFuture listener = new PlainActionFuture<>(); + // + // action.doExecute(null, request, listener); + // + // assertException(listener, InternalFailure.class, AnomalyResultTransportAction.NO_ACK_ERR); + // verify(stateManager, times(1)).addPressure(anyString()); + // } +} diff --git a/src/test/java/org/opensearch/ad/transport/MultientityResultTests.java b/src/test/java/org/opensearch/ad/transport/MultientityResultTests.java deleted file mode 100644 index 75981ca75..000000000 --- a/src/test/java/org/opensearch/ad/transport/MultientityResultTests.java +++ /dev/null @@ -1,510 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. 
- */ - -package org.opensearch.ad.transport; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.io.IOException; -import java.time.Clock; -import java.time.Instant; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.NoSuchElementException; -import java.util.Optional; -import java.util.function.Function; - -import org.junit.After; -import org.junit.AfterClass; -import org.junit.Before; -import org.junit.BeforeClass; -import org.opensearch.action.ActionListener; -import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.ad.AbstractADTest; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.TestHelpers; -import org.opensearch.ad.breaker.ADCircuitBreakerService; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.cluster.HashRing; -import org.opensearch.ad.common.exception.AnomalyDetectionException; -import org.opensearch.ad.common.exception.EndRunException; -import org.opensearch.ad.common.exception.InternalFailure; -import org.opensearch.ad.common.exception.LimitExceededException; -import org.opensearch.ad.constant.CommonErrorMessages; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SearchFeatureDao; -import org.opensearch.ad.indices.AnomalyDetectionIndices; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelPartitioner; -import org.opensearch.ad.ml.ThresholdingResult; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.stats.ADStat; -import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.StatNames; -import org.opensearch.ad.stats.suppliers.CounterSupplier; -import org.opensearch.ad.transport.handler.MultiEntityResultHandler; -import org.opensearch.ad.util.ClientUtil; -import org.opensearch.ad.util.IndexUtils; -import org.opensearch.client.Client; -import org.opensearch.cluster.metadata.IndexNameExpressionResolver; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.index.IndexNotFoundException; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.Transport; -import org.opensearch.transport.TransportException; -import org.opensearch.transport.TransportInterceptor; -import org.opensearch.transport.TransportRequest; -import org.opensearch.transport.TransportRequestOptions; -import org.opensearch.transport.TransportResponse; -import org.opensearch.transport.TransportResponseHandler; -import org.opensearch.transport.TransportService; - -import test.org.opensearch.ad.util.MLUtil; - -public class MultientityResultTests extends AbstractADTest { - private AnomalyResultTransportAction action; - private AnomalyResultRequest request; - private 
TransportInterceptor entityResultInterceptor; - private Clock clock; - private AnomalyDetector detector; - private NodeStateManager stateManager; - private static Settings settings; - private TransportService transportService; - private SearchFeatureDao searchFeatureDao; - private Client client; - private FeatureManager featureQuery; - private ModelManager normalModelManager; - private ModelPartitioner normalModelPartitioner; - private HashRing hashRing; - private ClusterService clusterService; - private IndexNameExpressionResolver indexNameResolver; - private ADCircuitBreakerService adCircuitBreakerService; - private ADStats adStats; - private ThreadPool mockThreadPool; - private String detectorId; - private Instant now; - private String modelId; - private MultiEntityResultHandler anomalyResultHandler; - private CheckpointDao checkpointDao; - private CacheProvider provider; - private AnomalyDetectionIndices indexUtil; - - @BeforeClass - public static void setUpBeforeClass() { - setUpThreadPool(AnomalyResultTests.class.getSimpleName()); - } - - @AfterClass - public static void tearDownAfterClass() { - tearDownThreadPool(); - } - - @SuppressWarnings({ "serial", "unchecked" }) - @Override - @Before - public void setUp() throws Exception { - super.setUp(); - now = Instant.now(); - clock = mock(Clock.class); - when(clock.instant()).thenReturn(now); - - detectorId = "123"; - modelId = "abc"; - String categoryField = "a"; - detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Collections.singletonList(categoryField)); - - stateManager = mock(NodeStateManager.class); - // make sure parameters are not null, otherwise this mock won't get invoked - doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(1); - listener.onResponse(Optional.of(detector)); - return null; - }).when(stateManager).getAnomalyDetector(anyString(), any(ActionListener.class)); - when(stateManager.getLastIndexThrottledTime()).thenReturn(Instant.MIN); - - settings = Settings.builder().put(AnomalyDetectorSettings.COOLDOWN_MINUTES.getKey(), TimeValue.timeValueMinutes(5)).build(); - - request = new AnomalyResultRequest(detectorId, 100, 200); - - transportService = mock(TransportService.class); - - client = mock(Client.class); - ThreadContext threadContext = new ThreadContext(settings); - mockThreadPool = mock(ThreadPool.class); - setUpADThreadPool(mockThreadPool); - when(client.threadPool()).thenReturn(mockThreadPool); - when(mockThreadPool.getThreadContext()).thenReturn(threadContext); - - featureQuery = mock(FeatureManager.class); - - normalModelManager = mock(ModelManager.class); - when(normalModelManager.getEntityModelId(anyString(), anyString())).thenReturn(modelId); - - normalModelPartitioner = mock(ModelPartitioner.class); - - hashRing = mock(HashRing.class); - - clusterService = mock(ClusterService.class); - - indexNameResolver = new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY)); - - adCircuitBreakerService = mock(ADCircuitBreakerService.class); - when(adCircuitBreakerService.isOpen()).thenReturn(false); - - IndexUtils indexUtils = new IndexUtils(client, mock(ClientUtil.class), clusterService, indexNameResolver); - Map> statsMap = new HashMap>() { - { - put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); - put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); - put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); - 
put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); - } - }; - adStats = new ADStats(indexUtils, normalModelManager, statsMap); - - searchFeatureDao = mock(SearchFeatureDao.class); - - action = new AnomalyResultTransportAction( - new ActionFilters(Collections.emptySet()), - transportService, - settings, - client, - stateManager, - featureQuery, - normalModelManager, - normalModelPartitioner, - hashRing, - clusterService, - indexNameResolver, - adCircuitBreakerService, - adStats, - mockThreadPool, - searchFeatureDao - ); - - anomalyResultHandler = mock(MultiEntityResultHandler.class); - checkpointDao = mock(CheckpointDao.class); - provider = mock(CacheProvider.class); - indexUtil = mock(AnomalyDetectionIndices.class); - } - - @Override - @After - public final void tearDown() throws Exception { - tearDownTestNodes(); - super.tearDown(); - } - - @SuppressWarnings("unchecked") - public void testQueryError() { - // non-EndRunException won't stop action from running - when(stateManager.fetchColdStartException(anyString())).thenReturn(Optional.of(new AnomalyDetectionException(detectorId, ""))); - - doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(3); - listener - .onFailure( - new EndRunException( - detectorId, - CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, - new NoSuchElementException("No value present"), - false - ) - ); - return null; - }).when(searchFeatureDao).getFeaturesByEntities(any(), anyLong(), anyLong(), any()); - - PlainActionFuture listener = new PlainActionFuture<>(); - - action.doExecute(null, request, listener); - - verify(stateManager, times(1)).getAnomalyDetector(anyString(), any(ActionListener.class)); - - assertException(listener, EndRunException.class, CommonErrorMessages.INVALID_SEARCH_QUERY_MSG); - } - - public void testIndexNotFound() { - // non-EndRunException won't stop action from running - when(stateManager.fetchColdStartException(anyString())).thenReturn(Optional.of(new AnomalyDetectionException(detectorId, ""))); - - doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(3); - listener.onFailure(new IndexNotFoundException("", "")); - return null; - }).when(searchFeatureDao).getFeaturesByEntities(any(), anyLong(), anyLong(), any()); - - PlainActionFuture listener = new PlainActionFuture<>(); - - action.doExecute(null, request, listener); - - assertException(listener, EndRunException.class, AnomalyResultTransportAction.TROUBLE_QUERYING_ERR_MSG); - } - - public void testColdStartEndRunException() { - when(stateManager.fetchColdStartException(anyString())) - .thenReturn( - Optional - .of( - new EndRunException( - detectorId, - CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, - new NoSuchElementException("No value present"), - false - ) - ) - ); - PlainActionFuture listener = new PlainActionFuture<>(); - action.doExecute(null, request, listener); - assertException(listener, EndRunException.class, CommonErrorMessages.INVALID_SEARCH_QUERY_MSG); - } - - public void testEmptyFeatures() { - doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(3); - listener.onResponse(new HashMap()); - return null; - }).when(searchFeatureDao).getFeaturesByEntities(any(), anyLong(), anyLong(), any()); - - PlainActionFuture listener = new PlainActionFuture<>(); - - action.doExecute(null, request, listener); - - AnomalyResultResponse response = listener.actionGet(10000L); - assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); - } - - private TransportResponseHandler 
entityResultHandler(TransportResponseHandler handler) { - return new TransportResponseHandler() { - @Override - public T read(StreamInput in) throws IOException { - return handler.read(in); - } - - @Override - @SuppressWarnings("unchecked") - public void handleResponse(T response) { - handler.handleResponse(response); - } - - @Override - public void handleException(TransportException exp) { - handler.handleException(exp); - } - - @Override - public String executor() { - return handler.executor(); - } - }; - } - - private TransportResponseHandler unackEntityResultHandler(TransportResponseHandler handler) { - return new TransportResponseHandler() { - @Override - public T read(StreamInput in) throws IOException { - return handler.read(in); - } - - @Override - @SuppressWarnings("unchecked") - public void handleResponse(T response) { - handler.handleResponse((T) new AcknowledgedResponse(false)); - } - - @Override - public void handleException(TransportException exp) { - handler.handleException(exp); - } - - @Override - public String executor() { - return handler.executor(); - } - }; - } - - private void setUpEntityResult() { - // register entity result action - new EntityResultTransportAction( - new ActionFilters(Collections.emptySet()), - // since we send requests to testNodes[1] - testNodes[1].transportService, - normalModelManager, - adCircuitBreakerService, - anomalyResultHandler, - checkpointDao, - provider, - stateManager, - settings, - clock, - indexUtil - ); - - EntityCache entityCache = mock(EntityCache.class); - when(provider.get()).thenReturn(entityCache); - when(entityCache.get(any(), any(), any(), anyString())).thenReturn(MLUtil.randomNonEmptyModelState()); - - when(normalModelManager.getAnomalyResultForEntity(anyString(), any(), anyString(), any(), anyString())) - .thenReturn(new ThresholdingResult(0, 1, 1)); - } - - private void setUpTransportInterceptor( - Function, TransportResponseHandler> interceptor - ) { - doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(3); - Map features = new HashMap(); - features.put("1.0.2.3", new double[] { 0 }); - features.put("2.0.2.3", new double[] { 1 }); - listener.onResponse(features); - return null; - }).when(searchFeatureDao).getFeaturesByEntities(any(), anyLong(), anyLong(), any()); - - entityResultInterceptor = new TransportInterceptor() { - @Override - public AsyncSender interceptSender(AsyncSender sender) { - return new AsyncSender() { - @SuppressWarnings("unchecked") - @Override - public void sendRequest( - Transport.Connection connection, - String action, - TransportRequest request, - TransportRequestOptions options, - TransportResponseHandler handler - ) { - if (action.equals(EntityResultAction.NAME)) { - sender - .sendRequest( - connection, - action, - request, - options, - interceptor.apply((TransportResponseHandler) handler) - ); - } else { - sender.sendRequest(connection, action, request, options, handler); - } - } - }; - } - }; - - setupTestNodes(settings, entityResultInterceptor); - - // mock hashing ring response. 
This has to happen after setting up test nodes with the failure interceptor - when(hashRing.getOwningNode(any(String.class))).thenReturn(Optional.of(testNodes[1].discoveryNode())); - - TransportService realTransportService = testNodes[0].transportService; - ClusterService realClusterService = testNodes[0].clusterService; - - action = new AnomalyResultTransportAction( - new ActionFilters(Collections.emptySet()), - realTransportService, - settings, - client, - stateManager, - featureQuery, - normalModelManager, - normalModelPartitioner, - hashRing, - realClusterService, - indexNameResolver, - adCircuitBreakerService, - adStats, - threadPool, - searchFeatureDao - ); - } - - public void testNonEmptyFeatures() { - setUpTransportInterceptor(this::entityResultHandler); - setUpEntityResult(); - - PlainActionFuture listener = new PlainActionFuture<>(); - - action.doExecute(null, request, listener); - - AnomalyResultResponse response = listener.actionGet(10000L); - assertEquals(0d, response.getAnomalyGrade(), 0.01); - } - - public void testCircuitBreakerOpen() { - setUpTransportInterceptor(this::entityResultHandler); - - ADCircuitBreakerService openBreaker = mock(ADCircuitBreakerService.class); - when(openBreaker.isOpen()).thenReturn(true); - // register entity result action - new EntityResultTransportAction( - new ActionFilters(Collections.emptySet()), - // since we send requests to testNodes[1] - testNodes[1].transportService, - normalModelManager, - openBreaker, - anomalyResultHandler, - checkpointDao, - provider, - stateManager, - settings, - clock, - indexUtil - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - - action.doExecute(null, request, listener); - assertException(listener, LimitExceededException.class, CommonErrorMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG); - } - - public void testNotAck() { - setUpTransportInterceptor(this::unackEntityResultHandler); - setUpEntityResult(); - - PlainActionFuture listener = new PlainActionFuture<>(); - - action.doExecute(null, request, listener); - - assertException(listener, InternalFailure.class, AnomalyResultTransportAction.NO_ACK_ERR); - verify(stateManager, times(1)).addPressure(anyString()); - } -} diff --git a/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorActionTests.java index 4e38dc996..bc012e1ae 100644 --- a/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorActionTests.java @@ -73,7 +73,7 @@ public void testPreviewRequest() throws Exception { public void testPreviewResponse() throws Exception { BytesStreamOutput out = new BytesStreamOutput(); AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); - AnomalyResult result = TestHelpers.randomMultiEntityAnomalyDetectResult(0.8d, 0d); + AnomalyResult result = TestHelpers.randomHCADAnomalyDetectResult(0.8d, 0d); PreviewAnomalyDetectorResponse response = new PreviewAnomalyDetectorResponse(ImmutableList.of(result), detector); response.writeTo(out); NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), writableRegistry()); diff --git a/src/test/java/org/opensearch/ad/transport/ProfileTests.java b/src/test/java/org/opensearch/ad/transport/ProfileTests.java index 7768bfbe4..9e670d231 100644 --- a/src/test/java/org/opensearch/ad/transport/ProfileTests.java +++ 
b/src/test/java/org/opensearch/ad/transport/ProfileTests.java @@ -31,6 +31,7 @@ import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -46,7 +47,7 @@ import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.ad.model.ModelProfile; +import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.Strings; @@ -147,7 +148,14 @@ public void testProfileNodeRequest() throws IOException { public void testProfileNodeResponse() throws IOException, JsonPathNotFoundException { // Test serialization - ProfileNodeResponse profileNodeResponse = new ProfileNodeResponse(discoveryNode1, modelSizeMap1, shingleSize, 0, 0); + ProfileNodeResponse profileNodeResponse = new ProfileNodeResponse( + discoveryNode1, + modelSizeMap1, + shingleSize, + 0, + 0, + new ArrayList<>() + ); BytesStreamOutput output = new BytesStreamOutput(); profileNodeResponse.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); @@ -163,7 +171,7 @@ public void testProfileNodeResponse() throws IOException, JsonPathNotFoundExcept for (Map.Entry profile : modelSizeMap1.entrySet()) { assertEquals( "toXContent has the wrong model size", - JsonDeserializer.getLongValue(json, ProfileNodeResponse.MODEL_SIZE_IN_BYTES, profile.getKey()), + JsonDeserializer.getLongValue(json, CommonName.MODEL_SIZE_IN_BYTES, profile.getKey()), profile.getValue().longValue() ); } @@ -194,8 +202,15 @@ public void testProfileRequest() throws IOException { @Test public void testProfileResponse() throws IOException, JsonPathNotFoundException { - ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse(discoveryNode1, modelSizeMap1, shingleSize, 0, 0); - ProfileNodeResponse profileNodeResponse2 = new ProfileNodeResponse(discoveryNode2, modelSizeMap2, -1, 0, 0); + ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse( + discoveryNode1, + modelSizeMap1, + shingleSize, + 0, + 0, + new ArrayList<>() + ); + ProfileNodeResponse profileNodeResponse2 = new ProfileNodeResponse(discoveryNode2, modelSizeMap2, -1, 0, 0, new ArrayList<>()); List profileNodeResponses = Arrays.asList(profileNodeResponse1, profileNodeResponse2); List failures = Collections.emptyList(); ProfileResponse profileResponse = new ProfileResponse(new ClusterName(clusterName), profileNodeResponses, failures); @@ -204,7 +219,7 @@ public void testProfileResponse() throws IOException, JsonPathNotFoundException assertEquals(shingleSize, profileResponse.getShingleSize()); assertEquals(modelSize * 2, profileResponse.getTotalSizeInBytes()); assertEquals(2, profileResponse.getModelProfile().length); - for (ModelProfile profile : profileResponse.getModelProfile()) { + for (ModelProfileOnNode profile : profileResponse.getModelProfile()) { assertTrue(node1.equals(profile.getNodeId()) || node2.equals(profile.getNodeId())); assertEquals(modelSize, profile.getModelSize()); if (node1.equals(profile.getNodeId())) { @@ -240,20 +255,20 @@ public void testProfileResponse() throws IOException, JsonPathNotFoundException JsonElement element = modelsJson.get(i); assertTrue( "toXContent has the wrong model id", - JsonDeserializer.getTextValue(element, ModelProfile.MODEL_ID).equals(model1Id) - || 
                    JsonDeserializer.getTextValue(element, ModelProfile.MODEL_ID).equals(model0Id)
+                    JsonDeserializer.getTextValue(element, CommonName.MODEL_ID_KEY).equals(model1Id)
+                        || JsonDeserializer.getTextValue(element, CommonName.MODEL_ID_KEY).equals(model0Id)
             );
             assertEquals(
                 "toXContent has the wrong model size",
-                JsonDeserializer.getLongValue(element, ModelProfile.MODEL_SIZE_IN_BYTES),
+                JsonDeserializer.getLongValue(element, CommonName.MODEL_SIZE_IN_BYTES),
                 modelSize
             );
-            if (JsonDeserializer.getTextValue(element, ModelProfile.MODEL_ID).equals(model1Id)) {
-                assertEquals("toXContent has the wrong node id", JsonDeserializer.getTextValue(element, ModelProfile.NODE_ID), node1);
+            if (JsonDeserializer.getTextValue(element, CommonName.MODEL_ID_KEY).equals(model1Id)) {
+                assertEquals("toXContent has the wrong node id", JsonDeserializer.getTextValue(element, ModelProfileOnNode.NODE_ID), node1);
             } else {
-                assertEquals("toXContent has the wrong node id", JsonDeserializer.getTextValue(element, ModelProfile.NODE_ID), node2);
+                assertEquals("toXContent has the wrong node id", JsonDeserializer.getTextValue(element, ModelProfileOnNode.NODE_ID), node2);
             }
         }
diff --git a/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java
index 601ad3528..36f30b3c1 100644
--- a/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java
@@ -48,6 +48,8 @@
 import org.opensearch.ad.feature.FeatureManager;
 import org.opensearch.ad.ml.ModelManager;
 import org.opensearch.ad.model.DetectorProfileName;
+import org.opensearch.ad.model.Entity;
+import org.opensearch.ad.model.ModelProfile;
 import org.opensearch.cluster.node.DiscoveryNode;
 import org.opensearch.test.OpenSearchIntegTestCase;
 import org.opensearch.transport.TransportService;
@@ -82,10 +84,22 @@ public void setUp() throws Exception {
         when(cache.getActiveEntities(anyString())).thenReturn(activeEntities);
         when(cache.getTotalUpdates(anyString())).thenReturn(totalUpdates);
         Map multiEntityModelSizeMap = new HashMap<>();
-        multiEntityModelSizeMap.put("T4c3dXUBj-2IZN7itix__entity_app_3", multiEntityModelSize);
-        multiEntityModelSizeMap.put("T4c3dXUBj-2IZN7itix__entity_app_2", multiEntityModelSize);
+        String modelId1 = "T4c3dXUBj-2IZN7itix__entity_app_3";
+        String modelId2 = "T4c3dXUBj-2IZN7itix__entity_app_2";
+        multiEntityModelSizeMap.put(modelId1, multiEntityModelSize);
+        multiEntityModelSizeMap.put(modelId2, multiEntityModelSize);
         when(cache.getModelSize(anyString())).thenReturn(multiEntityModelSizeMap);
+        List modelProfiles = new ArrayList<>();
+        String field = "field";
+        String fieldVal1 = "value1";
+        String fieldVal2 = "value2";
+        Entity entity1 = Entity.createSingleAttributeEntity(detectorId, field, fieldVal1);
+        Entity entity2 = Entity.createSingleAttributeEntity(detectorId, field, fieldVal2);
+        modelProfiles.add(new ModelProfile(modelId1, entity1, multiEntityModelSize));
+        modelProfiles.add(new ModelProfile(modelId1, entity2, multiEntityModelSize));
+        when(cache.getAllModelProfile(anyString())).thenReturn(modelProfiles);
+
         Map modelSizes = new HashMap<>();
         modelSizes.put(modelId, modelSize);
         when(modelManager.getModelSize(any(String.class))).thenReturn(modelSizes);
@@ -109,7 +123,7 @@ public void testNewResponse() {
         DiscoveryNode node = clusterService().localNode();
         ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, false, node);

-        ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse(node, new HashMap<>(), shingleSize, 0, 0);
+        ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse(node, new HashMap<>(), shingleSize, 0, 0, new ArrayList<>());
         List profileNodeResponses = Arrays.asList(profileNodeResponse1);
         List failures = new ArrayList<>();
@@ -178,7 +192,8 @@ public void testMultiEntityNodeOperation() {
         response = action.nodeOperation(new ProfileNodeRequest(profileRequest));

         assertEquals(activeEntities, response.getActiveEntities());
-        assertEquals(2, response.getModelSize().size());
+        assertEquals(null, response.getModelSize());
+        assertEquals(2, response.getModelProfiles().size());
         assertEquals(totalUpdates, response.getTotalUpdates());
     }
 }
diff --git a/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java b/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java
index f69831b67..4f5b90947 100644
--- a/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java
+++ b/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java
@@ -50,7 +50,7 @@
 import org.opensearch.ad.cluster.HashRing;
 import org.opensearch.ad.common.exception.AnomalyDetectionException;
 import org.opensearch.ad.common.exception.JsonPathNotFoundException;
-import org.opensearch.ad.constant.CommonMessageAttributes;
+import org.opensearch.ad.constant.CommonName;
 import org.opensearch.ad.ml.ModelManager;
 import org.opensearch.ad.ml.ModelPartitioner;
 import org.opensearch.cluster.node.DiscoveryNode;
@@ -313,7 +313,7 @@ public String executor() {
     }

     public void testGetRemoteNormalResponse() {
-        setupTestNodes(Settings.EMPTY, normalTransportInterceptor);
+        setupTestNodes(normalTransportInterceptor, Settings.EMPTY);
         try {
             TransportService realTransportService = testNodes[0].transportService;
             clusterService = testNodes[0].clusterService;
@@ -341,7 +341,7 @@ public void testGetRemoteNormalResponse() {
     }

     public void testGetRemoteFailureResponse() {
-        setupTestNodes(Settings.EMPTY, failureTransportInterceptor);
+        setupTestNodes(failureTransportInterceptor, Settings.EMPTY);
         try {
             TransportService realTransportService = testNodes[0].transportService;
             clusterService = testNodes[0].clusterService;
@@ -376,7 +376,7 @@ public void testResponseToXContent() throws IOException, JsonPathNotFoundExcepti
     public void testRequestToXContent() throws IOException, JsonPathNotFoundException {
         RCFPollingRequest response = new RCFPollingRequest(detectorId);
         String json = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS));
-        assertEquals(detectorId, JsonDeserializer.getTextValue(json, CommonMessageAttributes.ID_JSON_KEY));
+        assertEquals(detectorId, JsonDeserializer.getTextValue(json, CommonName.ID_JSON_KEY));
     }

     public void testNullDetectorId() {
diff --git a/src/test/java/org/opensearch/ad/transport/RCFResultTests.java b/src/test/java/org/opensearch/ad/transport/RCFResultTests.java
index f7e4f83d8..a1f94109c 100644
--- a/src/test/java/org/opensearch/ad/transport/RCFResultTests.java
+++ b/src/test/java/org/opensearch/ad/transport/RCFResultTests.java
@@ -48,7 +48,7 @@
 import org.opensearch.ad.common.exception.JsonPathNotFoundException;
 import org.opensearch.ad.common.exception.LimitExceededException;
 import org.opensearch.ad.constant.CommonErrorMessages;
-import org.opensearch.ad.constant.CommonMessageAttributes;
+import org.opensearch.ad.constant.CommonName;
 import org.opensearch.ad.ml.ModelManager;
 import org.opensearch.ad.ml.RcfResult;
 import org.opensearch.common.Strings;
@@ -194,12 +194,8 @@
     public void testJsonRequest() throws IOException, JsonPathNotFoundException {
         request.toXContent(builder, ToXContent.EMPTY_PARAMS);
         String json = Strings.toString(builder);
-        assertEquals(JsonDeserializer.getTextValue(json, CommonMessageAttributes.ID_JSON_KEY), request.getAdID());
-        assertArrayEquals(
-            JsonDeserializer.getDoubleArrayValue(json, CommonMessageAttributes.FEATURE_JSON_KEY),
-            request.getFeatures(),
-            0.001
-        );
+        assertEquals(JsonDeserializer.getTextValue(json, CommonName.ID_JSON_KEY), request.getAdID());
+        assertArrayEquals(JsonDeserializer.getDoubleArrayValue(json, CommonName.FEATURE_JSON_KEY), request.getFeatures(), 0.001);
     }

     @SuppressWarnings("unchecked")
diff --git a/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java b/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java
index 99477ebe6..440ab0ecd 100644
--- a/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java
+++ b/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java
@@ -44,7 +44,7 @@
 import org.opensearch.action.support.PlainActionFuture;
 import org.opensearch.ad.common.exception.JsonPathNotFoundException;
 import org.opensearch.ad.constant.CommonErrorMessages;
-import org.opensearch.ad.constant.CommonMessageAttributes;
+import org.opensearch.ad.constant.CommonName;
 import org.opensearch.ad.ml.ModelManager;
 import org.opensearch.ad.ml.ThresholdingResult;
 import org.opensearch.common.Strings;
@@ -133,12 +133,8 @@ public void testJsonResponse() throws IOException, JsonPathNotFoundException {
         response.toXContent(builder, ToXContent.EMPTY_PARAMS);
         String json = Strings.toString(builder);
-        assertEquals(
-            JsonDeserializer.getDoubleValue(json, CommonMessageAttributes.ANOMALY_GRADE_JSON_KEY),
-            response.getAnomalyGrade(),
-            0.001
-        );
-        assertEquals(JsonDeserializer.getDoubleValue(json, CommonMessageAttributes.CONFIDENCE_JSON_KEY), response.getConfidence(), 0.001);
+        assertEquals(JsonDeserializer.getDoubleValue(json, CommonName.ANOMALY_GRADE_JSON_KEY), response.getAnomalyGrade(), 0.001);
+        assertEquals(JsonDeserializer.getDoubleValue(json, CommonName.CONFIDENCE_JSON_KEY), response.getConfidence(), 0.001);
     }

     public void testEmptyID() {
@@ -163,7 +159,7 @@ public void testJsonRequest() throws IOException, JsonPathNotFoundException {
         request.toXContent(builder, ToXContent.EMPTY_PARAMS);
         String json = Strings.toString(builder);
-        assertEquals(JsonDeserializer.getTextValue(json, CommonMessageAttributes.ID_JSON_KEY), request.getAdID());
-        assertEquals(JsonDeserializer.getDoubleValue(json, CommonMessageAttributes.RCF_SCORE_JSON_KEY), request.getRCFScore(), 0.001);
+        assertEquals(JsonDeserializer.getTextValue(json, CommonName.ID_JSON_KEY), request.getAdID());
+        assertEquals(JsonDeserializer.getDoubleValue(json, CommonName.RCF_SCORE_JSON_KEY), request.getRCFScore(), 0.001);
     }
 }
diff --git a/src/test/java/org/opensearch/ad/util/MultiResponsesDelegateActionListenerTests.java b/src/test/java/org/opensearch/ad/util/MultiResponsesDelegateActionListenerTests.java
index 204cb5450..8bc9c5049 100644
--- a/src/test/java/org/opensearch/ad/util/MultiResponsesDelegateActionListenerTests.java
+++ b/src/test/java/org/opensearch/ad/util/MultiResponsesDelegateActionListenerTests.java
@@ -28,7 +28,7 @@
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
-import static org.opensearch.ad.TestHelpers.randomMutlEntityAnomalyDetectResult;
+import static org.opensearch.ad.TestHelpers.randomHCADAnomalyDetectResult;

 import java.util.ArrayList;
 import java.util.concurrent.CountDownLatch;
@@ -66,8 +66,8 @@ public void testEmptyResponse() throws InterruptedException {
     @SuppressWarnings("unchecked")
     public void testForceResponse() {
-        AnomalyResult anomalyResult1 = randomMutlEntityAnomalyDetectResult(0.25, 0.25, "error");
-        AnomalyResult anomalyResult2 = randomMutlEntityAnomalyDetectResult(0.5, 0.5, "error");
+        AnomalyResult anomalyResult1 = randomHCADAnomalyDetectResult(0.25, 0.25, "error");
+        AnomalyResult anomalyResult2 = randomHCADAnomalyDetectResult(0.5, 0.5, "error");

         EntityAnomalyResult entityAnomalyResult1 = new EntityAnomalyResult(new ArrayList() {
             {
diff --git a/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java b/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java
index 21200af2a..ceab9a95e 100644
--- a/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java
+++ b/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java
@@ -197,7 +197,14 @@ private void setUpProfileAction() {
             ActionListener listener = (ActionListener) args[2];

-            ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse(discoveryNode1, new HashMap<>(), shingleSize, 0, 0);
+            ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse(
+                discoveryNode1,
+                new HashMap<>(),
+                shingleSize,
+                0,
+                0,
+                new ArrayList<>()
+            );
             List profileNodeResponses = Arrays.asList(profileNodeResponse1);
             listener.onResponse(new ProfileResponse(new ClusterName(clusterName), profileNodeResponses, Collections.emptyList()));
diff --git a/src/test/java/test/org/opensearch/ad/util/FakeNode.java b/src/test/java/test/org/opensearch/ad/util/FakeNode.java
index d0d779ba4..bf7a58062 100644
--- a/src/test/java/test/org/opensearch/ad/util/FakeNode.java
+++ b/src/test/java/test/org/opensearch/ad/util/FakeNode.java
@@ -33,12 +33,15 @@
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Function;

+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.core.Logger;
 import org.apache.lucene.util.SetOnce;
 import org.opensearch.Version;
 import org.opensearch.action.admin.cluster.node.tasks.cancel.TransportCancelTasksAction;
@@ -51,6 +54,8 @@
 import org.opensearch.common.io.stream.NamedWriteableRegistry;
 import org.opensearch.common.lease.Releasable;
 import org.opensearch.common.network.NetworkService;
+import org.opensearch.common.settings.ClusterSettings;
+import org.opensearch.common.settings.Setting;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.transport.BoundTransportAddress;
 import org.opensearch.common.transport.TransportAddress;
@@ -65,15 +70,23 @@
 import org.opensearch.transport.nio.MockNioTransport;

 public class FakeNode implements Releasable {
-    public FakeNode(String name, ThreadPool threadPool, Settings settings, TransportInterceptor transportInterceptor) {
+    protected static final Logger LOG = (Logger) LogManager.getLogger(FakeNode.class);
+
+    public FakeNode(
+        String name,
+        ThreadPool threadPool,
+        final Settings nodeSettings,
+        final Set<Setting<?>> settingsSet,
+        TransportInterceptor transportInterceptor
+    ) {
         final Function boundTransportAddressDiscoveryNodeFunction = address -> {
             discoveryNode.set(new DiscoveryNode(name, address.publishAddress(), emptyMap(), emptySet(), Version.CURRENT));
             return discoveryNode.get();
         };
         transportService = new TransportService(
-            settings,
+            Settings.EMPTY,
             new MockNioTransport(
-                settings,
+                Settings.EMPTY,
                 Version.CURRENT,
                 threadPool,
                 new NetworkService(Collections.emptyList()),
@@ -103,7 +116,10 @@ protected TaskManager createTaskManager(Settings settings, ThreadPool threadPool
         };
         transportService.start();
-        clusterService = createClusterService(threadPool, discoveryNode.get());
+        Set<Setting<?>> internalSettings = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
+        internalSettings.addAll(settingsSet);
+        ClusterSettings clusterSettings = new ClusterSettings(nodeSettings, internalSettings);
+        clusterService = createClusterService(threadPool, discoveryNode.get(), clusterSettings);
         clusterService.addStateApplier(transportService.getTaskManager());
         ActionFilters actionFilters = new ActionFilters(emptySet());
         transportListTasksAction = new TransportListTasksAction(clusterService, transportService, actionFilters);
@@ -111,8 +127,8 @@ protected TaskManager createTaskManager(Settings settings, ThreadPool threadPool
         transportService.acceptIncomingRequests();
     }

-    public FakeNode(String name, ThreadPool threadPool, Settings settings) {
-        this(name, threadPool, settings, TransportService.NOOP_TRANSPORT_INTERCEPTOR);
+    public FakeNode(String name, ThreadPool threadPool, Set<Setting<?>> settings) {
+        this(name, threadPool, Settings.EMPTY, settings, TransportService.NOOP_TRANSPORT_INTERCEPTOR);
     }

     public final ClusterService clusterService;
diff --git a/src/test/java/test/org/opensearch/ad/util/JsonDeserializer.java b/src/test/java/test/org/opensearch/ad/util/JsonDeserializer.java
index 981f99616..8b7d39c58 100644
--- a/src/test/java/test/org/opensearch/ad/util/JsonDeserializer.java
+++ b/src/test/java/test/org/opensearch/ad/util/JsonDeserializer.java
@@ -407,6 +407,25 @@ public static double getDoubleValue(String jsonString, String... paths) throws J
         throw new JsonPathNotFoundException();
     }

+    /**
+     * Search a float number inside a JSON string matching the input path
+     * expression
+     *
+     * @param jsonString an encoded JSON string
+     * @param paths path fragments
+     * @return the matching double number
+     * @throws JsonPathNotFoundException if json path is invalid
+     * @throws IOException if the underlying input source has problems
+     * during parsing
+     */
+    public static double getFloatValue(String jsonString, String... paths) throws JsonPathNotFoundException, IOException {
+        JsonElement jsonNode = getChildNode(jsonString, paths);
+        if (jsonNode != null) {
+            return jsonNode.getAsFloat();
+        }
+        throw new JsonPathNotFoundException();
+    }
+
     /**
      * Search an int number inside a JSON string matching the input path expression
      *
diff --git a/src/test/java/test/org/opensearch/ad/util/MLUtil.java b/src/test/java/test/org/opensearch/ad/util/MLUtil.java
index aacd369ac..902b5c955 100644
--- a/src/test/java/test/org/opensearch/ad/util/MLUtil.java
+++ b/src/test/java/test/org/opensearch/ad/util/MLUtil.java
@@ -38,6 +38,7 @@
 import org.opensearch.ad.ml.ModelManager.ModelType;
 import org.opensearch.ad.ml.ModelState;
 import org.opensearch.ad.ml.ThresholdingModel;
+import org.opensearch.ad.model.Entity;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;

 import com.amazon.randomcutforest.RandomCutForest;
@@ -70,48 +71,33 @@ public static Queue createQueueSamples(int size) {
         return res;
     }

-    public static ModelState randomModelState() {
-        return randomModelState(random.nextBoolean(), random.nextFloat(), randomString(15), random.nextInt(minSampleSize));
-    }
+    public static ModelState randomModelState(RandomModelStateConfig config) {
+        boolean fullModel = config.getFullModel() != null && config.getFullModel().booleanValue() ? true : false;
+        float priority = config.getPriority() != null ? config.getPriority() : random.nextFloat();
+        String detectorId = config.getDetectorId() != null ? config.getDetectorId() : randomString(15);
+        int sampleSize = config.getSampleSize() != null ? config.getSampleSize() : random.nextInt(minSampleSize);
+        Clock clock = config.getClock() != null ? config.getClock() : Clock.systemUTC();

-    public static ModelState randomModelState(boolean fullModel, float priority, String modelId, int sampleSize) {
-        String detectorId = randomString(5);
         EntityModel model = null;
         if (fullModel) {
-            model = createNonEmptyModel(modelId, sampleSize);
+            model = createNonEmptyModel(detectorId, sampleSize);
         } else {
-            model = createEmptyModel(modelId, sampleSize);
+            model = createEmptyModel(Entity.createSingleAttributeEntity(detectorId, "", ""), sampleSize);
         }
-        return new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), Clock.systemUTC(), priority);
-    }
-
-    public static ModelState randomNonEmptyModelState() {
-        return randomModelState(true, random.nextFloat(), randomString(15), random.nextInt(minSampleSize));
-    }
-
-    public static ModelState randomEmptyModelState() {
-        return randomModelState(false, random.nextFloat(), randomString(15), random.nextInt(minSampleSize));
-    }
-
-    public static ModelState randomModelState(float priority, String modelId) {
-        return randomModelState(random.nextBoolean(), priority, modelId, random.nextInt(minSampleSize));
-    }
-
-    public static ModelState randomModelStateWithSample(boolean fullModel, int sampleSize) {
-        return randomModelState(fullModel, random.nextFloat(), randomString(15), sampleSize);
+        return new ModelState<>(model, detectorId, detectorId, ModelType.ENTITY.getName(), clock, priority);
     }

-    public static EntityModel createEmptyModel(String modelId, int sampleSize) {
+    public static EntityModel createEmptyModel(Entity entity, int sampleSize) {
         Queue samples = createQueueSamples(sampleSize);
-        return new EntityModel(modelId, samples, null, null);
+        return new EntityModel(entity, samples, null, null);
     }

-    public static EntityModel createEmptyModel(String modelId) {
-        return createEmptyModel(modelId, random.nextInt(minSampleSize));
+    public static EntityModel createEmptyModel(Entity entity) {
+        return createEmptyModel(entity, random.nextInt(minSampleSize));
     }

-    public static EntityModel createNonEmptyModel(String modelId, int sampleSize) {
+    public static EntityModel createNonEmptyModel(String detectorId, int sampleSize) {
         Queue samples = createQueueSamples(sampleSize);
         RandomCutForest rcf = RandomCutForest
             .builder()
@@ -140,10 +126,10 @@ public static EntityModel createNonEmptyModel(String modelId, int sampleSize) {
             AnomalyDetectorSettings.THRESHOLD_MAX_SAMPLES
         );
         threshold.train(nonZeroScores);
-        return new EntityModel(modelId, samples, rcf, threshold);
+        return new EntityModel(Entity.createSingleAttributeEntity(detectorId, "", ""), samples, rcf, threshold);
     }

-    public static EntityModel createNonEmptyModel(String modelId) {
-        return createNonEmptyModel(modelId, random.nextInt(minSampleSize));
+    public static EntityModel createNonEmptyModel(String detectorId) {
+        return createNonEmptyModel(detectorId, random.nextInt(minSampleSize));
     }
 }
diff --git a/src/test/java/test/org/opensearch/ad/util/RandomModelStateConfig.java b/src/test/java/test/org/opensearch/ad/util/RandomModelStateConfig.java
new file mode 100644
index 000000000..cefd81dc5
--- /dev/null
+++ b/src/test/java/test/org/opensearch/ad/util/RandomModelStateConfig.java
@@ -0,0 +1,103 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package test.org.opensearch.ad.util;
+
+import java.time.Clock;
+
+public class RandomModelStateConfig {
+    private final Boolean fullModel;
+    private final Float priority;
+    private final String detectorId;
+    private final Integer sampleSize;
+    private final Clock clock;
+
+    private RandomModelStateConfig(Builder builder) {
+        this.fullModel = builder.fullModel;
+        this.priority = builder.priority;
+        this.detectorId = builder.detectorId;
+        this.sampleSize = builder.sampleSize;
+        this.clock = builder.clock;
+    }
+
+    public Boolean getFullModel() {
+        return fullModel;
+    }
+
+    public Float getPriority() {
+        return priority;
+    }
+
+    public String getDetectorId() {
+        return detectorId;
+    }
+
+    public Integer getSampleSize() {
+        return sampleSize;
+    }
+
+    public Clock getClock() {
+        return clock;
+    }
+
+    public static class Builder {
+        private Boolean fullModel = null;
+        private Float priority = null;
+        private String detectorId = null;
+        private Integer sampleSize = null;
+        private Clock clock = null;
+
+        public Builder fullModel(boolean fullModel) {
+            this.fullModel = fullModel;
+            return this;
+        }
+
+        public Builder priority(float priority) {
+            this.priority = priority;
+            return this;
+        }
+
+        public Builder detectorId(String detectorId) {
+            this.detectorId = detectorId;
+            return this;
+        }
+
+        public Builder sampleSize(int sampleSize) {
+            this.sampleSize = sampleSize;
+            return this;
+        }
+
+        public Builder clock(Clock clock) {
+            this.clock = clock;
+            return this;
+        }
+
+        public RandomModelStateConfig build() {
+            RandomModelStateConfig config = new RandomModelStateConfig(this);
+            return config;
+        }
+    }
+}
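Editorial note (not part of the diff): MLUtil.randomModelState now takes a single RandomModelStateConfig instead of the removed overloads (randomNonEmptyModelState, randomEmptyModelState, randomModelStateWithSample), so a test pins only the dimensions it asserts on and lets everything else fall back to random defaults. A minimal usage sketch, assuming only the builder methods and the MLUtil, ModelState, and EntityModel names shown in this diff; the detector id below is a hypothetical example value and the raw ModelState type mirrors the diff:

    // Pin the fields the test cares about; unset fields (e.g. clock) default
    // to Clock.systemUTC() or random values inside MLUtil.randomModelState.
    RandomModelStateConfig config = new RandomModelStateConfig.Builder()
        .fullModel(true)
        .priority(0.7f)
        .detectorId("test-detector-1")
        .sampleSize(32)
        .build();
    ModelState state = MLUtil.randomModelState(config);

The boxed getter types (Boolean, Float, Integer) exist for exactly this reason: a null return means "not specified", which randomModelState replaces with its default, avoiding the combinatorial growth of helper signatures that the old overloads required.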