From b5b05eeb1a13ec05ab3da65325d244e59df91757 Mon Sep 17 00:00:00 2001
From: Yaliang Wu
Date: Fri, 3 Nov 2023 02:24:52 -0700
Subject: [PATCH] validate step size (#1587)

Signed-off-by: Yaliang Wu
---
 .../remote/RemoteConnectorExecutor.java       |  9 +++-
 .../remote/HttpJsonConnectorExecutorTest.java | 41 +++++++++++++++++++
 2 files changed, 48 insertions(+), 2 deletions(-)

diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java
index 7b6aa94754..c473525194 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java
@@ -48,11 +48,16 @@ default ModelTensorOutput executePredict(MLInput mlInput) {
                 if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) {
                     tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size();
                 }
-                // This is to support some model which takes N text docs and embedding size is less than N-1.
+                // This is to support some models which take N text docs and return fewer than N embeddings.
                 // We need to tell executor what's the step size for each model run.
                 Map<String, String> parameters = getConnector().getParameters();
                 if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) {
-                    processedDocs += Integer.parseInt(parameters.get("input_docs_processed_step_size"));
+                    int stepSize = Integer.parseInt(parameters.get("input_docs_processed_step_size"));
+                    // We need to check the parameter at runtime because it can also be passed in the predict request
+                    if (stepSize <= 0) {
+                        throw new IllegalArgumentException("Invalid parameter: input_docs_processed_step_size. It must be a positive integer.");
+                    }
+                    processedDocs += stepSize;
                 } else {
                     processedDocs += Math.max(tensorCount, 1);
                 }
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java
index cc6a923d09..7780408fc5 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java
@@ -300,4 +300,45 @@ public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOE
         Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
         Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
     }
+
+    @Test
+    public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs_InvalidStepSize() throws IOException {
+        exceptionRule.expect(IllegalArgumentException.class);
+        exceptionRule.expectMessage("Invalid parameter: input_docs_processed_step_size. It must be a positive integer.");
+        String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
+        String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }";
+        when(scriptService.compile(any(), any()))
+            .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1))
+            .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2));
+
+        ConnectorAction predictAction = ConnectorAction.builder()
+            .actionType(ConnectorAction.ActionType.PREDICT)
+            .method("POST")
+            .url("http://test.com/mock")
+            .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
+            .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
+            .requestBody("{\"input\": ${parameters.input}}")
+            .build();
+        // step size must be a positive integer; setting it to -1 should trigger IllegalArgumentException
+        Map<String, String> parameters = ImmutableMap.of("input_docs_processed_step_size", "-1");
+        HttpConnector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
+        HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
+        executor.setScriptService(scriptService);
+        when(httpClient.execute(any())).thenReturn(response);
+        // model takes 2 input docs, but only outputs 1 embedding
+        String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n"
+            + " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n"
+            + " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n"
+            + " } ],\n"
+            + " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n"
+            + " \"total_tokens\": 5\n" + " }\n" + "}";
+        StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
+        when(response.getStatusLine()).thenReturn(statusLine);
+        HttpEntity entity = new StringEntity(modelResponse);
+        when(response.getEntity()).thenReturn(entity);
+        when(executor.getHttpClient()).thenReturn(httpClient);
+        when(executor.getConnector()).thenReturn(connector);
+        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
+        ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
+    }
 }
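
For context on the validation above: the executor advances processedDocs by the step size after every model run, so a step size of zero or less would never make progress. The following is a minimal, self-contained sketch of that accounting, not the ml-commons executor itself; the class name StepSizeSketch, the countModelRuns method, and the default of one doc per run are illustrative assumptions.

import java.util.List;
import java.util.Map;

public class StepSizeSketch {
    static int countModelRuns(List<String> docs, Map<String, String> parameters) {
        int processedDocs = 0;
        int runs = 0;
        while (processedDocs < docs.size()) {
            int stepSize = 1; // assumption: one doc per model run when no step size is configured
            if (parameters.containsKey("input_docs_processed_step_size")) {
                stepSize = Integer.parseInt(parameters.get("input_docs_processed_step_size"));
                // Same guard as the patch: a non-positive step size would mean
                // processedDocs never advances and this loop never terminates.
                if (stepSize <= 0) {
                    throw new IllegalArgumentException(
                        "Invalid parameter: input_docs_processed_step_size. It must be a positive integer.");
                }
            }
            processedDocs += stepSize;
            runs++;
        }
        return runs;
    }

    public static void main(String[] args) {
        List<String> docs = List.of("test doc1", "test doc2");
        // Two docs with step size 2 -> a single model invocation.
        System.out.println(countModelRuns(docs, Map.of("input_docs_processed_step_size", "2")));
        try {
            countModelRuns(docs, Map.of("input_docs_processed_step_size", "-1"));
        } catch (IllegalArgumentException e) {
            // With the validation in place, "-1" fails fast instead of looping forever.
            System.out.println(e.getMessage());
        }
    }
}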