From 86fdfccb0b539195b45fb7f03ab79d63386c7de4 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 20 Oct 2020 16:36:13 +0300 Subject: [PATCH 1/2] [ML] Extend default evaluation metrics to all available This commit extends the set of default metrics for the data frame analytics evaluation API to all available metrics. The motivation is that if the user skips setting an explicit set of metrics, they get most of the evaluation offering. --- .../evaluation/classification/Classification.java | 2 +- .../ml/dataframe/evaluation/regression/Huber.java | 4 ++++ .../dataframe/evaluation/regression/Regression.java | 2 +- .../classification/ClassificationTests.java | 9 +++++++++ .../outlierdetection/OutlierDetectionTests.java | 12 ++++++++++++ .../evaluation/regression/RegressionTests.java | 9 +++++++++ .../rest-api-spec/test/ml/evaluate_data_frame.yml | 6 +++++- 7 files changed, 41 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java index 680cda5396779..adb257525bcfe 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java @@ -92,7 +92,7 @@ public Classification(String actualField, } private static List defaultMetrics() { - return Arrays.asList(new MulticlassConfusionMatrix()); + return Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()); } public Classification(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java index 978ac0c74cded..4c6ee1af7f6bb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java @@ -80,6 +80,10 @@ public Huber(StreamInput in) throws IOException { this.delta = in.readDouble(); } + public Huber() { + this(DEFAULT_DELTA); + } + public Huber(@Nullable Double delta) { this.delta = delta != null ? delta : DEFAULT_DELTA; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java index a90e6821255d7..a8bf53c781272 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java @@ -76,7 +76,7 @@ public Regression(String actualField, String predictedField, @Nullable List defaultMetrics() { - return Arrays.asList(new MeanSquaredError(), new RSquared()); + return Arrays.asList(new MeanSquaredError(), new RSquared(), new Huber()); } public Regression(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java index c511ac42ee4ca..60c30b94a7fb7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java @@ -40,6 +40,7 @@ import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.is; @@ -110,6 +111,14 @@ public void testConstructor_GivenEmptyMetrics() { assertThat(e.getMessage(), equalTo("[classification] must have one or more metrics")); } + public void testConstructor_GivenDefaultMetrics() { + Classification classification = new Classification("actual", "predicted", null, null); + + List metrics = classification.getMetrics(); + + assertThat(metrics, containsInAnyOrder(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())); + } + public void testGetFields() { Classification evaluation = new Classification("foo", "bar", "results", null); EvaluationFields fields = evaluation.getFields(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java index c0b72dbe1c234..d68daf218882f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java @@ -25,6 +25,7 @@ import java.util.Collections; import java.util.List; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.is; @@ -89,6 +90,17 @@ public void testConstructor_GivenEmptyMetrics() { assertThat(e.getMessage(), equalTo("[outlier_detection] must have one or more metrics")); } + public void testConstructor_GivenDefaultMetrics() { + OutlierDetection outlierDetection = new OutlierDetection("actual", "predicted", null); + + List metrics = outlierDetection.getMetrics(); + + assertThat(metrics, containsInAnyOrder(new AucRoc(false), + new Precision(Arrays.asList(0.25, 0.5, 0.75)), + new Recall(Arrays.asList(0.25, 0.5, 0.75)), + new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75)))); + } + public void testGetFields() { OutlierDetection evaluation = new OutlierDetection("foo", "bar", null); EvaluationFields fields = evaluation.getFields(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java index c8fc2d5d67d55..c6d060a3b84e4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java @@ -25,6 +25,7 @@ import java.util.Collections; import java.util.List; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.is; @@ -76,6 +77,14 @@ public void testConstructor_GivenEmptyMetrics() { assertThat(e.getMessage(), equalTo("[regression] must have one or more metrics")); } + public void testConstructor_GivenDefaultMetrics() { + Regression regression = new Regression("actual", "predicted", null); + + List metrics = regression.getMetrics(); + + assertThat(metrics, containsInAnyOrder(new Huber(), new MeanSquaredError(), new RSquared())); + } + public void testGetFields() { Regression evaluation = new Regression("foo", "bar", null); EvaluationFields fields = evaluation.getFields(); diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 09cf11d266612..83fe922c02492 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -938,6 +938,10 @@ setup: } - is_true: classification.multiclass_confusion_matrix + - is_true: classification.accuracy + - is_true: classification.precision + - is_true: classification.recall + - is_false: classification.auc_roc --- "Test classification given missing actual_field": - do: @@ -1104,8 +1108,8 @@ setup: - match: { regression.mse.value: 28.67749840974834 } - match: { regression.r_squared.value: 0.8551031778603486 } + - match: { regression.huber.value: 1.9205280586939963 } - is_false: regression.msle.value - - is_false: regression.huber.value --- "Test regression given missing actual_field": - do: From b56fdfbb07fdbb7266d030d624ef225e035a2edc Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 20 Oct 2020 17:59:30 +0300 Subject: [PATCH 2/2] Fix integ tests --- .../xpack/ml/integration/ClassificationEvaluationIT.java | 9 ++++++++- .../xpack/ml/integration/RegressionEvaluationIT.java | 8 +++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 3deebf757c3a5..a02a3c4869b5c 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -32,6 +32,7 @@ import static java.util.stream.Collectors.toList; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @@ -81,7 +82,13 @@ public void testEvaluate_DefaultMetrics() { assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat( evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()), - contains(MulticlassConfusionMatrix.NAME.getPreferredName())); + containsInAnyOrder( + MulticlassConfusionMatrix.NAME.getPreferredName(), + Accuracy.NAME.getPreferredName(), + Precision.NAME.getPreferredName(), + Recall.NAME.getPreferredName() + ) + ); } public void testEvaluate_AllMetrics() { diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java index a947fe66c03cb..a849d244a0f82 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java @@ -24,6 +24,7 @@ import static java.util.stream.Collectors.toList; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -52,7 +53,12 @@ public void testEvaluate_DefaultMetrics() { assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName())); assertThat( evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()), - contains(MeanSquaredError.NAME.getPreferredName(), RSquared.NAME.getPreferredName())); + containsInAnyOrder( + MeanSquaredError.NAME.getPreferredName(), + RSquared.NAME.getPreferredName(), + Huber.NAME.getPreferredName() + ) + ); } public void testEvaluate_AllMetrics() {