[Backport 2.x] Add rule validation in AnomalyDetector constructor #1342

Merged: 1 commit (Oct 18, 2024)
3 changes: 0 additions & 3 deletions build.gradle
@@ -699,9 +699,6 @@ List<String> jacocoExclusions = [

// TODO: add test coverage (kaituo)
'org.opensearch.forecast.*',
'org.opensearch.timeseries.transport.SuggestConfigParamResponse',
'org.opensearch.timeseries.transport.SuggestConfigParamRequest',
'org.opensearch.timeseries.ml.MemoryAwareConcurrentHashmap',
'org.opensearch.timeseries.transport.ResultBulkTransportAction',
'org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler',
'org.opensearch.timeseries.transport.handler.ResultIndexingHandler',
26 changes: 1 addition & 25 deletions src/main/java/org/opensearch/ad/ml/ADModelManager.java
@@ -15,7 +15,6 @@
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
@@ -42,7 +41,6 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.timeseries.AnalysisModelSize;
import org.opensearch.timeseries.MemoryTracker;
import org.opensearch.timeseries.common.exception.ResourceNotFoundException;
import org.opensearch.timeseries.common.exception.TimeSeriesException;
@@ -52,7 +50,6 @@
import org.opensearch.timeseries.ml.ModelColdStart;
import org.opensearch.timeseries.ml.ModelManager;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.SingleStreamModelIdMapper;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.settings.TimeSeriesSettings;
import org.opensearch.timeseries.util.DateUtils;
@@ -69,9 +66,7 @@
* A facade managing ML operations and models.
*/
public class ADModelManager extends
ModelManager<ThresholdedRandomCutForest, AnomalyResult, ThresholdingResult, ADIndex, ADIndexManagement, ADCheckpointDao, ADCheckpointWriteWorker, ADColdStart>
implements
AnalysisModelSize {
ModelManager<ThresholdedRandomCutForest, AnomalyResult, ThresholdingResult, ADIndex, ADIndexManagement, ADCheckpointDao, ADCheckpointWriteWorker, ADColdStart> {
protected static final String ENTITY_SAMPLE = "sp";
protected static final String ENTITY_RCF = "rcf";
protected static final String ENTITY_THRESHOLD = "th";
@@ -594,25 +589,6 @@ public List<ThresholdingResult> getPreviewResults(Features features, AnomalyDete
}).collect(Collectors.toList());
}

/**
* Get all RCF partition's size corresponding to a detector. Thresholding models' size is a constant since they are small in size (KB).
* @param detectorId detector id
* @return a map of model id to its memory size
*/
@Override
public Map<String, Long> getModelSize(String detectorId) {
Map<String, Long> res = new HashMap<>();
res.putAll(forests.getModelSize(detectorId));
thresholds
.entrySet()
.stream()
.filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(detectorId))
.forEach(entry -> {
res.put(entry.getKey(), (long) memoryTracker.getThresholdModelBytes());
});
return res;
}

/**
* Get a RCF model's total updates.
* @param modelId the RCF model's id
121 changes: 121 additions & 0 deletions src/main/java/org/opensearch/ad/model/AnomalyDetector.java
@@ -22,6 +22,7 @@
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@@ -109,6 +110,7 @@ public Integer getShingleSize(Integer customShingleSize) {
@Deprecated
public static final String DETECTION_DATE_RANGE_FIELD = "detection_date_range";
public static final String RULES_FIELD = "rules";
private static final String SUPPRESSION_RULE_ISSUE_PREFIX = "Suppression Rule Error: ";

protected String detectorType;

@@ -229,6 +231,8 @@ public AnomalyDetector(
issueType = ValidationIssueType.CATEGORY;
}

validateRules(features, rules);

checkAndThrowValidationErrors(ValidationAspect.DETECTOR);

this.detectorType = isHC(categoryFields) ? MULTI_ENTITY.name() : SINGLE_ENTITY.name();
@@ -720,4 +724,121 @@ private static Boolean onlyParseBooleanValue(XContentParser parser) throws IOExc
}
return null;
}

/**
* Validates each condition in the list of rules against the list of features.
* Checks that:
* - The feature name exists in the list of features.
* - The related feature is enabled.
* - The value is not NaN and is positive.
*
* @param features The list of available features. Must not be null.
* @param rules The list of rules containing conditions to validate. Can be null.
*/
private void validateRules(List<Feature> features, List<Rule> rules) {
// Null check for rules
if (rules == null || rules.isEmpty()) {
return; // No suppression rules to validate; consider as valid
}

// Null check for features
if (features == null || features.isEmpty()) {
// Cannot proceed with validation if features are null but rules are not null
this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + "Features are not defined while suppression rules are provided.";
this.issueType = ValidationIssueType.RULE;
return;
}

// Create a map of feature names to their enabled status for quick lookup
Map<String, Boolean> featureEnabledMap = new HashMap<>();
for (Feature feature : features) {
if (feature != null && feature.getName() != null) {
featureEnabledMap.put(feature.getName(), feature.getEnabled());
}
}

// Iterate over each rule
for (Rule rule : rules) {
if (rule == null || rule.getConditions() == null) {
// Invalid rule or conditions list is null
this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + "A suppression rule or its conditions are not properly defined.";
this.issueType = ValidationIssueType.RULE;
return;
}

// Iterate over each condition in the rule
for (Condition condition : rule.getConditions()) {
if (condition == null) {
// Invalid condition
this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + "A condition within a suppression rule is not properly defined.";
this.issueType = ValidationIssueType.RULE;
return;
}

String featureName = condition.getFeatureName();

// Check if the feature name is null
if (featureName == null) {
// Feature name is required
this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + "A condition is missing the feature name.";
this.issueType = ValidationIssueType.RULE;
return;
}

// Check if the feature exists
if (!featureEnabledMap.containsKey(featureName)) {
// Feature does not exist
this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX
+ "Feature \""
+ featureName
+ "\" specified in a suppression rule does not exist.";
this.issueType = ValidationIssueType.RULE;
return;
}

// Check if the feature is enabled
if (!featureEnabledMap.get(featureName)) {
// Feature is not enabled
this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX
+ "Feature \""
+ featureName
+ "\" specified in a suppression rule is not enabled.";
this.issueType = ValidationIssueType.RULE;
return;
}

// other threshold types may not have value operand
ThresholdType thresholdType = condition.getThresholdType();
if (thresholdType == ThresholdType.ACTUAL_OVER_EXPECTED_MARGIN
|| thresholdType == ThresholdType.EXPECTED_OVER_ACTUAL_MARGIN
|| thresholdType == ThresholdType.ACTUAL_OVER_EXPECTED_RATIO
|| thresholdType == ThresholdType.EXPECTED_OVER_ACTUAL_RATIO) {
// Check if the value is not NaN
double value = condition.getValue();
if (Double.isNaN(value)) {
// Value is NaN
this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX
+ "The threshold value for feature \""
+ featureName
+ "\" is not a valid number.";
this.issueType = ValidationIssueType.RULE;
return;
}

// Check if the value is positive
if (value <= 0) {
// Value is not positive
this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX
+ "The threshold value for feature \""
+ featureName
+ "\" must be a positive number.";
this.issueType = ValidationIssueType.RULE;
return;
}
}
}
}

// All checks passed
}
}
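
For readers skimming the diff, the following is a minimal, self-contained sketch of the validation flow that the new validateRules method implements. FeatureSketch, ConditionSketch, RuleSketch, the feature names, and the main method below are illustrative stand-ins and assumptions, not the plugin's actual Feature/Rule/Condition classes or API; the real method records the failure in errorMessage/issueType rather than returning a string.

import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Illustrative stand-ins for the plugin's Feature, Condition, and Rule classes (hypothetical, simplified).
record FeatureSketch(String name, boolean enabled) {}
record ConditionSketch(String featureName, String thresholdType, double value) {}
record RuleSketch(List<ConditionSketch> conditions) {}

public class RuleValidationSketch {

    // Mirrors validateRules: returns an error message on the first failed check, or null when all rules pass.
    static String validate(List<FeatureSketch> features, List<RuleSketch> rules) {
        if (rules == null || rules.isEmpty()) {
            return null; // no suppression rules to validate
        }
        if (features == null || features.isEmpty()) {
            return "Features are not defined while suppression rules are provided.";
        }
        // Feature name -> enabled flag, for quick lookup.
        Map<String, Boolean> enabled = new HashMap<>();
        for (FeatureSketch f : features) {
            enabled.put(f.name(), f.enabled());
        }
        for (RuleSketch rule : rules) {
            for (ConditionSketch c : rule.conditions()) {
                if (!enabled.containsKey(c.featureName())) {
                    return "Feature \"" + c.featureName() + "\" specified in a suppression rule does not exist.";
                }
                if (!enabled.get(c.featureName())) {
                    return "Feature \"" + c.featureName() + "\" specified in a suppression rule is not enabled.";
                }
                // Only the margin/ratio threshold types carry a numeric operand, which must be a positive number.
                boolean hasValueOperand = c.thresholdType().endsWith("_MARGIN") || c.thresholdType().endsWith("_RATIO");
                if (hasValueOperand && (Double.isNaN(c.value()) || c.value() <= 0)) {
                    return "The threshold value for feature \"" + c.featureName() + "\" must be a positive number.";
                }
            }
        }
        return null; // all checks passed
    }

    public static void main(String[] args) {
        List<FeatureSketch> features = List.of(new FeatureSketch("cpu_usage", true));
        // A condition that references an unknown feature fails validation.
        List<RuleSketch> rules = List.of(new RuleSketch(List.of(new ConditionSketch("memory_usage", "ACTUAL_OVER_EXPECTED_RATIO", 0.2))));
        System.out.println(validate(features, rules));
    }
}

In the plugin itself, a failed check sets errorMessage and ValidationIssueType.RULE before checkAndThrowValidationErrors(ValidationAspect.DETECTOR) runs, so an invalid suppression rule rejects the detector at construction time.
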
@@ -11,9 +11,6 @@

package org.opensearch.timeseries.ml;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

import org.opensearch.timeseries.MemoryTracker;
@@ -55,48 +52,4 @@ public ModelState<RCFModelType> put(String key, ModelState<RCFModelType> value)
}
return previousAssociatedState;
}

/**
* Gets all of a config's model sizes hosted on a node
*
* @param configId config Id
* @return a map of model id to its memory size
*/
public Map<String, Long> getModelSize(String configId) {
Map<String, Long> res = new HashMap<>();
super.entrySet()
.stream()
.filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(configId))
.forEach(entry -> {
Optional<RCFModelType> modelOptional = entry.getValue().getModel();
if (modelOptional.isPresent()) {
res.put(entry.getKey(), memoryTracker.estimateTRCFModelSize(modelOptional.get()));
}
});
return res;
}

/**
* Checks if a model exists for the given config.
* @param configId Config Id
* @return `true` if the model exists, `false` otherwise.
*/
public boolean doesModelExist(String configId) {
return super.entrySet()
.stream()
.filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(configId))
.anyMatch(n -> true);
}

public boolean hostIfPossible(String modelId, ModelState<RCFModelType> toUpdate) {
return Optional
.ofNullable(toUpdate)
.filter(state -> state.getModel().isPresent())
.filter(state -> memoryTracker.isHostingAllowed(modelId, state.getModel().get()))
.map(state -> {
super.put(modelId, toUpdate);
return true;
})
.orElse(false);
}
}
@@ -38,7 +38,8 @@ public enum ValidationIssueType implements Name {
SUBAGGREGATION(SearchTopForecastResultRequest.SUBAGGREGATIONS_FIELD),
RECENCY_EMPHASIS(Config.RECENCY_EMPHASIS_FIELD),
DESCRIPTION(Config.DESCRIPTION_FIELD),
HISTORY(Config.HISTORY_INTERVAL_FIELD);
HISTORY(Config.HISTORY_INTERVAL_FIELD),
RULE(AnomalyDetector.RULES_FIELD);

private String name;

@@ -33,9 +33,9 @@ public class SuggestConfigParamRequest extends ActionRequest {
public SuggestConfigParamRequest(StreamInput in) throws IOException {
super(in);
context = in.readEnum(AnalysisType.class);
if (context.isAD()) {
if (getContext().isAD()) {
config = new AnomalyDetector(in);
} else if (context.isForecast()) {
} else if (getContext().isForecast()) {
config = new Forecaster(in);
} else {
throw new UnsupportedOperationException("This method is not supported");
@@ -55,7 +55,7 @@ public SuggestConfigParamRequest(AnalysisType context, Config config, String par
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeEnum(context);
out.writeEnum(getContext());
config.writeTo(out);
out.writeString(param);
out.writeTimeValue(requestTimeout);
@@ -77,4 +77,8 @@ public String getParam() {
public TimeValue getRequestTimeout() {
return requestTimeout;
}

public AnalysisType getContext() {
return context;
}
}