validate step size (#1587)
Signed-off-by: Yaliang Wu <[email protected]>
ylwu-amzn authored and dhrubo-os committed Dec 1, 2023
1 parent 855fd66 commit 86d6de1
Showing 2 changed files with 132 additions and 36 deletions.
@@ -11,7 +11,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
@@ -49,11 +48,18 @@ default ModelTensorOutput executePredict(MLInput mlInput) {
if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) {
tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size();
}
// This is to support some model which takes N text docs and embedding size is less than N-1.
// This is to support some model which takes N text docs and embedding size is less than N.
// We need to tell executor what's the step size for each model run.
Map<String, String> parameters = getConnector().getParameters();
if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) {
processedDocs += Integer.parseInt(parameters.get("input_docs_processed_step_size"));
int stepSize = Integer.parseInt(parameters.get("input_docs_processed_step_size"));
// We need to check the parameter on runtime as parameter can be passed into predict request
if (stepSize <= 0) {
throw new IllegalArgumentException(
"Invalid parameter: input_docs_processed_step_size. It must be positive integer."
);
}
processedDocs += stepSize;
} else {
processedDocs += Math.max(tensorCount, 1);
}
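
For reference, the stepping behavior introduced in the hunk above can be exercised in isolation with a small sketch like the one below. The class and method names (StepSizeExample, resolveStepSize) are illustrative only and are not part of the ml-commons API; the validation and the fallback to the tensor count simply mirror the change shown above.

// Illustrative sketch only; not part of the commit or the ml-commons code base.
import java.util.Map;

public class StepSizeExample {

    // Returns how far the doc cursor should advance after one model call.
    static int resolveStepSize(Map<String, String> parameters, int tensorCount) {
        if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) {
            int stepSize = Integer.parseInt(parameters.get("input_docs_processed_step_size"));
            // Reject zero or negative values so the processing loop always advances.
            if (stepSize <= 0) {
                throw new IllegalArgumentException(
                    "Invalid parameter: input_docs_processed_step_size. It must be positive integer."
                );
            }
            return stepSize;
        }
        // No explicit step size: fall back to the number of returned tensors, at least one.
        return Math.max(tensorCount, 1);
    }

    public static void main(String[] args) {
        System.out.println(resolveStepSize(Map.of("input_docs_processed_step_size", "2"), 1)); // prints 2
        System.out.println(resolveStepSize(Map.of(), 3)); // prints 3 (falls back to tensor count)
        // resolveStepSize(Map.of("input_docs_processed_step_size", "-1"), 1) would throw IllegalArgumentException
    }
}
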
@@ -6,13 +6,12 @@
package org.opensearch.ml.engine.algorithms.remote;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.util.Arrays;
import java.util.Map;

import org.apache.http.HttpEntity;
import org.apache.http.ProtocolVersion;
@@ -44,9 +43,6 @@
import org.opensearch.script.ScriptService;

import com.google.common.collect.ImmutableMap;
import java.util.Map;



public class HttpJsonConnectorExecutorTest {
@Rule
@@ -197,15 +193,22 @@ public void executePredict_TextDocsInput() throws IOException {
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1))
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2));

ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
.postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
.requestBody("{\"input\": ${parameters.input}}")
.build();
HttpConnector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
.postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
.requestBody("{\"input\": ${parameters.input}}")
.build();
HttpConnector connector = HttpConnector
.builder()
.name("test connector")
.version("1")
.protocol("http")
.actions(Arrays.asList(predictAction))
.build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
executor.setScriptService(scriptService);
when(httpClient.execute(any())).thenReturn(response);
@@ -266,40 +269,127 @@ public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOE
String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }";
when(scriptService.compile(any(), any()))
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1))
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2));
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1))
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2));

ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
.postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
.requestBody("{\"input\": ${parameters.input}}")
.build();
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
.postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
.requestBody("{\"input\": ${parameters.input}}")
.build();
Map<String, String> parameters = ImmutableMap.of("input_docs_processed_step_size", "2");
HttpConnector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
HttpConnector connector = HttpConnector
.builder()
.name("test connector")
.version("1")
.protocol("http")
.parameters(parameters)
.actions(Arrays.asList(predictAction))
.build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
executor.setScriptService(scriptService);
when(httpClient.execute(any())).thenReturn(response);
// model takes 2 input docs, but only output 1 embedding
String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n"
+ " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n"
+ " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n"
+ " } ],\n"
+ " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n"
+ " \"total_tokens\": 5\n" + " }\n" + "}";
String modelResponse = "{\n"
+ " \"object\": \"list\",\n"
+ " \"data\": [\n"
+ " {\n"
+ " \"object\": \"embedding\",\n"
+ " \"index\": 0,\n"
+ " \"embedding\": [\n"
+ " -0.014555434,\n"
+ " -0.002135904,\n"
+ " 0.0035105038\n"
+ " ]\n"
+ " } ],\n"
+ " \"model\": \"text-embedding-ada-002-v2\",\n"
+ " \"usage\": {\n"
+ " \"prompt_tokens\": 5,\n"
+ " \"total_tokens\": 5\n"
+ " }\n"
+ "}";
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
HttpEntity entity = new StringEntity(modelResponse);
when(response.getEntity()).thenReturn(entity);
when(executor.getHttpClient()).thenReturn(httpClient);
when(executor.getConnector()).thenReturn(connector);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
ModelTensorOutput modelTensorOutput = executor
.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
Assert
.assertArrayEquals(
new Number[] { -0.014555434, -0.002135904, 0.0035105038 },
modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()
);
}

@Test
public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs_InvalidStepSize() throws IOException {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Invalid parameter: input_docs_processed_step_size. It must be positive integer.");
String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }";
when(scriptService.compile(any(), any()))
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1))
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2));

ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
.postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
.requestBody("{\"input\": ${parameters.input}}")
.build();
// step size must be positive integer, here we set it as -1, should trigger IllegalArgumentException
Map<String, String> parameters = ImmutableMap.of("input_docs_processed_step_size", "-1");
HttpConnector connector = HttpConnector
.builder()
.name("test connector")
.version("1")
.protocol("http")
.parameters(parameters)
.actions(Arrays.asList(predictAction))
.build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
executor.setScriptService(scriptService);
when(httpClient.execute(any())).thenReturn(response);
// model takes 2 input docs, but only output 1 embedding
String modelResponse = "{\n"
+ " \"object\": \"list\",\n"
+ " \"data\": [\n"
+ " {\n"
+ " \"object\": \"embedding\",\n"
+ " \"index\": 0,\n"
+ " \"embedding\": [\n"
+ " -0.014555434,\n"
+ " -0.002135904,\n"
+ " 0.0035105038\n"
+ " ]\n"
+ " } ],\n"
+ " \"model\": \"text-embedding-ada-002-v2\",\n"
+ " \"usage\": {\n"
+ " \"prompt_tokens\": 5,\n"
+ " \"total_tokens\": 5\n"
+ " }\n"
+ "}";
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
HttpEntity entity = new StringEntity(modelResponse);
when(response.getEntity()).thenReturn(entity);
when(executor.getHttpClient()).thenReturn(httpClient);
when(executor.getConnector()).thenReturn(connector);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor
.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
}
}
