diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java index 680cda5396779..adb257525bcfe 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java @@ -92,7 +92,7 @@ public Classification(String actualField, } private static List defaultMetrics() { - return Arrays.asList(new MulticlassConfusionMatrix()); + return Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()); } public Classification(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java index 978ac0c74cded..4c6ee1af7f6bb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java @@ -80,6 +80,10 @@ public Huber(StreamInput in) throws IOException { this.delta = in.readDouble(); } + public Huber() { + this(DEFAULT_DELTA); + } + public Huber(@Nullable Double delta) { this.delta = delta != null ? delta : DEFAULT_DELTA; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java index a90e6821255d7..a8bf53c781272 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java @@ -76,7 +76,7 @@ public Regression(String actualField, String predictedField, @Nullable List defaultMetrics() { - return Arrays.asList(new MeanSquaredError(), new RSquared()); + return Arrays.asList(new MeanSquaredError(), new RSquared(), new Huber()); } public Regression(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java index c511ac42ee4ca..60c30b94a7fb7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java @@ -40,6 +40,7 @@ import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.is; @@ -110,6 +111,14 @@ public void testConstructor_GivenEmptyMetrics() { assertThat(e.getMessage(), equalTo("[classification] must have one or more metrics")); } + public void testConstructor_GivenDefaultMetrics() { + Classification classification = new Classification("actual", "predicted", null, null); + + List metrics = classification.getMetrics(); + + assertThat(metrics, containsInAnyOrder(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())); + } + public void testGetFields() { Classification evaluation = new Classification("foo", "bar", "results", null); EvaluationFields fields = evaluation.getFields(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java index c0b72dbe1c234..d68daf218882f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java @@ -25,6 +25,7 @@ import java.util.Collections; import java.util.List; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.is; @@ -89,6 +90,17 @@ public void testConstructor_GivenEmptyMetrics() { assertThat(e.getMessage(), equalTo("[outlier_detection] must have one or more metrics")); } + public void testConstructor_GivenDefaultMetrics() { + OutlierDetection outlierDetection = new OutlierDetection("actual", "predicted", null); + + List metrics = outlierDetection.getMetrics(); + + assertThat(metrics, containsInAnyOrder(new AucRoc(false), + new Precision(Arrays.asList(0.25, 0.5, 0.75)), + new Recall(Arrays.asList(0.25, 0.5, 0.75)), + new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75)))); + } + public void testGetFields() { OutlierDetection evaluation = new OutlierDetection("foo", "bar", null); EvaluationFields fields = evaluation.getFields(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java index c8fc2d5d67d55..c6d060a3b84e4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java @@ -25,6 +25,7 @@ import java.util.Collections; import java.util.List; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.is; @@ -76,6 +77,14 @@ public void testConstructor_GivenEmptyMetrics() { assertThat(e.getMessage(), equalTo("[regression] must have one or more metrics")); } + public void testConstructor_GivenDefaultMetrics() { + Regression regression = new Regression("actual", "predicted", null); + + List metrics = regression.getMetrics(); + + assertThat(metrics, containsInAnyOrder(new Huber(), new MeanSquaredError(), new RSquared())); + } + public void testGetFields() { Regression evaluation = new Regression("foo", "bar", null); EvaluationFields fields = evaluation.getFields(); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 3deebf757c3a5..a02a3c4869b5c 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -32,6 +32,7 @@ import static java.util.stream.Collectors.toList; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @@ -81,7 +82,13 @@ public void testEvaluate_DefaultMetrics() { assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat( evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()), - contains(MulticlassConfusionMatrix.NAME.getPreferredName())); + containsInAnyOrder( + MulticlassConfusionMatrix.NAME.getPreferredName(), + Accuracy.NAME.getPreferredName(), + Precision.NAME.getPreferredName(), + Recall.NAME.getPreferredName() + ) + ); } public void testEvaluate_AllMetrics() { diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java index a947fe66c03cb..a849d244a0f82 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java @@ -24,6 +24,7 @@ import static java.util.stream.Collectors.toList; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -52,7 +53,12 @@ public void testEvaluate_DefaultMetrics() { assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName())); assertThat( evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()), - contains(MeanSquaredError.NAME.getPreferredName(), RSquared.NAME.getPreferredName())); + containsInAnyOrder( + MeanSquaredError.NAME.getPreferredName(), + RSquared.NAME.getPreferredName(), + Huber.NAME.getPreferredName() + ) + ); } public void testEvaluate_AllMetrics() { diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 09cf11d266612..83fe922c02492 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -938,6 +938,10 @@ setup: } - is_true: classification.multiclass_confusion_matrix + - is_true: classification.accuracy + - is_true: classification.precision + - is_true: classification.recall + - is_false: classification.auc_roc --- "Test classification given missing actual_field": - do: @@ -1104,8 +1108,8 @@ setup: - match: { regression.mse.value: 28.67749840974834 } - match: { regression.r_squared.value: 0.8551031778603486 } + - match: { regression.huber.value: 1.9205280586939963 } - is_false: regression.msle.value - - is_false: regression.huber.value --- "Test regression given missing actual_field": - do: