From 9f02334caf0f9b27057b6746834170d2d14b605f Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Thu, 10 Oct 2024 20:53:07 -0700 Subject: [PATCH] Bump RCF Version and Fix Default Rules Bug in AnomalyDetector (#1334) * Updated RCF version to the latest release. * Fixed a bug in AnomalyDetector where default rules were not applied when the user provided an empty ruleset. Testing: * Added unit tests to cover the bug fix Signed-off-by: Kaituo Li --- build.gradle | 9 +- .../opensearch/ad/model/AnomalyDetector.java | 2 +- .../opensearch/timeseries/JobProcessor.java | 8 +- .../timeseries/ratelimit/ColdStartWorker.java | 21 +- .../rest/handler/IndexJobActionHandler.java | 2 +- .../handler/ResultBulkIndexingHandler.java | 2 +- .../ad/e2e/AbstractRuleTestCase.java | 2 +- .../ad/e2e/RealTimeRuleModelPerfIT.java | 2 +- .../opensearch/ad/ml/ADColdStartTests.java | 63 ++++++ .../ad/rest/AnomalyDetectorRestApiIT.java | 13 ++ .../ad/task/ADTaskManagerTests.java | 29 +++ .../AnomalyResultBulkIndexHandlerTests.java | 126 +++++++++++ .../opensearch/timeseries/TestHelpers.java | 2 +- .../timeseries/transport/JobRequestTests.java | 197 ++++++++++++++++++ 14 files changed, 451 insertions(+), 27 deletions(-) create mode 100644 src/test/java/org/opensearch/ad/ml/ADColdStartTests.java create mode 100644 src/test/java/org/opensearch/timeseries/transport/JobRequestTests.java diff --git a/build.gradle b/build.gradle index 5ee288da3..416c52896 100644 --- a/build.gradle +++ b/build.gradle @@ -126,9 +126,9 @@ dependencies { implementation group: 'com.yahoo.datasketches', name: 'memory', version: '0.12.2' implementation group: 'commons-lang', name: 'commons-lang', version: '2.6' implementation group: 'org.apache.commons', name: 'commons-pool2', version: '2.12.0' - implementation 'software.amazon.randomcutforest:randomcutforest-serialization:4.1.0' - implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:4.1.0' - implementation 'software.amazon.randomcutforest:randomcutforest-core:4.1.0' + implementation 'software.amazon.randomcutforest:randomcutforest-serialization:4.2.0' + implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:4.2.0' + implementation 'software.amazon.randomcutforest:randomcutforest-core:4.2.0' // we inherit jackson-core from opensearch core implementation "com.fasterxml.jackson.core:jackson-databind:2.16.1" @@ -700,9 +700,6 @@ List jacocoExclusions = [ // TODO: add test coverage (kaituo) 'org.opensearch.forecast.*', - 'org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler', - 'org.opensearch.timeseries.transport.SingleStreamResultRequest', - 'org.opensearch.timeseries.rest.handler.IndexJobActionHandler.1', 'org.opensearch.timeseries.transport.SuggestConfigParamResponse', 'org.opensearch.timeseries.transport.SuggestConfigParamRequest', 'org.opensearch.timeseries.ml.MemoryAwareConcurrentHashmap', diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java index c8ba4a685..9b057d000 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java @@ -233,7 +233,7 @@ public AnomalyDetector( this.detectorType = isHC(categoryFields) ? MULTI_ENTITY.name() : SINGLE_ENTITY.name(); - this.rules = rules == null ? getDefaultRule() : rules; + this.rules = rules == null || rules.isEmpty() ? getDefaultRule() : rules; } /* diff --git a/src/main/java/org/opensearch/timeseries/JobProcessor.java b/src/main/java/org/opensearch/timeseries/JobProcessor.java index f9b4863e9..4900fc916 100644 --- a/src/main/java/org/opensearch/timeseries/JobProcessor.java +++ b/src/main/java/org/opensearch/timeseries/JobProcessor.java @@ -200,7 +200,7 @@ public void process(Job jobParameter, JobExecutionContext context) { * @param executionStartTime analysis start time * @param executionEndTime analysis end time * @param recorder utility to record job execution result - * @param detector associated detector accessor + * @param config associated config accessor */ public void runJob( Job jobParameter, @@ -209,7 +209,7 @@ public void runJob( Instant executionStartTime, Instant executionEndTime, ExecuteResultResponseRecorderType recorder, - Config detector + Config config ) { String configId = jobParameter.getName(); if (lock == null) { @@ -222,7 +222,7 @@ public void runJob( "Can't run job due to null lock", false, recorder, - detector + config ); return; } @@ -243,7 +243,7 @@ public void runJob( user, roles, recorder, - detector + config ); } diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/ColdStartWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ColdStartWorker.java index aa6df3d7a..a9d7f1b8c 100644 --- a/src/main/java/org/opensearch/timeseries/ratelimit/ColdStartWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ColdStartWorker.java @@ -163,19 +163,18 @@ protected void executeRequest(FeatureRequest coldStartRequest, ActionListener stopConfigListener( + public ActionListener stopConfigListener( String configId, TransportService transportService, ActionListener listener diff --git a/src/main/java/org/opensearch/timeseries/transport/handler/ResultBulkIndexingHandler.java b/src/main/java/org/opensearch/timeseries/transport/handler/ResultBulkIndexingHandler.java index 5a4c94a5c..2dddaa475 100644 --- a/src/main/java/org/opensearch/timeseries/transport/handler/ResultBulkIndexingHandler.java +++ b/src/main/java/org/opensearch/timeseries/transport/handler/ResultBulkIndexingHandler.java @@ -145,7 +145,7 @@ public void bulk(String resultIndexOrAlias, List results, String con } catch (Exception e) { String error = "Failed to bulk index result"; LOG.error(error, e); - listener.onFailure(new TimeSeriesException(error, e)); + listener.onFailure(new TimeSeriesException(configId, error, e)); } } diff --git a/src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java b/src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java index 5aa931cdd..8bb73b147 100644 --- a/src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java +++ b/src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java @@ -84,7 +84,7 @@ protected String genDetector(String datasetName, int intervalMinutes, int trainT if (relative) { thresholdType1 = "actual_over_expected_ratio"; thresholdType2 = "expected_over_actual_ratio"; - value = 0.3; + value = 0.2; } else { thresholdType1 = "actual_over_expected_margin"; thresholdType2 = "expected_over_actual_margin"; diff --git a/src/test/java/org/opensearch/ad/e2e/RealTimeRuleModelPerfIT.java b/src/test/java/org/opensearch/ad/e2e/RealTimeRuleModelPerfIT.java index 5062fe63c..1ca3a3b8f 100644 --- a/src/test/java/org/opensearch/ad/e2e/RealTimeRuleModelPerfIT.java +++ b/src/test/java/org/opensearch/ad/e2e/RealTimeRuleModelPerfIT.java @@ -36,7 +36,7 @@ public void testRule() throws Exception { minPrecision.put("Scottsdale", 0.5); Map minRecall = new HashMap<>(); minRecall.put("Phoenix", 0.9); - minRecall.put("Scottsdale", 0.6); + minRecall.put("Scottsdale", 0.3); verifyRule("rule", 10, minPrecision.size(), 1500, minPrecision, minRecall, 20); } } diff --git a/src/test/java/org/opensearch/ad/ml/ADColdStartTests.java b/src/test/java/org/opensearch/ad/ml/ADColdStartTests.java new file mode 100644 index 000000000..1c7c66bef --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/ADColdStartTests.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.ml; + +import java.io.IOException; +import java.util.ArrayList; + +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class ADColdStartTests extends OpenSearchTestCase { + private int baseDimensions = 1; + private int shingleSize = 8; + private int dimensions; + + @Override + public void setUp() throws Exception { + super.setUp(); + dimensions = baseDimensions * shingleSize; + } + + /** + * Test if no explicit rule is provided, we apply 20% rule. + * @throws IOException when failing to constructor detector + */ + public void testEmptyRule() throws IOException { + AnomalyDetector detector = TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(new ArrayList<>()).build(); + ThresholdedRandomCutForest.Builder builder = new ThresholdedRandomCutForest.Builder<>() + .dimensions(dimensions) + .shingleSize(shingleSize); + ADColdStart.applyRule(builder, detector); + + ThresholdedRandomCutForest forest = builder.build(); + double[] ignore = forest.getPredictorCorrector().getIgnoreNearExpected(); + + // Specify a small delta for floating-point comparison + double delta = 1e-6; + + assertArrayEquals("The double arrays are not equal", new double[] { 0, 0, 0.2, 0.2 }, ignore, delta); + } + + public void testNullRule() throws IOException { + AnomalyDetector detector = TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(null).build(); + ThresholdedRandomCutForest.Builder builder = new ThresholdedRandomCutForest.Builder<>() + .dimensions(dimensions) + .shingleSize(shingleSize); + ADColdStart.applyRule(builder, detector); + + ThresholdedRandomCutForest forest = builder.build(); + double[] ignore = forest.getPredictorCorrector().getIgnoreNearExpected(); + + // Specify a small delta for floating-point comparison + double delta = 1e-6; + + assertArrayEquals("The double arrays are not equal", new double[] { 0, 0, 0.2, 0.2 }, ignore, delta); + } +} diff --git a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java index c47638325..1a2007e69 100644 --- a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java +++ b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java @@ -607,6 +607,19 @@ public void testStatsAnomalyDetector() throws Exception { .makeRequest(client(), "GET", TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE + "/stats", ImmutableMap.of(), "", null); assertEquals("Get stats failed", RestStatus.OK, TestHelpers.restStatus(statsResponse)); + + statsResponse = TestHelpers + .makeRequest( + client(), + "GET", + TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE + + "/_local/stats/ad_execute_request_count,anomaly_detectors_index_status,ad_hc_execute_request_count,ad_hc_execute_failure_count,ad_execute_failure_count,models_checkpoint_index_status,anomaly_results_index_status", + ImmutableMap.of(), + "", + null + ); + + assertEquals("Get stats failed", RestStatus.OK, TestHelpers.restStatus(statsResponse)); } public void testPreviewAnomalyDetector() throws Exception { diff --git a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java index faad44cce..308a0bc90 100644 --- a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java +++ b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java @@ -61,6 +61,7 @@ import org.apache.lucene.search.TotalHits; import org.mockito.ArgumentCaptor; import org.mockito.Captor; +import org.opensearch.OpenSearchStatusException; import org.opensearch.Version; import org.opensearch.action.DocWriteRequest; import org.opensearch.action.DocWriteResponse; @@ -104,6 +105,7 @@ import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.index.IndexNotFoundException; @@ -136,6 +138,7 @@ import org.opensearch.timeseries.transport.JobResponse; import org.opensearch.timeseries.transport.StatsNodeResponse; import org.opensearch.timeseries.transport.StatsNodesResponse; +import org.opensearch.timeseries.transport.StopConfigResponse; import org.opensearch.timeseries.util.ClientUtil; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.transport.TransportResponseHandler; @@ -1544,4 +1547,30 @@ public void testDeleteTaskDocs() { verify(adTaskCacheManager, times(1)).addDeletedTask(anyString()); verify(function, times(1)).execute(); } + + public void testStopConfigListener_onResponse_failure() { + // Arrange + String configId = randomAlphaOfLength(5); + TransportService transportService = mock(TransportService.class); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + // Act + ActionListener stopConfigListener = indexAnomalyDetectorJobActionHandler + .stopConfigListener(configId, transportService, listener); + StopConfigResponse stopConfigResponse = mock(StopConfigResponse.class); + when(stopConfigResponse.success()).thenReturn(false); + + stopConfigListener.onResponse(stopConfigResponse); + + // Assert + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + + verify(adTaskManager, times(1)) + .stopLatestRealtimeTask(eq(configId), eq(TaskState.FAILED), exceptionCaptor.capture(), eq(transportService), eq(listener)); + + OpenSearchStatusException capturedException = exceptionCaptor.getValue(); + assertEquals("Failed to delete model", capturedException.getMessage()); + assertEquals(RestStatus.INTERNAL_SERVER_ERROR, capturedException.status()); + } } diff --git a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java index 98daeb1d9..d7cbd9817 100644 --- a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java +++ b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java @@ -27,6 +27,7 @@ import java.time.Clock; import java.util.Optional; +import org.opensearch.ResourceAlreadyExistsException; import org.opensearch.action.DocWriteRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.bulk.BulkAction; @@ -43,11 +44,13 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; import org.opensearch.timeseries.util.ClientUtil; import org.opensearch.timeseries.util.IndexUtils; @@ -232,4 +235,127 @@ private AnomalyResult wrongAnomalyResult() { null ); } + + public void testResponseIsAcknowledgedTrue() throws InterruptedException { + String testIndex = "testIndex"; + + // Set up mocks for doesIndexExist and doesAliasExist + when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(testIndex)).thenReturn(false); + + // Mock initCustomResultIndexDirectly to simulate index creation and call the listener + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + // Simulate immediate onResponse call + listener.onResponse(new CreateIndexResponse(true, true, testIndex)); + return null; + }).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any()); + + AnomalyResult result = mock(AnomalyResult.class); + + // Call bulk method + bulkIndexHandler.bulk(testIndex, ImmutableList.of(result), configId, listener); + + // Verify that listener.onResponse is called + verify(client, times(1)).prepareBulk(); + } + + public void testResponseIsAcknowledgedFalse() { + String testIndex = "testIndex"; + when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(testIndex)).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new CreateIndexResponse(false, false, testIndex)); + return null; + }).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any()); + + AnomalyResult result = mock(AnomalyResult.class); + bulkIndexHandler.bulk(testIndex, ImmutableList.of(result), configId, listener); + + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Creating custom result index with mappings call not acknowledged", exceptionCaptor.getValue().getMessage()); + } + + public void testResourceAlreadyExistsException() { + String testIndex = "testIndex"; + when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenReturn(false, true); + when(anomalyDetectionIndices.doesAliasExist(testIndex)).thenReturn(false, false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new ResourceAlreadyExistsException("index already exists")); + return null; + }).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(true); + return null; + }).when(anomalyDetectionIndices).validateResultIndexMapping(eq(testIndex), any()); + + AnomalyResult result = mock(AnomalyResult.class); + bulkIndexHandler.bulk(testIndex, ImmutableList.of(result), configId, listener); + + // Verify that listener.onResponse is called + verify(client, times(1)).prepareBulk(); + } + + public void testOtherException() { + String testIndex = "testIndex"; + when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(testIndex)).thenReturn(false); + + Exception testException = new OpenSearchRejectedExecutionException("Test exception"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(testException); + return null; + }).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any()); + + AnomalyResult result = mock(AnomalyResult.class); + bulkIndexHandler.bulk(testIndex, ImmutableList.of(result), configId, listener); + + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals(testException, exceptionCaptor.getValue()); + } + + public void testTimeSeriesExceptionCaughtInBulk() { + String testIndex = "testIndex"; + TimeSeriesException testException = new TimeSeriesException("Test TimeSeriesException"); + + // Mock doesIndexExist to throw TimeSeriesException + when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenThrow(testException); + + AnomalyResult result = mock(AnomalyResult.class); + + // Call bulk method + bulkIndexHandler.bulk(testIndex, ImmutableList.of(result), configId, listener); + + // Verify that listener.onFailure is called with the TimeSeriesException + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals(testException, exceptionCaptor.getValue()); + } + + public void testExceptionCaughtInBulk() { + String testIndex = "testIndex"; + NullPointerException testException = new NullPointerException("Test NullPointerException"); + + // Mock doesIndexExist to throw NullPointerException + when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenThrow(testException); + + AnomalyResult result = mock(AnomalyResult.class); + + // Call bulk method + bulkIndexHandler.bulk(testIndex, ImmutableList.of(result), configId, listener); + + // Verify that listener.onFailure is called with a TimeSeriesException wrapping the original exception + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + Exception capturedException = exceptionCaptor.getValue(); + assertTrue(capturedException instanceof TimeSeriesException); + assertEquals("Failed to bulk index result", capturedException.getMessage()); + assertEquals(testException, capturedException.getCause()); + } } diff --git a/src/test/java/org/opensearch/timeseries/TestHelpers.java b/src/test/java/org/opensearch/timeseries/TestHelpers.java index 22dcf64bd..7e4a0b7d0 100644 --- a/src/test/java/org/opensearch/timeseries/TestHelpers.java +++ b/src/test/java/org/opensearch/timeseries/TestHelpers.java @@ -755,7 +755,7 @@ public AnomalyDetector build() { // as ModelColdStart.selectNumberOfSamples will select the smaller of // 32 and historical intervals. randomIntBetween(TimeSeriesSettings.NUM_MIN_SAMPLES, 1000), - null, + rules, null, null, null, diff --git a/src/test/java/org/opensearch/timeseries/transport/JobRequestTests.java b/src/test/java/org/opensearch/timeseries/transport/JobRequestTests.java new file mode 100644 index 000000000..6487f0785 --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/transport/JobRequestTests.java @@ -0,0 +1,197 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.constant.CommonName; + +public class JobRequestTests extends OpenSearchTestCase { + public void testSerializationDeserialization() throws IOException { + String configId = "test-config-id"; + String modelId = "test-model-id"; + long startMillis = 1622548800000L; // June 1, 2021 00:00:00 GMT + long endMillis = 1622635200000L; // June 2, 2021 00:00:00 GMT + double[] datapoint = new double[] { 1.0, 2.0, 3.0 }; + String taskId = "test-task-id"; + + // Create the original request + SingleStreamResultRequest originalRequest = new SingleStreamResultRequest( + configId, + modelId, + startMillis, + endMillis, + datapoint, + taskId + ); + + // Serialize the request to a BytesStreamOutput + BytesStreamOutput out = new BytesStreamOutput(); + originalRequest.writeTo(out); + + // Deserialize the request from the StreamInput + StreamInput in = out.bytes().streamInput(); + SingleStreamResultRequest deserializedRequest = new SingleStreamResultRequest(in); + + // Assert that the deserialized request matches the original + assertEquals(originalRequest.getConfigId(), deserializedRequest.getConfigId()); + assertEquals(originalRequest.getModelId(), deserializedRequest.getModelId()); + assertEquals(originalRequest.getStart(), deserializedRequest.getStart()); + assertEquals(originalRequest.getEnd(), deserializedRequest.getEnd()); + assertArrayEquals(originalRequest.getDataPoint(), deserializedRequest.getDataPoint(), 0.0001); + assertEquals(originalRequest.getTaskId(), deserializedRequest.getTaskId()); + } + + public void testSerializationDeserialization_NullTaskId() throws IOException { + String configId = "test-config-id"; + String modelId = "test-model-id"; + long startMillis = 1622548800000L; + long endMillis = 1622635200000L; + double[] datapoint = new double[] { 1.0, 2.0, 3.0 }; + String taskId = null; + + SingleStreamResultRequest originalRequest = new SingleStreamResultRequest( + configId, + modelId, + startMillis, + endMillis, + datapoint, + taskId + ); + + BytesStreamOutput out = new BytesStreamOutput(); + originalRequest.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + SingleStreamResultRequest deserializedRequest = new SingleStreamResultRequest(in); + + assertEquals(originalRequest.getConfigId(), deserializedRequest.getConfigId()); + assertEquals(originalRequest.getModelId(), deserializedRequest.getModelId()); + assertEquals(originalRequest.getStart(), deserializedRequest.getStart()); + assertEquals(originalRequest.getEnd(), deserializedRequest.getEnd()); + assertArrayEquals(originalRequest.getDataPoint(), deserializedRequest.getDataPoint(), 0.0001); + assertNull(deserializedRequest.getTaskId()); + } + + public void testToXContent() throws IOException { + String configId = "test-config-id"; + String modelId = "test-model-id"; + long startMillis = 1622548800000L; + long endMillis = 1622635200000L; + double[] datapoint = new double[] { 1.0, 2.0, 3.0 }; + String taskId = "test-task-id"; + + SingleStreamResultRequest request = new SingleStreamResultRequest(configId, modelId, startMillis, endMillis, datapoint, taskId); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + request.toXContent(builder, null); + String jsonString = builder.toString(); + + String expectedJson = String + .format( + Locale.ROOT, + "{\"%s\":\"%s\",\"%s\":\"%s\",\"%s\":%d,\"%s\":%d,\"%s\":[1.0,2.0,3.0],\"%s\":\"%s\"}", + CommonName.CONFIG_ID_KEY, + configId, + CommonName.MODEL_ID_KEY, + modelId, + CommonName.START_JSON_KEY, + startMillis, + CommonName.END_JSON_KEY, + endMillis, + CommonName.VALUE_LIST_FIELD, + CommonName.RUN_ONCE_FIELD, + taskId + ); + + assertEquals(expectedJson, jsonString); + } + + public void testToXContent_NullTaskId() throws IOException { + String configId = "test-config-id"; + String modelId = "test-model-id"; + long startMillis = 1622548800000L; + long endMillis = 1622635200000L; + double[] datapoint = new double[] { 1.0, 2.0, 3.0 }; + String taskId = null; + + SingleStreamResultRequest request = new SingleStreamResultRequest(configId, modelId, startMillis, endMillis, datapoint, taskId); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + request.toXContent(builder, null); + String jsonString = builder.toString(); + + String expectedJson = String + .format( + Locale.ROOT, + "{\"%s\":\"%s\",\"%s\":\"%s\",\"%s\":%d,\"%s\":%d,\"%s\":[1.0,2.0,3.0],\"%s\":null}", + CommonName.CONFIG_ID_KEY, + configId, + CommonName.MODEL_ID_KEY, + modelId, + CommonName.START_JSON_KEY, + startMillis, + CommonName.END_JSON_KEY, + endMillis, + CommonName.VALUE_LIST_FIELD, + CommonName.RUN_ONCE_FIELD + ); + + assertEquals(expectedJson, jsonString); + } + + public void testValidate_MissingConfigId() { + String configId = null; // Missing configId + String modelId = "test-model-id"; + long startMillis = 1622548800000L; + long endMillis = 1622635200000L; + double[] datapoint = new double[] { 1.0, 2.0, 3.0 }; + String taskId = "test-task-id"; + + SingleStreamResultRequest request = new SingleStreamResultRequest(configId, modelId, startMillis, endMillis, datapoint, taskId); + + ActionRequestValidationException validationException = request.validate(); + assertNotNull(validationException); + assertTrue("actual: " + validationException.getMessage(), validationException.getMessage().contains("config ID is missing")); + } + + public void testValidate_MissingModelId() { + String configId = "test-config-id"; + String modelId = null; // Missing modelId + long startMillis = 1622548800000L; + long endMillis = 1622635200000L; + double[] datapoint = new double[] { 1.0, 2.0, 3.0 }; + String taskId = "test-task-id"; + + SingleStreamResultRequest request = new SingleStreamResultRequest(configId, modelId, startMillis, endMillis, datapoint, taskId); + + ActionRequestValidationException validationException = request.validate(); + assertNotNull(validationException); + assertTrue("actual: " + validationException.getMessage(), validationException.getMessage().contains("model ID is missing")); + } + + public void testValidate_InvalidTimestamps() { + String configId = "test-config-id"; + String modelId = "test-model-id"; + long startMillis = 1622635200000L; // End time before start time + long endMillis = 1622548800000L; + double[] datapoint = new double[] { 1.0, 2.0, 3.0 }; + String taskId = "test-task-id"; + + SingleStreamResultRequest request = new SingleStreamResultRequest(configId, modelId, startMillis, endMillis, datapoint, taskId); + + ActionRequestValidationException validationException = request.validate(); + assertNotNull(validationException); + assertTrue("actual: " + validationException.getMessage(), validationException.getMessage().contains("timestamp is invalid")); + } +}