From 8e66ea305ca4e7d0dfc5072cec80fc814e7b37aa Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Thu, 12 Mar 2020 10:09:43 +0100 Subject: [PATCH] Make accuracy evaluation metric work when there is field mapping type mismatch --- .../evaluation/classification/Accuracy.java | 11 +- .../classification/PainlessScripts.java | 34 +++ .../evaluation/classification/Precision.java | 9 +- .../evaluation/classification/Recall.java | 11 +- .../ClassificationEvaluationIT.java | 195 +++++++++++++++--- 5 files changed, 207 insertions(+), 53 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PainlessScripts.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java index 7aa95e14b576e..972384b09e943 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java @@ -28,11 +28,9 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.text.MessageFormat; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Locale; import java.util.Objects; import java.util.Optional; @@ -66,12 +64,6 @@ public class Accuracy implements EvaluationMetric { static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy"; - private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value"; - - private static Script buildScript(Object...args) { - return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args)); - } - private static final ObjectParser PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new); public static Accuracy fromXContent(XContentParser parser) { @@ -112,7 +104,8 @@ public final Tuple, List> a List aggs = new ArrayList<>(); List pipelineAggs = new ArrayList<>(); if (overallAccuracy.get() == null) { - aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField))); + Script script = PainlessScripts.buildComparisonScript(actualField, predictedField); + aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(script)); } if (result.get() == null) { Tuple, List> matrixAggs = diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PainlessScripts.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PainlessScripts.java new file mode 100644 index 0000000000000..812d1a8ae2439 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PainlessScripts.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.script.Script; + +import java.text.MessageFormat; +import java.util.Locale; + +/** + * Painless scripts used by classification metrics in this package. + */ +final class PainlessScripts { + + /** + * Template for the comparison script. + * It uses "String.valueOf" method in case the mapping types of the two fields are different. + */ + private static final MessageFormat COMPARISON_SCRIPT_TEMPLATE = + new MessageFormat("String.valueOf(doc[''{0}''].value).equals(String.valueOf(doc[''{1}''].value))", Locale.ROOT); + + /** + * Builds field comparison script for the given actual and predicted field names. + * @param actualField name of the actual field + * @param predictedField name of the predicted field + * @return script that compares values of actualField and predictedField + */ + static Script buildComparisonScript(String actualField, String predictedField) { + return new Script(COMPARISON_SCRIPT_TEMPLATE.format(new Object[]{ actualField, predictedField })); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index 899c25f1c64e0..6f15d113e9f25 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -34,11 +34,9 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.text.MessageFormat; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Locale; import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; @@ -59,17 +57,12 @@ public class Precision implements EvaluationMetric { public static final ParseField NAME = new ParseField("precision"); - private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value"; private static final String AGG_NAME_PREFIX = "classification_precision_"; static final String ACTUAL_CLASSES_NAMES_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class"; static final String BY_PREDICTED_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_predicted_class"; static final String PER_PREDICTED_CLASS_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "per_predicted_class_precision"; static final String AVG_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "avg_precision"; - private static Script buildScript(Object...args) { - return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args)); - } - private static final ObjectParser PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Precision::new); public static Precision fromXContent(XContentParser parser) { @@ -116,7 +109,7 @@ public final Tuple, List> a topActualClassNames.get().stream() .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) .toArray(KeyedFilter[]::new); - Script script = buildScript(actualField, predictedField); + Script script = PainlessScripts.buildComparisonScript(actualField, predictedField); return Tuple.tuple( List.of( AggregationBuilders.filters(BY_PREDICTED_CLASS_AGG_NAME, keyedFiltersPredicted) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java index b71cbecd54d4c..8a6964933e471 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java @@ -20,6 +20,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.BucketOrder; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders; import org.elasticsearch.search.aggregations.bucket.terms.Terms; @@ -30,11 +31,9 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.text.MessageFormat; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Locale; import java.util.Objects; import java.util.Optional; @@ -54,16 +53,11 @@ public class Recall implements EvaluationMetric { public static final ParseField NAME = new ParseField("recall"); - private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value"; private static final String AGG_NAME_PREFIX = "classification_recall_"; static final String BY_ACTUAL_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class"; static final String PER_ACTUAL_CLASS_RECALL_AGG_NAME = AGG_NAME_PREFIX + "per_actual_class_recall"; static final String AVG_RECALL_AGG_NAME = AGG_NAME_PREFIX + "avg_recall"; - private static Script buildScript(Object...args) { - return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args)); - } - private static final ObjectParser PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Recall::new); public static Recall fromXContent(XContentParser parser) { @@ -98,11 +92,12 @@ public final Tuple, List> a if (result.get() != null) { return Tuple.tuple(List.of(), List.of()); } - Script script = buildScript(actualField, predictedField); + Script script = PainlessScripts.buildComparisonScript(actualField, predictedField); return Tuple.tuple( List.of( AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME) .field(actualField) + .order(List.of(BucketOrder.count(false), BucketOrder.key(true))) .size(MAX_CLASSES_CARDINALITY) .subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))), List.of( diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index bdcfac276eaaf..773f91b5027d2 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -37,11 +37,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index"; - private static final String ANIMAL_NAME_FIELD = "animal_name"; + private static final String ANIMAL_NAME_KEYWORD_FIELD = "animal_name_keyword"; private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction"; - private static final String NO_LEGS_FIELD = "no_legs"; + private static final String NO_LEGS_KEYWORD_FIELD = "no_legs_keyword"; + private static final String NO_LEGS_INTEGER_FIELD = "no_legs_integer"; private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction"; - private static final String IS_PREDATOR_FIELD = "predator"; + private static final String IS_PREDATOR_KEYWORD_FIELD = "predator_keyword"; + private static final String IS_PREDATOR_BOOLEAN_FIELD = "predator_boolean"; private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction"; @Before @@ -61,7 +63,7 @@ public void cleanup() { public void testEvaluate_DefaultMetrics() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null)); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null)); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat( @@ -74,7 +76,7 @@ public void testEvaluate_AllMetrics() { evaluateDataFrame( ANIMALS_DATA_INDEX, new Classification( - ANIMAL_NAME_FIELD, + ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); @@ -91,7 +93,7 @@ public void testEvaluate_AllMetrics() { public void testEvaluate_Accuracy_KeywordField() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Accuracy()))); + ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Accuracy()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -110,10 +112,9 @@ public void testEvaluate_Accuracy_KeywordField() { assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75)); } - public void testEvaluate_Accuracy_IntegerField() { + private void evaluateAccuracy_IntegerField(String actualField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(NO_LEGS_FIELD, NO_LEGS_PREDICTION_FIELD, List.of(new Accuracy()))); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, List.of(new Accuracy()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -132,10 +133,18 @@ public void testEvaluate_Accuracy_IntegerField() { assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75)); } - public void testEvaluate_Accuracy_BooleanField() { + public void testEvaluate_Accuracy_IntegerField() { + evaluateAccuracy_IntegerField(NO_LEGS_INTEGER_FIELD); + } + + public void testEvaluate_Accuracy_IntegerField_MappingTypeMismatch() { + evaluateAccuracy_IntegerField(NO_LEGS_KEYWORD_FIELD); + } + + private void evaluateAccuracy_BooleanField(String actualField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(IS_PREDATOR_FIELD, IS_PREDATOR_PREDICTION_FIELD, List.of(new Accuracy()))); + ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, List.of(new Accuracy()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -151,10 +160,18 @@ public void testEvaluate_Accuracy_BooleanField() { assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75)); } - public void testEvaluate_Precision() { + public void testEvaluate_Accuracy_BooleanField() { + evaluateAccuracy_BooleanField(IS_PREDATOR_BOOLEAN_FIELD); + } + + public void testEvaluate_Accuracy_BooleanField_MappingTypeMismatch() { + evaluateAccuracy_BooleanField(IS_PREDATOR_KEYWORD_FIELD); + } + + public void testEvaluate_Precision_KeywordField() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Precision()))); + ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Precision()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -173,20 +190,78 @@ public void testEvaluate_Precision() { assertThat(precisionResult.getAvgPrecision(), equalTo(5.0 / 75)); } + private void evaluatePrecision_IntegerField(String actualField) { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, List.of(new Precision()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); + assertThat( + precisionResult.getClasses(), + equalTo( + List.of( + new Precision.PerClassResult("1", 0.2), + new Precision.PerClassResult("2", 0.2), + new Precision.PerClassResult("3", 0.2), + new Precision.PerClassResult("4", 0.2), + new Precision.PerClassResult("5", 0.2)))); + assertThat(precisionResult.getAvgPrecision(), equalTo(0.2)); + } + + public void testEvaluate_Precision_IntegerField() { + evaluatePrecision_IntegerField(NO_LEGS_INTEGER_FIELD); + } + + public void testEvaluate_Precision_IntegerField_MappingTypeMismatch() { + evaluatePrecision_IntegerField(NO_LEGS_KEYWORD_FIELD); + } + + private void evaluatePrecision_BooleanField(String actualField) { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, List.of(new Precision()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); + assertThat( + precisionResult.getClasses(), + equalTo( + List.of( + new Precision.PerClassResult("false", 0.5), + new Precision.PerClassResult("true", 9.0 / 13)))); + assertThat(precisionResult.getAvgPrecision(), equalTo(31.0 / 52)); + } + + public void testEvaluate_Precision_BooleanField() { + evaluatePrecision_BooleanField(IS_PREDATOR_BOOLEAN_FIELD); + } + + public void testEvaluate_Precision_BooleanField_MappingTypeMismatch() { + evaluatePrecision_BooleanField(IS_PREDATOR_KEYWORD_FIELD); + } + public void testEvaluate_Precision_CardinalityTooHigh() { indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001); ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException.class, () -> evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Precision())))); - assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); + ANIMALS_DATA_INDEX, + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Precision())))); + assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high")); } - public void testEvaluate_Recall() { + public void testEvaluate_Recall_KeywordField() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Recall()))); + ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Recall()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -205,21 +280,79 @@ public void testEvaluate_Recall() { assertThat(recallResult.getAvgRecall(), equalTo(5.0 / 75)); } + private void evaluateRecall_IntegerField(String actualField) { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_INTEGER_FIELD, List.of(new Recall()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); + assertThat( + recallResult.getClasses(), + equalTo( + List.of( + new Recall.PerClassResult("1", 1.0), + new Recall.PerClassResult("2", 1.0), + new Recall.PerClassResult("3", 1.0), + new Recall.PerClassResult("4", 1.0), + new Recall.PerClassResult("5", 1.0)))); + assertThat(recallResult.getAvgRecall(), equalTo(1.0)); + } + + public void testEvaluate_Recall_IntegerField() { + evaluateRecall_IntegerField(NO_LEGS_INTEGER_FIELD); + } + + public void testEvaluate_Recall_IntegerField_MappingTypeMismatch() { + evaluateRecall_IntegerField(NO_LEGS_KEYWORD_FIELD); + } + + private void evaluateRecall_BooleanField(String actualField) { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, List.of(new Recall()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); + assertThat( + recallResult.getClasses(), + equalTo( + List.of( + new Recall.PerClassResult("true", 0.6), + new Recall.PerClassResult("false", 0.6)))); + assertThat(recallResult.getAvgRecall(), equalTo(0.6)); + } + + public void testEvaluate_Recall_BooleanField() { + evaluateRecall_BooleanField(IS_PREDATOR_BOOLEAN_FIELD); + } + + public void testEvaluate_Recall_BooleanField_MappingTypeMismatch() { + evaluateRecall_BooleanField(IS_PREDATOR_KEYWORD_FIELD); + } + public void testEvaluate_Recall_CardinalityTooHigh() { indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001); ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException.class, () -> evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Recall())))); - assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); + ANIMALS_DATA_INDEX, + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Recall())))); + assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high")); } private void evaluateWithMulticlassConfusionMatrix() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new MulticlassConfusionMatrix()))); + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new MulticlassConfusionMatrix()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -299,7 +432,8 @@ public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new MulticlassConfusionMatrix(3, null)))); + new Classification( + ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new MulticlassConfusionMatrix(3, null)))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -336,11 +470,13 @@ public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() { private static void createAnimalsIndex(String indexName) { client().admin().indices().prepareCreate(indexName) .setMapping( - ANIMAL_NAME_FIELD, "type=keyword", + ANIMAL_NAME_KEYWORD_FIELD, "type=keyword", ANIMAL_NAME_PREDICTION_FIELD, "type=keyword", - NO_LEGS_FIELD, "type=integer", + NO_LEGS_KEYWORD_FIELD, "type=keyword", + NO_LEGS_INTEGER_FIELD, "type=integer", NO_LEGS_PREDICTION_FIELD, "type=integer", - IS_PREDATOR_FIELD, "type=boolean", + IS_PREDATOR_KEYWORD_FIELD, "type=keyword", + IS_PREDATOR_BOOLEAN_FIELD, "type=boolean", IS_PREDATOR_PREDICTION_FIELD, "type=boolean") .get(); } @@ -355,11 +491,13 @@ private static void indexAnimalsData(String indexName) { bulkRequestBuilder.add( new IndexRequest(indexName) .source( - ANIMAL_NAME_FIELD, animalNames.get(i), + ANIMAL_NAME_KEYWORD_FIELD, animalNames.get(i), ANIMAL_NAME_PREDICTION_FIELD, animalNames.get((i + j) % animalNames.size()), - NO_LEGS_FIELD, i + 1, + NO_LEGS_KEYWORD_FIELD, String.valueOf(i + 1), + NO_LEGS_INTEGER_FIELD, i + 1, NO_LEGS_PREDICTION_FIELD, j + 1, - IS_PREDATOR_FIELD, i % 2 == 0, + IS_PREDATOR_KEYWORD_FIELD, String.valueOf(i % 2 == 0), + IS_PREDATOR_BOOLEAN_FIELD, i % 2 == 0, IS_PREDATOR_PREDICTION_FIELD, (i + j) % 2 == 0)); } } @@ -375,7 +513,8 @@ private static void indexDistinctAnimals(String indexName, int distinctAnimalCou .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); for (int i = 0; i < distinctAnimalCount; i++) { bulkRequestBuilder.add( - new IndexRequest(indexName).source(ANIMAL_NAME_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5))); + new IndexRequest(indexName) + .source(ANIMAL_NAME_KEYWORD_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5))); } BulkResponse bulkResponse = bulkRequestBuilder.get(); if (bulkResponse.hasFailures()) {