diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 5674d16c12..09c7b8dec9 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -148,7 +148,7 @@ public static ModelTensors processOutput(String modelResponse, Connector connect if (MLPostProcessFunction.contains(postProcessFunction)) { // in this case, we can use jsonpath to build a List> result from model response. if (StringUtils.isBlank(responseFilter)) responseFilter = MLPostProcessFunction.getResponseFilter(postProcessFunction); - List> vectors = JsonPath.read(modelResponse, responseFilter); + List vectors = JsonPath.read(modelResponse, responseFilter); List processedResponse = executeBuildInPostProcessFunction(vectors, MLPostProcessFunction.get(postProcessFunction)); return ModelTensors.builder().mlModelTensors(processedResponse).build(); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 8b0d5a8173..b69654d96e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -7,9 +7,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.apache.http.ProtocolVersion; -import org.apache.http.StatusLine; -import org.apache.http.message.BasicStatusLine; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -22,6 +19,8 @@ import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.ConnectorProtocols; +import org.opensearch.ml.common.connector.MLPostProcessFunction; import org.opensearch.ml.common.connector.MLPreProcessFunction; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; @@ -41,6 +40,7 @@ import java.io.IOException; import java.io.InputStream; import java.util.Arrays; +import java.util.List; import java.util.Map; import java.util.Optional; @@ -214,4 +214,96 @@ public void executePredict_TextDocsInferenceInput() throws IOException { Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key")); } + + @Test + public void executePredict_BedRock_TextDocsInferenceInput() throws IOException { + String jsonString = "{\"embedding\": [-0.043945312,-0.18847656,-0.21679688]}"; + InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes()); + AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream); + when(response.responseBody()).thenReturn(Optional.of(abortableInputStream)); + when(httpRequest.call()).thenReturn(response); + SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); + when(httpResponse.statusCode()).thenReturn(200); + when(response.httpResponse()).thenReturn(httpResponse); + when(httpClient.prepareRequest(any())).thenReturn(httpRequest); + + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"inputText\": \"${parameters.inputText}\"}") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT) + .postProcessFunction(MLPostProcessFunction.BEDROCK_EMBEDDING) + .build(); + Map credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); + Connector connector = AwsConnector + .awsConnectorBuilder() + .name("test connector") + .version("1") + .protocol(ConnectorProtocols.AWS_SIGV4) + .parameters(parameters) + .credential(credential) + .actions(Arrays.asList(predictAction)) + .build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1")).build(); + ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); + Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); + Assert.assertEquals(3, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData().length); + } + + @Test + public void test_executePredict_InvalidPredictAction() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("no predict action found"); + Map credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); + Connector connector = AwsConnector + .awsConnectorBuilder() + .name("test connector") + .version("1") + .protocol(ConnectorProtocols.AWS_SIGV4) + .parameters(parameters) + .credential(credential) + .actions(null) + .build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1")).build(); + executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build()); + } + + @Test + public void test_executePredict_InvalidInputDataset() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Wrong input type"); + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"inputText\": \"${parameters.inputText}\"}") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT) + .postProcessFunction(MLPostProcessFunction.BEDROCK_EMBEDDING) + .build(); + Map credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); + Connector connector = AwsConnector + .awsConnectorBuilder() + .name("test connector") + .version("1") + .protocol(ConnectorProtocols.AWS_SIGV4) + .parameters(parameters) + .credential(credential) + .actions(List.of(predictAction)) + .build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(null).build()); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 1bb72ec68d..3b75c44f07 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -11,6 +11,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.ingest.TestTemplateService; @@ -24,6 +25,7 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.engine.utils.ScriptUtils; import org.opensearch.script.ScriptService; import java.io.IOException; @@ -33,7 +35,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; +import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.utils.StringUtils.gson; @@ -118,7 +122,7 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec .build(); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); ConnectorUtils.processRemoteInput(mlInput); - Assert.assertEquals(expectedInput, ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters().get("input")); + assertEquals(expectedInput, ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters().get("input")); } @Test @@ -158,10 +162,10 @@ public void processOutput_NoPostprocessFunction_jsonResponse() throws IOExceptio parameters.put("key1", "value1"); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); ModelTensors tensors = ConnectorUtils.processOutput("{\"response\": \"test response\"}", connector, scriptService, ImmutableMap.of()); - Assert.assertEquals(1, tensors.getMlModelTensors().size()); - Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName()); - Assert.assertEquals(1, tensors.getMlModelTensors().get(0).getDataAsMap().size()); - Assert.assertEquals("test response", tensors.getMlModelTensors().get(0).getDataAsMap().get("response")); + assertEquals(1, tensors.getMlModelTensors().size()); + assertEquals("response", tensors.getMlModelTensors().get(0).getName()); + assertEquals(1, tensors.getMlModelTensors().get(0).getDataAsMap().size()); + assertEquals("test response", tensors.getMlModelTensors().get(0).getDataAsMap().get("response")); } @Test @@ -181,13 +185,22 @@ public void processOutput_PostprocessFunction() throws IOException { Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of()); - Assert.assertEquals(1, tensors.getMlModelTensors().size()); - Assert.assertEquals("sentence_embedding", tensors.getMlModelTensors().get(0).getName()); + assertEquals(1, tensors.getMlModelTensors().size()); + assertEquals("sentence_embedding", tensors.getMlModelTensors().get(0).getName()); Assert.assertNull(tensors.getMlModelTensors().get(0).getDataAsMap()); - Assert.assertEquals(3, tensors.getMlModelTensors().get(0).getData().length); - Assert.assertEquals(-0.014555434, tensors.getMlModelTensors().get(0).getData()[0]); - Assert.assertEquals(-0.0002135904, tensors.getMlModelTensors().get(0).getData()[1]); - Assert.assertEquals(0.0035105038, tensors.getMlModelTensors().get(0).getData()[2]); + assertEquals(3, tensors.getMlModelTensors().get(0).getData().length); + assertEquals(-0.014555434, tensors.getMlModelTensors().get(0).getData()[0]); + assertEquals(-0.0002135904, tensors.getMlModelTensors().get(0).getData()[1]); + assertEquals(0.0035105038, tensors.getMlModelTensors().get(0).getData()[2]); + } + + @Test + public void processInput_TextDocsInputDataSet_userDefinedScriptPreprocessFunction() { + List input = Collections.singletonList("test_value"); + String preprocessFunction = "mock user defined preprocess function"; + when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"parameters\": {\"input\": \"test_value\"}}")); + processInput_TextDocsInputDataSet_PreprocessFunction( + "{\"input\": \"${parameters.input}\"}", input, "test_value", preprocessFunction, "input"); } private void processInput_TextDocsInputDataSet_PreprocessFunction(String requestBody, List inputs, String expectedProcessedInput, String preProcessName, String resultKey) { @@ -206,7 +219,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction(String request Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils.processTextDocsInput(dataSet, preProcessName, new HashMap<>(), scriptService); Assert.assertNotNull(remoteInferenceInputDataSet.getParameters()); - Assert.assertEquals(1, remoteInferenceInputDataSet.getParameters().size()); - Assert.assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get(resultKey)); + assertEquals(1, remoteInferenceInputDataSet.getParameters().size()); + assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get(resultKey)); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index 6016748a1e..fad7a26af9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -75,6 +75,16 @@ public void predict_ModelDeployed_WrongInput() { remoteModel.predict(mlInput); } + @Test + public void predict_ModelDeployed_NullInput() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("Input is null"); + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + remoteModel.predict(null); + } + @Test public void initModel_RuntimeException() { exceptionRule.expect(IllegalArgumentException.class);