Skip to content

Commit

Permalink
[ML] Adds feature importance option to inference processor (#52218)
Browse files Browse the repository at this point in the history
This adds machine learning model feature importance calculations to the inference processor. 

The new flag in the configuration matches the analytics parameter name: `num_top_feature_importance_values`
Example:
```
"inference": {
   "field_mappings": {},
   "model_id": "my_model",
   "inference_config": {
      "regression": {
         "num_top_feature_importance_values": 3
      }
   }
}
```

This will write to the document as follows:
```
"inference" : {
   "feature_importance" : { 
      "FlightTimeMin" : -76.90955548511226,
      "FlightDelayType" : 114.13514762158526,
      "DistanceMiles" : 13.731580450792187
   },
   "predicted_value" : 108.33165831875137,
   "model_id" : "my_model"
}
```

This is done through calculating the [SHAP values](https://arxiv.org/abs/1802.03888). 

It requires that models have populated `number_samples` for each tree node. Models created before 7.7 do not have this information, so feature importance is not available for them. 

Additionally, if the inference config requests feature importance and not all nodes have been upgraded yet, the pipeline cannot be created. This safeguards mixed-version environments where only some ingest nodes have been upgraded.

NOTE: the algorithm is a Java port of the one laid out in ml-cpp: https://github.com/elastic/ml-cpp/blob/master/lib/maths/CTreeShapFeatureImportance.cc

usability blocked by: elastic/ml-cpp#991
  • Loading branch information
benwtrent authored Feb 21, 2020
1 parent 01a6dae commit 20f5427
Show file tree
Hide file tree
Showing 29 changed files with 980 additions and 104 deletions.
12 changes: 12 additions & 0 deletions docs/reference/ingest/processors/inference.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ include::common-options.asciidoc[]
Specifies the field to which the inference prediction is written. Defaults to
`predicted_value`.

`num_top_feature_importance_values`::::
(Optional, integer)
Specifies the maximum number of
{ml-docs}/dfa-regression.html#dfa-regression-feature-importance[feature
importance] values per document. By default, it is zero and no feature importance
calculation occurs.

[discrete]
[[inference-processor-classification-opt]]
Expand All @@ -63,6 +69,12 @@ Specifies the number of top class predictions to return. Defaults to 0.
Specifies the field to which the top classes are written. Defaults to
`top_classes`.

`num_top_feature_importance_values`::::
(Optional, integer)
Specifies the maximum number of
{ml-docs}/dfa-classification.html#dfa-classification-feature-importance[feature
importance] values per document. By default, it is zero and no feature importance
calculation occurs.

[discrete]
[[inference-processor-config-example]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -73,6 +74,7 @@ public static TrainedModelDefinition.Builder fromXContent(XContentParser parser,

private final TrainedModel trainedModel;
private final List<PreProcessor> preProcessors;
private Map<String, String> decoderMap;

private TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
Expand Down Expand Up @@ -115,13 +117,35 @@ public List<PreProcessor> getPreProcessors() {
return preProcessors;
}

private void preProcess(Map<String, Object> fields) {
void preProcess(Map<String, Object> fields) {
preProcessors.forEach(preProcessor -> preProcessor.process(fields));
}

public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
preProcess(fields);
return trainedModel.infer(fields, config);
if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) {
throw ExceptionsHelper.badRequestException(
"Feature importance is not supported for the configured model of type [{}]",
trainedModel.getName());
}
return trainedModel.infer(fields,
config,
config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
}

/**
 * Lazily builds and caches the combined reverse-lookup map of all configured pre-processors.
 *
 * The map is the union of every {@link PreProcessor#reverseLookup()} result, merged in
 * pre-processor order (later pre-processors overwrite earlier entries with the same key).
 *
 * @return a read-only map built from the pre-processors' reverse lookups; never {@code null}
 */
private Map<String, String> getDecoderMap() {
    // Read the field exactly once on the lock-free path: the field is not volatile, so two
    // separate racy reads are not guaranteed to observe the same value.
    Map<String, String> cached = decoderMap;
    if (cached != null) {
        return cached;
    }
    synchronized (this) {
        cached = decoderMap;
        if (cached == null) {
            Map<String, String> merged = new HashMap<>();
            for (PreProcessor preProcessor : preProcessors) {
                merged.putAll(preProcessor.reverseLookup());
            }
            // Publish through an immutable wrapper. The wrapper keeps the backing map in a final
            // field, so a thread that sees the reference without taking the lock still sees a
            // fully populated map. A bare HashMap published through a non-volatile field gives
            // no such guarantee.
            cached = Collections.unmodifiableMap(merged);
            decoderMap = cached;
        }
        return cached;
    }
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -235,6 +236,11 @@ public void process(Map<String, Object> fields) {
fields.put(destField, concatEmbeddings(processedFeatures));
}

/**
 * Maps the destination field written by {@code process} back to the field name this
 * pre-processor is configured with.
 *
 * @return a single-entry map of {@code destField -> fieldName}
 */
@Override
public Map<String, String> reverseLookup() {
    final Map<String, String> processedToOriginal = Collections.singletonMap(destField, fieldName);
    return processedToOriginal;
}

@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ public String getFeatureName() {
return featureName;
}

/**
 * Maps the feature name exposed by {@code getFeatureName()} back to this pre-processor's
 * {@code field}.
 *
 * @return a single-entry map of {@code featureName -> field}
 */
@Override
public Map<String, String> reverseLookup() {
    final Map<String, String> processedToOriginal = Collections.singletonMap(featureName, field);
    return processedToOriginal;
}

@Override
public String getName() {
return NAME.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
* PreProcessor for one hot encoding a set of categorical values for a given field.
Expand Down Expand Up @@ -80,6 +82,11 @@ public Map<String, String> getHotMap() {
return hotMap;
}

/**
 * Maps every value of the hot map back to this pre-processor's {@code field}.
 *
 * @return a map with one entry per distinct hot-map value, each pointing at {@code field}
 */
@Override
public Map<String, String> reverseLookup() {
    // Every entry maps to the same value (field), so if two hot-map entries share a value it is
    // safe to keep the first one. Without the merge function, toMap throws
    // IllegalStateException on such duplicates.
    return hotMap.values()
        .stream()
        .collect(Collectors.toMap(encodedFeature -> encodedFeature, encodedFeature -> field, (first, second) -> first));
}

@Override
public String getName() {
return NAME.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,9 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accou
* @param fields The fields and their values to process
*/
void process(Map<String, Object> fields);

/**
 * Describes how the features produced by this pre-processor relate back to its input.
 *
 * @return Reverse lookup map to match resulting features to their original feature name:
 *         keys are the resulting (processed) feature names, values are the original
 *         feature name each one was derived from
 */
Map<String, String> reverseLookup();
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ public String getFeatureName() {
return featureName;
}

/**
 * Maps the feature name exposed by {@code getFeatureName()} back to this pre-processor's
 * {@code field}.
 *
 * @return a single-entry map of {@code featureName -> field}
 */
@Override
public Map<String, String> reverseLookup() {
    final Map<String, String> processedToOriginal = Collections.singletonMap(featureName, field);
    return processedToOriginal;
}

@Override
public String getName() {
return NAME.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,25 @@ public ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
InferenceConfig config) {
super(value);
assert config instanceof ClassificationConfig;
ClassificationConfig classificationConfig = (ClassificationConfig)config;
this(value, classificationLabel, topClasses, Collections.emptyMap(), (ClassificationConfig)config);
}

/**
 * Creates classification results that carry feature importance values.
 *
 * Delegates to the private constructor after casting {@code config}; a config that is not a
 * {@link ClassificationConfig} fails with a {@link ClassCastException}.
 *
 * @param value the inference value
 * @param classificationLabel the label for the classification result
 * @param topClasses the top class entries
 * @param featureImportance feature importance values keyed by feature name
 * @param config the inference config; must be a {@link ClassificationConfig}
 */
public ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
InferenceConfig config) {
this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig)config);
}

private ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
ClassificationConfig classificationConfig) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
classificationConfig.getNumTopFeatureImportanceValues()));
this.classificationLabel = classificationLabel;
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
this.topNumClassesField = classificationConfig.getTopClassesResultsField();
Expand Down Expand Up @@ -74,16 +90,17 @@ public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
ClassificationInferenceResults that = (ClassificationInferenceResults) object;
return Objects.equals(value(), that.value()) &&
Objects.equals(classificationLabel, that.classificationLabel) &&
Objects.equals(resultsField, that.resultsField) &&
Objects.equals(topNumClassesField, that.topNumClassesField) &&
Objects.equals(topClasses, that.topClasses);
return Objects.equals(value(), that.value())
&& Objects.equals(classificationLabel, that.classificationLabel)
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(topNumClassesField, that.topNumClassesField)
&& Objects.equals(topClasses, that.topClasses)
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
}

@Override
public int hashCode() {
return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField);
return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField, getFeatureImportance());
}

@Override
Expand All @@ -100,6 +117,9 @@ public void writeResult(IngestDocument document, String parentResultField) {
document.setFieldValue(parentResultField + "." + topNumClassesField,
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,19 @@
import org.elasticsearch.ingest.IngestDocument;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

public class RawInferenceResults extends SingleValueInferenceResults {

public static final String NAME = "raw";

public RawInferenceResults(double value) {
super(value);
public RawInferenceResults(double value, Map<String, Double> featureImportance) {
super(value, featureImportance);
}

public RawInferenceResults(StreamInput in) throws IOException {
super(in.readDouble());
super(in);
}

@Override
Expand All @@ -34,12 +35,13 @@ public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RawInferenceResults that = (RawInferenceResults) object;
return Objects.equals(value(), that.value());
return Objects.equals(value(), that.value())
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
}

@Override
public int hashCode() {
return Objects.hash(value());
return Objects.hash(value(), getFeatureImportance());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;

public class RegressionInferenceResults extends SingleValueInferenceResults {
Expand All @@ -22,14 +24,22 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
private final String resultsField;

public RegressionInferenceResults(double value, InferenceConfig config) {
super(value);
assert config instanceof RegressionConfig;
RegressionConfig regressionConfig = (RegressionConfig)config;
this(value, (RegressionConfig) config, Collections.emptyMap());
}

/**
 * Creates regression results that carry feature importance values.
 *
 * Delegates to the private constructor after casting {@code config}; a config that is not a
 * {@link RegressionConfig} fails with a {@link ClassCastException}.
 *
 * @param value the predicted value
 * @param config the inference config; must be a {@link RegressionConfig}
 * @param featureImportance feature importance values keyed by feature name
 */
public RegressionInferenceResults(double value, InferenceConfig config, Map<String, Double> featureImportance) {
this(value, (RegressionConfig)config, featureImportance);
}

/**
 * Shared constructor behind the public ones.
 *
 * Only the top feature importance values are kept, limited to the number the config returns
 * from {@code getNumTopFeatureImportanceValues()}; the results field name is also taken
 * from the config.
 *
 * @param value the predicted value
 * @param regressionConfig the regression config supplying the results field and the importance limit
 * @param featureImportance the unsorted feature importance values keyed by feature name
 */
private RegressionInferenceResults(double value, RegressionConfig regressionConfig, Map<String, Double> featureImportance) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
regressionConfig.getNumTopFeatureImportanceValues()));
this.resultsField = regressionConfig.getResultsField();
}

public RegressionInferenceResults(StreamInput in) throws IOException {
super(in.readDouble());
super(in);
this.resultsField = in.readString();
}

Expand All @@ -44,19 +54,24 @@ public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RegressionInferenceResults that = (RegressionInferenceResults) object;
return Objects.equals(value(), that.value()) && Objects.equals(this.resultsField, that.resultsField);
return Objects.equals(value(), that.value())
&& Objects.equals(this.resultsField, that.resultsField)
&& Objects.equals(this.getFeatureImportance(), that.getFeatureImportance());
}

@Override
public int hashCode() {
return Objects.hash(value(), resultsField);
return Objects.hash(value(), resultsField, getFeatureImportance());
}

/**
 * Writes the predicted value, and any feature importance values, into the ingest document.
 *
 * The value is written to {@code parentResultField + "." + resultsField}. Feature importance
 * is written to {@code parentResultField + ".feature_importance"} only when it is non-empty.
 *
 * @param document the document to write into; must not be null
 * @param parentResultField the parent field under which the results are written; must not be null
 */
@Override
public void writeResult(IngestDocument document, String parentResultField) {
    ExceptionsHelper.requireNonNull(document, "document");
    ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
    document.setFieldValue(parentResultField + "." + this.resultsField, value());

    final Map<String, Double> importance = getFeatureImportance();
    if (importance.isEmpty()) {
        return;
    }
    document.setFieldValue(parentResultField + ".feature_importance", importance);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,61 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

public abstract class SingleValueInferenceResults implements InferenceResults {

private final double value;
private final Map<String, Double> featureImportance;

/**
 * Keeps the {@code numTopFeatures} entries with the largest absolute importance.
 *
 * @param unsortedFeatureImportances feature importance values keyed by feature name, in any order
 * @param numTopFeatures the maximum number of entries to keep
 * @return a new map whose iteration order is by descending absolute importance
 */
static Map<String, Double> takeTopFeatureImportances(Map<String, Double> unsortedFeatureImportances, int numTopFeatures) {
    final LinkedHashMap<String, Double> topImportances = new LinkedHashMap<>();
    unsortedFeatureImportances.entrySet()
        .stream()
        // Right-hand side first: largest magnitude sorts to the front.
        .sorted((left, right) -> Double.compare(Math.abs(right.getValue()), Math.abs(left.getValue())))
        .limit(numTopFeatures)
        // forEachOrdered keeps the sorted order when filling the insertion-ordered map.
        .forEachOrdered(entry -> topImportances.put(entry.getKey(), entry.getValue()));
    return topImportances;
}

SingleValueInferenceResults(StreamInput in) throws IOException {
value = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.featureImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
} else {
this.featureImportance = Collections.emptyMap();
}
}

SingleValueInferenceResults(double value) {
SingleValueInferenceResults(double value, Map<String, Double> featureImportance) {
this.value = value;
this.featureImportance = ExceptionsHelper.requireNonNull(featureImportance, "featureImportance");
}

/**
 * @return the single inference value, boxed from the primitive field (never {@code null})
 */
public Double value() {
return value;
}

/**
 * @return the feature importance values keyed by feature name; this is the stored map
 *         itself, not a copy
 */
public Map<String, Double> getFeatureImportance() {
return featureImportance;
}

/**
 * @return the inference value formatted with {@link String#valueOf(double)}
 */
public String valueAsString() {
return String.valueOf(value);
}

/**
 * Writes the results to a stream, mirroring the stream-reading constructor.
 *
 * The value is always written. The feature importance map is written only when the stream
 * version is on or after 7.7.0.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
    out.writeDouble(value);
    if (out.getVersion().onOrAfter(Version.V_7_7_0) == false) {
        // Streams older than 7.7.0 carry the value only.
        return;
    }
    out.writeMap(featureImportance, StreamOutput::writeString, StreamOutput::writeDouble);
}

}
Loading

0 comments on commit 20f5427

Please sign in to comment.