diff --git a/common/build.gradle b/common/build.gradle
index dc3337370f..8b81080a3b 100644
--- a/common/build.gradle
+++ b/common/build.gradle
@@ -35,6 +35,7 @@ dependencies {
exclude group: 'com.google.j2objc', module: 'j2objc-annotations'
exclude group: 'com.google.guava', module: 'listenablefuture'
}
+ compileOnly 'com.jayway.jsonpath:json-path:2.9.0'
}
lombok {
diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java
index 4bf74de3a9..37bfac6f3f 100644
--- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java
+++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java
@@ -28,6 +28,7 @@
import com.google.gson.JsonElement;
import com.google.gson.JsonParser;
import com.google.gson.JsonSyntaxException;
+import com.jayway.jsonpath.JsonPath;
import lombok.extern.log4j.Log4j2;
@@ -293,4 +294,29 @@ public static String getJsonPath(String jsonPathWithSource) {
// Extract the substring from the startIndex to the end of the input string
return (startIndex != -1) ? jsonPathWithSource.substring(startIndex) : jsonPathWithSource;
}
+
+ /**
+ * Checks if the given input string matches the JSONPath format.
+ *
+ *
The JSONPath format is a way to navigate and extract data from JSON documents.
+ * It uses a syntax similar to XPath for XML documents. This method attempts to compile
+ * the input string as a JSONPath expression using the {@link com.jayway.jsonpath.JsonPath}
+ * library. If the compilation succeeds, it means the input string is a valid JSONPath
+ * expression.
+ *
+ * @param input the input string to be checked for JSONPath format validity
+ * @return true if the input string is a valid JSONPath expression, false otherwise
+ */
+ public static boolean isValidJSONPath(String input) {
+ if (input == null || input.isBlank()) {
+ return false;
+ }
+ try {
+ JsonPath.compile(input); // This will throw an exception if the path is invalid
+ return true;
+ } catch (Exception e) {
+ return false;
+ }
+ }
+
}
diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java
index 48e0464fbc..5424746d1a 100644
--- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java
+++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java
@@ -7,7 +7,6 @@
import static org.opensearch.ml.processor.InferenceProcessorAttributes.*;
import java.io.IOException;
-import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
@@ -37,9 +36,7 @@
import org.opensearch.script.ScriptService;
import org.opensearch.script.TemplateScript;
-import com.jayway.jsonpath.Configuration;
import com.jayway.jsonpath.JsonPath;
-import com.jayway.jsonpath.Option;
/**
* MLInferenceIngestProcessor requires a modelId string to call model inferences
@@ -75,11 +72,6 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod
public static final String DEFAULT_MODEl_INPUT = "{ \"parameters\": ${ml_inference.parameters} }";
private final NamedXContentRegistry xContentRegistry;
- private Configuration suppressExceptionConfiguration = Configuration
- .builder()
- .options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL, Option.ALWAYS_RETURN_LIST)
- .build();
-
protected MLInferenceIngestProcessor(
String modelId,
List> inputMaps,
@@ -320,24 +312,29 @@ private void getMappedModelInputFromDocuments(
Object documentFieldValue = ingestDocument.getFieldValue(originalFieldPath, Object.class);
String documentFieldValueAsString = toString(documentFieldValue);
updateModelParameters(modelInputFieldName, documentFieldValueAsString, modelParameters);
+ return;
}
- // else when cannot find field path in document, try check for nested array using json path
- else {
- if (documentFieldName.contains(DOT_SYMBOL)) {
-
- Map sourceObject = ingestDocument.getSourceAndMetadata();
- ArrayList fieldValueList = JsonPath
- .using(suppressExceptionConfiguration)
- .parse(sourceObject)
- .read(documentFieldName);
- if (!fieldValueList.isEmpty()) {
- updateModelParameters(modelInputFieldName, toString(fieldValueList), modelParameters);
- } else if (!ignoreMissing) {
- throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName);
+ // If the standard dot path fails, try to check for a nested array using JSON path
+ if (StringUtils.isValidJSONPath(documentFieldName)) {
+ Map sourceObject = ingestDocument.getSourceAndMetadata();
+ Object fieldValue = JsonPath.using(suppressExceptionConfiguration).parse(sourceObject).read(documentFieldName);
+
+ if (fieldValue != null) {
+ if (fieldValue instanceof List) {
+ List> fieldValueList = (List>) fieldValue;
+ if (!fieldValueList.isEmpty()) {
+ updateModelParameters(modelInputFieldName, toString(fieldValueList), modelParameters);
+ } else if (!ignoreMissing) {
+ throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
+ }
+ } else {
+ updateModelParameters(modelInputFieldName, toString(fieldValue), modelParameters);
}
} else if (!ignoreMissing) {
- throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName);
+ throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
}
+ } else {
+ throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
}
}
diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java
index 203392eb75..3ff5d957f3 100644
--- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java
@@ -7,6 +7,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
+import static org.opensearch.ml.common.utils.StringUtils.toJson;
import static org.opensearch.ml.processor.MLInferenceIngestProcessor.DEFAULT_OUTPUT_FIELD_NAME;
import java.io.IOException;
@@ -31,11 +32,13 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
+import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.transport.MLTaskResponse;
+import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;
import org.opensearch.script.ScriptService;
@@ -164,13 +167,26 @@ public void testExecute_nestedObjectStringDocumentSuccess() {
return null;
}).when(client).execute(any(), any(), any());
+ ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class);
processor.execute(nestedObjectIngestDocument, handler);
+
// match output documents
Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource();
sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, ImmutableMap.of("response", Arrays.asList(1, 2, 3)));
IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>());
verify(handler).accept(eq(ingestDocument1), isNull());
assertEquals(nestedObjectIngestDocument, ingestDocument1);
+
+ // match model input
+ verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any());
+ MLPredictionTaskRequest req = argCaptor.getValue();
+ MLInput mlInput = req.getMlInput();
+ RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset();
+ assertEquals(
+ toJson(inputDataSet.getParameters()),
+ "{\"inputs\":\"[{\\\"text\\\":[{\\\"chapter\\\":\\\"first chapter\\\",\\\"context\\\":\\\"this is first\\\"},{\\\"chapter\\\":\\\"first chapter\\\",\\\"context\\\":\\\"this is second\\\"}]},{\\\"text\\\":[{\\\"chapter\\\":\\\"second chapter\\\",\\\"context\\\":\\\"this is third\\\"},{\\\"chapter\\\":\\\"second chapter\\\",\\\"context\\\":\\\"this is fourth\\\"}]}]\"}"
+ );
+
}
/**
@@ -202,6 +218,23 @@ public void testExecute_nestedObjectMapDocumentSuccess() throws IOException {
return null;
}).when(client).execute(any(), any(), any());
+ /**
+ * Preview of sourceAndMetadata
+ * {
+ * "chunks": [
+ * {
+ * "chunk": {
+ * "text": "this is first"
+ * }
+ * },
+ * {
+ * "chunk": {
+ * "text": "this is second"
+ * }
+ * }
+ * ]
+ * }
+ */
ArrayList childDocuments = new ArrayList<>();
Map childDocument1Text = new HashMap<>();
childDocument1Text.put("text", "this is first");
@@ -219,6 +252,8 @@ public void testExecute_nestedObjectMapDocumentSuccess() throws IOException {
Map sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("chunks", childDocuments);
+ ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class);
+
IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
processor.execute(nestedObjectIngestDocument, handler);
@@ -250,6 +285,13 @@ public void testExecute_nestedObjectMapDocumentSuccess() throws IOException {
IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>());
verify(handler).accept(eq(ingestDocument1), isNull());
assertEquals(nestedObjectIngestDocument, ingestDocument1);
+
+ // match model input
+ verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any());
+ MLPredictionTaskRequest req = argCaptor.getValue();
+ MLInput mlInput = req.getMlInput();
+ RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset();
+ assertEquals(toJson(inputDataSet.getParameters()), "{\"inputs\":\"[\\\"this is first\\\",\\\"this is second\\\"]\"}");
}
public void testExecute_jsonPathWithMissingLeaves() {
@@ -274,7 +316,7 @@ public void testExecute_jsonPathWithMissingLeaves() {
/**
* test nested object document with array of Map,
- * the value Object is a also a nested object,
+ * the value Object is also a nested object,
*/
public void testExecute_nestedObjectAndNestedObjectDocumentOutputInOneFieldSuccess() throws IOException {
List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context");
@@ -371,6 +413,7 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArraySuccess(
Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource();
+ ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class);
IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
processor.execute(nestedObjectIngestDocument, handler);
@@ -379,6 +422,17 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArraySuccess(
assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.0.chunk.text.1.embedding", Object.class), Arrays.asList(2));
assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.1.chunk.text.0.embedding", Object.class), Arrays.asList(3));
assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.1.chunk.text.1.embedding", Object.class), Arrays.asList(4));
+
+ // match model input
+ verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any());
+ MLPredictionTaskRequest req = argCaptor.getValue();
+ MLInput mlInput = req.getMlInput();
+ RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset();
+ assertEquals(
+ toJson(inputDataSet.getParameters()),
+ "{\"inputs\":\"[\\\"this is first\\\",\\\"this is second\\\",\\\"this is third\\\",\\\"this is fourth\\\"]\"}"
+ );
+
}
public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArrayMissingLeaveSuccess() {
@@ -955,6 +1009,8 @@ public void testExecute_getModelOutputFieldWithFieldNameSuccess() {
return null;
}).when(client).execute(any(), any(), any());
+ ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class);
+
processor.execute(ingestDocument, handler);
Map sourceAndMetadata = new HashMap<>();
@@ -962,8 +1018,75 @@ public void testExecute_getModelOutputFieldWithFieldNameSuccess() {
sourceAndMetadata.put("key2", "value2");
sourceAndMetadata.put("classification", ImmutableMap.of("language", "en", "score", "0.9876"));
IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>());
+
+ // match output
verify(handler).accept(eq(ingestDocument1), isNull());
assertEquals(ingestDocument, ingestDocument1);
+
+ // match model input
+ verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any());
+ MLPredictionTaskRequest req = argCaptor.getValue();
+ MLInput mlInput = req.getMlInput();
+ RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset();
+ assertEquals(toJson(inputDataSet.getParameters()), "{\"key1\":\"value1\",\"key2\":\"value2\"}");
+ }
+
+ public void testExecute_InputMapAndOutputMapSuccess() {
+ List> outputMap = new ArrayList<>();
+ Map output = new HashMap<>();
+ output.put("classification", "response");
+ outputMap.add(output);
+
+ List> inputMap = new ArrayList<>();
+ Map input = new HashMap<>();
+ input.put("inputs", "key1");
+ inputMap.add(input);
+
+ MLInferenceIngestProcessor processor = createMLInferenceProcessor(
+ "model1",
+ inputMap,
+ outputMap,
+ null,
+ true,
+ "remote",
+ false,
+ false,
+ false,
+ null
+ );
+ ModelTensor modelTensor = ModelTensor
+ .builder()
+ .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876")))
+ .build();
+ ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
+ ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
+
+ doAnswer(invocation -> {
+ ActionListener actionListener = invocation.getArgument(2);
+ actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
+ return null;
+ }).when(client).execute(any(), any(), any());
+
+ ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class);
+
+ processor.execute(ingestDocument, handler);
+
+ Map sourceAndMetadata = new HashMap<>();
+ sourceAndMetadata.put("key1", "value1");
+ sourceAndMetadata.put("key2", "value2");
+ sourceAndMetadata.put("classification", ImmutableMap.of("language", "en", "score", "0.9876"));
+ IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>());
+
+ // match output
+ verify(handler).accept(eq(ingestDocument1), isNull());
+ assertEquals(ingestDocument, ingestDocument1);
+
+ // match model input
+ verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any());
+ MLPredictionTaskRequest req = argCaptor.getValue();
+ MLInput mlInput = req.getMlInput();
+ RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset();
+ assertEquals(toJson(inputDataSet.getParameters()), "{\"inputs\":\"value1\"}");
}
public void testExecute_getModelOutputFieldWithDotPathSuccess() {
@@ -1209,7 +1332,7 @@ public void testExecute_documentNotExistedFieldNameException() {
processor.execute(ingestDocument, handler);
verify(handler)
- .accept(eq(null), argThat(exception -> exception.getMessage().equals("cannot find field name defined from input map: key99")));
+ .accept(eq(null), argThat(exception -> exception.getMessage().equals("Cannot find field name defined from input map: key99")));
}
public void testExecute_nestedDocumentNotExistedFieldNameException() {
@@ -1235,7 +1358,7 @@ public void testExecute_nestedDocumentNotExistedFieldNameException() {
argThat(
exception -> exception
.getMessage()
- .equals("cannot find field name defined from input map: chunks.*.chunk.text.*.context1")
+ .equals("Cannot find field name defined from input map: chunks.*.chunk.text.*.context1")
)
);
}
@@ -1613,7 +1736,40 @@ public void testExecute_localModelSuccess() {
updatedBooks.add(updatedBook2);
sourceAndMetadata.put("books", updatedBooks);
- IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>());
+ // match meta data
+ Map expectedIngestMetadata = new HashMap<>();
+ Map valueMap = new HashMap<>();
+ Map titleEmbeddingMap = new HashMap<>();
+ List> inferenceResultsList = new ArrayList<>();
+
+ Map expectedOutputMap = new HashMap<>();
+ List> outputList = new ArrayList<>();
+ Map dataMap = new HashMap<>();
+ dataMap.put("data", Arrays.asList(1.0, 2.0, 3.0, 4.0));
+ dataMap.put("name", "sentence_embedding");
+
+ Map inferenceResultMap = new HashMap<>();
+ List> outputListInner = new ArrayList<>();
+ outputListInner.add(dataMap);
+ inferenceResultMap.put("output", outputListInner);
+
+ Map dataAsMapMap = new HashMap<>();
+ List> inferenceResultsListInner = new ArrayList<>();
+ inferenceResultsListInner.add(inferenceResultMap);
+ dataAsMapMap.put("inference_results", inferenceResultsListInner);
+
+ Map expectedDataAsMap = new HashMap<>();
+ expectedDataAsMap.put("dataAsMap", dataAsMapMap);
+ outputList.add(expectedDataAsMap);
+ expectedOutputMap.put("output", outputList);
+ inferenceResultsList.add(expectedOutputMap);
+
+ titleEmbeddingMap.put("inference_results", inferenceResultsList);
+ valueMap.put("title_embedding", titleEmbeddingMap);
+ expectedIngestMetadata.put("_value", valueMap);
+
+ IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, expectedIngestMetadata);
+ System.out.println(ingestDocument1);
verify(handler).accept(eq(ingestDocument1), isNull());
assertEquals(nestedObjectIngestDocument, ingestDocument1);
}
@@ -1810,6 +1966,18 @@ public void testWriteNewDotPathForNestedObject() {
}
private static Map getNestedObjectWithAnotherNestedObjectSource() {
+ /**
+ * {chunks=[
+ * {chunk={text=[
+ * {context=this is first, chapter=first chapter},
+ * {context=this is second, chapter=first chapter}
+ * ]}},
+ * {chunk={text=[
+ * {context=this is third, chapter=second chapter},
+ * {context=this is fourth, chapter=second chapter}
+ * ]}}
+ * ]}
+ */
ArrayList childDocuments = new ArrayList<>();
Map childDocument1Text = new HashMap<>();
@@ -1836,7 +2004,7 @@ private static Map getNestedObjectWithAnotherNestedObjectSource(
grandChildDocument3Text.put("chapter", "second chapter");
Map grandChildDocument4Text = new HashMap<>();
grandChildDocument4Text.put("context", "this is fourth");
- grandChildDocument4Text.put("chapter", "first chapter");
+ grandChildDocument4Text.put("chapter", "second chapter");
grandChildDocuments2.add(grandChildDocument3Text);
grandChildDocuments2.add(grandChildDocument4Text);