Skip to content

Commit

Permalink
[7.4][ML] Regression dependent variable must be numeric (elastic#46072)
Browse files Browse the repository at this point in the history
* [ML] Regression dependent variable must be numeric

This adds a validation that the dependent variable of a regression
analysis must be numeric.

* Address review comments and fix some problems

In addition to addressing the review comments, this
commit fixes a few issues I found during testing.

In particular:

- if there were mappings for required fields but they were
not included we were not reporting the error
- if explicitly included fields had unsupported types we were
not reporting the error

Unfortunately, I couldn't get those fixed without refactoring
the code in `ExtractedFieldsDetector`.
  • Loading branch information
dimitris-athanasiou committed Aug 30, 2019
1 parent 5b61708 commit 4d8eff8
Show file tree
Hide file tree
Showing 8 changed files with 326 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;

import java.util.List;
import java.util.Map;
import java.util.Set;

public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {

Expand All @@ -24,9 +24,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
boolean supportsCategoricalFields();

/**
* @return The set of fields that analyzed documents must have for the analysis to operate
* @return The names and types of the fields that analyzed documents must have for the analysis to operate
*/
Set<String> getRequiredFields();
List<RequiredField> getRequiredFields();

/**
* @return {@code true} if this analysis supports data frame rows with missing values
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

public class OutlierDetection implements DataFrameAnalysis {

Expand Down Expand Up @@ -160,8 +160,8 @@ public boolean supportsCategoricalFields() {
}

@Override
public Set<String> getRequiredFields() {
return Collections.emptySet();
public List<RequiredField> getRequiredFields() {
return Collections.emptyList();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

public class Regression implements DataFrameAnalysis {

Expand Down Expand Up @@ -201,8 +201,8 @@ public boolean supportsCategoricalFields() {
}

@Override
public Set<String> getRequiredFields() {
return Collections.singleton(dependentVariable);
public List<RequiredField> getRequiredFields() {
return Collections.singletonList(new RequiredField(dependentVariable, Types.numerical()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;

import java.util.Collections;
import java.util.Objects;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

public class RequiredField {

    private final String name;

    /**
     * The required field must have one of those types.
     * We use a sorted set to ensure types are reported alphabetically in error messages.
     */
    private final SortedSet<String> types;

    /**
     * @param name the name of the field the analysis requires; must not be {@code null}
     * @param types the accepted mapping types for the field; must not be {@code null}
     */
    public RequiredField(String name, Set<String> types) {
        this.name = Objects.requireNonNull(name);
        // Null-check the argument explicitly (mirrors the check on name) so a caller error
        // fails here rather than inside the TreeSet copy. The defensive, unmodifiable copy
        // guarantees callers cannot mutate the set and that error messages list types
        // in deterministic alphabetical order.
        this.types = Collections.unmodifiableSortedSet(new TreeSet<>(Objects.requireNonNull(types)));
    }

    public String getName() {
        return name;
    }

    /**
     * @return the accepted mapping types, sorted alphabetically; the returned set is unmodifiable
     */
    public SortedSet<String> getTypes() {
        return types;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;

import org.elasticsearch.index.mapper.NumberFieldMapper;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Helper class that defines groups of types
*/
public final class Types {

    private Types() {}

    // Field types treated as categorical by analyses that support them.
    private static final Set<String> CATEGORICAL_TYPES =
        Collections.unmodifiableSet(new HashSet<>(Arrays.asList("text", "keyword", "ip")));

    // Every numeric mapper type, plus "scaled_float" which NumberFieldMapper.NumberType
    // does not enumerate and therefore has to be added explicitly.
    private static final Set<String> NUMERICAL_TYPES = Collections.unmodifiableSet(
        Stream.concat(
            Stream.of(NumberFieldMapper.NumberType.values()).map(NumberFieldMapper.NumberType::typeName),
            Stream.of("scaled_float")
        ).collect(Collectors.toCollection(HashSet::new)));

    /**
     * @return the set of categorical field types; unmodifiable
     */
    public static Set<String> categorical() {
        return CATEGORICAL_TYPES;
    }

    /**
     * @return the set of numerical field types; unmodifiable
     */
    public static Set<String> numerical() {
        return NUMERICAL_TYPES;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.search.fetch.StoredFieldsContext;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsIndex;

Expand Down Expand Up @@ -268,7 +269,7 @@ public Set<String> getCategoricalFields() {
Set<String> categoricalFields = new HashSet<>();
for (ExtractedField extractedField : context.extractedFields.getAllFields()) {
String fieldName = extractedField.getName();
if (ExtractedFieldsDetector.CATEGORICAL_TYPES.containsAll(extractedField.getTypes())) {
if (Types.categorical().containsAll(extractedField.getTypes())) {
categoricalFields.add(fieldName);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
import org.elasticsearch.common.regex.Regex;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.mapper.BooleanFieldMapper;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NameResolver;
Expand All @@ -35,10 +36,10 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class ExtractedFieldsDetector {

Expand All @@ -50,18 +51,6 @@ public class ExtractedFieldsDetector {
private static final List<String> IGNORE_FIELDS = Arrays.asList("_id", "_field_names", "_index", "_parent", "_routing", "_seq_no",
"_source", "_type", "_uid", "_version", "_feature", "_ignored", DataFrameAnalyticsIndex.ID_COPY);

public static final Set<String> CATEGORICAL_TYPES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList("text", "keyword", "ip")));

private static final Set<String> NUMERICAL_TYPES;

static {
Set<String> numericalTypes = Stream.of(NumberFieldMapper.NumberType.values())
.map(NumberFieldMapper.NumberType::typeName)
.collect(Collectors.toSet());
numericalTypes.add("scaled_float");
NUMERICAL_TYPES = Collections.unmodifiableSet(numericalTypes);
}

private final String[] index;
private final DataFrameAnalyticsConfig config;
private final String resultsField;
Expand All @@ -80,33 +69,32 @@ public class ExtractedFieldsDetector {
}

public ExtractedFields detect() {
Set<String> fields = new HashSet<>(fieldCapabilitiesResponse.get().keySet());
fields.removeAll(IGNORE_FIELDS);
removeFieldsUnderResultsField(fields);
includeAndExcludeFields(fields);
removeFieldsWithIncompatibleTypes(fields);
checkRequiredFieldsArePresent(fields);
Set<String> fields = getIncludedFields();

if (fields.isEmpty()) {
throw ExceptionsHelper.badRequestException("No compatible fields could be detected in index {}. Supported types are {}.",
Arrays.toString(index),
getSupportedTypes());
}

List<String> sortedFields = new ArrayList<>(fields);
// We sort the fields to ensure the checksum for each document is deterministic
Collections.sort(sortedFields);
ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse);
if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
extractedFields = fetchFromSourceIfSupported(extractedFields);
if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
throw ExceptionsHelper.badRequestException("[{}] fields must be retrieved from doc_values but the limit is [{}]; " +
"please adjust the index level setting [{}]", extractedFields.getDocValueFields().size(), docValueFieldsLimit,
IndexSettings.MAX_DOCVALUE_FIELDS_SEARCH_SETTING.getKey());
}
checkNoIgnoredFields(fields);
checkFieldsHaveCompatibleTypes(fields);
checkRequiredFields(fields);
return detectExtractedFields(fields);
}

private Set<String> getIncludedFields() {
Set<String> fields = new HashSet<>(fieldCapabilitiesResponse.get().keySet());
removeFieldsUnderResultsField(fields);
FetchSourceContext analyzedFields = config.getAnalyzedFields();

// If the user has not explicitly included fields we'll include all compatible fields
if (analyzedFields == null || analyzedFields.includes().length == 0) {
fields.removeAll(IGNORE_FIELDS);
removeFieldsWithIncompatibleTypes(fields);
}
extractedFields = fetchBooleanFieldsAsIntegers(extractedFields);
return extractedFields;
includeAndExcludeFields(fields);
return fields;
}

private void removeFieldsUnderResultsField(Set<String> fields) {
Expand Down Expand Up @@ -139,31 +127,38 @@ private void removeFieldsWithIncompatibleTypes(Set<String> fields) {
Iterator<String> fieldsIterator = fields.iterator();
while (fieldsIterator.hasNext()) {
String field = fieldsIterator.next();
Map<String, FieldCapabilities> fieldCaps = fieldCapabilitiesResponse.getField(field);
if (fieldCaps == null) {
LOGGER.debug("[{}] Removing field [{}] because it is missing from mappings", config.getId(), field);
if (hasCompatibleType(field) == false) {
fieldsIterator.remove();
} else {
Set<String> fieldTypes = fieldCaps.keySet();
if (NUMERICAL_TYPES.containsAll(fieldTypes)) {
LOGGER.debug("[{}] field [{}] is compatible as it is numerical", config.getId(), field);
} else if (config.getAnalysis().supportsCategoricalFields() && CATEGORICAL_TYPES.containsAll(fieldTypes)) {
LOGGER.debug("[{}] field [{}] is compatible as it is categorical", config.getId(), field);
} else if (isBoolean(fieldTypes)) {
LOGGER.debug("[{}] field [{}] is compatible as it is boolean", config.getId(), field);
} else {
LOGGER.debug("[{}] Removing field [{}] because its types are not supported; types {}; supported {}",
config.getId(), field, fieldTypes, getSupportedTypes());
fieldsIterator.remove();
}
}
}
}

/**
 * Checks whether all mapped types of {@code field} fall into one supported
 * group: numerical, categorical (if the analysis supports it), or boolean.
 * Logs the reason for the verdict at debug level.
 */
private boolean hasCompatibleType(String field) {
    Map<String, FieldCapabilities> caps = fieldCapabilitiesResponse.getField(field);
    if (caps == null) {
        // The field has no mappings in the source index(es)
        LOGGER.debug("[{}] incompatible field [{}] because it is missing from mappings", config.getId(), field);
        return false;
    }
    Set<String> fieldTypes = caps.keySet();
    if (Types.numerical().containsAll(fieldTypes)) {
        LOGGER.debug("[{}] field [{}] is compatible as it is numerical", config.getId(), field);
        return true;
    }
    if (config.getAnalysis().supportsCategoricalFields() && Types.categorical().containsAll(fieldTypes)) {
        LOGGER.debug("[{}] field [{}] is compatible as it is categorical", config.getId(), field);
        return true;
    }
    if (isBoolean(fieldTypes)) {
        LOGGER.debug("[{}] field [{}] is compatible as it is boolean", config.getId(), field);
        return true;
    }
    LOGGER.debug("[{}] incompatible field [{}]; types {}; supported {}", config.getId(), field, fieldTypes, getSupportedTypes());
    return false;
}

private Set<String> getSupportedTypes() {
Set<String> supportedTypes = new TreeSet<>(NUMERICAL_TYPES);
Set<String> supportedTypes = new TreeSet<>(Types.numerical());
if (config.getAnalysis().supportsCategoricalFields()) {
supportedTypes.addAll(CATEGORICAL_TYPES);
supportedTypes.addAll(Types.categorical());
}
supportedTypes.add(BooleanFieldMapper.CONTENT_TYPE);
return supportedTypes;
Expand Down Expand Up @@ -202,16 +197,61 @@ private void includeAndExcludeFields(Set<String> fields) {
}
}

private void checkRequiredFieldsArePresent(Set<String> fields) {
List<String> missingFields = config.getAnalysis().getRequiredFields()
.stream()
.filter(f -> fields.contains(f) == false)
.collect(Collectors.toList());
if (missingFields.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("required fields {} are missing", missingFields);
/**
 * Throws if any internal/metadata field (see IGNORE_FIELDS) was explicitly
 * included for analysis; reports the first such field found.
 */
private void checkNoIgnoredFields(Set<String> fields) {
    // Iterate IGNORE_FIELDS (not fields) so the reported field is deterministic
    for (String ignoredField : IGNORE_FIELDS) {
        if (fields.contains(ignoredField)) {
            throw ExceptionsHelper.badRequestException("field [{}] cannot be analyzed", ignoredField);
        }
    }
}

/**
 * Validates every selected field: it must have mappings and a supported type.
 * Fails fast with a 400-level error on the first offending field.
 */
private void checkFieldsHaveCompatibleTypes(Set<String> fields) {
    for (String field : fields) {
        Map<String, FieldCapabilities> caps = fieldCapabilitiesResponse.getField(field);
        if (caps == null) {
            throw ExceptionsHelper.badRequestException("no mappings could be found for field [{}]", field);
        }
        if (hasCompatibleType(field)) {
            continue;
        }
        throw ExceptionsHelper.badRequestException("field [{}] has unsupported type {}. Supported types are {}.", field,
            caps.keySet(), getSupportedTypes());
    }
}

/**
 * Validates that each field required by the analysis is included, mapped,
 * and mapped only to types the analysis accepts for that field.
 */
private void checkRequiredFields(Set<String> fields) {
    List<RequiredField> requiredFields = config.getAnalysis().getRequiredFields();
    for (RequiredField requiredField : requiredFields) {
        String requiredFieldName = requiredField.getName();
        Map<String, FieldCapabilities> fieldCaps = fieldCapabilitiesResponse.getField(requiredFieldName);
        // A required field is missing if it was excluded/not included or has no mappings at all
        boolean isMissing = fields.contains(requiredFieldName) == false || fieldCaps == null || fieldCaps.isEmpty();
        if (isMissing) {
            List<String> requiredFieldNames = requiredFields.stream().map(RequiredField::getName).collect(Collectors.toList());
            throw ExceptionsHelper.badRequestException("required field [{}] is missing; analysis requires fields {}",
                requiredFieldName, requiredFieldNames);
        }
        Set<String> fieldTypes = fieldCaps.keySet();
        // All mapped types of the field must be acceptable for the analysis
        if (requiredField.getTypes().containsAll(fieldTypes) == false) {
            throw ExceptionsHelper.badRequestException("invalid types {} for required field [{}]; expected types are {}",
                fieldTypes, requiredFieldName, requiredField.getTypes());
        }
    }
}

/**
 * Builds the final ExtractedFields from the validated field names, enforcing
 * the doc_values fields limit (falling back to _source where supported) and
 * mapping boolean fields so they are extracted as integers.
 */
private ExtractedFields detectExtractedFields(Set<String> fields) {
    // We sort the fields to ensure the checksum for each document is deterministic
    List<String> sortedFields = new ArrayList<>(fields);
    Collections.sort(sortedFields);
    ExtractedFields extracted = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse);
    if (extracted.getDocValueFields().size() > docValueFieldsLimit) {
        // Over the limit: retry with _source extraction for fields that support it
        extracted = fetchFromSourceIfSupported(extracted);
        if (extracted.getDocValueFields().size() > docValueFieldsLimit) {
            throw ExceptionsHelper.badRequestException("[{}] fields must be retrieved from doc_values but the limit is [{}]; " +
                "please adjust the index level setting [{}]", extracted.getDocValueFields().size(), docValueFieldsLimit,
                IndexSettings.MAX_DOCVALUE_FIELDS_SEARCH_SETTING.getKey());
        }
    }
    return fetchBooleanFieldsAsIntegers(extracted);
}

private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFields) {
List<ExtractedField> adjusted = new ArrayList<>(extractedFields.getAllFields().size());
for (ExtractedField field : extractedFields.getDocValueFields()) {
Expand Down
Loading

0 comments on commit 4d8eff8

Please sign in to comment.