Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
Fix AD job not being terminated when training data is missing
Browse files Browse the repository at this point in the history
We expect an EndRunException to be thrown when training data is missing. However, the exception is not properly propagated back to the AD job and results in an InternalFailure instead. This PR fixes the bug.

Testing done:
1. Manually reproduced all possible EndRunExceptions, and checked that the AD job is terminated and that the profile API status is correct.
2. Added unit tests to expose the bug.
  • Loading branch information
kaituo committed May 4, 2020
1 parent fbc8a4e commit cf48892
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ protected void handleAdException(
detectorEndRunExceptionCount.remove(detectorId);
if (exception instanceof InternalFailure) {
// AnomalyResultTransportAction already prints exception stack trace
log.error("InternalFailure happened when executed anomaly result action for " + detectorId);
log.error("InternalFailure happened when executing anomaly result action for " + detectorId);
} else {
log.error("Failed to execute anomaly result action for " + detectorId, exception);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ private boolean isHostingAllowed(String detectorId, RandomCutForest rcf) {
} else {
throw new LimitExceededException(
detectorId,
String.format("Exceeded memory limit. New size is %d and max limit is %f", total, heapLimit)
String.format("Exceeded memory limit. New size is %d bytes and max limit is %f bytes", total, heapLimit)
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -418,15 +418,7 @@ private void findException(Throwable cause, String adID, AtomicReference<Anomaly
if (isException(causeException, ResourceNotFoundException.class, RESOURCE_NOT_FOUND_EXCEPTION_NAME_UNDERSCORE)
|| (causeException instanceof IndexNotFoundException
&& causeException.getMessage().contains(CommonName.CHECKPOINT_INDEX_NAME))) {
// fetch previous cold start exception
Optional<? extends AnomalyDetectionException> previousException = globalRunner.fetchException(adID);

if (previousException.isPresent()) {
LOG.error("Previous exception of {}: {}", () -> adID, () -> previousException.get());
failure.set(previousException.get());
} else {
failure.set(new ResourceNotFoundException(adID, causeException.getMessage()));
}
failure.set(new ResourceNotFoundException(adID, causeException.getMessage()));
} else if (isException(causeException, LimitExceededException.class, LIMIT_EXCEEDED_EXCEPTION_NAME_UNDERSCORE)) {
failure.set(new LimitExceededException(adID, causeException.getMessage()));
} else if (causeException instanceof ElasticsearchTimeoutException) {
Expand Down Expand Up @@ -553,7 +545,7 @@ public void onResponse(RCFResultResponse response) {
@Override
public void onFailure(Exception e) {
try {
handlePredictionFailure(e, modelID, rcfNodeID, failure);
handlePredictionFailure(e, adID, rcfNodeID, failure);
} catch (Exception ex) {
LOG.error("Unexpected exception: {} for {}", ex, adID);
} finally {
Expand All @@ -565,7 +557,20 @@ public void onFailure(Exception e) {

private void handleRCFResults() {
try {
if (coldStartIfNoModel(failure, detector) || rcfResults.isEmpty()) {
if (coldStartIfNoModel(failure, detector)) {
// fetch previous cold start exception
Optional<? extends AnomalyDetectionException> previousException = globalRunner.fetchException(adID);

if (previousException.isPresent()) {
LOG.error("Previous exception of {}: {}", () -> adID, () -> previousException.get());
listener.onFailure(previousException.get());
} else {
listener.onFailure(new InternalFailure(adID, NO_MODEL_ERR_MSG));
}
return;
}

if (rcfResults.isEmpty()) {
listener.onFailure(new InternalFailure(adID, NO_MODEL_ERR_MSG));
return;
}
Expand All @@ -580,7 +585,6 @@ private void handleRCFResults() {
ThresholdActionListener thresholdListener = new ThresholdActionListener(
anomalyResultResponse,
featureInResponse,
thresholdModelID,
thresholdNodeId,
detector,
combinedResult,
Expand All @@ -604,7 +608,6 @@ private void handleRCFResults() {
class ThresholdActionListener implements ActionListener<ThresholdResultResponse> {
private AtomicReference<AnomalyResultResponse> anomalyResultResponse;
private List<FeatureData> features;
private String modelID;
private AtomicReference<AnomalyDetectionException> failure;
private String thresholdNodeID;
private ActionListener<AnomalyResultResponse> listener;
Expand All @@ -615,7 +618,6 @@ class ThresholdActionListener implements ActionListener<ThresholdResultResponse>
ThresholdActionListener(
AtomicReference<AnomalyResultResponse> anomalyResultResponse,
List<FeatureData> features,
String modelID,
String thresholdNodeID,
AnomalyDetector detector,
CombinedRcfResult combinedResult,
Expand All @@ -624,7 +626,6 @@ class ThresholdActionListener implements ActionListener<ThresholdResultResponse>
) {
this.anomalyResultResponse = anomalyResultResponse;
this.features = features;
this.modelID = modelID;
this.thresholdNodeID = thresholdNodeID;
this.detector = detector;
this.combinedResult = combinedResult;
Expand All @@ -649,7 +650,7 @@ public void onResponse(ThresholdResultResponse response) {
@Override
public void onFailure(Exception e) {
try {
handlePredictionFailure(e, modelID, thresholdNodeID, failure);
handlePredictionFailure(e, adID, thresholdNodeID, failure);
} catch (Exception ex) {
LOG.error("Unexpected exception: {} for {}", ex, adID);
} finally {
Expand Down Expand Up @@ -807,6 +808,8 @@ public Boolean call() {
"Time out while indexing cold start checkpoint or get training data",
timeoutEx
);
} catch (EndRunException endRunEx) {
throw endRunEx;
} catch (Exception ex) {
throw new EndRunException(detector.getDetectorId(), "Error while cold start", ex, false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ Optional<Boolean> checkResult() {
if (cause instanceof AnomalyDetectionException) {
AnomalyDetectionException adException = (AnomalyDetectionException) cause;
currentExceptions.put(adException.getAnomalyDetectorId(), adException);
LOG.info("added cause for {}", adException.getAnomalyDetectorId());
} else {
LOG.error("Get an unexpected exception");
}
}
return Optional.empty();
Expand All @@ -79,9 +82,12 @@ Optional<Boolean> checkResult() {
public Optional<AnomalyDetectionException> fetchException(String adID) {
checkResult();

if (currentExceptions.containsKey(adID)) {
LOG.error("Found matching exception for {}", adID);
AnomalyDetectionException ex = currentExceptions.get(adID);
if (ex != null) {
LOG.error("Found a matching exception for " + adID, ex);
return Optional.of(currentExceptions.remove(adID));
} else {
LOG.info("Cannot find a matching exception for {}", adID);
}
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@
import com.amazon.opendistroforelasticsearch.ad.AbstractADTest;
import com.amazon.opendistroforelasticsearch.ad.cluster.HashRing;
import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException;
import com.amazon.opendistroforelasticsearch.ad.common.exception.ClientException;
import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException;
import com.amazon.opendistroforelasticsearch.ad.common.exception.InternalFailure;
import com.amazon.opendistroforelasticsearch.ad.common.exception.JsonPathNotFoundException;
import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException;
import com.amazon.opendistroforelasticsearch.ad.common.exception.ResourceNotFoundException;
import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages;
import com.amazon.opendistroforelasticsearch.ad.constant.CommonMessageAttributes;
import com.amazon.opendistroforelasticsearch.ad.constant.CommonName;
import com.amazon.opendistroforelasticsearch.ad.feature.FeatureManager;
import com.amazon.opendistroforelasticsearch.ad.feature.SinglePointFeatures;
import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager;
Expand Down Expand Up @@ -105,6 +105,7 @@
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.NodeNotConnectedException;
Expand Down Expand Up @@ -165,6 +166,7 @@ public void setUp() throws Exception {
super.setUpLog4jForJUnit(AnomalyResultTransportAction.class);
setupTestNodes(Settings.EMPTY);
FakeNode.connectNodes(testNodes);
runner = new ColdStartRunner();
transportService = testNodes[0].transportService;
clusterService = testNodes[0].clusterService;
stateManager = mock(ADStateManager.class);
Expand Down Expand Up @@ -264,7 +266,6 @@ public void setupTestNodes(Settings settings) {
for (int i = 0; i < testNodes.length; i++) {
testNodes[i] = new FakeNode("node" + i, threadPool, settings);
}
runner = new ColdStartRunner();
}

@Override
Expand All @@ -273,6 +274,7 @@ public final void tearDown() throws Exception {
for (FakeNode testNode : testNodes) {
testNode.close();
}
testNodes = null;
runner.shutDown();
runner = null;
client = null;
Expand Down Expand Up @@ -887,7 +889,7 @@ public void testColdStartNoTrainingData() throws Exception {
);

AnomalyResultTransportAction.ColdStartJob job = action.new ColdStartJob(detector);
expectThrows(AnomalyDetectionException.class, () -> job.call());
expectThrows(EndRunException.class, () -> job.call());
}

public void testColdStartTimeoutPutCheckpoint() throws Exception {
Expand All @@ -912,7 +914,30 @@ public void testColdStartTimeoutPutCheckpoint() throws Exception {
);

AnomalyResultTransportAction.ColdStartJob job = action.new ColdStartJob(detector);
expectThrows(ClientException.class, () -> job.call());
expectThrows(InternalFailure.class, () -> job.call());
}

/**
 * Checks that running a {@code ColdStartJob} throws an {@link EndRunException} when
 * cold start data is available but model training throws an
 * {@link IllegalArgumentException}.
 *
 * @throws Exception if the test fixture itself fails unexpectedly
 */
public void testColdStartIllegalArgumentException() throws Exception {
// Cold start data is present (a single one-element row), so the failure below comes from training.
when(featureQuery.getColdStartData(any(AnomalyDetector.class))).thenReturn(Optional.of(new double[][] { { 1.0 } }));
// Stub model training to fail with an IllegalArgumentException carrying an empty message.
doThrow(new IllegalArgumentException("")).when(normalModelManager).trainModel(any(AnomalyDetector.class), any(double[][].class));

// Build the action with the test's shared collaborators, including the stubbed
// featureQuery and normalModelManager.
AnomalyResultTransportAction action = new AnomalyResultTransportAction(
new ActionFilters(Collections.emptySet()),
transportService,
settings,
stateManager,
runner,
featureQuery,
normalModelManager,
hashRing,
clusterService,
indexNameResolver,
adCircuitBreakerService,
adStats
);

// ColdStartJob is an inner class, hence the action.new instantiation.
AnomalyResultTransportAction.ColdStartJob job = action.new ColdStartJob(detector);
// The training failure is expected to reach the caller as an EndRunException.
expectThrows(EndRunException.class, () -> job.call());
}

enum FeatureTestMode {
Expand Down Expand Up @@ -1089,8 +1114,7 @@ public void testNullRCFResult() {

@SuppressWarnings("unchecked")
public void testAllFeaturesDisabled() {
// doThrow(IllegalArgumentException.class).when(featureQuery)
// .getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class));

doAnswer(invocation -> {
Object[] args = invocation.getArguments();
ActionListener<SinglePointFeatures> listener = (ActionListener<SinglePointFeatures>) args[3];
Expand Down Expand Up @@ -1120,4 +1144,45 @@ public void testAllFeaturesDisabled() {

assertException(listener, EndRunException.class, AnomalyResultTransportAction.ALL_FEATURES_DISABLED_ERR_MSG);
}

/**
 * Checks that executing the anomaly result action fails the listener with an
 * {@link EndRunException} when the RCF result lookup fails with an
 * {@link IndexNotFoundException} for the checkpoint index and the cold start runner
 * returns a previously recorded {@link EndRunException} for the detector.
 */
@SuppressWarnings("unchecked")
public void testEndRunDueToNoTrainingData() {
ModelManager rcfManager = mock(ModelManager.class);
// Make every RCF result request fail as if the checkpoint index did not exist.
doAnswer(invocation -> {
Object[] args = invocation.getArguments();
// The fourth argument of getRcfResult is the callback listener.
ActionListener<RcfResult> listener = (ActionListener<RcfResult>) args[3];
listener.onFailure(new IndexNotFoundException(CommonName.CHECKPOINT_INDEX_NAME));
return null;
}).when(rcfManager).getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class));
when(rcfManager.getRcfModelId(any(String.class), anyInt())).thenReturn(rcfModelID);

// The mocked runner always reports a stored EndRunException, standing in for a
// cold start that failed because no training data could be fetched.
ColdStartRunner mockRunner = mock(ColdStartRunner.class);
when(mockRunner.fetchException(any(String.class)))
.thenReturn(Optional.of(new EndRunException(adID, "Cannot get training data", false)));

// These constructors register handler in transport service
new RCFResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, rcfManager, adCircuitBreakerService);
new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager);

// Use the mocked runner instead of the shared runner so fetchException is controlled by this test.
AnomalyResultTransportAction action = new AnomalyResultTransportAction(
new ActionFilters(Collections.emptySet()),
transportService,
settings,
stateManager,
mockRunner,
featureQuery,
normalModelManager,
hashRing,
clusterService,
indexNameResolver,
adCircuitBreakerService,
adStats
);

AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200);
PlainActionFuture<AnomalyResultResponse> listener = new PlainActionFuture<>();
// The first argument (the task) is passed as null here.
action.doExecute(null, request, listener);

// NOTE(review): assertException is defined outside this view; presumably it waits on the
// future and checks the failure type — confirm against the helper.
assertException(listener, EndRunException.class);
}
}

0 comments on commit cf48892

Please sign in to comment.