Skip to content

Commit

Permalink
Make accuracy evaluation metric work when there is field mapping type…
Browse files Browse the repository at this point in the history
… mismatch
  • Loading branch information
przemekwitek committed Mar 16, 2020
1 parent 060b4ee commit 8e66ea3
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,9 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;

Expand Down Expand Up @@ -66,12 +64,6 @@ public class Accuracy implements EvaluationMetric {

static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";

private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";

/**
 * Builds a Painless script by substituting the given arguments into {@link #PAINLESS_TEMPLATE}.
 * @param args values for the numbered placeholders of the template
 * @return script holding the formatted template text
 */
private static Script buildScript(Object... args) {
    // A formatter is created per call, so no formatter state is shared between callers.
    final MessageFormat template = new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT);
    final String source = template.format(args);
    return new Script(source);
}

private static final ObjectParser<Accuracy, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new);

public static Accuracy fromXContent(XContentParser parser) {
Expand Down Expand Up @@ -112,7 +104,8 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
List<AggregationBuilder> aggs = new ArrayList<>();
List<PipelineAggregationBuilder> pipelineAggs = new ArrayList<>();
if (overallAccuracy.get() == null) {
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField)));
Script script = PainlessScripts.buildComparisonScript(actualField, predictedField);
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(script));
}
if (result.get() == null) {
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import org.elasticsearch.script.Script;

import java.text.MessageFormat;
import java.util.Locale;

/**
* Painless scripts used by classification metrics in this package.
*/
final class PainlessScripts {

    /**
     * Pattern for the comparison script: {0} is replaced with the actual field name, {1} with the predicted field name.
     * It uses "String.valueOf" method in case the mapping types of the two fields are different.
     * The pattern is kept as an immutable String (rather than a shared MessageFormat instance) because
     * MessageFormat is not safe for concurrent use without external synchronization.
     */
    private static final String COMPARISON_SCRIPT_PATTERN =
        "String.valueOf(doc[''{0}''].value).equals(String.valueOf(doc[''{1}''].value))";

    /** Utility class holding only static members; not meant to be instantiated. */
    private PainlessScripts() {}

    /**
     * Builds field comparison script for the given actual and predicted field names.
     * @param actualField name of the actual field
     * @param predictedField name of the predicted field
     * @return script that compares values of actualField and predictedField
     */
    static Script buildComparisonScript(String actualField, String predictedField) {
        // A fresh formatter per call: sharing one MessageFormat between threads could corrupt the produced script text.
        MessageFormat template = new MessageFormat(COMPARISON_SCRIPT_PATTERN, Locale.ROOT);
        return new Script(template.format(new Object[]{ actualField, predictedField }));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,9 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
Expand All @@ -59,17 +57,12 @@ public class Precision implements EvaluationMetric {

public static final ParseField NAME = new ParseField("precision");

private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
private static final String AGG_NAME_PREFIX = "classification_precision_";
static final String ACTUAL_CLASSES_NAMES_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class";
static final String BY_PREDICTED_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_predicted_class";
static final String PER_PREDICTED_CLASS_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "per_predicted_class_precision";
static final String AVG_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "avg_precision";

/**
 * Builds a Painless script by substituting the given arguments into {@link #PAINLESS_TEMPLATE}.
 * @param args values for the numbered placeholders of the template
 * @return script holding the formatted template text
 */
private static Script buildScript(Object... args) {
    // A formatter is created per call, so no formatter state is shared between callers.
    final MessageFormat template = new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT);
    final String source = template.format(args);
    return new Script(source);
}

private static final ObjectParser<Precision, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Precision::new);

public static Precision fromXContent(XContentParser parser) {
Expand Down Expand Up @@ -116,7 +109,7 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
topActualClassNames.get().stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
.toArray(KeyedFilter[]::new);
Script script = buildScript(actualField, predictedField);
Script script = PainlessScripts.buildComparisonScript(actualField, predictedField);
return Tuple.tuple(
List.of(
AggregationBuilders.filters(BY_PREDICTED_CLASS_AGG_NAME, keyedFiltersPredicted)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.BucketOrder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
Expand All @@ -30,11 +31,9 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;

Expand All @@ -54,16 +53,11 @@ public class Recall implements EvaluationMetric {

public static final ParseField NAME = new ParseField("recall");

private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
private static final String AGG_NAME_PREFIX = "classification_recall_";
static final String BY_ACTUAL_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class";
static final String PER_ACTUAL_CLASS_RECALL_AGG_NAME = AGG_NAME_PREFIX + "per_actual_class_recall";
static final String AVG_RECALL_AGG_NAME = AGG_NAME_PREFIX + "avg_recall";

/**
 * Builds a Painless script by substituting the given arguments into {@link #PAINLESS_TEMPLATE}.
 * @param args values for the numbered placeholders of the template
 * @return script holding the formatted template text
 */
private static Script buildScript(Object... args) {
    // A formatter is created per call, so no formatter state is shared between callers.
    final MessageFormat template = new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT);
    final String source = template.format(args);
    return new Script(source);
}

private static final ObjectParser<Recall, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Recall::new);

public static Recall fromXContent(XContentParser parser) {
Expand Down Expand Up @@ -98,11 +92,12 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
if (result.get() != null) {
return Tuple.tuple(List.of(), List.of());
}
Script script = buildScript(actualField, predictedField);
Script script = PainlessScripts.buildComparisonScript(actualField, predictedField);
return Tuple.tuple(
List.of(
AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME)
.field(actualField)
.order(List.of(BucketOrder.count(false), BucketOrder.key(true)))
.size(MAX_CLASSES_CARDINALITY)
.subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))),
List.of(
Expand Down
Loading

0 comments on commit 8e66ea3

Please sign in to comment.