From 060bcab7509ff155ce4d49912f81ab1c787116ad Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Fri, 3 Nov 2023 02:24:52 -0700 Subject: [PATCH] validate step size (#1587) Signed-off-by: Yaliang Wu --- .../remote/RemoteConnectorExecutor.java | 11 +- .../remote/HttpJsonConnectorExecutorTest.java | 167 ++++++++++++++---- 2 files changed, 138 insertions(+), 40 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 7b6aa94754..c73f4ac1c7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -48,11 +48,18 @@ default ModelTensorOutput executePredict(MLInput mlInput) { if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) { tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size(); } - // This is to support some model which takes N text docs and embedding size is less than N-1. + // This is to support some model which takes N text docs and embedding size is less than N. // We need to tell executor what's the step size for each model run. Map parameters = getConnector().getParameters(); if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) { - processedDocs += Integer.parseInt(parameters.get("input_docs_processed_step_size")); + int stepSize = Integer.parseInt(parameters.get("input_docs_processed_step_size")); + // We need to check the parameter on runtime as parameter can be passed into predict request + if (stepSize <= 0) { + throw new IllegalArgumentException( + "Invalid parameter: input_docs_processed_step_size. It must be positive integer." + ); + } + processedDocs += stepSize; } else { processedDocs += Math.max(tensorCount, 1); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index cc6a923d09..818a4184cb 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -5,15 +5,7 @@ package org.opensearch.ml.engine.algorithms.remote; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; - -import java.io.IOException; -import java.util.Arrays; - +import com.google.common.collect.ImmutableMap; import org.apache.http.HttpEntity; import org.apache.http.ProtocolVersion; import org.apache.http.StatusLine; @@ -43,9 +35,14 @@ import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.script.ScriptService; -import com.google.common.collect.ImmutableMap; +import java.io.IOException; +import java.util.Arrays; import java.util.Map; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + public class HttpJsonConnectorExecutorTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -195,15 +192,22 @@ public void executePredict_TextDocsInput() throws IOException { .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) - .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) - .requestBody("{\"input\": ${parameters.input}}") - .build(); - HttpConnector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) + .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) + .requestBody("{\"input\": ${parameters.input}}") + .build(); + HttpConnector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); @@ -264,29 +268,49 @@ public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOE String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }"; when(scriptService.compile(any(), any())) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) - .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) - .requestBody("{\"input\": ${parameters.input}}") - .build(); + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) + .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) + .requestBody("{\"input\": ${parameters.input}}") + .build(); Map parameters = ImmutableMap.of("input_docs_processed_step_size", "2"); - HttpConnector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); + HttpConnector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .parameters(parameters) + .actions(Arrays.asList(predictAction)) + .build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); // model takes 2 input docs, but only output 1 embedding - String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n" - + " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n" - + " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n" - + " } ],\n" - + " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n" - + " \"total_tokens\": 5\n" + " }\n" + "}"; + String modelResponse = "{\n" + + " \"object\": \"list\",\n" + + " \"data\": [\n" + + " {\n" + + " \"object\": \"embedding\",\n" + + " \"index\": 0,\n" + + " \"embedding\": [\n" + + " -0.014555434,\n" + + " -0.002135904,\n" + + " 0.0035105038\n" + + " ]\n" + + " } ],\n" + + " \"model\": \"text-embedding-ada-002-v2\",\n" + + " \"usage\": {\n" + + " \"prompt_tokens\": 5,\n" + + " \"total_tokens\": 5\n" + + " }\n" + + "}"; StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); when(response.getStatusLine()).thenReturn(statusLine); HttpEntity entity = new StringEntity(modelResponse); @@ -294,10 +318,77 @@ public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOE when(executor.getHttpClient()).thenReturn(httpClient); when(executor.getConnector()).thenReturn(connector); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); - ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + ModelTensorOutput modelTensorOutput = executor + .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); - Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()); + Assert + .assertArrayEquals( + new Number[] { -0.014555434, -0.002135904, 0.0035105038 }, + modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData() + ); + } + + @Test + public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs_InvalidStepSize() throws IOException { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Invalid parameter: input_docs_processed_step_size. It must be positive integer."); + String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; + String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }"; + when(scriptService.compile(any(), any())) + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); + + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) + .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) + .requestBody("{\"input\": ${parameters.input}}") + .build(); + // step size must be positive integer, here we set it as -1, should trigger IllegalArgumentException + Map parameters = ImmutableMap.of("input_docs_processed_step_size", "-1"); + HttpConnector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .parameters(parameters) + .actions(Arrays.asList(predictAction)) + .build(); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + executor.setScriptService(scriptService); + when(httpClient.execute(any())).thenReturn(response); + // model takes 2 input docs, but only output 1 embedding + String modelResponse = "{\n" + + " \"object\": \"list\",\n" + + " \"data\": [\n" + + " {\n" + + " \"object\": \"embedding\",\n" + + " \"index\": 0,\n" + + " \"embedding\": [\n" + + " -0.014555434,\n" + + " -0.002135904,\n" + + " 0.0035105038\n" + + " ]\n" + + " } ],\n" + + " \"model\": \"text-embedding-ada-002-v2\",\n" + + " \"usage\": {\n" + + " \"prompt_tokens\": 5,\n" + + " \"total_tokens\": 5\n" + + " }\n" + + "}"; + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); + HttpEntity entity = new StringEntity(modelResponse); + when(response.getEntity()).thenReturn(entity); + when(executor.getHttpClient()).thenReturn(httpClient); + when(executor.getConnector()).thenReturn(connector); + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); + ModelTensorOutput modelTensorOutput = executor + .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); } }