Skip to content

Commit

Permalink
[ML] Extend classification to support multiple classes (#53539)
Browse files Browse the repository at this point in the history
* [ML] Extend classification to support multiple classes

Prepares classification analysis to support more than just
two classes. It introduces a new parameter to the process config
which dictates the `num_classes` to the process. It also
changes the max classes limit to `30` provisionally.

* We can no longer test that the cardinality is too high in the YML tests

* Extract max number of classes in a constant
  • Loading branch information
dimitris-athanasiou authored Mar 16, 2020
1 parent eaa8ead commit 7e0a3c4
Show file tree
Hide file tree
Showing 15 changed files with 305 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,16 @@ public class Classification implements DataFrameAnalysis {

private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1";

private static final String NUM_CLASSES = "num_classes";

private static final ConstructingObjectParser<Classification, Void> LENIENT_PARSER = createParser(true);
private static final ConstructingObjectParser<Classification, Void> STRICT_PARSER = createParser(false);

/**
* The max number of classes classification supports
*/
private static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;

private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
Expand Down Expand Up @@ -218,7 +225,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

@Override
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
public Map<String, Object> getParams(FieldInfo fieldInfo) {
Map<String, Object> params = new HashMap<>();
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
params.putAll(boostedTreeParams.getParams());
Expand All @@ -227,10 +234,11 @@ public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
if (predictionFieldName != null) {
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
String predictionFieldType = getPredictionFieldType(extractedFields.get(dependentVariable));
String predictionFieldType = getPredictionFieldType(fieldInfo.getTypes(dependentVariable));
if (predictionFieldType != null) {
params.put(PREDICTION_FIELD_TYPE, predictionFieldType);
}
params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable));
return params;
}

Expand Down Expand Up @@ -272,7 +280,7 @@ public List<RequiredField> getRequiredFields() {
@Override
public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
// The lower bound ensures the dependent variable has at least 2 classes; the upper bound is the max number of classes the C++ backend currently supports.
return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, 2));
return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, MAX_DEPENDENT_VARIABLE_CARDINALITY));
}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;

Expand All @@ -16,9 +17,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {

/**
* @return The analysis parameters as a map
* @param extractedFields map of (name, types) for all the extracted fields
* @param fieldInfo Information about the fields like types and cardinalities
*/
Map<String, Object> getParams(Map<String, Set<String>> extractedFields);
Map<String, Object> getParams(FieldInfo fieldInfo);

/**
* @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip)
Expand Down Expand Up @@ -64,4 +65,27 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
* Returns the document id for the analysis state
*/
String getStateDocId(String jobId);

/**
 * Summarizes the per-field information (types and cardinalities) that an analysis
 * needs in order to generate the parameters for the process configuration.
 */
interface FieldInfo {

    /**
     * Returns the set of mapped types for the given field.
     *
     * @param field the name of the field whose types to return
     * @return the types of the field, or {@code null} if the field is unknown
     */
    @Nullable
    Set<String> getTypes(String field);

    /**
     * Returns the cardinality (number of distinct values) of the given field.
     *
     * @param field the name of the field whose cardinality to return
     * @return the cardinality of the field, or {@code null} if no cardinality is available for that field
     */
    @Nullable
    Long getCardinality(String field);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ public int hashCode() {
}

@Override
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
public Map<String, Object> getParams(FieldInfo fieldInfo) {
Map<String, Object> params = new HashMap<>();
if (nNeighbors != null) {
params.put(N_NEIGHBORS.getPreferredName(), nNeighbors);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

@Override
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
public Map<String, Object> getParams(FieldInfo fieldInfo) {
Map<String, Object> params = new HashMap<>();
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
params.putAll(boostedTreeParams.getParams());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,38 +187,46 @@ public void testGetTrainingPercent() {
}

public void testGetParams() {
Map<String, Set<String>> extractedFields =
DataFrameAnalysis.FieldInfo fieldInfo = new TestFieldInfo(
Map.of(
"foo", Set.of(BooleanFieldMapper.CONTENT_TYPE),
"bar", Set.of(NumberFieldMapper.NumberType.LONG.typeName()),
"baz", Set.of(KeywordFieldMapper.CONTENT_TYPE));
"baz", Set.of(KeywordFieldMapper.CONTENT_TYPE)),
Map.of(
"foo", 10L,
"bar", 20L,
"baz", 30L)
);
assertThat(
new Classification("foo").getParams(extractedFields),
new Classification("foo").getParams(fieldInfo),
equalTo(
Map.of(
"dependent_variable", "foo",
"class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
"num_top_classes", 2,
"prediction_field_name", "foo_prediction",
"prediction_field_type", "bool")));
"prediction_field_type", "bool",
"num_classes", 10L)));
assertThat(
new Classification("bar").getParams(extractedFields),
new Classification("bar").getParams(fieldInfo),
equalTo(
Map.of(
"dependent_variable", "bar",
"class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
"num_top_classes", 2,
"prediction_field_name", "bar_prediction",
"prediction_field_type", "int")));
"prediction_field_type", "int",
"num_classes", 20L)));
assertThat(
new Classification("baz").getParams(extractedFields),
new Classification("baz").getParams(fieldInfo),
equalTo(
Map.of(
"dependent_variable", "baz",
"class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
"num_top_classes", 2,
"prediction_field_name", "baz_prediction",
"prediction_field_type", "string")));
"prediction_field_type", "string",
"num_classes", 30L)));
}

public void testRequiredFieldsIsNonEmpty() {
Expand All @@ -232,7 +240,7 @@ public void testFieldCardinalityLimitsIsNonEmpty() {
assertThat(constraints.size(), equalTo(1));
assertThat(constraints.get(0).getField(), equalTo(classification.getDependentVariable()));
assertThat(constraints.get(0).getLowerBound(), equalTo(2L));
assertThat(constraints.get(0).getUpperBound(), equalTo(2L));
assertThat(constraints.get(0).getUpperBound(), equalTo(30L));
}

public void testGetExplicitlyMappedFields() {
Expand Down Expand Up @@ -331,4 +339,25 @@ public void testExtractJobIdFromStateDoc() {
protected Classification mutateInstanceForVersion(Classification instance, Version version) {
return mutateForVersion(instance, version);
}

/**
 * A simple {@link DataFrameAnalysis.FieldInfo} implementation backed by in-memory
 * maps, for use in tests.
 */
private static class TestFieldInfo implements DataFrameAnalysis.FieldInfo {

    // Field name -> set of mapped types for that field
    private final Map<String, Set<String>> typesByField;
    // Field name -> cardinality of that field
    private final Map<String, Long> cardinalityByField;

    private TestFieldInfo(Map<String, Set<String>> typesByField, Map<String, Long> cardinalityByField) {
        this.typesByField = typesByField;
        this.cardinalityByField = cardinalityByField;
    }

    @Override
    public Set<String> getTypes(String field) {
        return typesByField.get(field);
    }

    @Override
    public Long getCardinality(String field) {
        return cardinalityByField.get(field);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,11 @@ public void testStopAndRestart() throws Exception {
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
}

@AwaitsFix(bugUrl = "Muted until ml-cpp supports multiple classes")
public void testDependentVariableCardinalityTooHighError() throws Exception {
initialize("cardinality_too_high");
indexData(sourceIndex, 6, 5, KEYWORD_FIELD);

// Index one more document with a class different than the two already used.
client().execute(IndexAction.INSTANCE, new IndexRequest(sourceIndex)
.source(KEYWORD_FIELD, "fox")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
Expand All @@ -27,7 +28,7 @@ public class TimeBasedExtractedFields extends ExtractedFields {
private final ExtractedField timeField;

public TimeBasedExtractedFields(ExtractedField timeField, List<ExtractedField> allFields) {
super(allFields);
super(allFields, Collections.emptyMap());
if (!allFields.contains(timeField)) {
throw new IllegalArgumentException("timeField should also be contained in allFields");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@ public class ExtractedFieldsDetector {
private final DataFrameAnalyticsConfig config;
private final int docValueFieldsLimit;
private final FieldCapabilitiesResponse fieldCapabilitiesResponse;
private final Map<String, Long> fieldCardinalities;
private final Map<String, Long> cardinalitiesForFieldsWithConstraints;

ExtractedFieldsDetector(String[] index, DataFrameAnalyticsConfig config, int docValueFieldsLimit,
FieldCapabilitiesResponse fieldCapabilitiesResponse, Map<String, Long> fieldCardinalities) {
FieldCapabilitiesResponse fieldCapabilitiesResponse, Map<String, Long> cardinalitiesForFieldsWithConstraints) {
this.index = Objects.requireNonNull(index);
this.config = Objects.requireNonNull(config);
this.docValueFieldsLimit = docValueFieldsLimit;
this.fieldCapabilitiesResponse = Objects.requireNonNull(fieldCapabilitiesResponse);
this.fieldCardinalities = Objects.requireNonNull(fieldCardinalities);
this.cardinalitiesForFieldsWithConstraints = Objects.requireNonNull(cardinalitiesForFieldsWithConstraints);
}

public Tuple<ExtractedFields, List<FieldSelection>> detect() {
Expand Down Expand Up @@ -286,12 +286,13 @@ private void checkRequiredFields(Set<String> fields) {

private void checkFieldsWithCardinalityLimit() {
for (FieldCardinalityConstraint constraint : config.getAnalysis().getFieldCardinalityConstraints()) {
constraint.check(fieldCardinalities.get(constraint.getField()));
constraint.check(cardinalitiesForFieldsWithConstraints.get(constraint.getField()));
}
}

private ExtractedFields detectExtractedFields(Set<String> fields, Set<FieldSelection> fieldSelection) {
ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse);
ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse,
cardinalitiesForFieldsWithConstraints);
boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit;
extractedFields = deduplicateMultiFields(extractedFields, preferSource, fieldSelection);
if (preferSource) {
Expand Down Expand Up @@ -321,7 +322,7 @@ private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields,
chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField, fieldSelection));
}
}
return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()));
return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()), cardinalitiesForFieldsWithConstraints);
}

private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set<String> requiredFields, ExtractedField parent,
Expand Down Expand Up @@ -372,7 +373,7 @@ private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFiel
for (ExtractedField field : extractedFields.getAllFields()) {
adjusted.add(field.supportsFromSource() ? field.newFromSource() : field);
}
return new ExtractedFields(adjusted);
return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints);
}

private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFields) {
Expand All @@ -389,7 +390,7 @@ private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFi
adjusted.add(field);
}
}
return new ExtractedFields(adjusted);
return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints);
}

private void addIncludedFields(ExtractedFields extractedFields, Set<FieldSelection> fieldSelection) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

import static java.util.stream.Collectors.toMap;

public class AnalyticsProcessConfig implements ToXContentObject {

private static final String JOB_ID = "job_id";
Expand Down Expand Up @@ -93,12 +92,31 @@ private DataFrameAnalysisWrapper(DataFrameAnalysis analysis, ExtractedFields ext
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("name", analysis.getWriteableName());
builder.field(
"parameters",
analysis.getParams(
extractedFields.getAllFields().stream().collect(toMap(ExtractedField::getName, ExtractedField::getTypes))));
builder.field("parameters", analysis.getParams(new AnalysisFieldInfo(extractedFields)));
builder.endObject();
return builder;
}
}

/**
 * A {@link DataFrameAnalysis.FieldInfo} implementation backed by the
 * {@link ExtractedFields} of the analytics job.
 */
private static class AnalysisFieldInfo implements DataFrameAnalysis.FieldInfo {

    private final ExtractedFields extractedFields;

    AnalysisFieldInfo(ExtractedFields extractedFields) {
        this.extractedFields = Objects.requireNonNull(extractedFields);
    }

    @Override
    public Set<String> getTypes(String field) {
        // Return the types of the first extracted field whose name matches;
        // null when no extracted field has that name.
        for (ExtractedField candidate : extractedFields.getAllFields()) {
            if (candidate.getName().equals(field)) {
                return candidate.getTypes();
            }
        }
        return null;
    }

    @Override
    public Long getCardinality(String field) {
        return extractedFields.getCardinalitiesForFieldsWithConstraints().get(field);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ public class ExtractedFields {
private final List<ExtractedField> allFields;
private final List<ExtractedField> docValueFields;
private final String[] sourceFields;
private final Map<String, Long> cardinalitiesForFieldsWithConstraints;

public ExtractedFields(List<ExtractedField> allFields) {
public ExtractedFields(List<ExtractedField> allFields, Map<String, Long> cardinalitiesForFieldsWithConstraints) {
this.allFields = Collections.unmodifiableList(allFields);
this.docValueFields = filterFields(ExtractedField.Method.DOC_VALUE, allFields);
this.sourceFields = filterFields(ExtractedField.Method.SOURCE, allFields).stream().map(ExtractedField::getSearchField)
.toArray(String[]::new);
this.cardinalitiesForFieldsWithConstraints = Collections.unmodifiableMap(cardinalitiesForFieldsWithConstraints);
}

public List<ExtractedField> getAllFields() {
Expand All @@ -48,14 +50,20 @@ public List<ExtractedField> getDocValueFields() {
return docValueFields;
}

public Map<String, Long> getCardinalitiesForFieldsWithConstraints() {
return cardinalitiesForFieldsWithConstraints;
}

private static List<ExtractedField> filterFields(ExtractedField.Method method, List<ExtractedField> fields) {
return fields.stream().filter(field -> field.getMethod() == method).collect(Collectors.toList());
}

public static ExtractedFields build(Collection<String> allFields, Set<String> scriptFields,
FieldCapabilitiesResponse fieldsCapabilities) {
FieldCapabilitiesResponse fieldsCapabilities,
Map<String, Long> cardinalitiesForFieldsWithConstraints) {
ExtractionMethodDetector extractionMethodDetector = new ExtractionMethodDetector(scriptFields, fieldsCapabilities);
return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList()));
return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList()),
cardinalitiesForFieldsWithConstraints);
}

public static TimeField newTimeField(String name, ExtractedField.Method method) {
Expand Down
Loading

0 comments on commit 7e0a3c4

Please sign in to comment.