Skip to content

Commit

Permalink
Get rid of maxClassesCardinality internal parameter (#50418)
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek authored Dec 20, 2019
1 parent c82113e commit 9e6e4bb
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
Expand Down Expand Up @@ -75,25 +74,15 @@ public static Precision fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000;
private static final int MAX_CLASSES_CARDINALITY = 1000;

private final int maxClassesCardinality;
private String actualField;
private List<String> topActualClassNames;
private EvaluationMetricResult result;

public Precision() {
this((Integer) null);
}

// Visible for testing
public Precision(@Nullable Integer maxClassesCardinality) {
this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY;
}
public Precision() {}

public Precision(StreamInput in) throws IOException {
this.maxClassesCardinality = DEFAULT_MAX_CLASSES_CARDINALITY;
}
public Precision(StreamInput in) throws IOException {}

@Override
public String getWriteableName() {
Expand All @@ -115,7 +104,7 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME)
.field(actualField)
.order(List.of(BucketOrder.count(false), BucketOrder.key(true)))
.size(maxClassesCardinality)),
.size(MAX_CLASSES_CARDINALITY)),
List.of());
}
if (result == null) { // This is step 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
Expand Down Expand Up @@ -69,24 +68,14 @@ public static Recall fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000;
private static final int MAX_CLASSES_CARDINALITY = 1000;

private final int maxClassesCardinality;
private String actualField;
private EvaluationMetricResult result;

public Recall() {
this((Integer) null);
}

// Visible for testing
public Recall(@Nullable Integer maxClassesCardinality) {
this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY;
}
public Recall() {}

public Recall(StreamInput in) throws IOException {
this.maxClassesCardinality = DEFAULT_MAX_CLASSES_CARDINALITY;
}
public Recall(StreamInput in) throws IOException {}

@Override
public String getWriteableName() {
Expand All @@ -110,7 +99,7 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
List.of(
AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME)
.field(actualField)
.size(maxClassesCardinality)
.size(MAX_CLASSES_CARDINALITY)
.subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))),
List.of(
PipelineAggregatorBuilders.avgBucket(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT

@Before
public void setup() {
createAnimalsIndex(ANIMALS_DATA_INDEX);
indexAnimalsData(ANIMALS_DATA_INDEX);
}

Expand Down Expand Up @@ -141,11 +142,12 @@ public void testEvaluate_Precision() {
}

public void testEvaluate_Precision_CardinalityTooHigh() {
indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001);
ElasticsearchStatusException e =
expectThrows(
ElasticsearchStatusException.class,
() -> evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Precision(4)))));
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Precision()))));
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
}

Expand All @@ -172,11 +174,12 @@ public void testEvaluate_Recall() {
}

public void testEvaluate_Recall_CardinalityTooHigh() {
indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001);
ElasticsearchStatusException e =
expectThrows(
ElasticsearchStatusException.class,
() -> evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Recall(4)))));
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Recall()))));
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
}

Expand Down Expand Up @@ -281,7 +284,7 @@ public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() {
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L));
}

private static void indexAnimalsData(String indexName) {
private static void createAnimalsIndex(String indexName) {
client().admin().indices().prepareCreate(indexName)
.addMapping("_doc",
ANIMAL_NAME_FIELD, "type=keyword",
Expand All @@ -291,7 +294,9 @@ private static void indexAnimalsData(String indexName) {
IS_PREDATOR_FIELD, "type=boolean",
IS_PREDATOR_PREDICTION_FIELD, "type=boolean")
.get();
}

private static void indexAnimalsData(String indexName) {
List<String> animalNames = List.of("dog", "cat", "mouse", "ant", "fox");
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
Expand All @@ -315,4 +320,17 @@ private static void indexAnimalsData(String indexName) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
}

private static void indexDistinctAnimals(String indexName, int distinctAnimalCount) {
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < distinctAnimalCount; i++) {
bulkRequestBuilder.add(
new IndexRequest(indexName).source(ANIMAL_NAME_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5)));
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
}
}

0 comments on commit 9e6e4bb

Please sign in to comment.