-
Notifications
You must be signed in to change notification settings - Fork 36
add test evaluating anomaly results #13
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,221 @@ | ||
/* | ||
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). | ||
* You may not use this file except in compliance with the License. | ||
* A copy of the License is located at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed | ||
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | ||
* express or implied. See the License for the specific language governing | ||
* permissions and limitations under the License. | ||
*/ | ||
|
||
package com.amazon.opendistroforelasticsearch.ad.e2e; | ||
|
||
import java.io.File; | ||
import java.io.FileReader; | ||
import java.time.Instant; | ||
import java.time.format.DateTimeFormatter; | ||
import java.time.temporal.ChronoUnit; | ||
import java.util.AbstractMap.SimpleEntry; | ||
import java.util.ArrayList; | ||
import java.util.HashSet; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Map.Entry; | ||
import java.util.Set; | ||
|
||
import com.google.gson.JsonArray; | ||
import com.google.gson.JsonObject; | ||
import com.google.gson.JsonParser; | ||
import org.elasticsearch.client.Request; | ||
import org.elasticsearch.client.RestClient; | ||
import org.elasticsearch.test.rest.ESRestTestCase; | ||
|
||
import static org.junit.Assert.assertTrue; | ||
|
||
public class DetectionResultEvalutationIT extends ESRestTestCase { | ||
|
||
public void testDataset() throws Exception { | ||
verifyAnomaly("synthetic", 1, 1500, 8, .9, .9, 10); | ||
} | ||
|
||
private void verifyAnomaly(String datasetName, int intervalMinutes, int trainTestSplit, int shingleSize, | ||
double minPrecision, double minRecall, double maxError) throws Exception { | ||
|
||
RestClient client = client(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How many nodes do we have for this cluster? We want both single node and multi-node cluster to test. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Default setting is a single node. I will add a multi-node test case in a separate pr. |
||
|
||
String dataFileName = String.format("data/%s.data", datasetName); | ||
String labelFileName = String.format("data/%s.label", datasetName); | ||
|
||
List<JsonObject> data = getData(dataFileName); | ||
List<Entry<Instant, Instant>> anomalies = getAnomalyWindows(labelFileName); | ||
|
||
indexTrainData(datasetName, data, trainTestSplit, client); | ||
String detectorId = createDetector(datasetName, intervalMinutes, client); | ||
startDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client); | ||
|
||
indexTestData(data, datasetName, trainTestSplit, client); | ||
double[] testResults = getTestResults(detectorId, data, trainTestSplit, intervalMinutes, anomalies, client); | ||
verifyTestResults(testResults, anomalies, minPrecision, minRecall, maxError); | ||
} | ||
|
||
private void verifyTestResults(double[] testResults, List<Entry<Instant, Instant>> anomalies, double minPrecision, double minRecall, | ||
double maxError) { | ||
|
||
double positives = testResults[0]; | ||
double truePositives = testResults[1]; | ||
double positiveAnomalies = testResults[2]; | ||
double errors = testResults[3]; | ||
|
||
// precision = predicted anomaly points that are true / predicted anomaly points | ||
double precision = positives > 0 ? truePositives / positives : 1; | ||
assertTrue(precision >= minPrecision); | ||
|
||
// recall = windows containing predicted anomaly points / total anomaly windows | ||
double recall = anomalies.size() > 0 ? positiveAnomalies / anomalies.size() : 1; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a comment to explain how you compute precision and recall? It's not easy to understand as precision is computed based on points and recall is based on windows. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added comments showing the definitions in the update |
||
assertTrue(recall >= minRecall); | ||
|
||
assertTrue(errors <= maxError); | ||
} | ||
|
||
private int isAnomaly(Instant time, List<Entry<Instant, Instant>> labels) { | ||
for (int i = 0; i < labels.size(); i++) { | ||
Entry<Instant, Instant> window = labels.get(i); | ||
if (time.compareTo(window.getKey()) >=0 && time.compareTo(window.getValue()) <= 0) { | ||
return i; | ||
} | ||
} | ||
return -1; | ||
} | ||
|
||
private double[] getTestResults(String detectorId, List<JsonObject> data, int trainTestSplit, int intervalMinutes, | ||
List<Entry<Instant, Instant>> anomalies, RestClient client) throws Exception { | ||
|
||
double positives = 0; | ||
double truePositives = 0; | ||
Set<Integer> positiveAnomalies = new HashSet<>(); | ||
double errors = 0; | ||
for (int i = trainTestSplit; i < data.size(); i++) { | ||
Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(i).get("timestamp").getAsString())); | ||
Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); | ||
try { | ||
Map<String, Object> response = getDetectionResult(detectorId, begin, end, client); | ||
double anomalyGrade = (double)response.get("anomalyGrade"); | ||
if (anomalyGrade > 0) { | ||
positives++; | ||
int result = isAnomaly(begin, anomalies); | ||
if (result != -1) { | ||
truePositives++; | ||
positiveAnomalies.add(result); | ||
} | ||
} | ||
} catch (Exception e) { | ||
errors++; | ||
e.printStackTrace(); | ||
} | ||
} | ||
return new double[] {positives, truePositives, positiveAnomalies.size(), errors}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will truePositives and positiveAnomalies.size() always be the same? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no, because truePositives is a count of correct data points. positiveAnomalies is a count of correctly found anomaly windows. The former is no less than the latter. |
||
} | ||
|
||
private void indexTestData(List<JsonObject> data, String datasetName, int trainTestSplit, RestClient client) throws Exception { | ||
data.stream().skip(trainTestSplit) | ||
.forEach(r -> { | ||
try { | ||
Request req = new Request("POST", String.format("/%s/_doc/", datasetName)); | ||
req.setJsonEntity(r.toString()); | ||
client.performRequest(req); | ||
} catch (Exception e ) { | ||
throw new RuntimeException(e); | ||
} }); | ||
Thread.sleep(1_000); | ||
} | ||
|
||
private void startDetector(String detectorId, List<JsonObject> data, int trainTestSplit, int shingleSize, int intervalMinutes, | ||
RestClient client) throws Exception { | ||
|
||
Instant trainTime = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(trainTestSplit-1).get("timestamp").getAsString())); | ||
|
||
Instant begin = null; | ||
Instant end = null; | ||
for (int i = 0; i < shingleSize; i++) { | ||
begin = trainTime.minus(intervalMinutes * (shingleSize - 1 - i), ChronoUnit.MINUTES); | ||
end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); | ||
try { | ||
getDetectionResult(detectorId, begin, end, client); | ||
} catch (Exception e) { | ||
} | ||
} | ||
Thread.sleep(5_000); | ||
getDetectionResult(detectorId, begin, end, client); | ||
} | ||
|
||
private String createDetector(String datasetName, int intervalMinutes, RestClient client) throws Exception { | ||
Request request = new Request("POST", "/_opendistro/_anomaly_detection/detectors/"); | ||
String requestBody = String.format("{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" | ||
+ ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " | ||
+ "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" | ||
+ ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " | ||
+ "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " | ||
+ "\"schema_version\": 0 }", datasetName, intervalMinutes); | ||
request.setJsonEntity(requestBody); | ||
Map<String, Object> response = entityAsMap(client.performRequest(request)); | ||
String detectorId = (String)response.get("_id"); | ||
Thread.sleep(1_000); | ||
return detectorId; | ||
} | ||
|
||
private List<Entry<Instant, Instant>> getAnomalyWindows(String labalFileName) throws Exception { | ||
JsonArray windows = new JsonParser().parse(new FileReader(new File(getClass().getResource(labalFileName).toURI()))) | ||
.getAsJsonArray(); | ||
List<Entry<Instant, Instant>> anomalies = new ArrayList<>(windows.size()); | ||
for (int i = 0; i < windows.size(); i++) { | ||
JsonArray window = windows.get(i).getAsJsonArray(); | ||
Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(window.get(0).getAsString())); | ||
Instant end = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(window.get(1).getAsString())); | ||
anomalies.add(new SimpleEntry<>(begin, end)); | ||
} | ||
return anomalies; | ||
} | ||
|
||
private void indexTrainData(String datasetName, List<JsonObject> data, int trainTestSplit, RestClient client) throws Exception { | ||
Request request = new Request("PUT", datasetName); | ||
String requestBody = "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," | ||
+ " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" } } } }"; | ||
request.setJsonEntity(requestBody); | ||
client.performRequest(request); | ||
Thread.sleep(1_000); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible that sometimes sleeping 1 second is not enough? Have you tried to run the test multiple times (say 100 times)? If yes, would it always pass? I am afraid concurrency and inter-process communication(if you are running the test in a simulated multi-node cluster) would sometimes cause the test to fail. Also, would https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-refresh.html help? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have run it for many times (certainly more than 30 times and less than 100 times) without failure. And it passes github checks multiple times. From my experience, it has been more reliable than some randomized unit tests. |
||
|
||
data.stream().limit(trainTestSplit) | ||
.forEach(r -> { | ||
try { | ||
Request req = new Request("POST", String.format("/%s/_doc/", datasetName)); | ||
req.setJsonEntity(r.toString()); | ||
client.performRequest(req); | ||
} catch (Exception e ) { | ||
throw new RuntimeException(e); | ||
} }); | ||
Thread.sleep(1_000); | ||
} | ||
|
||
private List<JsonObject> getData(String datasetFileName) throws Exception { | ||
JsonArray jsonArray = new JsonParser().parse(new FileReader(new File(getClass().getResource(datasetFileName).toURI()))) | ||
.getAsJsonArray(); | ||
List<JsonObject> list = new ArrayList<>(jsonArray.size()); | ||
jsonArray.iterator().forEachRemaining(i -> list.add(i.getAsJsonObject())); | ||
return list; | ||
} | ||
|
||
private Map<String, Object> getDetectionResult(String detectorId, Instant begin, Instant end, RestClient client) { | ||
try { | ||
Request request = new Request("POST", String.format("/_opendistro/_anomaly_detection/detectors/%s/_run", detectorId)); | ||
request.setJsonEntity(String.format("{ \"period_start\": %d, \"period_end\": %d }", begin.toEpochMilli(), end.toEpochMilli())); | ||
return entityAsMap(client.performRequest(request)); | ||
} catch (Exception e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How long does the test take to run?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
about a minute