diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index aa471ba3fe..d742d05970 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -11,6 +11,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -48,7 +49,14 @@ default ModelTensorOutput executePredict(MLInput mlInput) { if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) { tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size(); } - processedDocs += Math.max(tensorCount, 1); + // This is to support some model which takes N text docs and embedding size is less than N-1. + // We need to tell executor what's the step size for each model run. + Map parameters = getConnector().getParameters(); + if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) { + processedDocs += Integer.parseInt(parameters.get("input_docs_processed_step_size")); + } else { + processedDocs += Math.max(tensorCount, 1); + } tensorOutputs.addAll(tempTensorOutputs); } } else { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index e91110b74e..ba1f434689 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -6,6 +6,8 @@ package org.opensearch.ml.engine.algorithms.remote; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; @@ -42,6 +44,9 @@ import org.opensearch.script.ScriptService; import com.google.common.collect.ImmutableMap; +import java.util.Map; + + public class HttpJsonConnectorExecutorTest { @Rule @@ -192,22 +197,15 @@ public void executePredict_TextDocsInput() throws IOException { .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) - .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) - .requestBody("{\"input\": ${parameters.input}}") - .build(); - Connector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) + .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) + .requestBody("{\"input\": ${parameters.input}}") + .build(); + HttpConnector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); @@ -244,6 +242,7 @@ public void executePredict_TextDocsInput() throws IOException { HttpEntity entity = new StringEntity(modelResponse); when(response.getEntity()).thenReturn(entity); when(executor.getHttpClient()).thenReturn(httpClient); + when(executor.getConnector()).thenReturn(connector); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); @@ -261,4 +260,46 @@ public void executePredict_TextDocsInput() throws IOException { modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData() ); } + + @Test + public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOException { + String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; + String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }"; + when(scriptService.compile(any(), any())) + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); + + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) + .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) + .requestBody("{\"input\": ${parameters.input}}") + .build(); + Map parameters = ImmutableMap.of("input_docs_processed_step_size", "2"); + HttpConnector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + executor.setScriptService(scriptService); + when(httpClient.execute(any())).thenReturn(response); + // model takes 2 input docs, but only output 1 embedding + String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n" + + " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n" + + " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n" + + " } ],\n" + + " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n" + + " \"total_tokens\": 5\n" + " }\n" + "}"; + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); + HttpEntity entity = new StringEntity(modelResponse); + when(response.getEntity()).thenReturn(entity); + when(executor.getHttpClient()).thenReturn(httpClient); + when(executor.getConnector()).thenReturn(connector); + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); + ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); + Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); + Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()); + } }