Skip to content

Commit

Permalink
[ML] Adds feature importance to option to inference processor (#52218) (
Browse files Browse the repository at this point in the history
#52666)

This adds machine learning model feature importance calculations to the inference processor.

The new flag in the configuration matches the analytics parameter name: `num_top_feature_importance_values`
Example:
```
"inference": {
   "field_mappings": {},
   "model_id": "my_model",
   "inference_config": {
      "regression": {
         "num_top_feature_importance_values": 3
      }
   }
}
```

This will write to the document as follows:
```
"inference" : {
   "feature_importance" : {
      "FlightTimeMin" : -76.90955548511226,
      "FlightDelayType" : 114.13514762158526,
      "DistanceMiles" : 13.731580450792187
   },
   "predicted_value" : 108.33165831875137,
   "model_id" : "my_model"
}
```

This is done through calculating the [SHAP values](https://arxiv.org/abs/1802.03888).

It requires that models have populated `number_samples` for each tree node. This is not available to models that were created before 7.7.

Additionally, if the inference config is requesting feature_importance, and not all nodes have been upgraded yet, it will not allow the pipeline to be created. This is to safe-guard in a mixed-version environment where only some ingest nodes have been upgraded.

NOTE: the algorithm is a Java port of the one laid out in ml-cpp: https://github.com/elastic/ml-cpp/blob/master/lib/maths/CTreeShapFeatureImportance.cc

usability blocked by: elastic/ml-cpp#991
  • Loading branch information
benwtrent authored Feb 21, 2020
1 parent f06d692 commit afd9064
Show file tree
Hide file tree
Showing 29 changed files with 980 additions and 104 deletions.
12 changes: 12 additions & 0 deletions docs/reference/ingest/processors/inference.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ include::common-options.asciidoc[]
Specifies the field to which the inference prediction is written. Defaults to
`predicted_value`.

`num_top_feature_importance_values`::::
(Optional, integer)
Specifies the maximum number of
{ml-docs}/dfa-regression.html#dfa-regression-feature-importance[feature
importance] values per document. By default, it is zero and no feature importance
calculation occurs.

[discrete]
[[inference-processor-classification-opt]]
Expand All @@ -63,6 +69,12 @@ Specifies the number of top class predictions to return. Defaults to 0.
Specifies the field to which the top classes are written. Defaults to
`top_classes`.

`num_top_feature_importance_values`::::
(Optional, integer)
Specifies the maximum number of
{ml-docs}/dfa-classification.html#dfa-classification-feature-importance[feature
importance] values per document. By default, it is zero and no feature importance
calculation occurs.

[discrete]
[[inference-processor-config-example]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -73,6 +74,7 @@ public static TrainedModelDefinition.Builder fromXContent(XContentParser parser,

private final TrainedModel trainedModel;
private final List<PreProcessor> preProcessors;
private Map<String, String> decoderMap;

private TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
Expand Down Expand Up @@ -115,13 +117,35 @@ public List<PreProcessor> getPreProcessors() {
return preProcessors;
}

private void preProcess(Map<String, Object> fields) {
void preProcess(Map<String, Object> fields) {
preProcessors.forEach(preProcessor -> preProcessor.process(fields));
}

public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
preProcess(fields);
return trainedModel.infer(fields, config);
if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) {
throw ExceptionsHelper.badRequestException(
"Feature importance is not supported for the configured model of type [{}]",
trainedModel.getName());
}
return trainedModel.infer(fields,
config,
config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
}

private Map<String, String> getDecoderMap() {
if (decoderMap != null) {
return decoderMap;
}
synchronized (this) {
if (decoderMap != null) {
return decoderMap;
}
this.decoderMap = preProcessors.stream()
.map(PreProcessor::reverseLookup)
.collect(HashMap::new, Map::putAll, Map::putAll);
return decoderMap;
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -235,6 +236,11 @@ public void process(Map<String, Object> fields) {
fields.put(destField, concatEmbeddings(processedFeatures));
}

@Override
public Map<String, String> reverseLookup() {
return Collections.singletonMap(destField, fieldName);
}

@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ public String getFeatureName() {
return featureName;
}

@Override
public Map<String, String> reverseLookup() {
return Collections.singletonMap(featureName, field);
}

@Override
public String getName() {
return NAME.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
* PreProcessor for one hot encoding a set of categorical values for a given field.
Expand Down Expand Up @@ -80,6 +82,11 @@ public Map<String, String> getHotMap() {
return hotMap;
}

@Override
public Map<String, String> reverseLookup() {
return hotMap.entrySet().stream().collect(Collectors.toMap(HashMap.Entry::getValue, (entry) -> field));
}

@Override
public String getName() {
return NAME.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,9 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accou
* @param fields The fields and their values to process
*/
void process(Map<String, Object> fields);

/**
* @return Reverse lookup map to match resulting features to their original feature name
*/
Map<String, String> reverseLookup();
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ public String getFeatureName() {
return featureName;
}

@Override
public Map<String, String> reverseLookup() {
return Collections.singletonMap(featureName, field);
}

@Override
public String getName() {
return NAME.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,25 @@ public ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
InferenceConfig config) {
super(value);
assert config instanceof ClassificationConfig;
ClassificationConfig classificationConfig = (ClassificationConfig)config;
this(value, classificationLabel, topClasses, Collections.emptyMap(), (ClassificationConfig)config);
}

public ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
InferenceConfig config) {
this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig)config);
}

private ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
ClassificationConfig classificationConfig) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
classificationConfig.getNumTopFeatureImportanceValues()));
this.classificationLabel = classificationLabel;
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
this.topNumClassesField = classificationConfig.getTopClassesResultsField();
Expand Down Expand Up @@ -74,16 +90,17 @@ public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
ClassificationInferenceResults that = (ClassificationInferenceResults) object;
return Objects.equals(value(), that.value()) &&
Objects.equals(classificationLabel, that.classificationLabel) &&
Objects.equals(resultsField, that.resultsField) &&
Objects.equals(topNumClassesField, that.topNumClassesField) &&
Objects.equals(topClasses, that.topClasses);
return Objects.equals(value(), that.value())
&& Objects.equals(classificationLabel, that.classificationLabel)
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(topNumClassesField, that.topNumClassesField)
&& Objects.equals(topClasses, that.topClasses)
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
}

@Override
public int hashCode() {
return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField);
return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField, getFeatureImportance());
}

@Override
Expand All @@ -100,6 +117,9 @@ public void writeResult(IngestDocument document, String parentResultField) {
document.setFieldValue(parentResultField + "." + topNumClassesField,
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,19 @@
import org.elasticsearch.ingest.IngestDocument;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

public class RawInferenceResults extends SingleValueInferenceResults {

public static final String NAME = "raw";

public RawInferenceResults(double value) {
super(value);
public RawInferenceResults(double value, Map<String, Double> featureImportance) {
super(value, featureImportance);
}

public RawInferenceResults(StreamInput in) throws IOException {
super(in.readDouble());
super(in);
}

@Override
Expand All @@ -34,12 +35,13 @@ public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RawInferenceResults that = (RawInferenceResults) object;
return Objects.equals(value(), that.value());
return Objects.equals(value(), that.value())
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
}

@Override
public int hashCode() {
return Objects.hash(value());
return Objects.hash(value(), getFeatureImportance());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;

public class RegressionInferenceResults extends SingleValueInferenceResults {
Expand All @@ -22,14 +24,22 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
private final String resultsField;

public RegressionInferenceResults(double value, InferenceConfig config) {
super(value);
assert config instanceof RegressionConfig;
RegressionConfig regressionConfig = (RegressionConfig)config;
this(value, (RegressionConfig) config, Collections.emptyMap());
}

public RegressionInferenceResults(double value, InferenceConfig config, Map<String, Double> featureImportance) {
this(value, (RegressionConfig)config, featureImportance);
}

private RegressionInferenceResults(double value, RegressionConfig regressionConfig, Map<String, Double> featureImportance) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
regressionConfig.getNumTopFeatureImportanceValues()));
this.resultsField = regressionConfig.getResultsField();
}

public RegressionInferenceResults(StreamInput in) throws IOException {
super(in.readDouble());
super(in);
this.resultsField = in.readString();
}

Expand All @@ -44,19 +54,24 @@ public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RegressionInferenceResults that = (RegressionInferenceResults) object;
return Objects.equals(value(), that.value()) && Objects.equals(this.resultsField, that.resultsField);
return Objects.equals(value(), that.value())
&& Objects.equals(this.resultsField, that.resultsField)
&& Objects.equals(this.getFeatureImportance(), that.getFeatureImportance());
}

@Override
public int hashCode() {
return Objects.hash(value(), resultsField);
return Objects.hash(value(), resultsField, getFeatureImportance());
}

@Override
public void writeResult(IngestDocument document, String parentResultField) {
ExceptionsHelper.requireNonNull(document, "document");
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
document.setFieldValue(parentResultField + "." + this.resultsField, value());
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,61 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

public abstract class SingleValueInferenceResults implements InferenceResults {

private final double value;
private final Map<String, Double> featureImportance;

static Map<String, Double> takeTopFeatureImportances(Map<String, Double> unsortedFeatureImportances, int numTopFeatures) {
return unsortedFeatureImportances.entrySet()
.stream()
.sorted((l, r)-> Double.compare(Math.abs(r.getValue()), Math.abs(l.getValue())))
.limit(numTopFeatures)
.collect(LinkedHashMap::new, (h, e) -> h.put(e.getKey(), e.getValue()) , LinkedHashMap::putAll);
}

SingleValueInferenceResults(StreamInput in) throws IOException {
value = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.featureImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
} else {
this.featureImportance = Collections.emptyMap();
}
}

SingleValueInferenceResults(double value) {
SingleValueInferenceResults(double value, Map<String, Double> featureImportance) {
this.value = value;
this.featureImportance = ExceptionsHelper.requireNonNull(featureImportance, "featureImportance");
}

public Double value() {
return value;
}

public Map<String, Double> getFeatureImportance() {
return featureImportance;
}

public String valueAsString() {
return String.valueOf(value);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(value);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeMap(this.featureImportance, StreamOutput::writeString, StreamOutput::writeDouble);
}
}

}
Loading

0 comments on commit afd9064

Please sign in to comment.