From fcbb7f19efcc59b6c93c6c8f499ac85d886f0538 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Mon, 7 Aug 2023 22:25:15 +0800 Subject: [PATCH 1/5] Fix breaking change caused by opensearch core Signed-off-by: zane-neo --- .../java/org/opensearch/ml/common/utils/StringUtils.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 968cda1575..34631b95b4 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -11,6 +11,8 @@ import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; @@ -82,4 +84,8 @@ public static Map getParameterMap(Map parameterObjs) } return parameters; } + + public static String xContentBuilderToString(XContentBuilder builder) { + return BytesReference.bytes(builder).utf8ToString(); + } } From 95572834a33bc8db076d4df42d47663781714c47 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Mon, 21 Aug 2023 16:53:32 +0800 Subject: [PATCH 2/5] Add neural search default pre/post process function support Signed-off-by: zane-neo --- common/build.gradle | 2 +- .../ml/common/connector/Connector.java | 6 +- .../connector/MLPostProcessFunction.java | 85 +++++------ .../connector/MLPreProcessFunction.java | 41 +++--- .../ml/common/output/model/ModelTensor.java | 6 +- .../opensearch/ml/common/utils/GsonUtil.java | 30 ++++ .../ml/common/utils/StringUtils.java | 12 +- .../connector/MLPostProcessFunctionTest.java | 29 ++++ .../ml/common/utils/GsonUtilsTest.java | 47 ++++++ .../org/opensearch/ml/engine/ModelHelper.java | 8 +- .../algorithms/remote/ConnectorUtils.java | 135 ++++++++++-------- .../remote/RemoteConnectorExecutor.java | 12 +- .../ml/engine/utils/ScriptUtils.java | 36 ++--- .../remote/AwsConnectorExecutorTest.java | 34 +++++ .../algorithms/remote/ConnectorUtilsTest.java | 48 +++++-- .../ml/engine/utils/ScriptUtilsTest.java | 59 ++++++++ 16 files changed, 400 insertions(+), 190 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/utils/GsonUtil.java create mode 100644 common/src/test/java/org/opensearch/ml/common/utils/GsonUtilsTest.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java diff --git a/common/build.gradle b/common/build.gradle index 3aeb644d42..01f548c140 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -19,7 +19,7 @@ dependencies { testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0' compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' - compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' + implementation group: 'com.google.code.gson', name: 'gson', version: '2.10.1' compileOnly group: 'org.json', name: 'json', version: '20230227' } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index e963cb4dfa..06f476c809 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -5,8 +5,6 @@ package org.opensearch.ml.common.connector; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.utils.StringUtils.gson; import java.io.IOException; import java.security.AccessController; @@ -32,6 +30,8 @@ import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.utils.GsonUtil; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; /** * Connector defines how to connect to a remote service. @@ -108,7 +108,7 @@ static Connector createConnector(XContentParser parser) throws IOException { Map connectorMap = parser.map(); String jsonStr; try { - jsonStr = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(connectorMap)); + jsonStr = AccessController.doPrivileged((PrivilegedExceptionAction) () -> GsonUtil.toJson(connectorMap)); } catch (PrivilegedActionException e) { throw new IllegalArgumentException("wrong connector"); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index 662db37341..663ee6b031 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -5,61 +5,64 @@ package org.opensearch.ml.common.connector; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.function.Function; public class MLPostProcessFunction { - private static Map POST_PROCESS_FUNCTIONS; public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding"; public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding"; + public static final String NEURAL_SEARCH_EMBEDDING = "connector.post_process.neural_search.text_embedding"; + + private static final Map JSON_PATH_EXPRESSION = new HashMap<>(); + + private static final Map>, List>> POST_PROCESS_FUNCTIONS = new HashMap<>(); + + static { - POST_PROCESS_FUNCTIONS = new HashMap<>(); - POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, "\n def name = \"sentence_embedding\";\n" + - " def dataType = \"FLOAT32\";\n" + - " if (params.embeddings == null || params.embeddings.length == 0) {\n" + - " return null;\n" + - " }\n" + - " def embeddings = params.embeddings;\n" + - " StringBuilder builder = new StringBuilder(\"[\");\n" + - " for (int i=0; i>, List> buildModelTensorList() { + return numbersList -> { + List modelTensors = new ArrayList<>(); + if (numbersList == null) { + throw new IllegalArgumentException("NumbersList is null when applying build-in post process function!"); + } + numbersList.forEach(numbers -> modelTensors.add( + ModelTensor + .builder() + .name("sentence_embedding") + .dataType(MLResultDataType.FLOAT32) + .shape(new long[]{numbers.size()}) + .data(numbers.toArray(new Number[0])) + .build() + )); + return modelTensors; + }; } - public static boolean contains(String functionName) { - return POST_PROCESS_FUNCTIONS.containsKey(functionName); + public static String getResponseFilter(String postProcessFunction) { + return JSON_PATH_EXPRESSION.get(postProcessFunction); } - public static String get(String postProcessFunction) { + public static Function>, List> get(String postProcessFunction) { return POST_PROCESS_FUNCTIONS.get(postProcessFunction); } + + public static boolean contains(String postProcessFunction) { + return POST_PROCESS_FUNCTIONS.containsKey(postProcessFunction); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index b49e075aea..23a575e860 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -6,44 +6,37 @@ package org.opensearch.ml.common.connector; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.function.Function; public class MLPreProcessFunction { - private static Map PRE_PROCESS_FUNCTIONS; + private static final Map, Map>> PRE_PROCESS_FUNCTIONS = new HashMap<>(); public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding"; public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding"; + public static final String NEURAL_SEARCH_EMBEDDING_INPUT = "connector.pre_process.neural_search.text_embedding"; + + private static Function, Map> cohereTextEmbeddingPreProcess() { + return inputs -> Map.of("parameters", Map.of("texts", inputs)); + } + + private static Function, Map> openAiTextEmbeddingPreProcess() { + return inputs -> Map.of("parameters", Map.of("input", inputs)); + } + static { - PRE_PROCESS_FUNCTIONS = new HashMap<>(); - //TODO: change to java for openAI, embedding and Titan - PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, "\n StringBuilder builder = new StringBuilder();\n" + - " builder.append(\"[\");\n" + - " for (int i=0; i< params.text_docs.length; i++) {\n" + - " builder.append(\"\\\"\");\n" + - " builder.append(params.text_docs[i]);\n" + - " builder.append(\"\\\"\");\n" + - " if (i < params.text_docs.length - 1) {\n" + - " builder.append(\",\")\n" + - " }\n" + - " }\n" + - " builder.append(\"]\");\n" + - " def parameters = \"{\" +\"\\\"texts\\\":\" + builder + \"}\";\n" + - " return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";"); - - PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, "\n StringBuilder builder = new StringBuilder();\n" + - " builder.append(\"\\\"\");\n" + - " builder.append(params.text_docs[0]);\n" + - " builder.append(\"\\\"\");\n" + - " def parameters = \"{\" +\"\\\"input\\\":\" + builder + \"}\";\n" + - " return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";"); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess()); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); + PRE_PROCESS_FUNCTIONS.put(NEURAL_SEARCH_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); } public static boolean contains(String functionName) { return PRE_PROCESS_FUNCTIONS.containsKey(functionName); } - public static String get(String postProcessFunction) { + public static Function, Map> get(String postProcessFunction) { return PRE_PROCESS_FUNCTIONS.get(postProcessFunction); } } diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java index 8957f33643..c51031ac3d 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java @@ -6,7 +6,6 @@ package org.opensearch.ml.common.output.model; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.utils.StringUtils.gson; import java.io.IOException; import java.nio.ByteBuffer; @@ -26,6 +25,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.utils.GsonUtil; @Data public class ModelTensor implements Writeable, ToXContentObject { @@ -224,7 +224,7 @@ public ModelTensor(StreamInput in) throws IOException { this.result = in.readOptionalString(); if (in.readBoolean()) { String mapStr = in.readString(); - this.dataAsMap = gson.fromJson(mapStr, Map.class); + this.dataAsMap = GsonUtil.fromJson(mapStr, Map.class); } } @@ -270,7 +270,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(true); try { AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - out.writeString(gson.toJson(dataAsMap)); + out.writeString(GsonUtil.toJson(dataAsMap)); return null; }); } catch (PrivilegedActionException e) { diff --git a/common/src/main/java/org/opensearch/ml/common/utils/GsonUtil.java b/common/src/main/java/org/opensearch/ml/common/utils/GsonUtil.java new file mode 100644 index 0000000000..b261d44124 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/utils/GsonUtil.java @@ -0,0 +1,30 @@ +package org.opensearch.ml.common.utils; + +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.stream.JsonReader; + +public class GsonUtil { + + private static final Gson gson; + + static { + gson = new Gson(); + } + + public static String toJson(Object obj) { + return gson.toJson(obj); + } + + public static T fromJson(String json, Class clazz) { + return gson.fromJson(json, clazz); + } + + public static T fromJson(JsonElement jsonElement, Class clazz) { + return gson.fromJson(jsonElement, clazz); + } + + public static T fromJson(JsonReader jsonReader, Class clazz) { + return gson.fromJson(jsonReader, clazz); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 34631b95b4..479adf7a79 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -5,7 +5,6 @@ package org.opensearch.ml.common.utils; -import com.google.gson.Gson; import com.google.gson.JsonElement; import com.google.gson.JsonParser; import org.json.JSONArray; @@ -25,11 +24,6 @@ public class StringUtils { - public static final Gson gson; - static { - gson = new Gson(); - } - public static boolean isJson(String Json) { try { new JSONObject(Json); @@ -54,9 +48,9 @@ public static Map fromJson(String jsonStr, String defaultKey) { Map result; JsonElement jsonElement = JsonParser.parseString(jsonStr); if (jsonElement.isJsonObject()) { - result = gson.fromJson(jsonElement, Map.class); + result = GsonUtil.fromJson(jsonElement, Map.class); } else if (jsonElement.isJsonArray()) { - List list = gson.fromJson(jsonElement, List.class); + List list = GsonUtil.fromJson(jsonElement, List.class); result = new HashMap<>(); result.put(defaultKey, list); } else { @@ -74,7 +68,7 @@ public static Map getParameterMap(Map parameterObjs) if (value instanceof String) { parameters.put(key, (String)value); } else { - parameters.put(key, gson.toJson(value)); + parameters.put(key, GsonUtil.toJson(value)); } return null; }); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java index 346d5901a8..5d4c0c88d7 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java @@ -6,12 +6,21 @@ package org.opensearch.ml.common.connector; import org.junit.Assert; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING; public class MLPostProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + @Test public void contains() { Assert.assertTrue(MLPostProcessFunction.contains(OPENAI_EMBEDDING)); @@ -23,4 +32,24 @@ public void get() { Assert.assertNotNull(MLPostProcessFunction.get(OPENAI_EMBEDDING)); Assert.assertNull(MLPostProcessFunction.get("wrong value")); } + + @Test + public void test_getResponseFilter() { + assert null != MLPostProcessFunction.getResponseFilter(OPENAI_EMBEDDING); + assert null == MLPostProcessFunction.getResponseFilter("wrong value"); + } + + @Test + public void test_buildModelTensorList() { + Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList()); + List> numbersList = new ArrayList<>(); + numbersList.add(Collections.singletonList(1.0f)); + Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList().apply(numbersList)); + } + + @Test + public void test_buildModelTensorList_exception() { + exceptionRule.expect(IllegalArgumentException.class); + MLPostProcessFunction.buildModelTensorList().apply(null); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/utils/GsonUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/GsonUtilsTest.java new file mode 100644 index 0000000000..c5ceb58917 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/utils/GsonUtilsTest.java @@ -0,0 +1,47 @@ +package org.opensearch.ml.common.utils; + +import com.google.gson.JsonElement; +import com.google.gson.JsonParser; +import com.google.gson.stream.JsonReader; +import org.junit.Assert; +import org.junit.Test; + +import java.io.StringReader; +import java.util.HashMap; +import java.util.Map; + +public class GsonUtilsTest { + @Test + public void test_toJson() { + Map map = new HashMap<>(); + map.put("key", "value"); + String mapString = GsonUtil.toJson(map); + assert mapString.equals("{\"key\":\"value\"}"); + } + + @Test + public void test_fromJsonString() { + Map map = GsonUtil.fromJson("{\"key\": \"value\"}", Map.class); + Assert.assertEquals(1, map.size()); + Assert.assertEquals("value", map.get("key")); + } + + @Test + public void test_fromJsonJsonElement() { + JsonElement jsonElement = JsonParser.parseString("{\"key\": \"value\"}"); + Map map = GsonUtil.fromJson(jsonElement, Map.class); + Assert.assertEquals(1, map.size()); + Assert.assertEquals("value", map.get("key")); + } + + @Test + public void test_fromJsonJsonReader() { + JsonReader reader = new JsonReader(new StringReader("{\"key\": \"value\"}")); + Map map = GsonUtil.fromJson(reader, Map.class); + Assert.assertEquals(1, map.size()); + Assert.assertEquals("value", map.get("key")); + } + + + +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java index 3dcb210dc1..2582c8a6df 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java @@ -7,7 +7,6 @@ import ai.djl.training.util.DownloadUtils; import ai.djl.training.util.ProgressBar; -import com.google.gson.Gson; import com.google.gson.stream.JsonReader; import lombok.extern.log4j.Log4j2; import org.opensearch.core.action.ActionListener; @@ -16,6 +15,7 @@ import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.utils.GsonUtil; import java.io.File; import java.io.FileReader; @@ -48,11 +48,9 @@ public class ModelHelper { public static final String PYTORCH_ENGINE = "PyTorch"; public static final String ONNX_ENGINE = "OnnxRuntime"; private final MLEngine mlEngine; - private Gson gson; public ModelHelper(MLEngine mlEngine) { this.mlEngine = mlEngine; - gson = new Gson(); } public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput registerModelInput, ActionListener listener) { @@ -74,7 +72,7 @@ public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput regi Map config = null; try (JsonReader reader = new JsonReader(new FileReader(configCacheFilePath))) { - config = gson.fromJson(reader, Map.class); + config = GsonUtil.fromJson(reader, Map.class); } if (config == null) { @@ -172,7 +170,7 @@ public List downloadPrebuiltModelMetaList(String taskId, MLRegisterModelInput re List config = null; try (JsonReader reader = new JsonReader(new FileReader(cacheFilePath))) { - config = gson.fromJson(reader, List.class); + config = GsonUtil.fromJson(reader, List.class); } return config; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index cd3038f49c..7b307d98fa 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -5,17 +5,20 @@ package org.opensearch.ml.engine.algorithms.remote; -import com.google.common.collect.ImmutableMap; import com.jayway.jsonpath.JsonPath; +import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.text.StringSubstitutor; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.MLPostProcessFunction; +import org.opensearch.ml.common.connector.MLPreProcessFunction; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.common.utils.GsonUtil; import org.opensearch.script.ScriptService; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.AwsCredentials; @@ -37,10 +40,11 @@ import static org.apache.commons.text.StringEscapeUtils.escapeJson; import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD; -import static org.opensearch.ml.engine.utils.ScriptUtils.executePostprocessFunction; +import static org.opensearch.ml.engine.utils.ScriptUtils.executeBuildInPostProcessFunction; +import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction; import static org.opensearch.ml.engine.utils.ScriptUtils.executePreprocessFunction; -import static org.opensearch.ml.engine.utils.ScriptUtils.gson; +@Log4j2 public class ConnectorUtils { private static final Aws4Signer signer; @@ -54,43 +58,7 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto } RemoteInferenceInputDataSet inputData; if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { - TextDocsInputDataSet inputDataSet = (TextDocsInputDataSet)mlInput.getInputDataset(); - List docs = new ArrayList<>(inputDataSet.getDocs()); - Map params = ImmutableMap.of("text_docs", docs); - Optional predictAction = connector.findPredictAction(); - if (!predictAction.isPresent()) { - throw new IllegalArgumentException("no predict action found"); - } - String preProcessFunction = predictAction.get().getPreProcessFunction(); - if (preProcessFunction == null) { - throw new IllegalArgumentException("Must provide pre_process_function for predict action to process text docs input."); - } - if (preProcessFunction != null && preProcessFunction.contains("${parameters")) { - StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); - preProcessFunction = substitutor.replace(preProcessFunction); - } - Optional processedResponse = executePreprocessFunction(scriptService, preProcessFunction, params); - if (!processedResponse.isPresent()) { - throw new IllegalArgumentException("Wrong input"); - } - Map map = gson.fromJson(processedResponse.get(), Map.class); - Map parametersMap = (Map) map.get("parameters"); - Map processedParameters = new HashMap<>(); - for (String key : parametersMap.keySet()) { - try { - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - if (parametersMap.get(key) instanceof String) { - processedParameters.put(key, (String) parametersMap.get(key)); - } else { - processedParameters.put(key, gson.toJson(parametersMap.get(key))); - } - return null; - }); - } catch (PrivilegedActionException e) { - throw new RuntimeException(e); - } - } - inputData = RemoteInferenceInputDataSet.builder().parameters(processedParameters).build(); + inputData = processTextDocsInput((TextDocsInputDataSet) mlInput.getInputDataset(), connector, parameters, scriptService); } else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { inputData = (RemoteInferenceInputDataSet)mlInput.getInputDataset(); } else { @@ -98,20 +66,65 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto } if (inputData.getParameters() != null) { Map newParameters = new HashMap<>(); - inputData.getParameters().entrySet().forEach(entry -> { - if (entry.getValue() == null) { - newParameters.put(entry.getKey(), entry.getValue()); - } else if (StringUtils.isJson(entry.getValue())) { + inputData.getParameters().forEach((key, value) -> { + if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) { // no need to escape if it's already valid json - newParameters.put(entry.getKey(), entry.getValue()); + newParameters.put(key, value); } else { - newParameters.put(entry.getKey(), escapeJson(entry.getValue())); + newParameters.put(key, escapeJson(value)); } }); inputData.setParameters(newParameters); } return inputData; } + private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDataSet inputDataSet, Connector connector, Map parameters, ScriptService scriptService) { + List docs = new ArrayList<>(inputDataSet.getDocs()); + Optional predictAction = connector.findPredictAction(); + if (predictAction.isEmpty()) { + throw new IllegalArgumentException("no predict action found"); + } + String preProcessFunction = predictAction.get().getPreProcessFunction(); + if (preProcessFunction == null) { + throw new IllegalArgumentException("Must provide pre_process_function for predict action to process text docs input."); + } + if (MLPreProcessFunction.contains(preProcessFunction)) { + Map buildInFunctionResult = MLPreProcessFunction.get(preProcessFunction).apply(docs); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(buildInFunctionResult)).build(); + } else { + if (preProcessFunction.contains("${parameters")) { + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + preProcessFunction = substitutor.replace(preProcessFunction); + } + Optional processedInput = executePreprocessFunction(scriptService, preProcessFunction, docs); + if (processedInput.isEmpty()) { + throw new IllegalArgumentException("Wrong input"); + } + Map map = GsonUtil.fromJson(processedInput.get(), Map.class); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build(); + } + } + + private static Map convertScriptStringToJsonString(Map processedInput) { + Map parameterStringMap = new HashMap<>(); + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + Map parametersMap = (Map) processedInput.get("parameters"); + for (String key : parametersMap.keySet()) { + if (parametersMap.get(key) instanceof String) { + parameterStringMap.put(key, (String) parametersMap.get(key)); + } else { + parameterStringMap.put(key, GsonUtil.toJson(parametersMap.get(key))); + } + } + return null; + }); + } catch (PrivilegedActionException e) { + log.error("Error processing parameters", e); + throw new RuntimeException(e); + } + return parameterStringMap; + } public static ModelTensors processOutput(String modelResponse, Connector connector, ScriptService scriptService, Map parameters) throws IOException { if (modelResponse == null) { @@ -119,26 +132,36 @@ public static ModelTensors processOutput(String modelResponse, Connector connect } List modelTensors = new ArrayList<>(); Optional predictAction = connector.findPredictAction(); - if (!predictAction.isPresent()) { + if (predictAction.isEmpty()) { throw new IllegalArgumentException("no predict action found"); } - String postProcessFunction = predictAction.get().getPostProcessFunction(); + ConnectorAction connectorAction = predictAction.get(); + String postProcessFunction = connectorAction.getPostProcessFunction(); if (postProcessFunction != null && postProcessFunction.contains("${parameters")) { StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); postProcessFunction = substitutor.replace(postProcessFunction); } - Optional processedResponse = executePostprocessFunction(scriptService, postProcessFunction, modelResponse); + String responseFilter = parameters.get(RESPONSE_FILTER_FIELD); + if (MLPostProcessFunction.contains(postProcessFunction)) { + // in this case, we can use jsonpath to build a List> result from model response. + if (StringUtils.isBlank(responseFilter)) responseFilter = MLPostProcessFunction.getResponseFilter(postProcessFunction); + List> vectors = JsonPath.read(modelResponse, responseFilter); + List processedResponse = executeBuildInPostProcessFunction(vectors, MLPostProcessFunction.get(postProcessFunction)); + return ModelTensors.builder().mlModelTensors(processedResponse).build(); + } + + // execute user defined painless script. + Optional processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse); String response = processedResponse.orElse(modelResponse); - if (parameters.get(RESPONSE_FILTER_FIELD) == null) { - connector.parseResponse(response, modelTensors, postProcessFunction != null && processedResponse.isPresent()); + boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent(); + if (responseFilter == null) { + connector.parseResponse(response, modelTensors, scriptReturnModelTensor); } else { Object filteredResponse = JsonPath.parse(response).read(parameters.get(RESPONSE_FILTER_FIELD)); - connector.parseResponse(filteredResponse, modelTensors, postProcessFunction != null && processedResponse.isPresent()); + connector.parseResponse(filteredResponse, modelTensors, scriptReturnModelTensor); } - - ModelTensors tensors = ModelTensors.builder().mlModelTensors(modelTensors).build(); - return tensors; + return ModelTensors.builder().mlModelTensors(modelTensors).build(); } public static SdkHttpFullRequest signRequest(SdkHttpFullRequest request, String accessKey, String secretKey, String sessionToken, String signingName, String region) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index c9b6e78873..8712f771c7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -32,14 +32,8 @@ default ModelTensorOutput executePredict(MLInput mlInput) { if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); - List textDocs = new ArrayList(textDocsInputDataSet.getDocs()); - for (int i = 0; i < textDocsInputDataSet.getDocs().size(); i++) { - preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tensorOutputs); - if (tensorOutputs.size() >= textDocsInputDataSet.getDocs().size()) { - break; - } - textDocs.remove(0); - } + List textDocs = new ArrayList<>(textDocsInputDataSet.getDocs()); + preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tensorOutputs); } else { preparePayloadAndInvokeRemoteModel(mlInput, tensorOutputs); } @@ -65,7 +59,7 @@ default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List executePreprocessFunction(ScriptService scriptService, String preProcessFunction, List inputSentences) { + return Optional.ofNullable(executeScript(scriptService, preProcessFunction, ImmutableMap.of("text_docs", inputSentences))); } - public static Optional executePreprocessFunction(ScriptService scriptService, - String preProcessFunction, - Map params) { - if (MLPreProcessFunction.contains(preProcessFunction)) { - preProcessFunction = MLPreProcessFunction.get(preProcessFunction); - } - if (preProcessFunction != null) { - return Optional.ofNullable(executeScript(scriptService, preProcessFunction, params)); - } - return Optional.empty(); + public static List executeBuildInPostProcessFunction(List> vectors, Function>, List> function) { + return function.apply(vectors); } - public static Optional executePostprocessFunction(ScriptService scriptService, - String postProcessFunction, - String resultJson) { - Map result = StringUtils.fromJson(resultJson, "result"); - if (MLPostProcessFunction.contains(postProcessFunction)) { - postProcessFunction = MLPostProcessFunction.get(postProcessFunction); - } + public static Optional executePostProcessFunction(ScriptService scriptService, String postProcessFunction, String resultJson) { + Map result = org.opensearch.ml.common.utils.StringUtils.fromJson(resultJson, "result"); if (postProcessFunction != null) { return Optional.ofNullable(executeScript(scriptService, postProcessFunction, result)); } return Optional.empty(); } - public static String executeScript(ScriptService scriptService, String painlessScript, Map params) { Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap()); TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index ecc143ea6f..5dbbf2090e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -5,6 +5,7 @@ package org.opensearch.ml.engine.algorithms.remote; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.junit.Assert; import org.junit.Before; @@ -18,7 +19,9 @@ import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.MLPreProcessFunction; import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensorOutput; @@ -136,4 +139,35 @@ public void executePredict_RemoteInferenceInput() throws IOException { Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key")); } + + @Test + public void executePredict_TextDocsInferenceInput() throws IOException { + String jsonString = "{\"key\":\"value\"}"; + InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes()); + AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream); + when(response.responseBody()).thenReturn(Optional.of(abortableInputStream)); + when(httpRequest.call()).thenReturn(response); + when(httpClient.prepareRequest(any())).thenReturn(httpRequest); + + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) + .build(); + Map credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); + Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input", "test input data")).build(); + ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); + Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); + Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key")); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 2a84e2fee1..86e1568137 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -24,11 +24,15 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.utils.GsonUtil; import org.opensearch.script.ScriptService; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.mockito.ArgumentMatchers.any; @@ -121,18 +125,20 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec @Test public void processInput_TextDocsInputDataSet_PreprocessFunction_OneTextDoc() { + List input = Collections.singletonList("test_value"); + String inputJson = GsonUtil.toJson(input); processInput_TextDocsInputDataSet_PreprocessFunction( - "{\"input\": \"${parameters.input}\"}", - "{\"parameters\": { \"input\": \"test_value\" } }", - "test_value"); + "{\"input\": \"${parameters.input}\"}", input, inputJson, MLPreProcessFunction.TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, "texts"); } @Test public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc() { + List input = new ArrayList<>(); + input.add("test_value1"); + input.add("test_value2"); + String inputJson = GsonUtil.toJson(input); processInput_TextDocsInputDataSet_PreprocessFunction( - "{\"input\": ${parameters.input}}", - "{\"parameters\": { \"input\": [\"test_value1\", \"test_value2\"] } }", - "[\"test_value1\",\"test_value2\"]"); + "{\"input\": ${parameters.input}}", input, inputJson, MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, "input"); } @Test @@ -143,7 +149,7 @@ public void processOutput_NullResponse() throws IOException { } @Test - public void processOutput_NoPostprocessFunction() throws IOException { + public void processOutput_NoPostprocessFunction_jsonResponse() throws IOException { ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") @@ -160,6 +166,24 @@ public void processOutput_NoPostprocessFunction() throws IOException { Assert.assertEquals("test response", tensors.getMlModelTensors().get(0).getDataAsMap().get("response")); } + @Test + public void processOutput_noPostProcessFunction_nonJsonResponse() throws IOException { + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Map parameters = new HashMap<>(); + parameters.put("key1", "value1"); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); + ModelTensors tensors = ConnectorUtils.processOutput("test response", connector, scriptService, ImmutableMap.of()); + Assert.assertEquals(1, tensors.getMlModelTensors().size()); + Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName()); + Assert.assertEquals(1, tensors.getMlModelTensors().get(0).getDataAsMap().size()); + Assert.assertEquals("test response", tensors.getMlModelTensors().get(0).getDataAsMap().get("response")); + } + @Test public void processOutput_PostprocessFunction() throws IOException { String postprocessResult = "{\"name\":\"sentence_embedding\",\"data_type\":\"FLOAT32\",\"shape\":[1536],\"data\":[-0.014555434, -2.135904E-4, 0.0035105038]}"; @@ -186,10 +210,8 @@ public void processOutput_PostprocessFunction() throws IOException { Assert.assertEquals(0.0035105038, tensors.getMlModelTensors().get(0).getData()[2]); } - private void processInput_TextDocsInputDataSet_PreprocessFunction(String requestBody, String preprocessResult, String expectedProcessedInput) { - when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult)); - - TextDocsInputDataSet dataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test1", "test2")).build(); + private void processInput_TextDocsInputDataSet_PreprocessFunction(String requestBody, List inputs, String expectedProcessedInput, String preProcessName, String resultKey) { + TextDocsInputDataSet dataSet = TextDocsInputDataSet.builder().docs(inputs).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataSet).build(); ConnectorAction predictAction = ConnectorAction.builder() @@ -197,7 +219,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction(String request .method("POST") .url("http://test.com/mock") .requestBody(requestBody) - .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT) + .preProcessFunction(preProcessName) .build(); Map parameters = new HashMap<>(); parameters.put("key1", "value1"); @@ -205,6 +227,6 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction(String request RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService); Assert.assertNotNull(remoteInferenceInputDataSet.getParameters()); Assert.assertEquals(1, remoteInferenceInputDataSet.getParameters().size()); - Assert.assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get("input")); + Assert.assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get(resultKey)); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java new file mode 100644 index 0000000000..24daf5ea8c --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java @@ -0,0 +1,59 @@ +package org.opensearch.ml.engine.utils; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ingest.TestTemplateService; +import org.opensearch.ml.common.connector.MLPostProcessFunction; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.script.ScriptService; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +public class ScriptUtilsTest { + + @Mock + ScriptService scriptService; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory("test result")); + } + + @Test + public void test_executePreprocessFunction() { + Optional resultOpt = ScriptUtils.executePreprocessFunction(scriptService, "any function", Collections.singletonList("any input")); + assertEquals("test result", resultOpt.get()); + } + + @Test + public void test_executeBuildInPostProcessFunction() { + List> input = Arrays.asList(Arrays.asList(1.0f, 2.0f), Arrays.asList(3.0f, 4.0f)); + List modelTensors = ScriptUtils.executeBuildInPostProcessFunction(input, MLPostProcessFunction.get(MLPostProcessFunction.NEURAL_SEARCH_EMBEDDING)); + assertNotNull(modelTensors); + assertEquals(2, modelTensors.size()); + } + + @Test + public void test_executePostProcessFunction() { + when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"result\": \"test result\"}")); + Optional resultOpt = ScriptUtils.executePostProcessFunction(scriptService, "any function", "{\"result\": \"test result\"}"); + assertEquals("{\"result\": \"test result\"}", resultOpt.get()); + } + + @Test + public void test_executeScript() { + String result = ScriptUtils.executeScript(scriptService, "any function", Collections.singletonMap("key", "value")); + assertEquals("test result", result); + } +} From 01c4ef792d9513cacfa45bd0a815e462f6596e53 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 22 Aug 2023 12:42:32 +0800 Subject: [PATCH 3/5] Fix UT failures Signed-off-by: zane-neo --- .../ml/common/utils/StringUtils.java | 8 +++++-- .../algorithms/remote/ConnectorUtils.java | 4 +++- .../algorithms/remote/ConnectorUtilsTest.java | 4 ++-- .../remote/HttpJsonConnectorExecutorTest.java | 23 +++++++++++-------- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 479adf7a79..5b35000807 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -18,6 +18,7 @@ import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -44,9 +45,12 @@ public static String toUTF8(String rawString) { return utf8EncodedString; } - public static Map fromJson(String jsonStr, String defaultKey) { + public static Map fromJson(String input, String defaultKey) { + if (!isJson(input)) { + return Collections.singletonMap(defaultKey, input); + } Map result; - JsonElement jsonElement = JsonParser.parseString(jsonStr); + JsonElement jsonElement = JsonParser.parseString(input); if (jsonElement.isJsonObject()) { result = GsonUtil.fromJson(jsonElement, Map.class); } else if (jsonElement.isJsonArray()) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 7b307d98fa..e8795a2e79 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -67,7 +67,9 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto if (inputData.getParameters() != null) { Map newParameters = new HashMap<>(); inputData.getParameters().forEach((key, value) -> { - if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) { + if (value == null) { + newParameters.put(key, null); + } else if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) { // no need to escape if it's already valid json newParameters.put(key, value); } else { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 86e1568137..98e90e668f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -175,9 +175,9 @@ public void processOutput_noPostProcessFunction_nonJsonResponse() throws IOExcep .requestBody("{\"input\": \"${parameters.input}\"}") .build(); Map parameters = new HashMap<>(); - parameters.put("key1", "value1"); + parameters.put("input", "value1"); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); - ModelTensors tensors = ConnectorUtils.processOutput("test response", connector, scriptService, ImmutableMap.of()); + ModelTensors tensors = ConnectorUtils.processOutput("test response", connector, scriptService, parameters); Assert.assertEquals(1, tensors.getMlModelTensors().size()); Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName()); Assert.assertEquals(1, tensors.getMlModelTensors().get(0).getDataAsMap().size()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 8d04603d2a..a8290f0692 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -112,14 +112,10 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() { @Test public void executePredict_TextDocsInput() throws IOException { String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; - String postprocessResult1 = "{\"name\":\"sentence_embedding\",\"data_type\":\"FLOAT32\",\"shape\":[3],\"data\":[1, 2, 3]}"; String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }"; - String postprocessResult2 = "{\"name\":\"sentence_embedding\",\"data_type\":\"FLOAT32\",\"shape\":[3],\"data\":[4, 5, 6]}"; when(scriptService.compile(any(), any())) .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(postprocessResult1)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(postprocessResult2)); + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) @@ -127,21 +123,28 @@ public void executePredict_TextDocsInput() throws IOException { .url("http://test.com/mock") .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) - .requestBody("{\"input\": \"${parameters.input}\"}") + .requestBody("{\"input\": ${parameters.input}}") .build(); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); - String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; + String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n" + + " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n" + + " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n" + + " },\n" + " {\n" + " \"object\": \"embedding\",\n" + " \"index\": 1,\n" + + " \"embedding\": [\n" + " -0.014555434,\n" + " -0.002135904,\n" + + " 0.0035105038\n" + " ]\n" + " }\n" + " ],\n" + + " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n" + + " \"total_tokens\": 5\n" + " }\n" + "}"; HttpEntity entity = new StringEntity(modelResponse); when(response.getEntity()).thenReturn(entity); when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); - Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); - Assert.assertArrayEquals(new Number[] {1, 2, 3}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()); - Assert.assertArrayEquals(new Number[] {4, 5, 6}, modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0).getData()); + Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()); + Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData()); } } From cb830222d1ba425fce3029b366a51aa37291a469 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 26 Sep 2023 09:14:05 +0800 Subject: [PATCH 4/5] Fix conflicts when backport Signed-off-by: zane-neo --- .../ml/common/connector/Connector.java | 16 +++++++++-- .../connector/MLPostProcessFunction.java | 18 ++++++------- .../connector/MLPreProcessFunction.java | 4 +-- .../input/remote/RemoteInferenceMLInput.java | 9 ++++--- .../ml/common/utils/StringUtils.java | 27 +++++++++---------- .../org/opensearch/ml/engine/ModelHelper.java | 1 + .../algorithms/remote/ConnectorUtils.java | 10 +++---- .../ml/engine/utils/ScriptUtils.java | 3 ++- .../algorithms/remote/ConnectorUtilsTest.java | 26 +++--------------- .../remote/HttpJsonConnectorExecutorTest.java | 21 ++++++++++----- .../ml/engine/utils/ScriptUtilsTest.java | 2 +- 11 files changed, 68 insertions(+), 69 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index 06f476c809..419e460c95 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -30,8 +30,20 @@ import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.output.model.ModelTensor; -import org.opensearch.ml.common.utils.GsonUtil; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.gson; /** * Connector defines how to connect to a remote service. @@ -108,7 +120,7 @@ static Connector createConnector(XContentParser parser) throws IOException { Map connectorMap = parser.map(); String jsonStr; try { - jsonStr = AccessController.doPrivileged((PrivilegedExceptionAction) () -> GsonUtil.toJson(connectorMap)); + jsonStr = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(connectorMap)); } catch (PrivilegedActionException e) { throw new IllegalArgumentException("wrong connector"); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index 663ee6b031..9d9ba90171 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -19,7 +19,7 @@ public class MLPostProcessFunction { public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding"; public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding"; - public static final String NEURAL_SEARCH_EMBEDDING = "connector.post_process.neural_search.text_embedding"; + public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding"; private static final Map JSON_PATH_EXPRESSION = new HashMap<>(); @@ -29,25 +29,25 @@ public class MLPostProcessFunction { static { JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding"); JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings"); - JSON_PATH_EXPRESSION.put(NEURAL_SEARCH_EMBEDDING, "$[*]"); + JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]"); POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList()); POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList()); - POST_PROCESS_FUNCTIONS.put(NEURAL_SEARCH_EMBEDDING, buildModelTensorList()); + POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList()); } public static Function>, List> buildModelTensorList() { - return numbersList -> { + return embeddings -> { List modelTensors = new ArrayList<>(); - if (numbersList == null) { - throw new IllegalArgumentException("NumbersList is null when applying build-in post process function!"); + if (embeddings == null) { + throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function."); } - numbersList.forEach(numbers -> modelTensors.add( + embeddings.forEach(embedding -> modelTensors.add( ModelTensor .builder() .name("sentence_embedding") .dataType(MLResultDataType.FLOAT32) - .shape(new long[]{numbers.size()}) - .data(numbers.toArray(new Number[0])) + .shape(new long[]{embedding.size()}) + .data(embedding.toArray(new Number[0])) .build() )); return modelTensors; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index 23a575e860..0a41e17a9b 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -16,7 +16,7 @@ public class MLPreProcessFunction { public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding"; public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding"; - public static final String NEURAL_SEARCH_EMBEDDING_INPUT = "connector.pre_process.neural_search.text_embedding"; + public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding"; private static Function, Map> cohereTextEmbeddingPreProcess() { return inputs -> Map.of("parameters", Map.of("texts", inputs)); @@ -29,7 +29,7 @@ private static Function, Map> openAiTextEmbeddingPr static { PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess()); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); - PRE_PROCESS_FUNCTIONS.put(NEURAL_SEARCH_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); } public static boolean contains(String functionName) { diff --git a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java index 412e7a8e7e..da4a9ad73d 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java @@ -5,10 +5,6 @@ package org.opensearch.ml.common.input.remote; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - -import java.io.IOException; -import java.util.Map; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentParser; @@ -17,6 +13,11 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.utils.StringUtils; +import java.io.IOException; +import java.util.Map; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + @org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.REMOTE}) public class RemoteInferenceMLInput extends MLInput { public static final String PARAMETERS_FIELD = "parameters"; diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 5b35000807..edbd94b37f 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -5,26 +5,30 @@ package org.opensearch.ml.common.utils; +import com.google.gson.Gson; import com.google.gson.JsonElement; import com.google.gson.JsonParser; import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; -import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.core.xcontent.XContentBuilder; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; public class StringUtils { + public static final Gson gson; + + static { + gson = new Gson(); + } + public static boolean isJson(String Json) { try { new JSONObject(Json); @@ -45,16 +49,13 @@ public static String toUTF8(String rawString) { return utf8EncodedString; } - public static Map fromJson(String input, String defaultKey) { - if (!isJson(input)) { - return Collections.singletonMap(defaultKey, input); - } + public static Map fromJson(String jsonStr, String defaultKey) { Map result; - JsonElement jsonElement = JsonParser.parseString(input); + JsonElement jsonElement = JsonParser.parseString(jsonStr); if (jsonElement.isJsonObject()) { - result = GsonUtil.fromJson(jsonElement, Map.class); + result = gson.fromJson(jsonElement, Map.class); } else if (jsonElement.isJsonArray()) { - List list = GsonUtil.fromJson(jsonElement, List.class); + List list = gson.fromJson(jsonElement, List.class); result = new HashMap<>(); result.put(defaultKey, list); } else { @@ -72,7 +73,7 @@ public static Map getParameterMap(Map parameterObjs) if (value instanceof String) { parameters.put(key, (String)value); } else { - parameters.put(key, GsonUtil.toJson(value)); + parameters.put(key, gson.toJson(value)); } return null; }); @@ -82,8 +83,4 @@ public static Map getParameterMap(Map parameterObjs) } return parameters; } - - public static String xContentBuilderToString(XContentBuilder builder) { - return BytesReference.bytes(builder).utf8ToString(); - } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java index 2582c8a6df..ab32c1580b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java @@ -32,6 +32,7 @@ import java.util.zip.ZipEntry; import java.util.zip.ZipFile; +import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.engine.utils.FileUtils.calculateFileHash; import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly; import static org.opensearch.ml.engine.utils.FileUtils.splitFileIntoChunks; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index e8795a2e79..ac3f8a7eda 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -18,7 +18,6 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.utils.GsonUtil; import org.opensearch.script.ScriptService; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.AwsCredentials; @@ -40,6 +39,7 @@ import static org.apache.commons.text.StringEscapeUtils.escapeJson; import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD; +import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.engine.utils.ScriptUtils.executeBuildInPostProcessFunction; import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction; import static org.opensearch.ml.engine.utils.ScriptUtils.executePreprocessFunction; @@ -87,9 +87,7 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat throw new IllegalArgumentException("no predict action found"); } String preProcessFunction = predictAction.get().getPreProcessFunction(); - if (preProcessFunction == null) { - throw new IllegalArgumentException("Must provide pre_process_function for predict action to process text docs input."); - } + preProcessFunction = preProcessFunction == null ? MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT : preProcessFunction; if (MLPreProcessFunction.contains(preProcessFunction)) { Map buildInFunctionResult = MLPreProcessFunction.get(preProcessFunction).apply(docs); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(buildInFunctionResult)).build(); @@ -102,7 +100,7 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat if (processedInput.isEmpty()) { throw new IllegalArgumentException("Wrong input"); } - Map map = GsonUtil.fromJson(processedInput.get(), Map.class); + Map map = gson.fromJson(processedInput.get(), Map.class); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build(); } } @@ -116,7 +114,7 @@ private static Map convertScriptStringToJsonString(Map executeBuildInPostProcessFunction(List executePostProcessFunction(ScriptService scriptService, String postProcessFunction, String resultJson) { - Map result = org.opensearch.ml.common.utils.StringUtils.fromJson(resultJson, "result"); + Map result = StringUtils.fromJson(resultJson, "result"); if (postProcessFunction != null) { return Optional.ofNullable(executeScript(scriptService, postProcessFunction, result)); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 98e90e668f..8e046b151c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -24,7 +24,6 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.utils.GsonUtil; import org.opensearch.script.ScriptService; import java.io.IOException; @@ -37,6 +36,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.utils.StringUtils.gson; public class ConnectorUtilsTest { @@ -60,8 +60,6 @@ public void processInput_NullInput() { @Test public void processInput_TextDocsInputDataSet_NoPreprocessFunction() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Must provide pre_process_function for predict action to process text docs input."); TextDocsInputDataSet dataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test1", "test2")).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataSet).build(); @@ -126,7 +124,7 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec @Test public void processInput_TextDocsInputDataSet_PreprocessFunction_OneTextDoc() { List input = Collections.singletonList("test_value"); - String inputJson = GsonUtil.toJson(input); + String inputJson = gson.toJson(input); processInput_TextDocsInputDataSet_PreprocessFunction( "{\"input\": \"${parameters.input}\"}", input, inputJson, MLPreProcessFunction.TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, "texts"); } @@ -136,7 +134,7 @@ public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc() List input = new ArrayList<>(); input.add("test_value1"); input.add("test_value2"); - String inputJson = GsonUtil.toJson(input); + String inputJson = gson.toJson(input); processInput_TextDocsInputDataSet_PreprocessFunction( "{\"input\": ${parameters.input}}", input, inputJson, MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, "input"); } @@ -166,24 +164,6 @@ public void processOutput_NoPostprocessFunction_jsonResponse() throws IOExceptio Assert.assertEquals("test response", tensors.getMlModelTensors().get(0).getDataAsMap().get("response")); } - @Test - public void processOutput_noPostProcessFunction_nonJsonResponse() throws IOException { - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Map parameters = new HashMap<>(); - parameters.put("input", "value1"); - Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); - ModelTensors tensors = ConnectorUtils.processOutput("test response", connector, scriptService, parameters); - Assert.assertEquals(1, tensors.getMlModelTensors().size()); - Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName()); - Assert.assertEquals(1, tensors.getMlModelTensors().get(0).getDataAsMap().size()); - Assert.assertEquals("test response", tensors.getMlModelTensors().get(0).getDataAsMap().get("response")); - } - @Test public void processOutput_PostprocessFunction() throws IOException { String postprocessResult = "{\"name\":\"sentence_embedding\",\"data_type\":\"FLOAT32\",\"shape\":[1536],\"data\":[-0.014555434, -2.135904E-4, 0.0035105038]}"; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index a8290f0692..9caf621087 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -29,12 +29,15 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; import org.opensearch.script.ScriptService; import java.io.IOException; import java.util.Arrays; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; @@ -94,19 +97,25 @@ public void executePredict_RemoteInferenceInput() throws IOException { } @Test - public void executePredict_TextDocsInput_NoPreprocessFunction() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Must provide pre_process_function for predict action to process text docs input."); + public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOException { ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") .url("http://test.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") + .requestBody("{\"input\": ${parameters.input}}") .build(); + when(httpClient.execute(any())).thenReturn(response); + HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); + when(response.getEntity()).thenReturn(entity); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); - HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); - executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); + Assert.assertEquals("test result", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response")); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java index 24daf5ea8c..6ca1401efd 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java @@ -39,7 +39,7 @@ public void test_executePreprocessFunction() { @Test public void test_executeBuildInPostProcessFunction() { List> input = Arrays.asList(Arrays.asList(1.0f, 2.0f), Arrays.asList(3.0f, 4.0f)); - List modelTensors = ScriptUtils.executeBuildInPostProcessFunction(input, MLPostProcessFunction.get(MLPostProcessFunction.NEURAL_SEARCH_EMBEDDING)); + List modelTensors = ScriptUtils.executeBuildInPostProcessFunction(input, MLPostProcessFunction.get(MLPostProcessFunction.DEFAULT_EMBEDDING)); assertNotNull(modelTensors); assertEquals(2, modelTensors.size()); } From 7b94c948fbb1e1bb7866f89753ac17837115b3d1 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 27 Sep 2023 11:08:14 +0800 Subject: [PATCH 5/5] Fix conflict when backport Signed-off-by: zane-neo --- common/build.gradle | 2 +- .../ml/common/output/model/ModelTensor.java | 6 +-- .../opensearch/ml/common/utils/GsonUtil.java | 30 ------------ .../ml/common/utils/GsonUtilsTest.java | 47 ------------------- .../org/opensearch/ml/engine/ModelHelper.java | 5 +- 5 files changed, 6 insertions(+), 84 deletions(-) delete mode 100644 common/src/main/java/org/opensearch/ml/common/utils/GsonUtil.java delete mode 100644 common/src/test/java/org/opensearch/ml/common/utils/GsonUtilsTest.java diff --git a/common/build.gradle b/common/build.gradle index 01f548c140..3aeb644d42 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -19,7 +19,7 @@ dependencies { testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0' compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' - implementation group: 'com.google.code.gson', name: 'gson', version: '2.10.1' + compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' compileOnly group: 'org.json', name: 'json', version: '20230227' } diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java index c51031ac3d..8957f33643 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.output.model; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.gson; import java.io.IOException; import java.nio.ByteBuffer; @@ -25,7 +26,6 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.utils.GsonUtil; @Data public class ModelTensor implements Writeable, ToXContentObject { @@ -224,7 +224,7 @@ public ModelTensor(StreamInput in) throws IOException { this.result = in.readOptionalString(); if (in.readBoolean()) { String mapStr = in.readString(); - this.dataAsMap = GsonUtil.fromJson(mapStr, Map.class); + this.dataAsMap = gson.fromJson(mapStr, Map.class); } } @@ -270,7 +270,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(true); try { AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - out.writeString(GsonUtil.toJson(dataAsMap)); + out.writeString(gson.toJson(dataAsMap)); return null; }); } catch (PrivilegedActionException e) { diff --git a/common/src/main/java/org/opensearch/ml/common/utils/GsonUtil.java b/common/src/main/java/org/opensearch/ml/common/utils/GsonUtil.java deleted file mode 100644 index b261d44124..0000000000 --- a/common/src/main/java/org/opensearch/ml/common/utils/GsonUtil.java +++ /dev/null @@ -1,30 +0,0 @@ -package org.opensearch.ml.common.utils; - -import com.google.gson.Gson; -import com.google.gson.JsonElement; -import com.google.gson.stream.JsonReader; - -public class GsonUtil { - - private static final Gson gson; - - static { - gson = new Gson(); - } - - public static String toJson(Object obj) { - return gson.toJson(obj); - } - - public static T fromJson(String json, Class clazz) { - return gson.fromJson(json, clazz); - } - - public static T fromJson(JsonElement jsonElement, Class clazz) { - return gson.fromJson(jsonElement, clazz); - } - - public static T fromJson(JsonReader jsonReader, Class clazz) { - return gson.fromJson(jsonReader, clazz); - } -} diff --git a/common/src/test/java/org/opensearch/ml/common/utils/GsonUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/GsonUtilsTest.java deleted file mode 100644 index c5ceb58917..0000000000 --- a/common/src/test/java/org/opensearch/ml/common/utils/GsonUtilsTest.java +++ /dev/null @@ -1,47 +0,0 @@ -package org.opensearch.ml.common.utils; - -import com.google.gson.JsonElement; -import com.google.gson.JsonParser; -import com.google.gson.stream.JsonReader; -import org.junit.Assert; -import org.junit.Test; - -import java.io.StringReader; -import java.util.HashMap; -import java.util.Map; - -public class GsonUtilsTest { - @Test - public void test_toJson() { - Map map = new HashMap<>(); - map.put("key", "value"); - String mapString = GsonUtil.toJson(map); - assert mapString.equals("{\"key\":\"value\"}"); - } - - @Test - public void test_fromJsonString() { - Map map = GsonUtil.fromJson("{\"key\": \"value\"}", Map.class); - Assert.assertEquals(1, map.size()); - Assert.assertEquals("value", map.get("key")); - } - - @Test - public void test_fromJsonJsonElement() { - JsonElement jsonElement = JsonParser.parseString("{\"key\": \"value\"}"); - Map map = GsonUtil.fromJson(jsonElement, Map.class); - Assert.assertEquals(1, map.size()); - Assert.assertEquals("value", map.get("key")); - } - - @Test - public void test_fromJsonJsonReader() { - JsonReader reader = new JsonReader(new StringReader("{\"key\": \"value\"}")); - Map map = GsonUtil.fromJson(reader, Map.class); - Assert.assertEquals(1, map.size()); - Assert.assertEquals("value", map.get("key")); - } - - - -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java index ab32c1580b..2dfe1b28e8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java @@ -15,7 +15,6 @@ import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; -import org.opensearch.ml.common.utils.GsonUtil; import java.io.File; import java.io.FileReader; @@ -73,7 +72,7 @@ public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput regi Map config = null; try (JsonReader reader = new JsonReader(new FileReader(configCacheFilePath))) { - config = GsonUtil.fromJson(reader, Map.class); + config = gson.fromJson(reader, Map.class); } if (config == null) { @@ -171,7 +170,7 @@ public List downloadPrebuiltModelMetaList(String taskId, MLRegisterModelInput re List config = null; try (JsonReader reader = new JsonReader(new FileReader(cacheFilePath))) { - config = GsonUtil.fromJson(reader, List.class); + config = gson.fromJson(reader, List.class); } return config;