[7.x] Pipeline Inference Aggregation #58965

Merged · 2 commits · Jul 3, 2020
--- New documentation file: inference bucket aggregation ---
@@ -0,0 +1,74 @@
[role="xpack"]
[testenv="basic"]
[[search-aggregations-pipeline-inference-bucket-aggregation]]
=== Inference Bucket Aggregation

A parent pipeline aggregation which loads a pre-trained model and performs inference on the
collated result fields from the parent bucket aggregation.

[[inference-bucket-agg-syntax]]
==== Syntax

An `inference` aggregation looks like this in isolation:

[source,js]
--------------------------------------------------
{
  "inference": {
    "model_id": "a_model_for_inference", <1>
    "inference_config": { <2>
      "regression_config": {
        "num_top_feature_importance_values": 2
      }
    },
    "buckets_path": {
      "avg_cost": "avg_agg", <3>
      "max_cost": "max_agg"
    }
  }
}
--------------------------------------------------
// NOTCONSOLE
<1> The ID of the model to use.
<2> The optional inference config, which overrides the model's default settings.
<3> Maps the value of the `avg_agg` aggregation to the model's input field `avg_cost`.

[[inference-bucket-params]]
.`inference` Parameters
[options="header"]
|===
|Parameter Name |Description |Required |Default Value
| `model_id` | The ID of the model to load and infer against | Required | -
| `inference_config` | Contains the inference type and its options. There are two types: <<inference-agg-regression-opt,`regression`>> and <<inference-agg-classification-opt,`classification`>> | Optional | -
| `buckets_path` | Defines the paths to the input aggregations and maps the aggregation names to the field names expected by the model.
See <<buckets-path-syntax>> for more details | Required | -
|===
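
The aggregation must be nested inside a bucket aggregation whose sub-aggregations supply the model's input values. A minimal sketch of a full request, reusing the names from the example above (the `sales` index and its `timestamp` and `cost` fields are illustrative):

[source,js]
--------------------------------------------------
{
  "size": 0,
  "aggs": {
    "daily": {
      "date_histogram": {
        "field": "timestamp",
        "calendar_interval": "day"
      },
      "aggs": {
        "avg_agg": { "avg": { "field": "cost" } },
        "max_agg": { "max": { "field": "cost" } },
        "cost_prediction": {
          "inference": {
            "model_id": "a_model_for_inference",
            "buckets_path": {
              "avg_cost": "avg_agg",
              "max_cost": "max_agg"
            }
          }
        }
      }
    }
  }
}
--------------------------------------------------
// NOTCONSOLE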


==== Configuration options for {infer} models
The `inference_config` setting is optional and usually isn't required, as the pre-trained models come equipped with sensible defaults.
In the context of aggregations, some options can be overridden for each of the two types of model.

[discrete]
[[inference-agg-regression-opt]]
===== Configuration options for {regression} models

`num_top_feature_importance_values`::
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-regression-num-top-feature-importance-values]

[discrete]
[[inference-agg-classification-opt]]
===== Configuration options for {classification} models

`num_top_classes`::
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-classes]

`num_top_feature_importance_values`::
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-feature-importance-values]

`prediction_field_type`::
(Optional, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-prediction-field-type]
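
For example, assuming the classification override follows the same `*_config` key naming as the regression example above, the defaults of a classification model could be overridden like this:

[source,js]
--------------------------------------------------
"inference_config": {
  "classification_config": {
    "num_top_classes": 2,
    "num_top_feature_importance_values": 2,
    "prediction_field_type": "string"
  }
}
--------------------------------------------------
// NOTCONSOLE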
--- Changed file: base aggregation builder test case ---
@@ -84,8 +84,11 @@ public void setUp() throws Exception {
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
entries.addAll(indicesModule.getNamedWriteables());
entries.addAll(searchModule.getNamedWriteables());
entries.addAll(additionalNamedWriteables());
namedWriteableRegistry = new NamedWriteableRegistry(entries);
xContentRegistry = new NamedXContentRegistry(searchModule.getNamedXContents());
List<NamedXContentRegistry.Entry> xContentEntries = searchModule.getNamedXContents();
xContentEntries.addAll(additionalNamedContents());
xContentRegistry = new NamedXContentRegistry(xContentEntries);
// create some random types with a default field; those types will stick around for all of the subclasses
currentTypes = new String[randomIntBetween(0, 5)];
for (int i = 0; i < currentTypes.length; i++) {
@@ -101,6 +104,20 @@ protected List<SearchPlugin> plugins() {
return emptyList();
}

/**
* Any extra named writeables required that are not already registered by {@link SearchModule}
*/
protected List<NamedWriteableRegistry.Entry> additionalNamedWriteables() {
return emptyList();
}

/**
* Any extra named xcontent entries required that are not already registered by {@link SearchModule}
*/
protected List<NamedXContentRegistry.Entry> additionalNamedContents() {
return emptyList();
}
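
// A sketch (not part of this diff) of how a subclass testing the ML inference
// aggregation might wire up these hooks, reusing the provider class changed below:
//
//     @Override
//     protected List<NamedWriteableRegistry.Entry> additionalNamedWriteables() {
//         return new MlInferenceNamedXContentProvider().getNamedWriteables();
//     }
//
//     @Override
//     protected List<NamedXContentRegistry.Entry> additionalNamedContents() {
//         return new MlInferenceNamedXContentProvider().getNamedXContentParsers();
//     }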

/**
* Generic test that creates new AggregatorFactory from the test
* AggregatorFactory and checks both for equality and asserts equality on
--- Changed file: MlInferenceNamedXContentProvider ---
@@ -12,6 +12,7 @@
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
@@ -20,6 +21,7 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
@@ -121,6 +123,8 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
ClassificationConfigUpdate::fromXContentStrict));
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, RegressionConfigUpdate.NAME,
RegressionConfigUpdate::fromXContentStrict));
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, ResultsFieldUpdate.NAME,
ResultsFieldUpdate::fromXContent));

// Inference models
namedXContent.add(new NamedXContentRegistry.Entry(InferenceModel.class, Ensemble.NAME, EnsembleInferenceModel::fromXContent));
@@ -170,6 +174,9 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
RegressionInferenceResults.NAME,
RegressionInferenceResults::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
WarningInferenceResults.NAME,
WarningInferenceResults::new));

// Inference Configs
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
--- Changed file: ClassificationInferenceResults ---
@@ -6,10 +6,9 @@
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
@@ -18,9 +17,7 @@

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

@@ -85,6 +82,10 @@ public List<TopClassEntry> getTopClasses() {
return topClasses;
}

public PredictionFieldType getPredictionFieldType() {
return predictionFieldType;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
@@ -127,6 +128,11 @@ public String valueAsString() {
return classificationLabel == null ? super.valueAsString() : classificationLabel;
}

@Override
public Object predictedValue() {
return predictionFieldType.transformPredictedValue(value(), valueAsString());
}

@Override
public void writeResult(IngestDocument document, String parentResultField) {
ExceptionsHelper.requireNonNull(document, "document");
@@ -138,7 +144,7 @@ public void writeResult(IngestDocument document, String parentResultField) {
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
document.setFieldValue(parentResultField + "." + FEATURE_IMPORTANCE, getFeatureImportance()
.stream()
.map(FeatureImportance::toMap)
.collect(Collectors.toList()));
@@ -150,74 +156,15 @@ public String getWriteableName() {
return NAME;
}

public static class TopClassEntry implements Writeable {

public final ParseField CLASS_NAME = new ParseField("class_name");
public final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
public final ParseField CLASS_SCORE = new ParseField("class_score");

private final Object classification;
private final double probability;
private final double score;

public TopClassEntry(Object classification, double probability, double score) {
this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME);
this.probability = probability;
this.score = score;
}

public TopClassEntry(StreamInput in) throws IOException {
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
this.classification = in.readGenericValue();
} else {
this.classification = in.readString();
}
this.probability = in.readDouble();
this.score = in.readDouble();
}

public Object getClassification() {
return classification;
}

public double getProbability() {
return probability;
}

public double getScore() {
return score;
}

public Map<String, Object> asValueMap() {
Map<String, Object> map = new HashMap<>(3, 1.0f);
map.put(CLASS_NAME.getPreferredName(), classification);
map.put(CLASS_PROBABILITY.getPreferredName(), probability);
map.put(CLASS_SCORE.getPreferredName(), score);
return map;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeGenericValue(classification);
} else {
out.writeString(classification.toString());
}
out.writeDouble(probability);
out.writeDouble(score);
}

@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
TopClassEntry that = (TopClassEntry) object;
return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score;
}

@Override
public int hashCode() {
return Objects.hash(classification, probability, score);
}
}

// end of the removed TopClassEntry inner class; the following toXContent is added:
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(resultsField, predictionFieldType.transformPredictedValue(value(), valueAsString()));
if (topClasses.size() > 0) {
builder.field(topNumClassesField, topClasses);
}
if (getFeatureImportance().size() > 0) {
builder.field(FEATURE_IMPORTANCE, getFeatureImportance());
}
return builder;
}
}
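
// For reference (not part of this diff): with the default results field
// "predicted_value", the toXContent above renders something like
//     {"predicted_value": "dog", "top_classes": [...], "feature_importance": [...]}
// where "dog" stands in for whatever PredictionFieldType yields, and the last
// two fields appear only when top classes / feature importance were collected.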
--- Changed file: FeatureImportance ---
@@ -5,23 +5,33 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

public class FeatureImportance implements Writeable {
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;

public class FeatureImportance implements Writeable, ToXContentObject {

private final Map<String, Double> classImportance;
private final double importance;
private final String featureName;
private static final String IMPORTANCE = "importance";
private static final String FEATURE_NAME = "feature_name";
static final String IMPORTANCE = "importance";
static final String FEATURE_NAME = "feature_name";
static final String CLASS_IMPORTANCE = "class_importance";

public static FeatureImportance forRegression(String featureName, double importance) {
return new FeatureImportance(featureName, importance, null);
@@ -31,7 +41,24 @@ public static FeatureImportance forClassification(String featureName, Map<String
return new FeatureImportance(featureName, classImportance.values().stream().mapToDouble(Math::abs).sum(), classImportance);
}

private FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
new ConstructingObjectParser<>("feature_importance",
a -> new FeatureImportance((String) a[0], (Double) a[1], (Map<String, Double>) a[2])
);

static {
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
new ParseField(FeatureImportance.CLASS_IMPORTANCE));
}

public static FeatureImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
this.featureName = Objects.requireNonNull(featureName);
this.importance = importance;
this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
@@ -79,6 +106,22 @@ public Map<String, Object> toMap() {
return map;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FEATURE_NAME, featureName);
builder.field(IMPORTANCE, importance);
if (classImportance != null && classImportance.isEmpty() == false) {
builder.startObject(CLASS_IMPORTANCE);
for (Map.Entry<String, Double> entry : classImportance.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
}
builder.endObject();
return builder;
}
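
// A sketch (not part of this diff) of the JSON this produces; the values are invented:
//
//     FeatureImportance importance = FeatureImportance.forClassification(
//         "cost", Map.of("dog", 0.1, "cat", -0.3));
//     XContentBuilder builder = importance.toXContent(
//         XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
//     // Strings.toString(builder) then yields something like:
//     // {"feature_name":"cost","importance":0.4,"class_importance":{"dog":0.1,"cat":-0.3}}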

@Override
public boolean equals(Object object) {
if (object == this) { return true; }
@@ -93,5 +136,4 @@ public boolean equals(Object object) {
public int hashCode() {
return Objects.hash(featureName, importance, classImportance);
}

}
--- Changed file: InferenceResults ---
@@ -6,10 +6,12 @@
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentFragment;
import org.elasticsearch.ingest.IngestDocument;

public interface InferenceResults extends NamedWriteable {
public interface InferenceResults extends NamedWriteable, ToXContentFragment {

void writeResult(IngestDocument document, String parentResultField);

Object predictedValue();
}
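
// Note: predictedValue() exposes the bare prediction without routing it through
// an IngestDocument, which is what the new inference aggregation consumes; for
// classification results it is the value transformed by PredictionFieldType
// (see ClassificationInferenceResults above).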