Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.16] Fix ml inference ingest processor always return list using JsonPath #3006

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies {
exclude group: 'com.google.j2objc', module: 'j2objc-annotations'
exclude group: 'com.google.guava', module: 'listenablefuture'
}
compileOnly 'com.jayway.jsonpath:json-path:2.9.0'
}

lombok {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD;
import com.jayway.jsonpath.JsonPath;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class StringUtils {
Expand Down Expand Up @@ -239,6 +244,46 @@ public static String getErrorMessage(String errorMessage, String modelId, Boolea
}
}

public static String obtainFieldNameFromJsonPath(String jsonPath) {
String[] parts = jsonPath.split("\\.");

// Get the last part which is the field name
return parts[parts.length - 1];
}

public static String getJsonPath(String jsonPathWithSource) {
// Find the index of the first occurrence of "$."
int startIndex = jsonPathWithSource.indexOf("$.");

// Extract the substring from the startIndex to the end of the input string
return (startIndex != -1) ? jsonPathWithSource.substring(startIndex) : jsonPathWithSource;
}

/**
* Checks if the given input string matches the JSONPath format.
*
* <p>The JSONPath format is a way to navigate and extract data from JSON documents.
* It uses a syntax similar to XPath for XML documents. This method attempts to compile
* the input string as a JSONPath expression using the {@link com.jayway.jsonpath.JsonPath}
* library. If the compilation succeeds, it means the input string is a valid JSONPath
* expression.
*
* @param input the input string to be checked for JSONPath format validity
* @return true if the input string is a valid JSONPath expression, false otherwise
*/
public static boolean isValidJSONPath(String input) {
if (input == null || input.isBlank()) {
return false;
}
try {
JsonPath.compile(input); // This will throw an exception if the path is invalid
return true;
} catch (Exception e) {
return false;
}
}


/**
* Collects the prefixes of the toString() method calls present in the values of the given map.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import static org.opensearch.ml.processor.InferenceProcessorAttributes.*;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
Expand Down Expand Up @@ -37,9 +36,7 @@
import org.opensearch.script.ScriptService;
import org.opensearch.script.TemplateScript;

import com.jayway.jsonpath.Configuration;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Option;

/**
* MLInferenceIngestProcessor requires a modelId string to call model inferences
Expand Down Expand Up @@ -75,11 +72,6 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod
public static final String DEFAULT_MODEl_INPUT = "{ \"parameters\": ${ml_inference.parameters} }";
private final NamedXContentRegistry xContentRegistry;

private Configuration suppressExceptionConfiguration = Configuration
.builder()
.options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL, Option.ALWAYS_RETURN_LIST)
.build();

protected MLInferenceIngestProcessor(
String modelId,
List<Map<String, String>> inputMaps,
Expand Down Expand Up @@ -320,24 +312,29 @@ private void getMappedModelInputFromDocuments(
Object documentFieldValue = ingestDocument.getFieldValue(originalFieldPath, Object.class);
String documentFieldValueAsString = toString(documentFieldValue);
updateModelParameters(modelInputFieldName, documentFieldValueAsString, modelParameters);
return;
}
// else when cannot find field path in document, try check for nested array using json path
else {
if (documentFieldName.contains(DOT_SYMBOL)) {

Map<String, Object> sourceObject = ingestDocument.getSourceAndMetadata();
ArrayList<Object> fieldValueList = JsonPath
.using(suppressExceptionConfiguration)
.parse(sourceObject)
.read(documentFieldName);
if (!fieldValueList.isEmpty()) {
updateModelParameters(modelInputFieldName, toString(fieldValueList), modelParameters);
} else if (!ignoreMissing) {
throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName);
// If the standard dot path fails, try to check for a nested array using JSON path
if (StringUtils.isValidJSONPath(documentFieldName)) {
Map<String, Object> sourceObject = ingestDocument.getSourceAndMetadata();
Object fieldValue = JsonPath.using(suppressExceptionConfiguration).parse(sourceObject).read(documentFieldName);

if (fieldValue != null) {
if (fieldValue instanceof List) {
List<?> fieldValueList = (List<?>) fieldValue;
if (!fieldValueList.isEmpty()) {
updateModelParameters(modelInputFieldName, toString(fieldValueList), modelParameters);
} else if (!ignoreMissing) {
throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
}
} else {
updateModelParameters(modelInputFieldName, toString(fieldValue), modelParameters);
}
} else if (!ignoreMissing) {
throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName);
throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
}
} else {
throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
}
}

Expand Down
Loading
Loading