Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support step size for embedding model which outputs less embeddings #1586

Merged
merged 3 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;

Expand All @@ -41,7 +42,14 @@ default ModelTensorOutput executePredict(MLInput mlInput) {
if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) {
tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size();
}
processedDocs += Math.max(tensorCount, 1);
// This is to support some model which takes N text docs and embedding size is less than N-1.
// We need to tell executor what's the step size for each model run.
Map<String, String> parameters = getConnector().getParameters();
int stepSize = 1;
if (parameters != null) {
stepSize = Integer.parseInt(Optional.ofNullable(parameters.get("input_docs_processed_step_size")).orElse("1"));
ylwu-amzn marked this conversation as resolved.
Show resolved Hide resolved
}
processedDocs += Math.max(tensorCount, stepSize);
tensorOutputs.addAll(tempTensorOutputs);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.Map;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -165,7 +166,7 @@ public void executePredict_TextDocsInput() throws IOException {
.postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
.requestBody("{\"input\": ${parameters.input}}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpConnector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
executor.setScriptService(scriptService);
when(httpClient.execute(any())).thenReturn(response);
Expand All @@ -182,6 +183,7 @@ public void executePredict_TextDocsInput() throws IOException {
HttpEntity entity = new StringEntity(modelResponse);
when(response.getEntity()).thenReturn(entity);
when(executor.getHttpClient()).thenReturn(httpClient);
when(executor.getConnector()).thenReturn(connector);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Expand All @@ -190,4 +192,46 @@ public void executePredict_TextDocsInput() throws IOException {
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData());
}

@Test
public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOException {
String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }";
when(scriptService.compile(any(), any()))
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1))
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2));

ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
.postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
.requestBody("{\"input\": ${parameters.input}}")
.build();
Map<String, String> parameters = ImmutableMap.of("input_docs_processed_step_size", "2");
HttpConnector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
executor.setScriptService(scriptService);
when(httpClient.execute(any())).thenReturn(response);
// model takes 2 input docs, but only output 1 embedding
String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n"
+ " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n"
+ " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n"
+ " } ],\n"
+ " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n"
+ " \"total_tokens\": 5\n" + " }\n" + "}";
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
HttpEntity entity = new StringEntity(modelResponse);
when(response.getEntity()).thenReturn(entity);
when(executor.getHttpClient()).thenReturn(httpClient);
when(executor.getConnector()).thenReturn(connector);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be we should provide 4 documents to see if we are getting 2 outputs? But we can do it later too.

ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
}
}
Loading