Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[backport to 2.x] Bump RCF Version and Fix Default Rules Bug in AnomalyDetector #1336

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ dependencies {
implementation group: 'com.yahoo.datasketches', name: 'memory', version: '0.12.2'
implementation group: 'commons-lang', name: 'commons-lang', version: '2.6'
implementation group: 'org.apache.commons', name: 'commons-pool2', version: '2.12.0'
implementation 'software.amazon.randomcutforest:randomcutforest-serialization:4.1.0'
implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:4.1.0'
implementation 'software.amazon.randomcutforest:randomcutforest-core:4.1.0'
implementation 'software.amazon.randomcutforest:randomcutforest-serialization:4.2.0'
implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:4.2.0'
implementation 'software.amazon.randomcutforest:randomcutforest-core:4.2.0'

// we inherit jackson-core from opensearch core
implementation "com.fasterxml.jackson.core:jackson-databind:2.16.1"
Expand Down Expand Up @@ -700,9 +700,6 @@ List<String> jacocoExclusions = [

// TODO: add test coverage (kaituo)
'org.opensearch.forecast.*',
'org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler',
'org.opensearch.timeseries.transport.SingleStreamResultRequest',
'org.opensearch.timeseries.rest.handler.IndexJobActionHandler.1',
'org.opensearch.timeseries.transport.SuggestConfigParamResponse',
'org.opensearch.timeseries.transport.SuggestConfigParamRequest',
'org.opensearch.timeseries.ml.MemoryAwareConcurrentHashmap',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ public AnomalyDetector(

this.detectorType = isHC(categoryFields) ? MULTI_ENTITY.name() : SINGLE_ENTITY.name();

this.rules = rules == null ? getDefaultRule() : rules;
this.rules = rules == null || rules.isEmpty() ? getDefaultRule() : rules;
}

/*
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/org/opensearch/timeseries/JobProcessor.java
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ public void process(Job jobParameter, JobExecutionContext context) {
* @param executionStartTime analysis start time
* @param executionEndTime analysis end time
* @param recorder utility to record job execution result
* @param detector associated detector accessor
* @param config associated config accessor
*/
public void runJob(
Job jobParameter,
Expand All @@ -209,7 +209,7 @@ public void runJob(
Instant executionStartTime,
Instant executionEndTime,
ExecuteResultResponseRecorderType recorder,
Config detector
Config config
) {
String configId = jobParameter.getName();
if (lock == null) {
Expand All @@ -222,7 +222,7 @@ public void runJob(
"Can't run job due to null lock",
false,
recorder,
detector
config
);
return;
}
Expand All @@ -243,7 +243,7 @@ public void runJob(
user,
roles,
recorder,
detector
config
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,18 @@ protected void executeRequest(FeatureRequest coldStartRequest, ActionListener<Vo
);
IntermediateResultType result = modelManager.getResult(currentSample, modelState, modelId, config, taskId);
resultSaver.saveResult(result, config, coldStartRequest, modelId);
}

// only load model to memory for real time analysis that has no task id
if (null == coldStartRequest.getTaskId()) {
boolean hosted = cacheProvider.hostIfPossible(configOptional.get(), modelState);
LOG
.debug(
hosted
? new ParameterizedMessage("Loaded model {}.", modelState.getModelId())
: new ParameterizedMessage("Failed to load model {}.", modelState.getModelId())
);
// only load model to memory for real time analysis that has no task id
if (null == coldStartRequest.getTaskId()) {
boolean hosted = cacheProvider.hostIfPossible(configOptional.get(), modelState);
LOG
.debug(
hosted
? new ParameterizedMessage("Loaded model {}.", modelState.getModelId())
: new ParameterizedMessage("Failed to load model {}.", modelState.getModelId())
);
}
}

} finally {
listener.onResponse(null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ public void stopJob(String configId, TransportService transportService, ActionLi
}));
}

private ActionListener<StopConfigResponse> stopConfigListener(
public ActionListener<StopConfigResponse> stopConfigListener(
String configId,
TransportService transportService,
ActionListener<JobResponse> listener
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public void bulk(String resultIndexOrAlias, List<ResultType> results, String con
} catch (Exception e) {
String error = "Failed to bulk index result";
LOG.error(error, e);
listener.onFailure(new TimeSeriesException(error, e));
listener.onFailure(new TimeSeriesException(configId, error, e));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ protected String genDetector(String datasetName, int intervalMinutes, int trainT
if (relative) {
thresholdType1 = "actual_over_expected_ratio";
thresholdType2 = "expected_over_actual_ratio";
value = 0.3;
value = 0.2;
} else {
thresholdType1 = "actual_over_expected_margin";
thresholdType2 = "expected_over_actual_margin";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public void testRule() throws Exception {
minPrecision.put("Scottsdale", 0.5);
Map<String, Double> minRecall = new HashMap<>();
minRecall.put("Phoenix", 0.9);
minRecall.put("Scottsdale", 0.6);
minRecall.put("Scottsdale", 0.3);
verifyRule("rule", 10, minPrecision.size(), 1500, minPrecision, minRecall, 20);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public void testRule() throws Exception {
minPrecision.put("Scottsdale", 0.5);
Map<String, Double> minRecall = new HashMap<>();
minRecall.put("Phoenix", 0.9);
minRecall.put("Scottsdale", 0.6);
minRecall.put("Scottsdale", 0.3);
verifyRule("rule", 10, minPrecision.size(), 1500, minPrecision, minRecall, 20);
}
}
Expand Down
63 changes: 63 additions & 0 deletions src/test/java/org/opensearch/ad/ml/ADColdStartTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ad.ml;

import java.io.IOException;
import java.util.ArrayList;

import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.timeseries.TestHelpers;

import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;

public class ADColdStartTests extends OpenSearchTestCase {
private int baseDimensions = 1;
private int shingleSize = 8;
private int dimensions;

@Override
public void setUp() throws Exception {
super.setUp();
dimensions = baseDimensions * shingleSize;
}

/**
* Test if no explicit rule is provided, we apply 20% rule.
* @throws IOException when failing to constructor detector
*/
public void testEmptyRule() throws IOException {
AnomalyDetector detector = TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(new ArrayList<>()).build();
ThresholdedRandomCutForest.Builder builder = new ThresholdedRandomCutForest.Builder<>()
.dimensions(dimensions)
.shingleSize(shingleSize);
ADColdStart.applyRule(builder, detector);

ThresholdedRandomCutForest forest = builder.build();
double[] ignore = forest.getPredictorCorrector().getIgnoreNearExpected();

// Specify a small delta for floating-point comparison
double delta = 1e-6;

assertArrayEquals("The double arrays are not equal", new double[] { 0, 0, 0.2, 0.2 }, ignore, delta);
}

public void testNullRule() throws IOException {
AnomalyDetector detector = TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(null).build();
ThresholdedRandomCutForest.Builder builder = new ThresholdedRandomCutForest.Builder<>()
.dimensions(dimensions)
.shingleSize(shingleSize);
ADColdStart.applyRule(builder, detector);

ThresholdedRandomCutForest forest = builder.build();
double[] ignore = forest.getPredictorCorrector().getIgnoreNearExpected();

// Specify a small delta for floating-point comparison
double delta = 1e-6;

assertArrayEquals("The double arrays are not equal", new double[] { 0, 0, 0.2, 0.2 }, ignore, delta);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,19 @@ public void testStatsAnomalyDetector() throws Exception {
.makeRequest(client(), "GET", TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE + "/stats", ImmutableMap.of(), "", null);

assertEquals("Get stats failed", RestStatus.OK, TestHelpers.restStatus(statsResponse));

statsResponse = TestHelpers
.makeRequest(
client(),
"GET",
TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE
+ "/_local/stats/ad_execute_request_count,anomaly_detectors_index_status,ad_hc_execute_request_count,ad_hc_execute_failure_count,ad_execute_failure_count,models_checkpoint_index_status,anomaly_results_index_status",
ImmutableMap.of(),
"",
null
);

assertEquals("Get stats failed", RestStatus.OK, TestHelpers.restStatus(statsResponse));
}

public void testPreviewAnomalyDetector() throws Exception {
Expand Down
29 changes: 29 additions & 0 deletions src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.apache.lucene.search.TotalHits;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.Version;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.DocWriteResponse;
Expand Down Expand Up @@ -104,6 +105,7 @@
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.index.IndexNotFoundException;
Expand Down Expand Up @@ -136,6 +138,7 @@
import org.opensearch.timeseries.transport.JobResponse;
import org.opensearch.timeseries.transport.StatsNodeResponse;
import org.opensearch.timeseries.transport.StatsNodesResponse;
import org.opensearch.timeseries.transport.StopConfigResponse;
import org.opensearch.timeseries.util.ClientUtil;
import org.opensearch.timeseries.util.DiscoveryNodeFilterer;
import org.opensearch.transport.TransportResponseHandler;
Expand Down Expand Up @@ -1544,4 +1547,30 @@ public void testDeleteTaskDocs() {
verify(adTaskCacheManager, times(1)).addDeletedTask(anyString());
verify(function, times(1)).execute();
}

public void testStopConfigListener_onResponse_failure() {
// Arrange
String configId = randomAlphaOfLength(5);
TransportService transportService = mock(TransportService.class);
@SuppressWarnings("unchecked")
ActionListener<JobResponse> listener = mock(ActionListener.class);

// Act
ActionListener<StopConfigResponse> stopConfigListener = indexAnomalyDetectorJobActionHandler
.stopConfigListener(configId, transportService, listener);
StopConfigResponse stopConfigResponse = mock(StopConfigResponse.class);
when(stopConfigResponse.success()).thenReturn(false);

stopConfigListener.onResponse(stopConfigResponse);

// Assert
ArgumentCaptor<OpenSearchStatusException> exceptionCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);

verify(adTaskManager, times(1))
.stopLatestRealtimeTask(eq(configId), eq(TaskState.FAILED), exceptionCaptor.capture(), eq(transportService), eq(listener));

OpenSearchStatusException capturedException = exceptionCaptor.getValue();
assertEquals("Failed to delete model", capturedException.getMessage());
assertEquals(RestStatus.INTERNAL_SERVER_ERROR, capturedException.status());
}
}
Loading
Loading