diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java
index 2366b150d4..4e312bd16e 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java
@@ -22,6 +22,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 
 import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;
 
@@ -41,7 +42,14 @@ default ModelTensorOutput executePredict(MLInput mlInput) {
             if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) {
                 tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size();
             }
-            processedDocs += Math.max(tensorCount, 1);
+            // This is to support models that take N text docs per run but return fewer than N embeddings.
+            // The connector parameter tells the executor the step size to advance for each model run.
+            Map<String, String> parameters = getConnector().getParameters();
+            if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) {
+                processedDocs += Integer.parseInt(parameters.get("input_docs_processed_step_size"));
+            } else {
+                processedDocs += Math.max(tensorCount, 1);
+            }
             tensorOutputs.addAll(tempTensorOutputs);
         }
 
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java
index 11ae20c470..d12bc8ed60 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java
@@ -39,6 +39,7 @@
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.Map;
 
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
@@ -165,7 +166,7 @@ public void executePredict_TextDocsInput() throws IOException {
                 .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
                 .requestBody("{\"input\": ${parameters.input}}")
                 .build();
-        Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
+        HttpConnector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
         HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
         executor.setScriptService(scriptService);
         when(httpClient.execute(any())).thenReturn(response);
@@ -182,6 +183,7 @@ public void executePredict_TextDocsInput() throws IOException {
         HttpEntity entity = new StringEntity(modelResponse);
         when(response.getEntity()).thenReturn(entity);
         when(executor.getHttpClient()).thenReturn(httpClient);
+        when(executor.getConnector()).thenReturn(connector);
         MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
         ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
         Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
@@ -190,4 +192,46 @@ public void executePredict_TextDocsInput() throws IOException {
         Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
         Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData());
     }
+
+    @Test
+    public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOException {
+        String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
+        String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }";
+        when(scriptService.compile(any(), any()))
+            .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1))
+            .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2));
+
+        ConnectorAction predictAction = ConnectorAction.builder()
+            .actionType(ConnectorAction.ActionType.PREDICT)
+            .method("POST")
+            .url("http://test.com/mock")
+            .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
+            .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
+            .requestBody("{\"input\": ${parameters.input}}")
+            .build();
+        Map<String, String> parameters = ImmutableMap.of("input_docs_processed_step_size", "2");
+        HttpConnector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
+        HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
+        executor.setScriptService(scriptService);
+        when(httpClient.execute(any())).thenReturn(response);
+        // model takes 2 input docs, but only outputs 1 embedding
+        String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n"
+            + " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n"
+            + " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n"
+            + " } ],\n"
+            + " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n"
+            + " \"total_tokens\": 5\n" + " }\n" + "}";
+        StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
+        when(response.getStatusLine()).thenReturn(statusLine);
+        HttpEntity entity = new StringEntity(modelResponse);
+        when(response.getEntity()).thenReturn(entity);
+        when(executor.getHttpClient()).thenReturn(httpClient);
+        when(executor.getConnector()).thenReturn(connector);
+        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
+        ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
+        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
+        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
+        Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
+        Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
+    }
 }
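Usage sketch (not part of the diff above): the new "input_docs_processed_step_size" connector parameter is read in executePredict() via getConnector().getParameters(), so a connector whose model consumes two docs per call but may return a single embedding can declare the step size up front. The snippet below mirrors the builder calls from the test; the endpoint URL and connector name are placeholders, and the imports and any preprocess/postprocess setup are assumed to match the test file.

// Hypothetical connector that consumes 2 input docs per model run; the
// "input_docs_processed_step_size" parameter makes executePredict() advance
// processedDocs by 2 instead of by the number of returned tensors.
ConnectorAction predictAction = ConnectorAction.builder()
    .actionType(ConnectorAction.ActionType.PREDICT)
    .method("POST")
    .url("http://test.com/mock")                              // placeholder endpoint
    .requestBody("{\"input\": ${parameters.input}}")
    .build();
Map<String, String> parameters = ImmutableMap.of("input_docs_processed_step_size", "2");
HttpConnector connector = HttpConnector.builder()
    .name("step-size-connector")                              // hypothetical name
    .version("1")
    .protocol("http")
    .parameters(parameters)                                   // read via getConnector().getParameters()
    .actions(Arrays.asList(predictAction))
    .build();

Without the parameter, the executor falls back to processedDocs += Math.max(tensorCount, 1), which under-counts whenever a model returns fewer embeddings than the docs it consumed.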