Skip to content

Commit

Permalink
Perform evaluation in multiple steps when necessary (#53295)
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek authored Mar 11, 2020
1 parent dba2a6e commit fd030dc
Show file tree
Hide file tree
Showing 21 changed files with 327 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ default <T extends EvaluationMetric> List<T> initMetrics(@Nullable List<T> parse
* Builds the search required to collect data to compute the evaluation result
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
*/
default SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBuilder userProvidedQueryBuilder) {
Objects.requireNonNull(userProvidedQueryBuilder);
BoolQueryBuilder boolQuery =
QueryBuilders.boolQuery()
Expand All @@ -78,7 +78,8 @@ default SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
for (EvaluationMetric metric : getMetrics()) {
// Fetch aggregations requested by individual metrics
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = metric.aggs(getActualField(), getPredictedField());
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs =
metric.aggs(parameters, getActualField(), getPredictedField());
aggs.v1().forEach(searchSourceBuilder::aggregation);
aggs.v2().forEach(searchSourceBuilder::aggregation);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {

/**
* Builds the aggregation that collect required data to compute the metric
* @param parameters settings that may be needed by aggregations
* @param actualField the field that stores the actual value
* @param predictedField the field that stores the predicted value (class name or probability)
* @return the aggregations required to compute the metric
*/
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField);
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField);

/**
* Processes given aggregations as a step towards computing result
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;

/**
* Encapsulates parameters needed by evaluation.
*/
public class EvaluationParameters {

/**
* Maximum number of buckets allowed in any single search request.
*/
private final int maxBuckets;

public EvaluationParameters(int maxBuckets) {
this.maxBuckets = maxBuckets;
}

public int getMaxBuckets() {
return maxBuckets;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
Expand Down Expand Up @@ -103,7 +104,9 @@ public String getName() {
}

@Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
this.actualField.trySet(actualField);
List<AggregationBuilder> aggs = new ArrayList<>();
Expand All @@ -112,7 +115,8 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField)));
}
if (result.get() == null) {
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs = matrix.aggs(actualField, predictedField);
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs =
matrix.aggs(parameters, actualField, predictedField);
aggs.addAll(matrixAggs.v1());
pipelineAggs.addAll(matrixAggs.v2());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.search.aggregations.metrics.Cardinality;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
Expand Down Expand Up @@ -72,9 +73,9 @@ public static MulticlassConfusionMatrix fromXContent(XContentParser parser) {
}

static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
static final String STEP_1_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_cardinality_of_actual_class";
static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
private static final String OTHER_BUCKET_KEY = "_other_";
private static final String DEFAULT_AGG_NAME_PREFIX = "";
private static final int DEFAULT_SIZE = 10;
Expand All @@ -83,6 +84,9 @@ public static MulticlassConfusionMatrix fromXContent(XContentParser parser) {
private final int size;
private final String aggNamePrefix;
private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
private final SetOnce<Long> actualClassesCardinality = new SetOnce<>();
/** Accumulates actual classes processed so far. It may take more than 1 call to #process method to fill this field completely. */
private final List<ActualClass> actualClasses = new ArrayList<>();
private final SetOnce<Result> result = new SetOnce<>();

public MulticlassConfusionMatrix() {
Expand Down Expand Up @@ -121,34 +125,45 @@ public int getSize() {
}

@Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
if (topActualClassNames.get() == null) { // This is step 1
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
if (topActualClassNames.get() == null && actualClassesCardinality.get() == null) { // This is step 1
return Tuple.tuple(
List.of(
AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS))
.field(actualField)
.order(List.of(BucketOrder.count(false), BucketOrder.key(true)))
.size(size)),
.size(size),
AggregationBuilders.cardinality(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS))
.field(actualField)),
List.of());
}
if (result.get() == null) { // This is step 2
KeyedFilter[] keyedFiltersActual =
topActualClassNames.get().stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
.toArray(KeyedFilter[]::new);
if (result.get() == null) { // These are steps 2, 3, 4 etc.
KeyedFilter[] keyedFiltersPredicted =
topActualClassNames.get().stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
.toArray(KeyedFilter[]::new);
return Tuple.tuple(
List.of(
AggregationBuilders.cardinality(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS))
.field(actualField),
AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual)
.subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted)
.otherBucket(true)
.otherBucketKey(OTHER_BUCKET_KEY))),
List.of());
// Knowing exactly how many buckets does each aggregation use, we can choose the size of the batch so that
// too_many_buckets_exception exception is not thrown.
// The only exception is when "search.max_buckets" is set far too low to even have 1 actual class in the batch.
// In such case, the exception will be thrown telling the user they should increase the value of "search.max_buckets".
int actualClassesPerBatch = Math.max(parameters.getMaxBuckets() / (topActualClassNames.get().size() + 2), 1);
KeyedFilter[] keyedFiltersActual =
topActualClassNames.get().stream()
.skip(actualClasses.size())
.limit(actualClassesPerBatch)
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
.toArray(KeyedFilter[]::new);
if (keyedFiltersActual.length > 0) {
return Tuple.tuple(
List.of(
AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual)
.subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted)
.otherBucket(true)
.otherBucketKey(OTHER_BUCKET_KEY))),
List.of());
}
}
return Tuple.tuple(List.of(), List.of());
}
Expand All @@ -159,10 +174,12 @@ public void process(Aggregations aggs) {
Terms termsAgg = aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS));
topActualClassNames.set(termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()));
}
if (actualClassesCardinality.get() == null && aggs.get(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS)) != null) {
Cardinality cardinalityAgg = aggs.get(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS));
actualClassesCardinality.set(cardinalityAgg.getValue());
}
if (result.get() == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
Cardinality cardinalityAgg = aggs.get(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS));
Filters filtersAgg = aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS));
List<ActualClass> actualClasses = new ArrayList<>(filtersAgg.getBuckets().size());
for (Filters.Bucket bucket : filtersAgg.getBuckets()) {
String actualClass = bucket.getKeyAsString();
long actualClassDocCount = bucket.getDocCount();
Expand All @@ -181,7 +198,9 @@ public void process(Aggregations aggs) {
predictedClasses.sort(comparing(PredictedClass::getPredictedClass));
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
}
result.set(new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)));
if (actualClasses.size() == topActualClassNames.get().size()) {
result.set(new Result(actualClasses, Math.max(actualClassesCardinality.get() - size, 0)));
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
Expand Down Expand Up @@ -96,7 +97,9 @@ public String getName() {
}

@Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
this.actualField.trySet(actualField);
if (topActualClassNames.get() == null) { // This is step 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
Expand Down Expand Up @@ -89,7 +90,9 @@ public String getName() {
}

@Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
this.actualField.trySet(actualField);
if (result.get() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;

import java.io.IOException;
import java.text.MessageFormat;
Expand Down Expand Up @@ -65,7 +66,9 @@ public String getName() {
}

@Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
if (result != null) {
return Tuple.tuple(List.of(), List.of());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;

import java.io.IOException;
import java.text.MessageFormat;
Expand Down Expand Up @@ -70,7 +71,9 @@ public String getName() {
}

@Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
if (result != null) {
return Tuple.tuple(List.of(), List.of());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
Expand Down Expand Up @@ -65,7 +66,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

@Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedProbabilityField) {
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedProbabilityField) {
if (result != null) {
return Tuple.tuple(List.of(), List.of());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.search.aggregations.metrics.Percentiles;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
Expand Down Expand Up @@ -127,7 +128,9 @@ public int hashCode() {
}

@Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedProbabilityField) {
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedProbabilityField) {
if (result != null) {
return Tuple.tuple(List.of(), List.of());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;

import org.elasticsearch.test.ESTestCase;

import static org.hamcrest.Matchers.equalTo;

public class EvaluationParametersTests extends ESTestCase {

public void testConstructorAndGetters() {
EvaluationParameters params = new EvaluationParameters(17);
assertThat(params.getMaxBuckets(), equalTo(17));
}
}
Loading

0 comments on commit fd030dc

Please sign in to comment.