Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ml inference ingest processor support for local models #2508

Merged
merged 5 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,10 @@ public void loadExtensions(ExtensionLoader loader) {
public Map<String, org.opensearch.ingest.Processor.Factory> getProcessors(org.opensearch.ingest.Processor.Parameters parameters) {
Map<String, org.opensearch.ingest.Processor.Factory> processors = new HashMap<>();
processors
.put(MLInferenceIngestProcessor.TYPE, new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client));
.put(
MLInferenceIngestProcessor.TYPE,
new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client, xContentRegistry)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need the xContentRegistry to be passed in from the plugin?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We add it here so that it can be passed as a dependency argument to MLInferenceIngestProcessor when it is instantiated (created) in the Factory.

);
return Collections.unmodifiableMap(processors);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,31 @@

import static org.opensearch.ml.processor.InferenceProcessorAttributes.*;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.Processor;
import org.opensearch.ingest.ValueSource;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.utils.StringUtils;
Expand All @@ -42,20 +48,31 @@
*/
public class MLInferenceIngestProcessor extends AbstractProcessor implements ModelExecutor {

private static final Logger logger = LogManager.getLogger(MLInferenceIngestProcessor.class);

public static final String DOT_SYMBOL = ".";
private final InferenceProcessorAttributes inferenceProcessorAttributes;
private final boolean ignoreMissing;
private final String functionName;
private final boolean fullResponsePath;
private final boolean ignoreFailure;
private final boolean override;
private final String modelInput;
private final ScriptService scriptService;
private static Client client;
public static final String TYPE = "ml_inference";
public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results";
// when enabled, a field from the mapping that is not present in the document is ignored, and when the output field is not
// found in the prediction outcomes, the whole prediction outcome is returned by skipping filtering
public static final String IGNORE_MISSING = "ignore_missing";
public static final String OVERRIDE = "override";
public static final String FUNCTION_NAME = "function_name";
public static final String FULL_RESPONSE_PATH = "full_response_path";
public static final String MODEL_INPUT = "model_input";
// By default, the ml inference processor allows a maximum of 10 prediction tasks to run in parallel;
// this can be overridden using max_prediction_tasks when creating the processor
public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
private final NamedXContentRegistry xContentRegistry;

private Configuration suppressExceptionConfiguration = Configuration
.builder()
Expand All @@ -71,9 +88,14 @@ protected MLInferenceIngestProcessor(
String tag,
String description,
boolean ignoreMissing,
String functionName,
boolean fullResponsePath,
boolean ignoreFailure,
boolean override,
String modelInput,
ScriptService scriptService,
Client client
Client client,
NamedXContentRegistry xContentRegistry
) {
super(tag, description);
this.inferenceProcessorAttributes = new InferenceProcessorAttributes(
Expand All @@ -84,9 +106,14 @@ protected MLInferenceIngestProcessor(
maxPredictionTask
);
this.ignoreMissing = ignoreMissing;
this.functionName = functionName;
this.fullResponsePath = fullResponsePath;
this.ignoreFailure = ignoreFailure;
this.override = override;
this.modelInput = modelInput;
this.scriptService = scriptService;
this.client = client;
this.xContentRegistry = xContentRegistry;
}

/**
Expand Down Expand Up @@ -162,10 +189,48 @@ private void processPredictions(
List<Map<String, String>> processOutputMap,
int inputMapIndex,
int inputMapSize
) {
) throws IOException {
Map<String, String> modelParameters = new HashMap<>();
Map<String, String> modelConfigs = new HashMap<>();

if (inferenceProcessorAttributes.getModelConfigMaps() != null) {
modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps());
modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps());
}

Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<>();
ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata());
ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata());

Map<String, List<String>> newOutputMapping = new HashMap<>();
if (processOutputMap != null) {

Map<String, String> outputMapping = processOutputMap.get(inputMapIndex);
for (Map.Entry<String, String> entry : outputMapping.entrySet()) {
String newDocumentFieldName = entry.getKey();
List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName);
newOutputMapping.put(newDocumentFieldName, dotPathsInArray);
}

for (Map.Entry<String, String> entry : outputMapping.entrySet()) {
String newDocumentFieldName = entry.getKey();
List<String> dotPaths = newOutputMapping.get(newDocumentFieldName);

int existingFields = 0;
for (String path : dotPaths) {
if (ingestDocument.hasField(path)) {
existingFields++;
}
}
if (!override && existingFields == dotPaths.size()) {
logger.debug("{} already exists in the ingest document. Removing it from output mapping", newDocumentFieldName);
newOutputMapping.remove(newDocumentFieldName);
}
}
if (newOutputMapping.size() == 0) {
batchPredictionListener.onResponse(null);
return;
}
}
// when no input mapping is provided, default to read all fields from documents as model input
if (inputMapSize == 0) {
Expand All @@ -184,15 +249,30 @@ private void processPredictions(
}
}

ActionRequest request = getRemoteModelInferenceRequest(modelParameters, inferenceProcessorAttributes.getModelId());
Set<String> inputMapKeys = new HashSet<>(modelParameters.keySet());
inputMapKeys.removeAll(modelConfigs.keySet());
rbhavna marked this conversation as resolved.
Show resolved Hide resolved

Map<String, String> inputMappings = new HashMap<>();
for (String k : inputMapKeys) {
inputMappings.put(k, modelParameters.get(k));
}
ActionRequest request = getMLModelInferenceRequest(
xContentRegistry,
modelParameters,
modelConfigs,
inputMappings,
inferenceProcessorAttributes.getModelId(),
functionName,
modelInput
);

client.execute(MLPredictionTaskAction.INSTANCE, request, new ActionListener<>() {

@Override
public void onResponse(MLTaskResponse mlTaskResponse) {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput();
MLOutput mlOutput = mlTaskResponse.getOutput();
if (processOutputMap == null || processOutputMap.isEmpty()) {
appendFieldValue(modelTensorOutput, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument);
appendFieldValue(mlOutput, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument);
} else {
// outMapping serves as a filter on modelTensorOutput; the fields that are not specified
// in the outputMapping will not be written to the document
Expand All @@ -202,14 +282,10 @@ public void onResponse(MLTaskResponse mlTaskResponse) {
// document field as key, model field as value
String newDocumentFieldName = entry.getKey();
String modelOutputFieldName = entry.getValue();
if (ingestDocument.hasField(newDocumentFieldName)) {
throw new IllegalArgumentException(
"document already has field name "
+ newDocumentFieldName
+ ". Not allow to overwrite the same field name, please check output_map."
);
if (!newOutputMapping.containsKey(newDocumentFieldName)) {
mingshl marked this conversation as resolved.
Show resolved Hide resolved
continue;
}
appendFieldValue(modelTensorOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument);
appendFieldValue(mlOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument);
}
}
batchPredictionListener.onResponse(null);
Expand Down Expand Up @@ -305,63 +381,61 @@ private String getFieldPath(IngestDocument ingestDocument, String documentFieldN
/**
* Appends the model output value to the specified field in the IngestDocument without modifying the source.
*
* @param modelTensorOutput the ModelTensorOutput containing the model output
* @param mlOutput the MLOutput containing the model output
* @param modelOutputFieldName the name of the field in the model output
* @param newDocumentFieldName the name of the field in the IngestDocument to append the value to
* @param ingestDocument the IngestDocument to append the value to
*/
private void appendFieldValue(
ModelTensorOutput modelTensorOutput,
MLOutput mlOutput,
String modelOutputFieldName,
String newDocumentFieldName,
IngestDocument ingestDocument
) {
Object modelOutputValue = null;

if (modelTensorOutput.getMlModelOutputs() != null && modelTensorOutput.getMlModelOutputs().size() > 0) {
if (mlOutput == null) {
throw new RuntimeException("model inference output is null");
}

modelOutputValue = getModelOutputValue(modelTensorOutput, modelOutputFieldName, ignoreMissing);
Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath);

Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<>();
ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata());
ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata());
List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName);
Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<>();
ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata());
ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata());
List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName);

if (dotPathsInArray.size() == 1) {
ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService);
if (dotPathsInArray.size() == 1) {
ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService);
TemplateScript.Factory ingestField = ConfigurationUtils
.compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService);
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
} else {
if (!(modelOutputValue instanceof List)) {
throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents.");
}
List<?> modelOutputValueArray = (List<?>) modelOutputValue;
// check length of the prediction array to be the same of the document array
if (dotPathsInArray.size() != modelOutputValueArray.size()) {
throw new RuntimeException(
"the prediction field: "
+ modelOutputFieldName
+ " is an array in size of "
+ modelOutputValueArray.size()
+ " but the document field array from field "
+ newDocumentFieldName
+ " is in size of "
+ dotPathsInArray.size()
);
}
// Iterate over dotPathInArray
for (int i = 0; i < dotPathsInArray.size(); i++) {
String dotPathInArray = dotPathsInArray.get(i);
Object modelOutputValueInArray = modelOutputValueArray.get(i);
ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService);
TemplateScript.Factory ingestField = ConfigurationUtils
.compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService);
.compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService);
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
} else {
if (!(modelOutputValue instanceof List)) {
throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents.");
}
List<?> modelOutputValueArray = (List<?>) modelOutputValue;
// check length of the prediction array to be the same of the document array
if (dotPathsInArray.size() != modelOutputValueArray.size()) {
throw new RuntimeException(
"the prediction field: "
+ modelOutputFieldName
+ " is an array in size of "
+ modelOutputValueArray.size()
+ " but the document field array from field "
+ newDocumentFieldName
+ " is in size of "
+ dotPathsInArray.size()
);
}
// Iterate over dotPathInArray
for (int i = 0; i < dotPathsInArray.size(); i++) {
String dotPathInArray = dotPathsInArray.get(i);
Object modelOutputValueInArray = modelOutputValueArray.get(i);
ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService);
TemplateScript.Factory ingestField = ConfigurationUtils
.compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService);
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
}
}
} else {
throw new RuntimeException("model inference output cannot be null");
}
}

Expand All @@ -374,16 +448,18 @@ public static class Factory implements Processor.Factory {

private final ScriptService scriptService;
private final Client client;
private final NamedXContentRegistry xContentRegistry;

/**
* Constructs a new instance of the Factory class.
*
* @param scriptService the ScriptService instance to be used by the Factory
* @param client the Client instance to be used by the Factory
*/
public Factory(ScriptService scriptService, Client client) {
public Factory(ScriptService scriptService, Client client, NamedXContentRegistry xContentRegistry) {
this.scriptService = scriptService;
this.client = client;
this.xContentRegistry = xContentRegistry;
}

/**
Expand All @@ -410,6 +486,14 @@ public MLInferenceIngestProcessor create(
int maxPredictionTask = ConfigurationUtils
.readIntProperty(TYPE, processorTag, config, MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS);
boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false);
boolean override = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, OVERRIDE, false);
String functionName = ConfigurationUtils
.readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name());
String modelInput = ConfigurationUtils
.readStringProperty(TYPE, processorTag, config, MODEL_INPUT, "{ \"parameters\": ${ml_inference.parameters} }");
boolean defaultValue = !functionName.equals("remote");
boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultValue);

boolean ignoreFailure = ConfigurationUtils
.readBooleanProperty(TYPE, processorTag, config, ConfigurationUtils.IGNORE_FAILURE_KEY, false);
// convert model config user input data structure to Map<String, String>
Expand Down Expand Up @@ -440,9 +524,14 @@ public MLInferenceIngestProcessor create(
processorTag,
description,
ignoreMissing,
functionName,
fullResponsePath,
ignoreFailure,
override,
modelInput,
scriptService,
client
client,
xContentRegistry
);
}
}
Expand Down
Loading
Loading