diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index f6bacbec05e7b..dd04f23710118 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; @@ -75,25 +74,15 @@ public static Precision fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000; + private static final int MAX_CLASSES_CARDINALITY = 1000; - private final int maxClassesCardinality; private String actualField; private List topActualClassNames; private EvaluationMetricResult result; - public Precision() { - this((Integer) null); - } - - // Visible for testing - public Precision(@Nullable Integer maxClassesCardinality) { - this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY; - } + public Precision() {} - public Precision(StreamInput in) throws IOException { - this.maxClassesCardinality = DEFAULT_MAX_CLASSES_CARDINALITY; - } + public Precision(StreamInput in) throws IOException {} @Override public String getWriteableName() { @@ -115,7 +104,7 @@ public final Tuple, List> a AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME) .field(actualField) .order(List.of(BucketOrder.count(false), BucketOrder.key(true))) - .size(maxClassesCardinality)), + .size(MAX_CLASSES_CARDINALITY)), List.of()); } if (result == null) { // This is step 2 diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java index 522810e57e2dd..01bdbe6db230b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; @@ -69,24 +68,14 @@ public static Recall fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000; + private static final int MAX_CLASSES_CARDINALITY = 1000; - private final int maxClassesCardinality; private String actualField; private EvaluationMetricResult result; - public Recall() { - this((Integer) null); - } - - // Visible for testing - public Recall(@Nullable Integer maxClassesCardinality) { - this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY; - } + public Recall() {} - public Recall(StreamInput in) throws IOException { - this.maxClassesCardinality = DEFAULT_MAX_CLASSES_CARDINALITY; - } + public Recall(StreamInput in) throws IOException {} @Override public String getWriteableName() { @@ -110,7 +99,7 @@ public final Tuple, List> a List.of( AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME) .field(actualField) - .size(maxClassesCardinality) + .size(MAX_CLASSES_CARDINALITY) .subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))), List.of( PipelineAggregatorBuilders.avgBucket( diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 14c2c3c9aca61..437b2ddbf5180 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -38,6 +38,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT @Before public void setup() { + createAnimalsIndex(ANIMALS_DATA_INDEX); indexAnimalsData(ANIMALS_DATA_INDEX); } @@ -141,11 +142,12 @@ public void testEvaluate_Precision() { } public void testEvaluate_Precision_CardinalityTooHigh() { + indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001); ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException.class, () -> evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Precision(4))))); + ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Precision())))); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); } @@ -172,11 +174,12 @@ public void testEvaluate_Recall() { } public void testEvaluate_Recall_CardinalityTooHigh() { + indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001); ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException.class, () -> evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Recall(4))))); + ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Recall())))); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); } @@ -281,7 +284,7 @@ public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() { assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L)); } - private static void indexAnimalsData(String indexName) { + private static void createAnimalsIndex(String indexName) { client().admin().indices().prepareCreate(indexName) .addMapping("_doc", ANIMAL_NAME_FIELD, "type=keyword", @@ -291,7 +294,9 @@ private static void indexAnimalsData(String indexName) { IS_PREDATOR_FIELD, "type=boolean", IS_PREDATOR_PREDICTION_FIELD, "type=boolean") .get(); + } + private static void indexAnimalsData(String indexName) { List animalNames = List.of("dog", "cat", "mouse", "ant", "fox"); BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); @@ -315,4 +320,17 @@ private static void indexAnimalsData(String indexName) { fail("Failed to index data: " + bulkResponse.buildFailureMessage()); } } + + private static void indexDistinctAnimals(String indexName, int distinctAnimalCount) { + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + for (int i = 0; i < distinctAnimalCount; i++) { + bulkRequestBuilder.add( + new IndexRequest(indexName).source(ANIMAL_NAME_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5))); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + } }