diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java index 1a79dff41e10c..9fdba68d4cda7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java @@ -66,7 +66,7 @@ default List initMetrics(@Nullable List parse * Builds the search required to collect data to compute the evaluation result * @param userProvidedQueryBuilder User-provided query that must be respected when collecting data */ - default SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) { + default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBuilder userProvidedQueryBuilder) { Objects.requireNonNull(userProvidedQueryBuilder); BoolQueryBuilder boolQuery = QueryBuilders.boolQuery() @@ -78,7 +78,8 @@ default SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery); for (EvaluationMetric metric : getMetrics()) { // Fetch aggregations requested by individual metrics - Tuple, List> aggs = metric.aggs(getActualField(), getPredictedField()); + Tuple, List> aggs = + metric.aggs(parameters, getActualField(), getPredictedField()); aggs.v1().forEach(searchSourceBuilder::aggregation); aggs.v2().forEach(searchSourceBuilder::aggregation); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java index 8a106175ace91..a5c3d657f55e9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java @@ -28,11 +28,14 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable { /** * Builds the aggregation that collect required data to compute the metric + * @param parameters settings that may be needed by aggregations * @param actualField the field that stores the actual value * @param predictedField the field that stores the predicted value (class name or probability) * @return the aggregations required to compute the metric */ - Tuple, List> aggs(String actualField, String predictedField); + Tuple, List> aggs(EvaluationParameters parameters, + String actualField, + String predictedField); /** * Processes given aggregations as a step towards computing result diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationParameters.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationParameters.java new file mode 100644 index 0000000000000..e834efc7e67f7 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationParameters.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +/** + * Encapsulates parameters needed by evaluation. + */ +public class EvaluationParameters { + + /** + * Maximum number of buckets allowed in any single search request. + */ + private final int maxBuckets; + + public EvaluationParameters(int maxBuckets) { + this.maxBuckets = maxBuckets; + } + + public int getMaxBuckets() { + return maxBuckets; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java index c6636329a65d9..7aa95e14b576e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java @@ -24,6 +24,7 @@ import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -103,7 +104,9 @@ public String getName() { } @Override - public final Tuple, List> aggs(String actualField, String predictedField) { + public final Tuple, List> aggs(EvaluationParameters parameters, + String actualField, + String predictedField) { // Store given {@code actualField} for the purpose of generating error message in {@code process}. this.actualField.trySet(actualField); List aggs = new ArrayList<>(); @@ -112,7 +115,8 @@ public final Tuple, List> a aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField))); } if (result.get() == null) { - Tuple, List> matrixAggs = matrix.aggs(actualField, predictedField); + Tuple, List> matrixAggs = + matrix.aggs(parameters, actualField, predictedField); aggs.addAll(matrixAggs.v1()); pipelineAggs.addAll(matrixAggs.v2()); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index 8376382e41b95..1dc3614723dfa 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -29,6 +29,7 @@ import org.elasticsearch.search.aggregations.metrics.Cardinality; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -73,9 +74,9 @@ public static MulticlassConfusionMatrix fromXContent(XContentParser parser) { } static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class"; + static final String STEP_1_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_cardinality_of_actual_class"; static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class"; static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class"; - static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class"; private static final String OTHER_BUCKET_KEY = "_other_"; private static final String DEFAULT_AGG_NAME_PREFIX = ""; private static final int DEFAULT_SIZE = 10; @@ -84,6 +85,9 @@ public static MulticlassConfusionMatrix fromXContent(XContentParser parser) { private final int size; private final String aggNamePrefix; private final SetOnce> topActualClassNames = new SetOnce<>(); + private final SetOnce actualClassesCardinality = new SetOnce<>(); + /** Accumulates actual classes processed so far. It may take more than 1 call to #process method to fill this field completely. */ + private final List actualClasses = new ArrayList<>(); private final SetOnce result = new SetOnce<>(); public MulticlassConfusionMatrix() { @@ -122,34 +126,45 @@ public int getSize() { } @Override - public final Tuple, List> aggs(String actualField, String predictedField) { - if (topActualClassNames.get() == null) { // This is step 1 + public final Tuple, List> aggs(EvaluationParameters parameters, + String actualField, + String predictedField) { + if (topActualClassNames.get() == null && actualClassesCardinality.get() == null) { // This is step 1 return Tuple.tuple( Arrays.asList( AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)) .field(actualField) .order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true))) - .size(size)), + .size(size), + AggregationBuilders.cardinality(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS)) + .field(actualField)), Collections.emptyList()); } - if (result.get() == null) { // This is step 2 - KeyedFilter[] keyedFiltersActual = - topActualClassNames.get().stream() - .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className))) - .toArray(KeyedFilter[]::new); + if (result.get() == null) { // These are steps 2, 3, 4 etc. KeyedFilter[] keyedFiltersPredicted = topActualClassNames.get().stream() .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) .toArray(KeyedFilter[]::new); - return Tuple.tuple( - Arrays.asList( - AggregationBuilders.cardinality(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)) - .field(actualField), - AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual) - .subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted) - .otherBucket(true) - .otherBucketKey(OTHER_BUCKET_KEY))), - Collections.emptyList()); + // Knowing exactly how many buckets does each aggregation use, we can choose the size of the batch so that + // too_many_buckets_exception exception is not thrown. + // The only exception is when "search.max_buckets" is set far too low to even have 1 actual class in the batch. + // In such case, the exception will be thrown telling the user they should increase the value of "search.max_buckets". + int actualClassesPerBatch = Math.max(parameters.getMaxBuckets() / (topActualClassNames.get().size() + 2), 1); + KeyedFilter[] keyedFiltersActual = + topActualClassNames.get().stream() + .skip(actualClasses.size()) + .limit(actualClassesPerBatch) + .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className))) + .toArray(KeyedFilter[]::new); + if (keyedFiltersActual.length > 0) { + return Tuple.tuple( + Arrays.asList( + AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual) + .subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted) + .otherBucket(true) + .otherBucketKey(OTHER_BUCKET_KEY))), + Collections.emptyList()); + } } return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } @@ -160,10 +175,12 @@ public void process(Aggregations aggs) { Terms termsAgg = aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)); topActualClassNames.set(termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList())); } + if (actualClassesCardinality.get() == null && aggs.get(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS)) != null) { + Cardinality cardinalityAgg = aggs.get(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS)); + actualClassesCardinality.set(cardinalityAgg.getValue()); + } if (result.get() == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) { - Cardinality cardinalityAgg = aggs.get(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)); Filters filtersAgg = aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)); - List actualClasses = new ArrayList<>(filtersAgg.getBuckets().size()); for (Filters.Bucket bucket : filtersAgg.getBuckets()) { String actualClass = bucket.getKeyAsString(); long actualClassDocCount = bucket.getDocCount(); @@ -182,7 +199,9 @@ public void process(Aggregations aggs) { predictedClasses.sort(comparing(PredictedClass::getPredictedClass)); actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount)); } - result.set(new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0))); + if (actualClasses.size() == topActualClassNames.get().size()) { + result.set(new Result(actualClasses, Math.max(actualClassesCardinality.get() - size, 0))); + } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index 30906efd41f9c..d3f0a259c160d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -30,6 +30,7 @@ import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -97,7 +98,9 @@ public String getName() { } @Override - public final Tuple, List> aggs(String actualField, String predictedField) { + public final Tuple, List> aggs(EvaluationParameters parameters, + String actualField, + String predictedField) { // Store given {@code actualField} for the purpose of generating error message in {@code process}. this.actualField.trySet(actualField); if (topActualClassNames.get() == null) { // This is step 1 diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java index c4f2e8e60ab8c..fa5b277daa45d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java @@ -26,6 +26,7 @@ import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -90,7 +91,9 @@ public String getName() { } @Override - public final Tuple, List> aggs(String actualField, String predictedField) { + public final Tuple, List> aggs(EvaluationParameters parameters, + String actualField, + String predictedField) { // Store given {@code actualField} for the purpose of generating error message in {@code process}. this.actualField.trySet(actualField); if (result.get() != null) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java index f2abbe54454f0..fb197b543140f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java @@ -20,6 +20,7 @@ import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import java.io.IOException; import java.text.MessageFormat; @@ -67,7 +68,9 @@ public String getName() { } @Override - public Tuple, List> aggs(String actualField, String predictedField) { + public Tuple, List> aggs(EvaluationParameters parameters, + String actualField, + String predictedField) { if (result != null) { return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java index c7b989dca1182..dd22745fa4614 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java @@ -22,6 +22,7 @@ import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import java.io.IOException; import java.text.MessageFormat; @@ -72,7 +73,9 @@ public String getName() { } @Override - public Tuple, List> aggs(String actualField, String predictedField) { + public Tuple, List> aggs(EvaluationParameters parameters, + String actualField, + String predictedField) { if (result != null) { return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java index 34667aaabc9b7..451d3c49e1f95 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java @@ -19,6 +19,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -66,7 +67,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } @Override - public Tuple, List> aggs(String actualField, String predictedProbabilityField) { + public Tuple, List> aggs(EvaluationParameters parameters, + String actualField, + String predictedProbabilityField) { if (result != null) { return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java index 614d351c887bb..1cd33ea8288ac 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java @@ -24,6 +24,7 @@ import org.elasticsearch.search.aggregations.metrics.Percentiles; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -127,7 +128,9 @@ public int hashCode() { } @Override - public Tuple, List> aggs(String actualField, String predictedProbabilityField) { + public Tuple, List> aggs(EvaluationParameters parameters, + String actualField, + String predictedProbabilityField) { if (result != null) { return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationParametersTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationParametersTests.java new file mode 100644 index 0000000000000..1efe45fa930e4 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationParametersTests.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +import org.elasticsearch.test.ESTestCase; + +import static org.hamcrest.Matchers.equalTo; + +public class EvaluationParametersTests extends ESTestCase { + + public void testConstructorAndGetters() { + EvaluationParameters params = new EvaluationParameters(17); + assertThat(params.getMaxBuckets(), equalTo(17)); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java index 5fe49b9e8119d..35a9a85d135a3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result; @@ -17,19 +18,21 @@ import java.util.Arrays; import java.util.Collections; +import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket; -import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; public class AccuracyTests extends AbstractSerializingTestCase { + private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100); + @Override protected Accuracy doParseInstance(XContentParser parser) throws IOException { return Accuracy.fromXContent(parser); @@ -62,6 +65,7 @@ public void testProcess() { mockTermsBucket("dog", new Aggregations(Collections.emptyList())), mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), 100L), + mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 1000L), mockFilters( "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, Arrays.asList( @@ -79,13 +83,12 @@ public void testProcess() { "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, Arrays.asList( mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), - mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1000L), mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5))); Accuracy accuracy = new Accuracy(); accuracy.process(aggs); - assertThat(accuracy.aggs("act", "pred"), isTuple(empty(), empty())); + assertThat(accuracy.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty())); Result result = accuracy.getResult().get(); assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); @@ -106,6 +109,7 @@ public void testProcess_GivenCardinalityTooHigh() { mockTermsBucket("dog", new Aggregations(Collections.emptyList())), mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), 100L), + mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 1001L), mockFilters( "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, Arrays.asList( @@ -123,11 +127,10 @@ public void testProcess_GivenCardinalityTooHigh() { "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, Arrays.asList( mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), - mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1001L), mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5))); Accuracy accuracy = new Accuracy(); - accuracy.aggs("foo", "bar"); + accuracy.aggs(EVALUATION_PARAMETERS, "foo", "bar"); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs)); assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high")); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java index 132195e78d1d3..ed1789c3d3875 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import java.io.IOException; @@ -43,6 +44,8 @@ public class ClassificationTests extends AbstractSerializingTestCase { + private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100); + @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); @@ -100,7 +103,7 @@ public void testBuildSearch() { Classification evaluation = new Classification("act", "pred", Arrays.asList(new MulticlassConfusionMatrix())); - SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery); + SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery); assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery)); assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0)); } @@ -196,7 +199,9 @@ public String getWriteableName() { } @Override - public Tuple, List> aggs(String actualField, String predictedField) { + public Tuple, List> aggs(EvaluationParameters parameters, + String actualField, + String predictedField) { return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java index 58390744b165b..e6662c0429bde 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result; @@ -23,18 +24,20 @@ import java.util.List; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; +import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket; -import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.not; public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase { + private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100); + @Override protected MulticlassConfusionMatrix doParseInstance(XContentParser parser) throws IOException { return MulticlassConfusionMatrix.fromXContent(parser); @@ -80,12 +83,12 @@ public void testConstructor_SizeValidationFailures() { public void testAggs() { MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(); - Tuple, List> aggs = confusionMatrix.aggs("act", "pred"); + Tuple, List> aggs = confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"); assertThat(aggs, isTuple(not(empty()), empty())); assertThat(confusionMatrix.getResult(), isEmpty()); } - public void testEvaluate() { + public void testProcess() { Aggregations aggs = new Aggregations(Arrays.asList( mockTerms( MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, @@ -93,6 +96,7 @@ public void testEvaluate() { mockTermsBucket("dog", new Aggregations(Collections.emptyList())), mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), 0L), + mockCardinality(MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 2L), mockFilters( MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, Arrays.asList( @@ -109,13 +113,13 @@ public void testEvaluate() { new Aggregations(Arrays.asList(mockFilters( MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, Arrays.asList( - mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), - mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 2L))); + mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))) + )); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null); confusionMatrix.process(aggs); - assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty())); + assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty())); Result result = confusionMatrix.getResult().get(); assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat( @@ -127,7 +131,7 @@ public void testEvaluate() { assertThat(result.getOtherActualClassCount(), equalTo(0L)); } - public void testEvaluate_OtherClassesCountGreaterThanZero() { + public void testProcess_OtherClassesCountGreaterThanZero() { Aggregations aggs = new Aggregations(Arrays.asList( mockTerms( MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, @@ -135,6 +139,7 @@ public void testEvaluate_OtherClassesCountGreaterThanZero() { mockTermsBucket("dog", new Aggregations(Collections.emptyList())), mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), 100L), + mockCardinality(MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 5L), mockFilters( MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, Arrays.asList( @@ -151,13 +156,13 @@ public void testEvaluate_OtherClassesCountGreaterThanZero() { new Aggregations(Arrays.asList(mockFilters( MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, Arrays.asList( - mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))), - mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 5L))); + mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))) + )); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null); confusionMatrix.process(aggs); - assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty())); + assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty())); Result result = confusionMatrix.getResult().get(); assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat( @@ -168,4 +173,106 @@ public void testEvaluate_OtherClassesCountGreaterThanZero() { new ActualClass("cat", 85, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15)))); assertThat(result.getOtherActualClassCount(), equalTo(3L)); } + + public void testProcess_MoreThanTwoStepsNeeded() { + Aggregations aggsStep1 = new Aggregations(Arrays.asList( + mockTerms( + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, + Arrays.asList( + mockTermsBucket("ant", new Aggregations(Arrays.asList())), + mockTermsBucket("cat", new Aggregations(Arrays.asList())), + mockTermsBucket("dog", new Aggregations(Arrays.asList())), + mockTermsBucket("fox", new Aggregations(Arrays.asList()))), + 0L), + mockCardinality(MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 2L) + )); + Aggregations aggsStep2 = new Aggregations(Arrays.asList( + mockFilters( + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, + Arrays.asList( + mockFiltersBucket( + "ant", + 46, + new Aggregations(Arrays.asList(mockFilters( + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, + Arrays.asList( + mockFiltersBucket("ant", 10L), + mockFiltersBucket("cat", 11L), + mockFiltersBucket("dog", 12L), + mockFiltersBucket("fox", 13L), + mockFiltersBucket("_other_", 0L)))))), + mockFiltersBucket( + "cat", + 86, + new Aggregations(Arrays.asList(mockFilters( + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, + Arrays.asList( + mockFiltersBucket("ant", 20L), + mockFiltersBucket("cat", 21L), + mockFiltersBucket("dog", 22L), + mockFiltersBucket("fox", 23L), + mockFiltersBucket("_other_", 0L)))))))) + )); + Aggregations aggsStep3 = new Aggregations(Arrays.asList( + mockFilters( + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, + Arrays.asList( + mockFiltersBucket( + "dog", + 126, + new Aggregations(Arrays.asList(mockFilters( + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, + Arrays.asList( + mockFiltersBucket("ant", 30L), + mockFiltersBucket("cat", 31L), + mockFiltersBucket("dog", 32L), + mockFiltersBucket("fox", 33L), + mockFiltersBucket("_other_", 0L)))))), + mockFiltersBucket( + "fox", + 166, + new Aggregations(Arrays.asList(mockFilters( + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, + Arrays.asList( + mockFiltersBucket("ant", 40L), + mockFiltersBucket("cat", 41L), + mockFiltersBucket("dog", 42L), + mockFiltersBucket("fox", 43L), + mockFiltersBucket("_other_", 0L)))))))) + )); + + MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(4, null); + confusionMatrix.process(aggsStep1); + confusionMatrix.process(aggsStep2); + confusionMatrix.process(aggsStep3); + + assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty())); + Result result = confusionMatrix.getResult().get(); + assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); + assertThat( + result.getConfusionMatrix(), + equalTo( + Arrays.asList( + new ActualClass("ant", 46, Arrays.asList( + new PredictedClass("ant", 10L), + new PredictedClass("cat", 11L), + new PredictedClass("dog", 12L), + new PredictedClass("fox", 13L)), 0), + new ActualClass("cat", 86, Arrays.asList( + new PredictedClass("ant", 20L), + new PredictedClass("cat", 21L), + new PredictedClass("dog", 22L), + new PredictedClass("fox", 23L)), 0), + new ActualClass("dog", 126, Arrays.asList( + new PredictedClass("ant", 30L), + new PredictedClass("cat", 31L), + new PredictedClass("dog", 32L), + new PredictedClass("fox", 33L)), 0), + new ActualClass("fox", 166, Arrays.asList( + new PredictedClass("ant", 40L), + new PredictedClass("cat", 41L), + new PredictedClass("dog", 42L), + new PredictedClass("fox", 43L)), 0)))); + assertThat(result.getOtherActualClassCount(), equalTo(0L)); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionTests.java index e9bb1b176db69..81c734863408d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionTests.java @@ -10,22 +10,25 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import java.io.IOException; import java.util.Arrays; import java.util.Collections; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; +import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; -import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; public class PrecisionTests extends AbstractSerializingTestCase { + private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100); + @Override protected Precision doParseInstance(XContentParser parser) throws IOException { return Precision.fromXContent(parser); @@ -61,7 +64,7 @@ public void testProcess() { Precision precision = new Precision(); precision.process(aggs); - assertThat(precision.aggs("act", "pred"), isTuple(empty(), empty())); + assertThat(precision.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty())); assertThat(precision.getResult().get(), equalTo(new Precision.Result(Collections.emptyList(), 0.8123))); } @@ -111,7 +114,7 @@ public void testProcess_GivenCardinalityTooHigh() { Aggregations aggs = new Aggregations(Collections.singletonList(mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME, Collections.emptyList(), 1))); Precision precision = new Precision(); - precision.aggs("foo", "bar"); + precision.aggs(EVALUATION_PARAMETERS, "foo", "bar"); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> precision.process(aggs)); assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high")); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallTests.java index 34ef9ca28bcd3..efced860b9192 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallTests.java @@ -10,21 +10,24 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import java.io.IOException; import java.util.Arrays; import java.util.Collections; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; +import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; -import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; public class RecallTests extends AbstractSerializingTestCase { + private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100); + @Override protected Recall doParseInstance(XContentParser parser) throws IOException { return Recall.fromXContent(parser); @@ -59,7 +62,7 @@ public void testProcess() { Recall recall = new Recall(); recall.process(aggs); - assertThat(recall.aggs("act", "pred"), isTuple(empty(), empty())); + assertThat(recall.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty())); assertThat(recall.getResult().get(), equalTo(new Recall.Result(Collections.emptyList(), 0.8123))); } @@ -110,7 +113,7 @@ public void testProcess_GivenCardinalityTooHigh() { mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME, Collections.emptyList(), 1), mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123))); Recall recall = new Recall(); - recall.aggs("foo", "bar"); + recall.aggs(EVALUATION_PARAMETERS, "foo", "bar"); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> recall.process(aggs)); assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high")); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java index 96ba97ecc9348..26dff097b1b32 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import java.io.IOException; @@ -28,6 +29,8 @@ public class RegressionTests extends AbstractSerializingTestCase { + private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100); + @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); @@ -85,7 +88,7 @@ public void testBuildSearch() { Regression evaluation = new Regression("act", "pred", Arrays.asList(new MeanSquaredError())); - SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery); + SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery); assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery)); assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java index 28e0a045b190d..2a9645e094291 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import java.io.IOException; @@ -28,6 +29,8 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase { + private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100); + @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); @@ -98,7 +101,7 @@ public void testBuildSearch() { BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7)))); - SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery); + SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery); assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery)); assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0)); } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index c2758a2b653a4..3985d9c18105e 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -5,18 +5,21 @@ */ package org.elasticsearch.xpack.ml.integration; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.search.aggregations.MultiBucketConsumerService.TooManyBucketsException; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall; import org.junit.After; import org.junit.Before; @@ -28,6 +31,8 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase { @@ -49,6 +54,10 @@ public void setup() { @After public void cleanup() { cleanUp(); + client().admin().cluster() + .prepareUpdateSettings() + .setTransientSettings(Settings.builder().putNull("search.max_buckets")) + .get(); } public void testEvaluate_DefaultMetrics() { @@ -208,7 +217,7 @@ public void testEvaluate_Recall_CardinalityTooHigh() { assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); } - public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() { + private void evaluateWithMulticlassConfusionMatrix() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( ANIMALS_DATA_INDEX, @@ -271,6 +280,23 @@ public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() { assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L)); } + public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() { + evaluateWithMulticlassConfusionMatrix(); + + client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 20)).get(); + evaluateWithMulticlassConfusionMatrix(); + + client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 7)).get(); + evaluateWithMulticlassConfusionMatrix(); + + client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 6)).get(); + ElasticsearchException e = expectThrows(ElasticsearchException.class, this::evaluateWithMulticlassConfusionMatrix); + + assertThat(e.getCause(), is(instanceOf(TooManyBucketsException.class))); + TooManyBucketsException tmbe = (TooManyBucketsException) e.getCause(); + assertThat(tmbe.getMaxBuckets(), equalTo(6)); + } + public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java index 5c48be663f117..c6d92413dd392 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.tasks.Task; @@ -18,26 +19,41 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING; public class TransportEvaluateDataFrameAction extends HandledTransportAction { private final ThreadPool threadPool; private final Client client; + private final AtomicReference maxBuckets = new AtomicReference<>(); @Inject - public TransportEvaluateDataFrameAction(TransportService transportService, ActionFilters actionFilters, ThreadPool threadPool, - Client client) { + public TransportEvaluateDataFrameAction(TransportService transportService, + ActionFilters actionFilters, + ThreadPool threadPool, + Client client, + ClusterService clusterService) { super(EvaluateDataFrameAction.NAME, transportService, actionFilters, EvaluateDataFrameAction.Request::new); this.threadPool = threadPool; this.client = client; + this.maxBuckets.set(MAX_BUCKET_SETTING.get(clusterService.getSettings())); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BUCKET_SETTING, this::setMaxBuckets); + } + + private void setMaxBuckets(int maxBuckets) { + this.maxBuckets.set(maxBuckets); } @Override - protected void doExecute(Task task, EvaluateDataFrameAction.Request request, + protected void doExecute(Task task, + EvaluateDataFrameAction.Request request, ActionListener listener) { ActionListener> resultsListener = ActionListener.wrap( unused -> { @@ -48,7 +64,9 @@ protected void doExecute(Task task, EvaluateDataFrameAction.Request request, listener::onFailure ); - EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, request); + // Create an immutable collection of parameters to be used by evaluation metrics. + EvaluationParameters parameters = new EvaluationParameters(maxBuckets.get()); + EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, parameters, request); evaluationExecutor.execute(resultsListener); } @@ -68,12 +86,14 @@ protected void doExecute(Task task, EvaluateDataFrameAction.Request request, private static final class EvaluationExecutor extends TypedChainTaskExecutor { private final Client client; + private final EvaluationParameters parameters; private final EvaluateDataFrameAction.Request request; private final Evaluation evaluation; - EvaluationExecutor(ThreadPool threadPool, Client client, EvaluateDataFrameAction.Request request) { + EvaluationExecutor(ThreadPool threadPool, Client client, EvaluationParameters parameters, EvaluateDataFrameAction.Request request) { super(threadPool.generic(), unused -> true, unused -> true); this.client = client; + this.parameters = parameters; this.request = request; this.evaluation = request.getEvaluation(); // Add one task only. Other tasks will be added as needed by the nextTask method itself. @@ -82,7 +102,7 @@ private static final class EvaluationExecutor extends TypedChainTaskExecutor nextTask() { return listener -> { - SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(request.getParsedQuery()); + SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(parameters, request.getParsedQuery()); SearchRequest searchRequest = new SearchRequest(request.getIndices()).source(searchSourceBuilder); client.execute( SearchAction.INSTANCE,