diff --git a/build.gradle b/build.gradle index 4addff425..ca10d416d 100644 --- a/build.gradle +++ b/build.gradle @@ -699,9 +699,6 @@ List jacocoExclusions = [ // TODO: add test coverage (kaituo) 'org.opensearch.forecast.*', - 'org.opensearch.timeseries.transport.SuggestConfigParamResponse', - 'org.opensearch.timeseries.transport.SuggestConfigParamRequest', - 'org.opensearch.timeseries.ml.MemoryAwareConcurrentHashmap', 'org.opensearch.timeseries.transport.ResultBulkTransportAction', 'org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler', 'org.opensearch.timeseries.transport.handler.ResultIndexingHandler', diff --git a/src/main/java/org/opensearch/ad/ml/ADModelManager.java b/src/main/java/org/opensearch/ad/ml/ADModelManager.java index 354b02557..a8f0febd9 100644 --- a/src/main/java/org/opensearch/ad/ml/ADModelManager.java +++ b/src/main/java/org/opensearch/ad/ml/ADModelManager.java @@ -15,7 +15,6 @@ import java.time.Duration; import java.time.Instant; import java.util.Arrays; -import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Locale; @@ -42,7 +41,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; -import org.opensearch.timeseries.AnalysisModelSize; import org.opensearch.timeseries.MemoryTracker; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TimeSeriesException; @@ -52,7 +50,6 @@ import org.opensearch.timeseries.ml.ModelColdStart; import org.opensearch.timeseries.ml.ModelManager; import org.opensearch.timeseries.ml.ModelState; -import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.DateUtils; @@ -69,9 +66,7 @@ * A facade managing ML operations and models. */ public class ADModelManager extends - ModelManager - implements - AnalysisModelSize { + ModelManager { protected static final String ENTITY_SAMPLE = "sp"; protected static final String ENTITY_RCF = "rcf"; protected static final String ENTITY_THRESHOLD = "th"; @@ -594,25 +589,6 @@ public List getPreviewResults(Features features, AnomalyDete }).collect(Collectors.toList()); } - /** - * Get all RCF partition's size corresponding to a detector. Thresholding models' size is a constant since they are small in size (KB). - * @param detectorId detector id - * @return a map of model id to its memory size - */ - @Override - public Map getModelSize(String detectorId) { - Map res = new HashMap<>(); - res.putAll(forests.getModelSize(detectorId)); - thresholds - .entrySet() - .stream() - .filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(detectorId)) - .forEach(entry -> { - res.put(entry.getKey(), (long) memoryTracker.getThresholdModelBytes()); - }); - return res; - } - /** * Get a RCF model's total updates. * @param modelId the RCF model's id diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java index 9b057d000..2572299b1 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java @@ -22,6 +22,7 @@ import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -109,6 +110,7 @@ public Integer getShingleSize(Integer customShingleSize) { @Deprecated public static final String DETECTION_DATE_RANGE_FIELD = "detection_date_range"; public static final String RULES_FIELD = "rules"; + private static final String SUPPRESSION_RULE_ISSUE_PREFIX = "Suppression Rule Error: "; protected String detectorType; @@ -229,6 +231,8 @@ public AnomalyDetector( issueType = ValidationIssueType.CATEGORY; } + validateRules(features, rules); + checkAndThrowValidationErrors(ValidationAspect.DETECTOR); this.detectorType = isHC(categoryFields) ? MULTI_ENTITY.name() : SINGLE_ENTITY.name(); @@ -720,4 +724,121 @@ private static Boolean onlyParseBooleanValue(XContentParser parser) throws IOExc } return null; } + + /** + * Validates each condition in the list of rules against the list of features. + * Checks that: + * - The feature name exists in the list of features. + * - The related feature is enabled. + * - The value is not NaN and is positive. + * + * @param features The list of available features. Must not be null. + * @param rules The list of rules containing conditions to validate. Can be null. + */ + private void validateRules(List features, List rules) { + // Null check for rules + if (rules == null || rules.isEmpty()) { + return; // No suppression rules to validate; consider as valid + } + + // Null check for features + if (features == null || features.isEmpty()) { + // Cannot proceed with validation if features are null but rules are not null + this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + "Features are not defined while suppression rules are provided."; + this.issueType = ValidationIssueType.RULE; + return; + } + + // Create a map of feature names to their enabled status for quick lookup + Map featureEnabledMap = new HashMap<>(); + for (Feature feature : features) { + if (feature != null && feature.getName() != null) { + featureEnabledMap.put(feature.getName(), feature.getEnabled()); + } + } + + // Iterate over each rule + for (Rule rule : rules) { + if (rule == null || rule.getConditions() == null) { + // Invalid rule or conditions list is null + this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + "A suppression rule or its conditions are not properly defined."; + this.issueType = ValidationIssueType.RULE; + return; + } + + // Iterate over each condition in the rule + for (Condition condition : rule.getConditions()) { + if (condition == null) { + // Invalid condition + this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + "A condition within a suppression rule is not properly defined."; + this.issueType = ValidationIssueType.RULE; + return; + } + + String featureName = condition.getFeatureName(); + + // Check if the feature name is null + if (featureName == null) { + // Feature name is required + this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + "A condition is missing the feature name."; + this.issueType = ValidationIssueType.RULE; + return; + } + + // Check if the feature exists + if (!featureEnabledMap.containsKey(featureName)) { + // Feature does not exist + this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + + "Feature \"" + + featureName + + "\" specified in a suppression rule does not exist."; + this.issueType = ValidationIssueType.RULE; + return; + } + + // Check if the feature is enabled + if (!featureEnabledMap.get(featureName)) { + // Feature is not enabled + this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + + "Feature \"" + + featureName + + "\" specified in a suppression rule is not enabled."; + this.issueType = ValidationIssueType.RULE; + return; + } + + // other threshold types may not have value operand + ThresholdType thresholdType = condition.getThresholdType(); + if (thresholdType == ThresholdType.ACTUAL_OVER_EXPECTED_MARGIN + || thresholdType == ThresholdType.EXPECTED_OVER_ACTUAL_MARGIN + || thresholdType == ThresholdType.ACTUAL_OVER_EXPECTED_RATIO + || thresholdType == ThresholdType.EXPECTED_OVER_ACTUAL_RATIO) { + // Check if the value is not NaN + double value = condition.getValue(); + if (Double.isNaN(value)) { + // Value is NaN + this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + + "The threshold value for feature \"" + + featureName + + "\" is not a valid number."; + this.issueType = ValidationIssueType.RULE; + return; + } + + // Check if the value is positive + if (value <= 0) { + // Value is not positive + this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + + "The threshold value for feature \"" + + featureName + + "\" must be a positive number."; + this.issueType = ValidationIssueType.RULE; + return; + } + } + } + } + + // All checks passed + } } diff --git a/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java b/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java index b477f454a..cc723b5f4 100644 --- a/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java +++ b/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java @@ -11,9 +11,6 @@ package org.opensearch.timeseries.ml; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import org.opensearch.timeseries.MemoryTracker; @@ -55,48 +52,4 @@ public ModelState put(String key, ModelState value) } return previousAssociatedState; } - - /** - * Gets all of a config's model sizes hosted on a node - * - * @param configId config Id - * @return a map of model id to its memory size - */ - public Map getModelSize(String configId) { - Map res = new HashMap<>(); - super.entrySet() - .stream() - .filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(configId)) - .forEach(entry -> { - Optional modelOptional = entry.getValue().getModel(); - if (modelOptional.isPresent()) { - res.put(entry.getKey(), memoryTracker.estimateTRCFModelSize(modelOptional.get())); - } - }); - return res; - } - - /** - * Checks if a model exists for the given config. - * @param configId Config Id - * @return `true` if the model exists, `false` otherwise. - */ - public boolean doesModelExist(String configId) { - return super.entrySet() - .stream() - .filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(configId)) - .anyMatch(n -> true); - } - - public boolean hostIfPossible(String modelId, ModelState toUpdate) { - return Optional - .ofNullable(toUpdate) - .filter(state -> state.getModel().isPresent()) - .filter(state -> memoryTracker.isHostingAllowed(modelId, state.getModel().get())) - .map(state -> { - super.put(modelId, toUpdate); - return true; - }) - .orElse(false); - } } diff --git a/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java b/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java index bd4a86cee..55d039eb4 100644 --- a/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java +++ b/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java @@ -38,7 +38,8 @@ public enum ValidationIssueType implements Name { SUBAGGREGATION(SearchTopForecastResultRequest.SUBAGGREGATIONS_FIELD), RECENCY_EMPHASIS(Config.RECENCY_EMPHASIS_FIELD), DESCRIPTION(Config.DESCRIPTION_FIELD), - HISTORY(Config.HISTORY_INTERVAL_FIELD); + HISTORY(Config.HISTORY_INTERVAL_FIELD), + RULE(AnomalyDetector.RULES_FIELD); private String name; diff --git a/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamRequest.java b/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamRequest.java index 3c7b9f45a..ee17f163c 100644 --- a/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamRequest.java @@ -33,9 +33,9 @@ public class SuggestConfigParamRequest extends ActionRequest { public SuggestConfigParamRequest(StreamInput in) throws IOException { super(in); context = in.readEnum(AnalysisType.class); - if (context.isAD()) { + if (getContext().isAD()) { config = new AnomalyDetector(in); - } else if (context.isForecast()) { + } else if (getContext().isForecast()) { config = new Forecaster(in); } else { throw new UnsupportedOperationException("This method is not supported"); @@ -55,7 +55,7 @@ public SuggestConfigParamRequest(AnalysisType context, Config config, String par @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeEnum(context); + out.writeEnum(getContext()); config.writeTo(out); out.writeString(param); out.writeTimeValue(requestTimeout); @@ -77,4 +77,8 @@ public String getParam() { public TimeValue getRequestTimeout() { return requestTimeout; } + + public AnalysisType getContext() { + return context; + } } diff --git a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java index b10c1afa4..902edb949 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java +++ b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java @@ -18,6 +18,7 @@ import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Locale; import java.util.concurrent.TimeUnit; @@ -1047,4 +1048,239 @@ public void testNullFixedValue() throws IOException { assertEquals("Got: " + e.getMessage(), "Enabled features are present, but no default fill values are provided.", e.getMessage()); assertEquals("Got :" + e.getType(), ValidationIssueType.IMPUTATION, e.getType()); } + + /** + * Test that validation passes when rules are null. + */ + public void testValidateRulesWithNullRules() throws IOException { + AnomalyDetector detector = TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(null).build(); + + // Should pass validation; no exception should be thrown + assertNotNull(detector); + } + + /** + * Test that validation fails when features are null but rules are provided. + */ + public void testValidateRulesWithNullFeatures() throws IOException { + List rules = Arrays.asList(createValidRule()); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(0).setFeatureAttributes(null).setRules(rules).build(); + fail("Expected ValidationException due to features being null while rules are provided"); + } catch (ValidationException e) { + assertEquals("Suppression Rule Error: Features are not defined while suppression rules are provided.", e.getMessage()); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation fails when a rule is null. + */ + public void testValidateRulesWithNullRule() throws IOException { + List rules = Arrays.asList((Rule) null); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(rules).build(); + fail("Expected ValidationException due to null rule"); + } catch (ValidationException e) { + assertEquals("Suppression Rule Error: A suppression rule or its conditions are not properly defined.", e.getMessage()); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation fails when a rule's conditions are null. + */ + public void testValidateRulesWithNullConditions() throws IOException { + Rule rule = new Rule(Action.IGNORE_ANOMALY, null); + List rules = Arrays.asList(rule); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(rules).build(); + fail("Expected ValidationException due to rule with null conditions"); + } catch (ValidationException e) { + assertEquals("Suppression Rule Error: A suppression rule or its conditions are not properly defined.", e.getMessage()); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation fails when a condition is null. + */ + public void testValidateRulesWithNullCondition() throws IOException { + Rule rule = new Rule(Action.IGNORE_ANOMALY, Arrays.asList((Condition) null)); + List rules = Arrays.asList(rule); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(rules).build(); + fail("Expected ValidationException due to null condition in rule"); + } catch (ValidationException e) { + assertEquals("Suppression Rule Error: A condition within a suppression rule is not properly defined.", e.getMessage()); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation fails when a condition's featureName is null. + */ + public void testValidateRulesWithNullFeatureName() throws IOException { + Condition condition = new Condition( + null, // featureName is null + ThresholdType.ACTUAL_OVER_EXPECTED_RATIO, + Operator.LTE, + 0.5 + ); + Rule rule = new Rule(Action.IGNORE_ANOMALY, Arrays.asList(condition)); + List rules = Arrays.asList(rule); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(rules).build(); + fail("Expected ValidationException due to condition with null feature name"); + } catch (ValidationException e) { + assertEquals("Suppression Rule Error: A condition is missing the feature name.", e.getMessage()); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation fails when a condition's featureName does not exist in features. + */ + public void testValidateRulesWithNonexistentFeatureName() throws IOException { + Condition condition = new Condition( + "nonexistentFeature", // featureName not in features + ThresholdType.ACTUAL_OVER_EXPECTED_RATIO, + Operator.LTE, + 0.5 + ); + Rule rule = new Rule(Action.IGNORE_ANOMALY, Arrays.asList(condition)); + List rules = Arrays.asList(rule); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(rules).build(); + fail("Expected ValidationException due to condition with nonexistent feature name"); + } catch (ValidationException e) { + assertEquals( + "Suppression Rule Error: Feature \"nonexistentFeature\" specified in a suppression rule does not exist.", + e.getMessage() + ); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation fails when the feature in condition is disabled. + */ + public void testValidateRulesWithDisabledFeature() throws IOException { + String featureName = "testFeature"; + Feature disabledFeature = TestHelpers.randomFeature(featureName, "agg", false); + + Condition condition = new Condition(featureName, ThresholdType.ACTUAL_OVER_EXPECTED_RATIO, Operator.LTE, 0.5); + Rule rule = new Rule(Action.IGNORE_ANOMALY, Arrays.asList(condition)); + List rules = Arrays.asList(rule); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(1).setFeatureAttributes(Arrays.asList(disabledFeature)).setRules(rules).build(); + fail("Expected ValidationException due to condition with disabled feature"); + } catch (ValidationException e) { + assertEquals( + "Suppression Rule Error: Feature \"" + featureName + "\" specified in a suppression rule is not enabled.", + e.getMessage() + ); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation fails when the value in condition is NaN for specific threshold types. + */ + public void testValidateRulesWithNaNValue() throws IOException { + String featureName = "testFeature"; + Feature enabledFeature = TestHelpers.randomFeature(featureName, "agg", true); + + Condition condition = new Condition( + featureName, + ThresholdType.ACTUAL_OVER_EXPECTED_RATIO, + Operator.LTE, + Double.NaN // Value is NaN + ); + Rule rule = new Rule(Action.IGNORE_ANOMALY, Arrays.asList(condition)); + List rules = Arrays.asList(rule); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(1).setFeatureAttributes(Arrays.asList(enabledFeature)).setRules(rules).build(); + fail("Expected ValidationException due to NaN value in condition"); + } catch (ValidationException e) { + assertEquals( + "Suppression Rule Error: The threshold value for feature \"" + featureName + "\" is not a valid number.", + e.getMessage() + ); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation fails when the value in condition is not positive for specific threshold types. + */ + public void testValidateRulesWithNonPositiveValue() throws IOException { + String featureName = "testFeature"; + Feature enabledFeature = TestHelpers.randomFeature(featureName, "agg", true); + + Condition condition = new Condition( + featureName, + ThresholdType.ACTUAL_OVER_EXPECTED_RATIO, + Operator.LTE, + -0.5 // Value is negative + ); + Rule rule = new Rule(Action.IGNORE_ANOMALY, Arrays.asList(condition)); + List rules = Arrays.asList(rule); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(1).setFeatureAttributes(Arrays.asList(enabledFeature)).setRules(rules).build(); + fail("Expected ValidationException due to non-positive value in condition"); + } catch (ValidationException e) { + assertEquals( + "Suppression Rule Error: The threshold value for feature \"" + featureName + "\" must be a positive number.", + e.getMessage() + ); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation passes when the threshold type is not one of the specified types and value is NaN. + */ + public void testValidateRulesWithOtherThresholdTypeAndNaNValue() throws IOException { + String featureName = "testFeature"; + Feature enabledFeature = TestHelpers.randomFeature(featureName, "agg", true); + + Condition condition = new Condition( + featureName, + null, // ThresholdType is null or another type not specified + Operator.LTE, + Double.NaN // Value is NaN, but should not be checked + ); + Rule rule = new Rule(Action.IGNORE_ANOMALY, Arrays.asList(condition)); + List rules = Arrays.asList(rule); + + AnomalyDetector detector = TestHelpers.AnomalyDetectorBuilder + .newInstance(1) + .setFeatureAttributes(Arrays.asList(enabledFeature)) + .setRules(rules) + .build(); + + // Should pass validation; no exception should be thrown + assertNotNull(detector); + } + + /** + * Helper method to create a valid rule for testing. + * + * @return A valid Rule instance + */ + private Rule createValidRule() { + String featureName = "testFeature"; + Condition condition = new Condition(featureName, ThresholdType.ACTUAL_OVER_EXPECTED_RATIO, Operator.LTE, 0.5); + return new Rule(Action.IGNORE_ANOMALY, Arrays.asList(condition)); + } } diff --git a/src/test/java/org/opensearch/timeseries/transport/SuggestConfigParamRequestTests.java b/src/test/java/org/opensearch/timeseries/transport/SuggestConfigParamRequestTests.java new file mode 100644 index 000000000..e3c772c38 --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/transport/SuggestConfigParamRequestTests.java @@ -0,0 +1,140 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.timeseries.transport; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Before; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.TestHelpers; + +public class SuggestConfigParamRequestTests extends OpenSearchTestCase { + private NamedWriteableRegistry registry; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + List namedWriteables = new ArrayList<>(); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, BoolQueryBuilder.NAME, BoolQueryBuilder::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, TermQueryBuilder.NAME, TermQueryBuilder::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, RangeQueryBuilder.NAME, RangeQueryBuilder::new)); + namedWriteables + .add( + new NamedWriteableRegistry.Entry( + AggregationBuilder.class, + ValueCountAggregationBuilder.NAME, + ValueCountAggregationBuilder::new + ) + ); + registry = new NamedWriteableRegistry(namedWriteables); + } + + /** + * Test serialization and deserialization of SuggestConfigParamRequest with AD context. + */ + public void testSerializationDeserialization_ADContext() throws IOException { + // Create an AnomalyDetector instance + AnomalyDetector detector = createTestAnomalyDetector(); + + AnalysisType context = AnalysisType.AD; + String param = "test-param"; + TimeValue requestTimeout = TimeValue.timeValueSeconds(30); + + SuggestConfigParamRequest originalRequest = new SuggestConfigParamRequest(context, detector, param, requestTimeout); + + // Serialize the request + BytesStreamOutput out = new BytesStreamOutput(); + originalRequest.writeTo(out); + + // Deserialize the request + StreamInput in = out.bytes().streamInput(); + + StreamInput input = new NamedWriteableAwareStreamInput(in, registry); + + SuggestConfigParamRequest deserializedRequest = new SuggestConfigParamRequest(input); + + // Verify the deserialized object + assertEquals(context, deserializedRequest.getContext()); + assertTrue(deserializedRequest.getConfig() instanceof AnomalyDetector); + AnomalyDetector deserializedDetector = (AnomalyDetector) deserializedRequest.getConfig(); + assertEquals(detector, deserializedDetector); + assertEquals(param, deserializedRequest.getParam()); + assertEquals(requestTimeout, deserializedRequest.getRequestTimeout()); + } + + /** + * Test serialization and deserialization of SuggestConfigParamRequest with Forecast context. + */ + public void testSerializationDeserialization_ForecastContext() throws IOException { + // Create a Forecaster instance using TestHelpers.ForecasterBuilder + Forecaster forecaster = createTestForecaster(); + + AnalysisType context = AnalysisType.FORECAST; + String param = "test-param"; + TimeValue requestTimeout = TimeValue.timeValueSeconds(30); + + SuggestConfigParamRequest originalRequest = new SuggestConfigParamRequest(context, forecaster, param, requestTimeout); + + // Serialize the request + BytesStreamOutput out = new BytesStreamOutput(); + originalRequest.writeTo(out); + + // Deserialize the request + StreamInput in = out.bytes().streamInput(); + StreamInput input = new NamedWriteableAwareStreamInput(in, registry); + + SuggestConfigParamRequest deserializedRequest = new SuggestConfigParamRequest(input); + + // Verify the deserialized object + assertEquals(context, deserializedRequest.getContext()); + assertTrue(deserializedRequest.getConfig() instanceof Forecaster); + Forecaster deserializedForecaster = (Forecaster) deserializedRequest.getConfig(); + assertEquals(forecaster, deserializedForecaster); + assertEquals(param, deserializedRequest.getParam()); + assertEquals(requestTimeout, deserializedRequest.getRequestTimeout()); + } + + // Helper methods to create test instances of AnomalyDetector and Forecaster + + private AnomalyDetector createTestAnomalyDetector() { + // Use TestHelpers.AnomalyDetectorBuilder to create a test AnomalyDetector instance + try { + return TestHelpers.AnomalyDetectorBuilder.newInstance(1).build(); + } catch (IOException e) { + fail("Failed to create test AnomalyDetector: " + e.getMessage()); + return null; + } + } + + private Forecaster createTestForecaster() { + // Use TestHelpers.ForecasterBuilder to create a Forecaster instance + try { + return TestHelpers.ForecasterBuilder.newInstance().build(); + } catch (IOException e) { + fail("Failed to create test Forecaster: " + e.getMessage()); + return null; + } + } +} diff --git a/src/test/java/org/opensearch/timeseries/transport/SuggestConfigParamResponseTests.java b/src/test/java/org/opensearch/timeseries/transport/SuggestConfigParamResponseTests.java new file mode 100644 index 000000000..7e083731e --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/transport/SuggestConfigParamResponseTests.java @@ -0,0 +1,147 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; +import java.time.temporal.ChronoUnit; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.Mergeable; + +public class SuggestConfigParamResponseTests extends OpenSearchTestCase { + + /** + * Test the serialization and deserialization of SuggestConfigParamResponse. + * This covers both the writeTo(StreamOutput out) method and the + * SuggestConfigParamResponse(StreamInput in) constructor. + */ + public void testSerializationDeserialization() throws IOException { + // Create an instance of SuggestConfigParamResponse + IntervalTimeConfiguration interval = new IntervalTimeConfiguration(10, ChronoUnit.MINUTES); + Integer horizon = 12; + Integer history = 24; + + SuggestConfigParamResponse originalResponse = new SuggestConfigParamResponse(interval, horizon, history); + + // Serialize it to a BytesStreamOutput + BytesStreamOutput out = new BytesStreamOutput(); + originalResponse.writeTo(out); + + // Deserialize it from the StreamInput + StreamInput in = out.bytes().streamInput(); + SuggestConfigParamResponse deserializedResponse = new SuggestConfigParamResponse(in); + + // Assert that the deserialized object matches the original + assertEquals(originalResponse.getInterval(), deserializedResponse.getInterval()); + assertEquals(originalResponse.getHorizon(), deserializedResponse.getHorizon()); + assertEquals(originalResponse.getHistory(), deserializedResponse.getHistory()); + } + + /** + * Test the toXContent(XContentBuilder builder) method. + * This ensures that the response is correctly converted to XContent. + */ + public void testToXContent() throws IOException { + IntervalTimeConfiguration interval = new IntervalTimeConfiguration(10, ChronoUnit.MINUTES); + Integer horizon = 12; + Integer history = 24; + + SuggestConfigParamResponse response = new SuggestConfigParamResponse(interval, horizon, history); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + response.toXContent(builder); + String jsonString = builder.toString(); + + // Expected JSON string contains interval, horizon, history + assertTrue("actual json: " + jsonString, jsonString.contains("\"interval\"")); + assertTrue("actual json: " + jsonString, jsonString.contains("\"interval\":10")); + assertTrue("actual json: " + jsonString, jsonString.contains("\"unit\":\"Minutes\"")); + assertTrue("actual json: " + jsonString, jsonString.contains("\"horizon\":12")); + assertTrue("actual json: " + jsonString, jsonString.contains("\"history\":24")); + } + + /** + * Test the merge(Mergeable other) method when it returns early due to: + * - other being null + * - this being equal to other + * - getClass() != other.getClass() + */ + public void testMerge_ReturnEarly() { + IntervalTimeConfiguration interval = new IntervalTimeConfiguration(10, ChronoUnit.MINUTES); + Integer horizon = 12; + Integer history = 24; + + SuggestConfigParamResponse response = new SuggestConfigParamResponse(interval, horizon, history); + + // Case when other == null + response.merge(null); + + // Response should remain unchanged + assertEquals(interval, response.getInterval()); + assertEquals(horizon, response.getHorizon()); + assertEquals(history, response.getHistory()); + + // Case when this == other + response.merge(response); + + // Response should remain unchanged + assertEquals(interval, response.getInterval()); + assertEquals(horizon, response.getHorizon()); + assertEquals(history, response.getHistory()); + + // Case when getClass() != other.getClass() + Mergeable other = new Mergeable() { + @Override + public void merge(Mergeable other) { + // No operation + } + }; + + response.merge(other); + + // Response should remain unchanged + assertEquals(interval, response.getInterval()); + assertEquals(horizon, response.getHorizon()); + assertEquals(history, response.getHistory()); + } + + /** + * Test the merge(Mergeable other) method when otherProfile.getHistory() != null. + * This ensures that the history field is correctly updated from the other object. + */ + public void testMerge_OtherHasHistory() { + IntervalTimeConfiguration interval = new IntervalTimeConfiguration(10, ChronoUnit.MINUTES); + Integer horizon = 12; + Integer history = null; // Initial history is null + + SuggestConfigParamResponse response = new SuggestConfigParamResponse(interval, horizon, history); + + Integer otherHistory = 30; + + SuggestConfigParamResponse otherResponse = new SuggestConfigParamResponse(null, null, otherHistory); + + // Before merge, response.history is null + assertNull(response.getHistory()); + + // Merge + response.merge(otherResponse); + + // After merge, response.history should be updated + assertEquals(otherHistory, response.getHistory()); + + // Interval and horizon should remain unchanged + assertEquals(interval, response.getInterval()); + assertEquals(horizon, response.getHorizon()); + } +}