Skip to content

Commit

Permalink
add feature direction
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Galitzky <[email protected]>
  • Loading branch information
amitgalitz committed Nov 17, 2024
1 parent 4c545ab commit a95ab17
Show file tree
Hide file tree
Showing 12 changed files with 204 additions and 51 deletions.
26 changes: 16 additions & 10 deletions src/main/java/org/opensearch/ad/ml/IgnoreSimilarExtractor.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,18 @@ public static ThresholdArrays processDetectorRules(AnomalyDetector detector) {
if (rules != null) {
for (Rule rule : rules) {
for (Condition condition : rule.getConditions()) {
processCondition(
condition,
featureNames,
baseDimension,
ignoreSimilarFromAbove,
ignoreSimilarFromBelow,
ignoreSimilarFromAboveByRatio,
ignoreSimilarFromBelowByRatio
);
// Only conditions that are NOT one of the two directional types feed the threshold arrays.
// Both inequalities must hold at once, hence &&: with || the test is true for every
// threshold type, because a type can never be equal to both constants simultaneously.
if (condition.getThresholdType() != ThresholdType.ACTUAL_IS_BELOW_EXPECTED
    && condition.getThresholdType() != ThresholdType.ACTUAL_IS_OVER_EXPECTED) {
    processCondition(
        condition,
        featureNames,
        baseDimension,
        ignoreSimilarFromAbove,
        ignoreSimilarFromBelow,
        ignoreSimilarFromAboveByRatio,
        ignoreSimilarFromBelowByRatio
    );
}
}
}
}
Expand Down Expand Up @@ -100,7 +103,10 @@ private static void processCondition(
int featureIndex = featureNames.indexOf(featureName);

ThresholdType thresholdType = condition.getThresholdType();
double value = condition.getValue();
Double value = condition.getValue();
if (value == null) {
value = 0d;
}

switch (thresholdType) {
case ACTUAL_OVER_EXPECTED_MARGIN:
Expand Down
11 changes: 10 additions & 1 deletion src/main/java/org/opensearch/ad/ml/ThresholdingResult.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@
package org.opensearch.ad.ml;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

import org.apache.commons.lang.builder.ToStringBuilder;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.model.AnomalyResult;
import org.opensearch.ad.model.Rule;
import org.opensearch.timeseries.ml.IntermediateResult;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.model.Entity;
Expand Down Expand Up @@ -331,6 +334,11 @@ public List<AnomalyResult> toIndexableResults(
String taskId,
String error
) {
List<Rule> rules = new ArrayList<>();
if (detector instanceof AnomalyDetector) {
AnomalyDetector detectorConfig = (AnomalyDetector) detector;
rules = detectorConfig.getRules();
}
return Collections
.singletonList(
AnomalyResult
Expand Down Expand Up @@ -358,7 +366,8 @@ public List<AnomalyResult> toIndexableResults(
likelihoodOfValues,
threshold,
currentData,
featureImputed
featureImputed,
rules
)
);
}
Expand Down
11 changes: 11 additions & 0 deletions src/main/java/org/opensearch/ad/model/AnomalyDetector.java
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,17 @@ private void validateRules(List<Feature> features, List<Rule> rules) {
this.issueType = ValidationIssueType.RULE;
return;
}
} else if (thresholdType == ThresholdType.ACTUAL_IS_BELOW_EXPECTED
|| thresholdType == ThresholdType.ACTUAL_IS_OVER_EXPECTED) {
// Check if both operator and value are null
if (condition.getOperator() != null || condition.getValue() != null) {
this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX
+ "For threshold type \""
+ thresholdType
+ "\", both operator and value must be empty or null.";
this.issueType = ValidationIssueType.RULE;
return;
}
}
}
}
Expand Down
94 changes: 93 additions & 1 deletion src/main/java/org/opensearch/ad/model/AnomalyResult.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

Expand Down Expand Up @@ -312,6 +313,7 @@ public AnomalyResult(
* @param threshold Current threshold
* @param currentData imputed data if any
* @param featureImputed whether feature is imputed or not
* @param rules rules we apply on anomaly grade based on condition
* @return the converted AnomalyResult instance
*/
public static AnomalyResult fromRawTRCFResult(
Expand All @@ -338,15 +340,20 @@ public static AnomalyResult fromRawTRCFResult(
double[] likelihoodOfValues,
Double threshold,
double[] currentData,
boolean[] featureImputed
boolean[] featureImputed,
List<Rule> rules
) {
List<DataByFeatureId> convertedRelevantAttribution = null;
List<DataByFeatureId> convertedPastValuesList = null;
List<ExpectedValueList> convertedExpectedValues = null;
List<String> featureNamesForComparison = null;

int featureSize = featureData == null ? 0 : featureData.size();

if (grade > 0) {
// Get the top feature names based on the relevant attribution criteria
featureNamesForComparison = getTopFeatureNames(featureData, relevantAttribution);

if (relevantAttribution != null) {
if (relevantAttribution.length == featureSize) {
convertedRelevantAttribution = new ArrayList<>(featureSize);
Expand Down Expand Up @@ -425,6 +432,66 @@ public static AnomalyResult fromRawTRCFResult(
);
}
}
// Directional suppression: reset the grade to 0 when, for one of the top-attributed features,
// a rule says "ignore when the actual value is below/over the expected value" and that holds.
// Without rules or expected values there is nothing to compare, so the check is skipped
// rather than failing with a NullPointerException.
if (rules != null && convertedExpectedValues != null) {
    featureLoop: for (String featureName : featureNamesForComparison) {
        // Actual value: the converted past value when available, otherwise the entry of
        // currentData at the feature's position in featureData; 0 when neither is found.
        double valueToCompare = 0d;
        if (convertedPastValuesList != null) {
            for (DataByFeatureId data : convertedPastValuesList) {
                if (data.getFeatureId().equals(featureName)) {
                    Double pastValue = data.getData();
                    if (pastValue != null) {
                        valueToCompare = pastValue;
                    }
                    break;
                }
            }
        } else {
            int featureIndex = 0;
            while (featureIndex < featureData.size() && !featureData.get(featureIndex).getFeatureId().equals(featureName)) {
                featureIndex++;
            }
            // Bounds are checked so a shorter currentData array cannot raise an index error.
            if (currentData != null && featureIndex < featureData.size() && featureIndex < currentData.length) {
                valueToCompare = currentData[featureIndex];
            }
        }

        // Expected value: first entry for this feature across all expected-value lists.
        Double expectedValue = null;
        expectedLookup: for (ExpectedValueList evList : convertedExpectedValues) {
            for (DataByFeatureId data : evList.getValueList()) {
                if (data.getFeatureId().equals(featureName)) {
                    expectedValue = data.getData();
                    break expectedLookup;
                }
            }
        }
        if (expectedValue == null) {
            continue; // No expected value for this feature, nothing to compare against
        }

        for (Rule rule : rules) {
            for (Condition condition : rule.getConditions()) {
                // NOTE(review): featureName here was produced from getFeatureId(), while the
                // condition stores a feature *name* - confirm that ids and names coincide.
                if (!condition.getFeatureName().equals(featureName)) {
                    continue;
                }
                ThresholdType thresholdType = condition.getThresholdType();
                boolean actualBelowExpected = thresholdType == ThresholdType.ACTUAL_IS_BELOW_EXPECTED && valueToCompare < expectedValue;
                boolean actualOverExpected = thresholdType == ThresholdType.ACTUAL_IS_OVER_EXPECTED && valueToCompare > expectedValue;
                if (actualBelowExpected || actualOverExpected) {
                    grade = 0d;
                    // One matching rule is enough; stop scanning features, rules and conditions.
                    break featureLoop;
                }
            }
        }
    }
}
}

List<FeatureImputed> featureImputedList = new ArrayList<>();
Expand Down Expand Up @@ -468,6 +535,31 @@ public static AnomalyResult fromRawTRCFResult(
);
}

/**
 * Selects the features that directional suppression rules should be checked against.
 *
 * <p>When the attribution array is usable (one entry per feature), only the feature(s) whose
 * attribution, rounded to two decimal places, equals the largest rounded attribution are
 * returned. When the array is null, empty, or differs in length from {@code featureData},
 * every feature is returned.
 *
 * <p>NOTE(review): despite the method name, the returned strings come from
 * {@code getFeatureId()} - confirm callers that compare them with feature names.
 *
 * @param featureData feature values of the current point; may be null
 * @param relevantAttribution per-feature attribution, read by the same index as {@code featureData}; may be null
 * @return a new mutable list of feature ids; empty when {@code featureData} is null or empty, never null
 */
private static List<String> getTopFeatureNames(List<FeatureData> featureData, double[] relevantAttribution) {
    List<String> topFeatureNames = new ArrayList<>();

    // The caller tolerates a null feature list, so return "no features" instead of throwing.
    if (featureData == null || featureData.isEmpty()) {
        return topFeatureNames;
    }

    // A missing, empty, or length-mismatched attribution cannot be mapped to features:
    // fall back to all of them. The empty case is covered by the length check because
    // featureData is known to be non-empty here.
    if (relevantAttribution == null || relevantAttribution.length != featureData.size()) {
        for (FeatureData feature : featureData) {
            topFeatureNames.add(feature.getFeatureId());
        }
        return topFeatureNames;
    }

    // Round each attribution once, tracking the maximum while doing so.
    double[] roundedAttribution = new double[relevantAttribution.length];
    double maxRoundedAttribution = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < relevantAttribution.length; i++) {
        roundedAttribution[i] = Math.round(relevantAttribution[i] * 100.0) / 100.0;
        maxRoundedAttribution = Math.max(maxRoundedAttribution, roundedAttribution[i]);
    }

    // Keep every feature tied for the maximum rounded attribution.
    for (int i = 0; i < roundedAttribution.length; i++) {
        if (roundedAttribution[i] == maxRoundedAttribution) {
            topFeatureNames.add(featureData.get(i).getFeatureId());
        }
    }

    return topFeatureNames;
}

public AnomalyResult(StreamInput input) throws IOException {
super(input);
this.modelId = input.readOptionalString();
Expand Down
35 changes: 24 additions & 11 deletions src/main/java/org/opensearch/ad/model/Condition.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ public class Condition implements Writeable, ToXContentObject {
private String featureName;
private ThresholdType thresholdType;
private Operator operator;
private double value;
private Double value;

public Condition(String featureName, ThresholdType thresholdType, Operator operator, double value) {
public Condition(String featureName, ThresholdType thresholdType, Operator operator, Double value) {
this.featureName = featureName;
this.thresholdType = thresholdType;
this.operator = operator;
Expand All @@ -42,7 +42,7 @@ public Condition(StreamInput input) throws IOException {
this.featureName = input.readString();
this.thresholdType = input.readEnum(ThresholdType.class);
this.operator = input.readEnum(Operator.class);
this.value = input.readDouble();
this.value = input.readBoolean() ? input.readDouble() : null;
}

/**
Expand All @@ -56,7 +56,7 @@ public static Condition parse(XContentParser parser) throws IOException {
String featureName = null;
ThresholdType thresholdType = null;
Operator operator = null;
Double value = 0d;
Double value = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -70,11 +70,19 @@ public static Condition parse(XContentParser parser) throws IOException {
case THRESHOLD_TYPE_FIELD:
thresholdType = ThresholdType.valueOf(parser.text().toUpperCase(Locale.ROOT));
break;
case OPERATOR_FIELD:
operator = Operator.valueOf(parser.text().toUpperCase(Locale.ROOT));
case OPERATOR_FIELD:
    // An explicit JSON null is accepted and kept as a null operator; any other token
    // is read as text and mapped (case-insensitively) to an Operator constant.
    if (parser.currentToken() == XContentParser.Token.VALUE_NULL) {
        operator = null;
    } else {
        operator = Operator.valueOf(parser.text().toUpperCase(Locale.ROOT));
    }
    break;
case VALUE_FIELD:
value = parser.doubleValue();
if (parser.currentToken() == XContentParser.Token.VALUE_NULL) {
value = null;
} else {
value = parser.doubleValue();
}
break;
default:
break;
Expand All @@ -89,8 +97,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
.startObject()
.field(FEATURE_NAME_FIELD, featureName)
.field(THRESHOLD_TYPE_FIELD, thresholdType)
.field(OPERATOR_FIELD, operator)
.field(VALUE_FIELD, value);
.field(OPERATOR_FIELD, operator);
// The value field is written only when a value is set; the key is the same constant
// that parse() switches on, so the two cannot drift apart.
if (value != null) {
    builder.field(VALUE_FIELD, value);
}
return xContentBuilder.endObject();
}

Expand All @@ -99,7 +109,10 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(featureName);
out.writeEnum(thresholdType);
out.writeEnum(operator);
out.writeDouble(value);
out.writeBoolean(value != null);
if (value != null) {
out.writeDouble(value);
}
}

public String getFeatureName() {
Expand All @@ -114,7 +127,7 @@ public Operator getOperator() {
return operator;
}

public double getValue() {
public Double getValue() {
return value;
}

Expand Down
14 changes: 13 additions & 1 deletion src/main/java/org/opensearch/ad/model/ThresholdType.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,19 @@ public enum ThresholdType {
* should be ignored if the ratio of the deviation from the expected to the actual
* (b-a)/|a| is less than or equal to ignoreNearExpectedFromBelowByRatio.
*/
EXPECTED_OVER_ACTUAL_RATIO("the ratio of the expected value over the actual value");
EXPECTED_OVER_ACTUAL_RATIO("the ratio of the expected value over the actual value"),

/**
* Specifies a threshold for ignoring anomalies based on whether the actual value
* is over the expected value returned from the model.
*/
ACTUAL_IS_OVER_EXPECTED("the actual value is over the expected value"),

/**
* Specifies a threshold for ignoring anomalies based on whether the actual value
* is below the expected value returned from the model.
*/
ACTUAL_IS_BELOW_EXPECTED("the actual value is below the expected value");

private final String description;

Expand Down
Loading

0 comments on commit a95ab17

Please sign in to comment.