Skip to content

Commit

Permalink
[ML] Simplify the Inference Ingest Processor configuration (#100205)
Browse files Browse the repository at this point in the history
Adds an `input_output` option that removes the need for a `field_map` and/or
target fields. Multiple inputs can be specified in `input_output`.
  • Loading branch information
davidkyle authored Oct 3, 2023
1 parent a5f65a5 commit b055204
Show file tree
Hide file tree
Showing 23 changed files with 1,054 additions and 78 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/100205.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 100205
summary: Simplify the Inference Ingest Processor configuration
area: Machine Learning
type: enhancement
issues: []
52 changes: 52 additions & 0 deletions docs/reference/ingest/processors/inference.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,64 @@ ingested in the pipeline.
|======
| Name | Required | Default | Description
| `model_id` | yes | - | (String) The ID or alias for the trained model, or the ID of the deployment.
| `input_output` | no | - | (List) Input fields for inference and output (destination) fields for the inference results. This option is incompatible with the `target_field` and `field_map` options.
| `target_field` | no | `ml.inference.<processor_tag>` | (String) Field added to incoming documents to contain results objects.
| `field_map` | no | If defined the model's default field map | (Object) Maps the document field names to the known field names of the model. This mapping takes precedence over any default mappings provided in the model configuration.
| `inference_config` | no | The default settings defined in the model | (Object) Contains the inference type and its options.
include::common-options.asciidoc[]
|======

[discrete]
[[inference-input-output-example]]
==== Configuring input and output fields
Select the `content` field for inference and write the result to `content_embedding`.

[source,js]
--------------------------------------------------
{
"inference": {
"model_id": "model_deployment_for_inference",
"input_output": [
{
"input_field": "content",
"output_field": "content_embedding"
}
]
}
}
--------------------------------------------------
// NOTCONSOLE

==== Configuring multiple inputs
The `content` and `title` fields will be read from the incoming document
and sent to the model for the inference. The inference output is written
to `content_embedding` and `title_embedding` respectively.
[source,js]
--------------------------------------------------
{
"inference": {
"model_id": "model_deployment_for_inference",
"input_output": [
{
"input_field": "content",
"output_field": "content_embedding"
},
{
"input_field": "title",
"output_field": "title_embedding"
}
]
}
}
--------------------------------------------------
// NOTCONSOLE

Selecting the input fields with `input_output` is incompatible with
the `target_field` and `field_map` options.

Data frame analytics models must use the `target_field` to specify the
root location results are written to and optionally a `field_map` to map
field names in the input document to the model input fields.

[source,js]
--------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.elasticsearch.inference;

import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xcontent.ToXContentFragment;

Expand All @@ -24,16 +25,58 @@ static void writeResult(InferenceResults results, IngestDocument ingestDocument,
Objects.requireNonNull(resultField, "resultField");
Map<String, Object> resultMap = results.asMap();
resultMap.put(MODEL_ID_RESULTS_FIELD, modelId);
if (ingestDocument.hasField(resultField)) {
ingestDocument.appendFieldValue(resultField, resultMap);
setOrAppendValue(resultField, resultMap, ingestDocument);
}

/**
 * Writes the inference result to the document, keyed by {@code outputField}.
 *
 * @param results        the inference result to write; must not be null
 * @param ingestDocument the document to write into; must not be null
 * @param basePath       optional prefix for the output fields; when null the
 *                       result entries are inserted at the root of the document
 * @param outputField    the field name the prediction is keyed by; must not be null
 * @param modelId        the id of the model that produced the result
 * @param includeModelId whether to add the model id to the result map
 */
static void writeResultToField(
    InferenceResults results,
    IngestDocument ingestDocument,
    @Nullable String basePath,
    String outputField,
    String modelId,
    boolean includeModelId
) {
    Objects.requireNonNull(results, "results");
    Objects.requireNonNull(ingestDocument, "ingestDocument");
    Objects.requireNonNull(outputField, "outputField");
    Map<String, Object> resultMap = results.asMap(outputField);
    if (includeModelId) {
        resultMap.put(MODEL_ID_RESULTS_FIELD, modelId);
    }
    if (basePath == null) {
        // insert the results into the root of the document
        for (var entry : resultMap.entrySet()) {
            setOrAppendValue(entry.getKey(), entry.getValue(), ingestDocument);
        }
    } else {
        // BUG FIX: removed stray `ingestDocument.setFieldValue(resultField, resultMap);`
        // — `resultField` is not in scope in this method (leftover line from the
        // replaced implementation) and the per-entry loop below already writes
        // every result under basePath.
        for (var entry : resultMap.entrySet()) {
            setOrAppendValue(basePath + "." + entry.getKey(), entry.getValue(), ingestDocument);
        }
    }
}

/**
 * Writes {@code value} at {@code path}: creates the field when absent,
 * appends to the existing value otherwise.
 */
private static void setOrAppendValue(String path, Object value, IngestDocument ingestDocument) {
    if (ingestDocument.hasField(path) == false) {
        ingestDocument.setFieldValue(path, value);
    } else {
        ingestDocument.appendFieldValue(path, value);
    }
}

String getResultsField();

/**
* Convert to a map
* @return Map representation of the InferenceResult
*/
Map<String, Object> asMap();

/**
* Convert to a map placing the inference result in {@code outputField}
* @param outputField Write the inference result to this field
* @return Map representation of the InferenceResult
*/
Map<String, Object> asMap(String outputField);

Object predictedValue();
}
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,19 @@ public String getResultsField() {
public Map<String, Object> asMap() {
    // The prediction keyed by the configured results field, plus any
    // supporting fields (top classes, feature importance).
    var map = new LinkedHashMap<String, Object>();
    map.put(resultsField, predictionFieldType.transformPredictedValue(value(), valueAsString()));
    addSupportingFieldsToMap(map);
    return map;
}

@Override
public Map<String, Object> asMap(String outputField) {
    // Same as asMap() but the prediction is written under the
    // caller-supplied field name instead of the configured results field.
    var map = new LinkedHashMap<String, Object>();
    map.put(outputField, predictionFieldType.transformPredictedValue(value(), valueAsString()));
    addSupportingFieldsToMap(map);
    return map;
}

private void addSupportingFieldsToMap(Map<String, Object> map) {
if (topClasses.isEmpty() == false) {
map.put(topNumClassesField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
Expand All @@ -235,7 +248,6 @@ public Map<String, Object> asMap() {
featureImportance.stream().map(ClassificationFeatureImportance::toMap).collect(Collectors.toList())
);
}
return map;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ public Map<String, Object> asMap() {
return asMap;
}

@Override
public Map<String, Object> asMap(String outputField) {
    // errors do not have a result, so there is nothing to key by outputField;
    // delegate to the plain asMap() representation
    return asMap();
}

@Override
public String toString() {
return Strings.toString(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ void addMapFields(Map<String, Object> map) {
map.put(resultsField + "_sequence", predictedSequence);
}

@Override
public Map<String, Object> asMap(String outputField) {
    // Base result plus the predicted sequence under "<outputField>_sequence".
    Map<String, Object> result = super.asMap(outputField);
    result.put(outputField + "_sequence", predictedSequence);
    return result;
}

@Override
public String getWriteableName() {
return NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,21 @@ public final XContentBuilder toXContent(XContentBuilder builder, Params params)
public final Map<String, Object> asMap() {
    // Subclass-specific fields first, then the fields common to all results.
    var map = new LinkedHashMap<String, Object>();
    addMapFields(map);
    addSupportingFieldsToMap(map);
    return map;
}

@Override
public Map<String, Object> asMap(String outputField) {
    // Only the common supporting fields here; subclasses override this and
    // add their prediction under outputField via super.asMap(outputField).
    var map = new LinkedHashMap<String, Object>();
    addSupportingFieldsToMap(map);
    return map;
}

// Adds fields common to all results; currently only the truncation flag.
// BUG FIX: removed a stray `return map;` statement — this method is void, so the
// line (a leftover tail of the old asMap() shown by the diff) does not compile.
private void addSupportingFieldsToMap(Map<String, Object> map) {
    // only reported when truncation actually occurred
    if (isTruncated) {
        map.put("is_truncated", isTruncated);
    }
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ void addMapFields(Map<String, Object> map) {
map.put(resultsField, inference);
}

@Override
public Map<String, Object> asMap(String outputField) {
    // Base result plus the inference value under the requested field.
    Map<String, Object> map = super.asMap(outputField);
    map.put(outputField, inference);
    return map;
}

@Override
public Object predictedValue() {
throw new UnsupportedOperationException("[" + NAME + "] does not support a single predicted value");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ public String predictedValue() {
@Override
void addMapFields(Map<String, Object> map) {
map.put(resultsField, answer);
addSupportingFieldsToMap(map);
}

@Override
public Map<String, Object> asMap(String outputField) {
    // The answer under the requested field, plus the supporting fields.
    Map<String, Object> map = super.asMap(outputField);
    map.put(outputField, answer);
    addSupportingFieldsToMap(map);
    return map;
}

private void addSupportingFieldsToMap(Map<String, Object> map) {
map.put(START_OFFSET.getPreferredName(), startOffset);
map.put(END_OFFSET.getPreferredName(), endOffset);
if (topClasses.isEmpty() == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ public Map<String, Object> asMap() {
throw new UnsupportedOperationException("[raw] does not support map conversion");
}

@Override
public Map<String, Object> asMap(String outputField) {
    // raw results deliberately have no map representation, regardless of outputField
    throw new UnsupportedOperationException("[raw] does not support map conversion");
}

@Override
public Object predictedValue() {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,23 @@ public String getResultsField() {
@Override
public Map<String, Object> asMap() {
    // Supporting fields first, then the predicted value under the
    // configured results field.
    var map = new LinkedHashMap<String, Object>();
    addSupportingFieldsToMap(map);
    map.put(resultsField, value());
    return map;
}

@Override
public Map<String, Object> asMap(String outputField) {
    // Same as asMap() but keyed by the caller-supplied field name.
    var map = new LinkedHashMap<String, Object>();
    addSupportingFieldsToMap(map);
    map.put(outputField, value());
    return map;
}

// Adds the optional feature-importance entries to the map.
// BUG FIX: removed a stray `return map;` statement — this method is void, so the
// line (a leftover tail of the old asMap() shown by the diff) does not compile.
private void addSupportingFieldsToMap(Map<String, Object> map) {
    // feature importance is optional; only emit when present
    if (featureImportance.isEmpty() == false) {
        map.put(FEATURE_IMPORTANCE, featureImportance.stream().map(RegressionFeatureImportance::toMap).collect(Collectors.toList()));
    }
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ void addMapFields(Map<String, Object> map) {
map.put(resultsField, inference);
}

@Override
public Map<String, Object> asMap(String outputField) {
    // Base result plus the inference value under the requested field.
    Map<String, Object> map = super.asMap(outputField);
    map.put(outputField, inference);
    return map;
}

@Override
public Object predictedValue() {
throw new UnsupportedOperationException("[" + NAME + "] does not support a single predicted value");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,11 @@ void doWriteTo(StreamOutput out) throws IOException {
void addMapFields(Map<String, Object> map) {
map.put(resultsField, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight)));
}

@Override
public Map<String, Object> asMap(String outputField) {
    // Flattens the weighted tokens into a token -> weight map under the
    // requested field, on top of the base result.
    Map<String, Object> map = super.asMap(outputField);
    map.put(outputField, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight)));
    return map;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ void addMapFields(Map<String, Object> map) {
map.put(resultsField, score);
}

@Override
public Map<String, Object> asMap(String outputField) {
    // Base result plus the score under the requested field.
    Map<String, Object> map = super.asMap(outputField);
    map.put(outputField, score);
    return map;
}

@Override
public String getWriteableName() {
return NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ public Map<String, Object> asMap() {
return asMap;
}

@Override
public Map<String, Object> asMap(String outputField) {
    // warnings do not have a result, so outputField is deliberately ignored;
    // delegate to the plain asMap() representation
    return asMap();
}

@Override
public Object predictedValue() {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str

public static final ParseField NAME = new ParseField("classification");

public static final ParseField RESULTS_FIELD = new ParseField("results_field");
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field");
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xpack.core.ml.MlConfigVersion;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

public interface InferenceConfig extends NamedXContentObject, VersionedNamedWriteable {

String DEFAULT_TOP_CLASSES_RESULTS_FIELD = "top_classes";
String DEFAULT_RESULTS_FIELD = "predicted_value";
ParseField RESULTS_FIELD = new ParseField("results_field");

boolean isTargetTypeSupported(TargetType targetType);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ public interface NlpConfig extends LenientlyParsedInferenceConfig, StrictlyParse
ParseField VOCABULARY = new ParseField("vocabulary");
ParseField TOKENIZATION = new ParseField("tokenization");
ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
ParseField RESULTS_FIELD = new ParseField("results_field");
ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");

MlConfigVersion MINIMUM_NLP_SUPPORTED_VERSION = MlConfigVersion.V_8_0_0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ public class RegressionConfig implements LenientlyParsedInferenceConfig, Strictl
public static final ParseField NAME = new ParseField("regression");
private static final MlConfigVersion MIN_SUPPORTED_VERSION = MlConfigVersion.V_7_6_0;
private static final TransportVersion MIN_SUPPORTED_TRANSPORT_VERSION = TransportVersions.V_7_6_0;
public static final ParseField RESULTS_FIELD = new ParseField("results_field");
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");

public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD, null);
Expand Down
Loading

0 comments on commit b055204

Please sign in to comment.