Skip to content

Commit

Permalink
[7.x] Pipeline Inference Aggregation (elastic#58965)
Browse files Browse the repository at this point in the history
Adds a pipeline aggregation that loads a model and performs inference on the
input aggregation results.
  • Loading branch information
davidkyle authored Jul 3, 2020
1 parent d22dd43 commit f6a0c2c
Show file tree
Hide file tree
Showing 45 changed files with 2,127 additions and 189 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
[role="xpack"]
[testenv="basic"]
[[search-aggregations-pipeline-inference-bucket-aggregation]]
=== Inference Bucket Aggregation

A parent pipeline aggregation which loads a pre-trained model and performs inference on the
collated result field from the parent bucket aggregation.

[[inference-bucket-agg-syntax]]
==== Syntax

An `inference` aggregation looks like this in isolation:

[source,js]
--------------------------------------------------
{
"inference": {
"model_id": "a_model_for_inference", <1>
"inference_config": { <2>
"regression_config": {
"num_top_feature_importance_values": 2
}
},
"buckets_path": {
"avg_cost": "avg_agg", <3>
"max_cost": "max_agg"
}
}
}
--------------------------------------------------
// NOTCONSOLE
<1> The ID of the model to use.
<2> The optional inference configuration, which overrides the model's default settings.
<3> Map the value of `avg_agg` to the model's input field `avg_cost`

[[inference-bucket-params]]
.`inference` Parameters
[options="header"]
|===
|Parameter Name |Description |Required |Default Value
| `model_id` | The ID of the model to load and infer against | Required | -
| `inference_config` | Contains the inference type and its options. There are two types: <<inference-agg-regression-opt,`regression`>> and <<inference-agg-classification-opt,`classification`>> | Optional | -
| `buckets_path` | Defines the paths to the input aggregations and maps the aggregation names to the field names expected by the model.
See <<buckets-path-syntax>> for more details | Required | -
|===


==== Configuration options for {infer} models
The `inference_config` setting is optional and usually isn't required as the pre-trained models come equipped with sensible defaults.
In the context of aggregations, some options can be overridden for each of the two types of model.

[discrete]
[[inference-agg-regression-opt]]
===== Configuration options for {regression} models

`num_top_feature_importance_values`::
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-regression-num-top-feature-importance-values]

[discrete]
[[inference-agg-classification-opt]]
===== Configuration options for {classification} models

`num_top_classes`::
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-classes]

`num_top_feature_importance_values`::
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-feature-importance-values]

`prediction_field_type`::
(Optional, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-prediction-field-type]
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,11 @@ public void setUp() throws Exception {
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
entries.addAll(indicesModule.getNamedWriteables());
entries.addAll(searchModule.getNamedWriteables());
entries.addAll(additionalNamedWriteables());
namedWriteableRegistry = new NamedWriteableRegistry(entries);
xContentRegistry = new NamedXContentRegistry(searchModule.getNamedXContents());
List<NamedXContentRegistry.Entry> xContentEntries = searchModule.getNamedXContents();
xContentEntries.addAll(additionalNamedContents());
xContentRegistry = new NamedXContentRegistry(xContentEntries);
//create some random type with some default field, those types will stick around for all of the subclasses
currentTypes = new String[randomIntBetween(0, 5)];
for (int i = 0; i < currentTypes.length; i++) {
Expand All @@ -101,6 +104,20 @@ protected List<SearchPlugin> plugins() {
return emptyList();
}

/**
 * Any extra named writeables required that are not registered by {@link SearchModule}.
 * Subclasses override this to register plugin-specific writeables (e.g. inference results).
 *
 * @return additional {@link NamedWriteableRegistry.Entry} instances; empty by default
 */
protected List<NamedWriteableRegistry.Entry> additionalNamedWriteables() {
return emptyList();
}

/**
 * Any extra named xcontents required that are not registered by {@link SearchModule}.
 * Subclasses override this to register plugin-specific parsers.
 *
 * @return additional {@link NamedXContentRegistry.Entry} instances; empty by default
 */
protected List<NamedXContentRegistry.Entry> additionalNamedContents() {
return emptyList();
}

/**
* Generic test that creates new AggregatorFactory from the test
* AggregatorFactory and checks both for equality and asserts equality on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
Expand All @@ -20,6 +21,7 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
Expand Down Expand Up @@ -121,6 +123,8 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
ClassificationConfigUpdate::fromXContentStrict));
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, RegressionConfigUpdate.NAME,
RegressionConfigUpdate::fromXContentStrict));
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, ResultsFieldUpdate.NAME,
ResultsFieldUpdate::fromXContent));

// Inference models
namedXContent.add(new NamedXContentRegistry.Entry(InferenceModel.class, Ensemble.NAME, EnsembleInferenceModel::fromXContent));
Expand Down Expand Up @@ -170,6 +174,9 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
RegressionInferenceResults.NAME,
RegressionInferenceResults::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
WarningInferenceResults.NAME,
WarningInferenceResults::new));

// Inference Configs
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
Expand All @@ -18,9 +17,7 @@

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -85,6 +82,10 @@ public List<TopClassEntry> getTopClasses() {
return topClasses;
}

/**
 * @return the {@link PredictionFieldType} used to transform the raw predicted
 *         value into its user-facing representation
 */
public PredictionFieldType getPredictionFieldType() {
return predictionFieldType;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
Expand Down Expand Up @@ -127,6 +128,11 @@ public String valueAsString() {
return classificationLabel == null ? super.valueAsString() : classificationLabel;
}

/**
 * The predicted class, converted by the configured {@link PredictionFieldType}
 * from the raw numeric value (or its string label, when one is available).
 */
@Override
public Object predictedValue() {
return predictionFieldType.transformPredictedValue(value(), valueAsString());
}

@Override
public void writeResult(IngestDocument document, String parentResultField) {
ExceptionsHelper.requireNonNull(document, "document");
Expand All @@ -138,7 +144,7 @@ public void writeResult(IngestDocument document, String parentResultField) {
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
document.setFieldValue(parentResultField + "." + FEATURE_IMPORTANCE, getFeatureImportance()
.stream()
.map(FeatureImportance::toMap)
.collect(Collectors.toList()));
Expand All @@ -150,74 +156,15 @@ public String getWriteableName() {
return NAME;
}

public static class TopClassEntry implements Writeable {

public final ParseField CLASS_NAME = new ParseField("class_name");
public final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
public final ParseField CLASS_SCORE = new ParseField("class_score");

private final Object classification;
private final double probability;
private final double score;

public TopClassEntry(Object classification, double probability, double score) {
this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME);
this.probability = probability;
this.score = score;
}

public TopClassEntry(StreamInput in) throws IOException {
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
this.classification = in.readGenericValue();
} else {
this.classification = in.readString();
}
this.probability = in.readDouble();
this.score = in.readDouble();
}

public Object getClassification() {
return classification;
}

public double getProbability() {
return probability;
}

public double getScore() {
return score;
}

public Map<String, Object> asValueMap() {
Map<String, Object> map = new HashMap<>(3, 1.0f);
map.put(CLASS_NAME.getPreferredName(), classification);
map.put(CLASS_PROBABILITY.getPreferredName(), probability);
map.put(CLASS_SCORE.getPreferredName(), score);
return map;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeGenericValue(classification);
} else {
out.writeString(classification.toString());
}
out.writeDouble(probability);
out.writeDouble(score);
}

@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
TopClassEntry that = (TopClassEntry) object;
return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score;
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(resultsField, predictionFieldType.transformPredictedValue(value(), valueAsString()));
if (topClasses.size() > 0) {
builder.field(topNumClassesField, topClasses);
}

@Override
public int hashCode() {
return Objects.hash(classification, probability, score);
if (getFeatureImportance().size() > 0) {
builder.field(FEATURE_IMPORTANCE, getFeatureImportance());
}
return builder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,33 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

public class FeatureImportance implements Writeable {
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;

public class FeatureImportance implements Writeable, ToXContentObject {

private final Map<String, Double> classImportance;
private final double importance;
private final String featureName;
private static final String IMPORTANCE = "importance";
private static final String FEATURE_NAME = "feature_name";
static final String IMPORTANCE = "importance";
static final String FEATURE_NAME = "feature_name";
static final String CLASS_IMPORTANCE = "class_importance";

public static FeatureImportance forRegression(String featureName, double importance) {
return new FeatureImportance(featureName, importance, null);
Expand All @@ -31,7 +41,24 @@ public static FeatureImportance forClassification(String featureName, Map<String
return new FeatureImportance(featureName, classImportance.values().stream().mapToDouble(Math::abs).sum(), classImportance);
}

private FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
// Parses { "feature_name": ..., "importance": ..., "class_importance": {...} }
// back into a FeatureImportance; the inverse of toXContent.
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
new ConstructingObjectParser<>("feature_importance",
a -> new FeatureImportance((String) a[0], (Double) a[1], (Map<String, Double>) a[2])
);

static {
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
// class_importance is optional: present only for classification models.
PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
new ParseField(FeatureImportance.CLASS_IMPORTANCE));
}

/**
 * Parse a {@code FeatureImportance} from xcontent.
 */
public static FeatureImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
this.featureName = Objects.requireNonNull(featureName);
this.importance = importance;
this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
Expand Down Expand Up @@ -79,6 +106,22 @@ public Map<String, Object> toMap() {
return map;
}

/**
 * Serialize as {@code { "feature_name": ..., "importance": ...,
 * "class_importance": { label: value, ... } }}, where the optional
 * class_importance object is emitted only when per-class values exist.
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
    builder.startObject()
        .field(FEATURE_NAME, featureName)
        .field(IMPORTANCE, importance);
    boolean hasClassImportance = classImportance != null && classImportance.isEmpty() == false;
    if (hasClassImportance) {
        builder.startObject(CLASS_IMPORTANCE);
        for (Map.Entry<String, Double> classAndValue : classImportance.entrySet()) {
            builder.field(classAndValue.getKey(), classAndValue.getValue());
        }
        builder.endObject();
    }
    return builder.endObject();
}

@Override
public boolean equals(Object object) {
if (object == this) { return true; }
Expand All @@ -93,5 +136,4 @@ public boolean equals(Object object) {
public int hashCode() {
return Objects.hash(featureName, importance, classImportance);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentFragment;
import org.elasticsearch.ingest.IngestDocument;

public interface InferenceResults extends NamedWriteable {
/**
 * The result of running inference against a trained model. Implementations are
 * named writeables and can render themselves as xcontent fragments.
 */
public interface InferenceResults extends NamedWriteable, ToXContentFragment {

/**
 * Write this result into the ingest document, nested under {@code parentResultField}.
 */
void writeResult(IngestDocument document, String parentResultField);

/**
 * @return the single predicted value for this result
 */
Object predictedValue();
}
Loading

0 comments on commit f6a0c2c

Please sign in to comment.