Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Enable pass query string to input_map in ml inference search response processor #3129

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ dependencies {
compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
compileOnly group: 'org.json', name: 'json', version: '20231013'

implementation('com.google.guava:guava:32.1.2-jre') {
exclude group: 'com.google.guava', module: 'failureaccess'
exclude group: 'com.google.code.findbugs', module: 'jsr305'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.utils.StringUtils.TO_STRING_FUNCTION_NAME;
import static org.opensearch.ml.common.utils.StringUtils.collectToStringPrefixes;
import static org.opensearch.ml.common.utils.StringUtils.getJsonPath;
import static org.opensearch.ml.common.utils.StringUtils.isValidJSONPath;
import static org.opensearch.ml.common.utils.StringUtils.obtainFieldNameFromJsonPath;
import static org.opensearch.ml.common.utils.StringUtils.parseParameters;
import static org.opensearch.ml.common.utils.StringUtils.toJson;
Expand Down Expand Up @@ -457,4 +458,53 @@ public void testGetJsonPath_ValidJsonPathWithoutSource() {
String result = getJsonPath(input);
assertEquals("$.response.body.data[*].embedding", result);
}

@Test
public void testisValidJSONPath_InvalidInputs() {
Assert.assertFalse(isValidJSONPath("..bar"));
Assert.assertFalse(isValidJSONPath("."));
Assert.assertFalse(isValidJSONPath(".."));
Assert.assertFalse(isValidJSONPath("foo.bar."));
Assert.assertFalse(isValidJSONPath(".foo.bar."));
}

@Test
public void testisValidJSONPath_NullInput() {
Assert.assertFalse(isValidJSONPath(null));
}

@Test
public void testisValidJSONPath_EmptyInput() {
Assert.assertFalse(isValidJSONPath(""));
}

@Test
public void testisValidJSONPath_ValidInputs() {
Assert.assertTrue(isValidJSONPath("foo"));
Assert.assertTrue(isValidJSONPath("foo.bar"));
Assert.assertTrue(isValidJSONPath("foo.bar.baz"));
Assert.assertTrue(isValidJSONPath("foo.bar.baz.qux"));
Assert.assertTrue(isValidJSONPath(".foo"));
Assert.assertTrue(isValidJSONPath("$.foo"));
Assert.assertTrue(isValidJSONPath(".foo.bar"));
Assert.assertTrue(isValidJSONPath("$.foo.bar"));
}

@Test
public void testisValidJSONPath_WithFilter() {
Assert.assertTrue(isValidJSONPath("$.store['book']"));
Assert.assertTrue(isValidJSONPath("$['store']['book'][0]['title']"));
Assert.assertTrue(isValidJSONPath("$.store.book[0]"));
Assert.assertTrue(isValidJSONPath("$.store.book[1,2]"));
Assert.assertTrue(isValidJSONPath("$.store.book[-1:] "));
Assert.assertTrue(isValidJSONPath("$.store.book[0:2]"));
Assert.assertTrue(isValidJSONPath("$.store.book[*]"));
Assert.assertTrue(isValidJSONPath("$.store.book[?(@.price < 10)]"));
Assert.assertTrue(isValidJSONPath("$.store.book[?(@.author == 'J.K. Rowling')]"));
Assert.assertTrue(isValidJSONPath("$..author"));
Assert.assertTrue(isValidJSONPath("$..book[?(@.price > 15)]"));
Assert.assertTrue(isValidJSONPath("$.store.book[0,1]"));
Assert.assertTrue(isValidJSONPath("$['store','warehouse']"));
Assert.assertTrue(isValidJSONPath("$..book[?(@.price > 20)].title"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.processor;

import static java.lang.Math.max;
import static org.opensearch.ml.common.utils.StringUtils.toJson;
import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP;
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS;
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG;
Expand Down Expand Up @@ -55,12 +56,11 @@
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

import com.jayway.jsonpath.Configuration;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Option;

public class MLInferenceSearchResponseProcessor extends AbstractProcessor implements SearchResponseProcessor, ModelExecutor {

public static final String REQUEST_PREFIX = "_request.";
private final NamedXContentRegistry xContentRegistry;
private static final Logger logger = LogManager.getLogger(MLInferenceSearchResponseProcessor.class);
private final InferenceProcessorAttributes inferenceProcessorAttributes;
Expand Down Expand Up @@ -155,6 +155,8 @@ public void processResponseAsync(
try {
SearchHit[] hits = response.getHits().getHits();
// skip processing when there is no hit

String queryString = request.source().toString();
if (hits.length == 0) {
responseListener.onResponse(response);
return;
Expand Down Expand Up @@ -183,7 +185,7 @@ public void processResponseAsync(
);
}

rewriteResponseDocuments(mlInferenceSearchResponse, responseListener);
rewriteResponseDocuments(mlInferenceSearchResponse, responseListener, queryString);
} else {
// if one to one, make one hit search response and run rewriteResponseDocuments
GroupedActionListener<SearchResponse> combineResponseListener = getCombineResponseGroupedActionListener(
Expand All @@ -198,7 +200,7 @@ public void processResponseAsync(
newHits[0] = hit;
SearchResponse oneHitResponse = SearchResponseUtil.replaceHits(newHits, response);
ActionListener<SearchResponse> oneHitListener = getOneHitListener(combineResponseListener, isOneHitListenerFailed);
rewriteResponseDocuments(oneHitResponse, oneHitListener);
rewriteResponseDocuments(oneHitResponse, oneHitListener, queryString);
// if any OneHitListener failure, try stop the rest of the predictions
if (isOneHitListenerFailed.get()) {
break;
Expand Down Expand Up @@ -305,9 +307,11 @@ public void onFailure(Exception e) {
*
* @param response the search response
* @param responseListener the listener to be notified when the response is processed
* @param queryString the query body in string format, for example, "{ \"query\": { \"match_all\": {} } }\n"
* @throws IOException if an I/O error occurs during the rewriting process
*/
private void rewriteResponseDocuments(SearchResponse response, ActionListener<SearchResponse> responseListener) throws IOException {
private void rewriteResponseDocuments(SearchResponse response, ActionListener<SearchResponse> responseListener, String queryString)
throws IOException {
List<Map<String, String>> processInputMap = inferenceProcessorAttributes.getInputMaps();
List<Map<String, String>> processOutputMap = inferenceProcessorAttributes.getOutputMaps();
int inputMapSize = (processInputMap == null) ? 0 : processInputMap.size();
Expand All @@ -329,7 +333,7 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
);
SearchHit[] hits = response.getHits().getHits();
for (int inputMapIndex = 0; inputMapIndex < max(inputMapSize, 1); inputMapIndex++) {
processPredictions(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions);
processPredictions(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions, queryString);
}
}

Expand All @@ -341,56 +345,80 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
* @param inputMapIndex the index of the input mapping to process
* @param batchPredictionListener the listener to be notified when the predictions are processed
* @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
* @param queryString the query body in string format, for example, "{ \"query\": { \"match_all\": {} } }\n"
* @throws IOException if an I/O error occurs during the prediction process
*/
private void processPredictions(
SearchHit[] hits,
List<Map<String, String>> processInputMap,
int inputMapIndex,
GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener,
Map<Integer, Integer> hitCountInPredictions
Map<Integer, Integer> hitCountInPredictions,
String queryString
) throws IOException {

Map<String, String> modelParameters = new HashMap<>();
Map<String, String> modelConfigs = new HashMap<>();

if (inferenceProcessorAttributes.getModelConfigMaps() != null) {
modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps());
modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps());
Map<String, String> modelConfigMapsInput = inferenceProcessorAttributes.getModelConfigMaps();

modelParameters.putAll(modelConfigMapsInput);
modelConfigs.putAll(modelConfigMapsInput);

}

Map<String, Object> modelInputParameters = new HashMap<>();

Map<String, String> inputMapping;
if (processInputMap != null && !processInputMap.isEmpty()) {
inputMapping = processInputMap.get(inputMapIndex);
boolean isRequestInputMissing = checkIsRequestInputMissing(queryString, inputMapping);
if (isRequestInputMissing) {
if (!ignoreMissing) {
throw new IllegalArgumentException(
"Missing required input field in query body. input_map: " + inputMapping.values() + ", query body:" + queryString
);
}
}

for (SearchHit hit : hits) {
Map<String, Object> document = hit.getSourceAsMap();
boolean isModelInputMissing = checkIsModelInputMissing(document, inputMapping);
if (!isModelInputMissing) {
boolean isDocumentFieldMissing = checkIsDocumentFieldMissing(document, inputMapping);
if (!isDocumentFieldMissing) {
MapUtils.incrementCounter(hitCountInPredictions, inputMapIndex);
for (Map.Entry<String, String> entry : inputMapping.entrySet()) {
// model field as key, document field name as value
String modelInputFieldName = entry.getKey();
String documentFieldName = entry.getValue();

Object documentJson = JsonPath.parse(document).read("$");
Configuration configuration = Configuration
.builder()
.options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL)
.build();

Object documentValue = JsonPath.using(configuration).parse(documentJson).read(documentFieldName);
if (documentValue != null) {
// when not existed in the map, add into the modelInputParameters map
updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue);
// read the query string when the mapping field name starts with "$._request." or "_request."
// skip when modelInputParameters already has this modelInputFieldName to avoid duplicate read
if (StringUtils.isValidJSONPath(documentFieldName)
&& (documentFieldName.startsWith("$." + REQUEST_PREFIX) || documentFieldName.startsWith(REQUEST_PREFIX))
&& !modelInputParameters.containsKey(modelInputFieldName)) {
String requestFieldName = documentFieldName.replaceFirst(REQUEST_PREFIX, "");

Object queryText = JsonPath.using(suppressExceptionConfiguration).parse(queryString).read(requestFieldName);
if (queryText != null) {
modelInputParameters.put(modelInputFieldName, toJson(queryText));
}
} else {
Object documentValue = JsonPath.using(suppressExceptionConfiguration).parse(document).read(documentFieldName);
if (documentValue != null) {
// when not existed in the map, add into the modelInputParameters map
updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue);
}
}
}
} else { // when document does not contain the documentFieldName, skip when ignoreMissing
if (!ignoreMissing) {
throw new IllegalArgumentException(
"cannot find all required input fields: " + inputMapping.values() + " in hit:" + hit
"cannot find all required input fields: "
+ inputMapping.values()
+ " in hit:"
+ hit
+ " and query body:"
+ queryString
);
}
}
Expand Down Expand Up @@ -542,11 +570,11 @@ public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) {
Map<String, String> inputMapping = getDefaultInputMapping(sourceAsMap, mappingIndex, processInputMap);
Map<String, String> outputMapping = getDefaultOutputMapping(mappingIndex, processOutputMap);

boolean isModelInputMissing = false;
boolean isDocumentFieldMissing = false;
if (processInputMap != null && !processInputMap.isEmpty()) {
isModelInputMissing = checkIsModelInputMissing(document, inputMapping);
isDocumentFieldMissing = checkIsDocumentFieldMissing(document, inputMapping);
}
if (!isModelInputMissing) {
if (!isDocumentFieldMissing) {
// Iterate over outputMapping
for (Map.Entry<String, String> outputMapEntry : outputMapping.entrySet()) {

Expand Down Expand Up @@ -637,22 +665,45 @@ public void onFailure(Exception e) {

/**
* Checks if the document is missing any of the required input fields specified in the input mapping.
* When model config contains the default model_input value, it's not considered as missing model input.
*
* @param document the document map
* @param inputMapping the input mapping
* @return true if the document is missing any of the required input fields, false otherwise
*/
private boolean checkIsModelInputMissing(Map<String, Object> document, Map<String, String> inputMapping) {
boolean isModelInputMissing = false;
for (Map.Entry<String, String> inputMapEntry : inputMapping.entrySet()) {
String oldDocumentFieldName = inputMapEntry.getValue();
boolean checkSingleModelInputPresent = hasField(document, oldDocumentFieldName);
if (!checkSingleModelInputPresent) {
isModelInputMissing = true;
break;
}
}
return isModelInputMissing;
private boolean checkIsDocumentFieldMissing(Map<String, Object> document, Map<String, String> inputMapping) {
return inputMapping
.values()
.stream()
.filter(fieldName -> !(fieldName.startsWith("$." + REQUEST_PREFIX) || fieldName.startsWith(REQUEST_PREFIX)))
.anyMatch(fieldName -> {
boolean isFieldPresentInDocument = document != null && hasField(document, fieldName);
boolean isFieldPresentInModelConfig = this.inferenceProcessorAttributes.modelConfigMaps != null
&& this.inferenceProcessorAttributes.modelConfigMaps.containsKey(fieldName);
return !isFieldPresentInDocument && !isFieldPresentInModelConfig;
});
}

/**
* Checks if the request is missing any of the required input fields specified in the input mapping.
* When model config contains the default model_input value, it's not considered as missing model input.
*
* @param queryString the query body in string format, e.g., "{ \"query\": { \"match_all\": {} } }\n"
* @param inputMapping the input mapping
* @return true if the document is missing any of the required input fields, false otherwise
*/
private boolean checkIsRequestInputMissing(String queryString, Map<String, String> inputMapping) {
return inputMapping
.values()
.stream()
.filter(fieldName -> fieldName.startsWith("$." + REQUEST_PREFIX) || fieldName.startsWith(REQUEST_PREFIX))
.map(fieldName -> fieldName.replaceFirst(REQUEST_PREFIX, ""))
.anyMatch(requestFieldName -> {
boolean isFieldPresentInQuery = queryString != null && hasField(queryString, requestFieldName);
boolean isFieldPresentInModelConfig = this.inferenceProcessorAttributes.modelConfigMaps != null
&& this.inferenceProcessorAttributes.modelConfigMaps.containsKey(requestFieldName);
return !isFieldPresentInQuery && !isFieldPresentInModelConfig;
});
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,12 @@ default String toString(Object originalFieldValue) {
}

default boolean hasField(Object json, String path) {
Object value = JsonPath.using(suppressExceptionConfiguration).parse(json).read(path);

Object value;
if (json instanceof String) {
value = JsonPath.using(suppressExceptionConfiguration).parse((String) json).read(path);
} else {
value = JsonPath.using(suppressExceptionConfiguration).parse(json).read(path);
}
if (value != null) {
return true;
}
Expand Down
Loading
Loading