Skip to content

Commit

Permalink
Add more UTs
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Oct 25, 2023
1 parent 96f6361 commit 72c78ff
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public static ModelTensors processOutput(String modelResponse, Connector connect
if (MLPostProcessFunction.contains(postProcessFunction)) {
// in this case, we can use jsonpath to build a List<List<Float>> result from model response.
if (StringUtils.isBlank(responseFilter)) responseFilter = MLPostProcessFunction.getResponseFilter(postProcessFunction);
List<List<Float>> vectors = JsonPath.read(modelResponse, responseFilter);
List<?> vectors = JsonPath.read(modelResponse, responseFilter);
List<ModelTensor> processedResponse = executeBuildInPostProcessFunction(vectors, MLPostProcessFunction.get(postProcessFunction));
return ModelTensors.builder().mlModelTensors(processedResponse).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.http.ProtocolVersion;
import org.apache.http.StatusLine;
import org.apache.http.message.BasicStatusLine;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
Expand All @@ -22,6 +19,8 @@
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.ConnectorProtocols;
import org.opensearch.ml.common.connector.MLPostProcessFunction;
import org.opensearch.ml.common.connector.MLPreProcessFunction;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
Expand All @@ -41,6 +40,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;

Expand Down Expand Up @@ -214,4 +214,96 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key"));
}

@Test
public void executePredict_BedRock_TextDocsInferenceInput() throws IOException {
String jsonString = "{\"embedding\": [-0.043945312,-0.18847656,-0.21679688]}";
InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes());
AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
when(httpRequest.call()).thenReturn(response);
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
when(httpResponse.statusCode()).thenReturn(200);
when(response.httpResponse()).thenReturn(httpResponse);
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);

ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"inputText\": \"${parameters.inputText}\"}")
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT)
.postProcessFunction(MLPostProcessFunction.BEDROCK_EMBEDDING)
.build();
Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol(ConnectorProtocols.AWS_SIGV4)
.parameters(parameters)
.credential(credential)
.actions(Arrays.asList(predictAction))
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
Assert.assertEquals(3, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData().length);
}

@Test
public void test_executePredict_InvalidPredictAction() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("no predict action found");
Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol(ConnectorProtocols.AWS_SIGV4)
.parameters(parameters)
.credential(credential)
.actions(null)
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
}

@Test
public void test_executePredict_InvalidInputDataset() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Wrong input type");
ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"inputText\": \"${parameters.inputText}\"}")
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT)
.postProcessFunction(MLPostProcessFunction.BEDROCK_EMBEDDING)
.build();
Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol(ConnectorProtocols.AWS_SIGV4)
.parameters(parameters)
.credential(credential)
.actions(List.of(predictAction))
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));
executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(null).build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.ingest.TestTemplateService;
Expand All @@ -24,6 +25,7 @@
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.utils.ScriptUtils;
import org.opensearch.script.ScriptService;

import java.io.IOException;
Expand All @@ -33,7 +35,9 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.utils.StringUtils.gson;
Expand Down Expand Up @@ -118,7 +122,7 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
ConnectorUtils.processRemoteInput(mlInput);
Assert.assertEquals(expectedInput, ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters().get("input"));
assertEquals(expectedInput, ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters().get("input"));
}

@Test
Expand Down Expand Up @@ -158,10 +162,10 @@ public void processOutput_NoPostprocessFunction_jsonResponse() throws IOExceptio
parameters.put("key1", "value1");
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
ModelTensors tensors = ConnectorUtils.processOutput("{\"response\": \"test response\"}", connector, scriptService, ImmutableMap.of());
Assert.assertEquals(1, tensors.getMlModelTensors().size());
Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName());
Assert.assertEquals(1, tensors.getMlModelTensors().get(0).getDataAsMap().size());
Assert.assertEquals("test response", tensors.getMlModelTensors().get(0).getDataAsMap().get("response"));
assertEquals(1, tensors.getMlModelTensors().size());
assertEquals("response", tensors.getMlModelTensors().get(0).getName());
assertEquals(1, tensors.getMlModelTensors().get(0).getDataAsMap().size());
assertEquals("test response", tensors.getMlModelTensors().get(0).getDataAsMap().get("response"));
}

@Test
Expand All @@ -181,13 +185,22 @@ public void processOutput_PostprocessFunction() throws IOException {
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}";
ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of());
Assert.assertEquals(1, tensors.getMlModelTensors().size());
Assert.assertEquals("sentence_embedding", tensors.getMlModelTensors().get(0).getName());
assertEquals(1, tensors.getMlModelTensors().size());
assertEquals("sentence_embedding", tensors.getMlModelTensors().get(0).getName());
Assert.assertNull(tensors.getMlModelTensors().get(0).getDataAsMap());
Assert.assertEquals(3, tensors.getMlModelTensors().get(0).getData().length);
Assert.assertEquals(-0.014555434, tensors.getMlModelTensors().get(0).getData()[0]);
Assert.assertEquals(-0.0002135904, tensors.getMlModelTensors().get(0).getData()[1]);
Assert.assertEquals(0.0035105038, tensors.getMlModelTensors().get(0).getData()[2]);
assertEquals(3, tensors.getMlModelTensors().get(0).getData().length);
assertEquals(-0.014555434, tensors.getMlModelTensors().get(0).getData()[0]);
assertEquals(-0.0002135904, tensors.getMlModelTensors().get(0).getData()[1]);
assertEquals(0.0035105038, tensors.getMlModelTensors().get(0).getData()[2]);
}

@Test
public void processInput_TextDocsInputDataSet_userDefinedScriptPreprocessFunction() {
List<String> input = Collections.singletonList("test_value");
String preprocessFunction = "mock user defined preprocess function";
when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"parameters\": {\"input\": \"test_value\"}}"));
processInput_TextDocsInputDataSet_PreprocessFunction(
"{\"input\": \"${parameters.input}\"}", input, "test_value", preprocessFunction, "input");
}

private void processInput_TextDocsInputDataSet_PreprocessFunction(String requestBody, List<String> inputs, String expectedProcessedInput, String preProcessName, String resultKey) {
Expand All @@ -206,7 +219,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction(String request
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils.processTextDocsInput(dataSet, preProcessName, new HashMap<>(), scriptService);
Assert.assertNotNull(remoteInferenceInputDataSet.getParameters());
Assert.assertEquals(1, remoteInferenceInputDataSet.getParameters().size());
Assert.assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get(resultKey));
assertEquals(1, remoteInferenceInputDataSet.getParameters().size());
assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get(resultKey));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ public void predict_ModelDeployed_WrongInput() {
remoteModel.predict(mlInput);
}

@Test
public void predict_ModelDeployed_NullInput() {
exceptionRule.expect(RuntimeException.class);
exceptionRule.expectMessage("Input is null");
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
when(mlModel.getConnector()).thenReturn(connector);
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
remoteModel.predict(null);
}

@Test
public void initModel_RuntimeException() {
exceptionRule.expect(IllegalArgumentException.class);
Expand Down

0 comments on commit 72c78ff

Please sign in to comment.