diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java
index 6fc69621af..786cc6bdca 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java
@@ -105,6 +105,9 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
                 throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST);
             }
             String modelResponse = responseBuilder.toString();
+            if (statusCode < 200 || statusCode >= 300) {
+                throw new OpenSearchStatusException(modelResponse, RestStatus.fromCode(statusCode));
+            }
 
             ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
             tensors.setStatusCode(statusCode);
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java
index 92a44a3d91..704d6e3e05 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java
@@ -115,7 +115,7 @@ private static RemoteInferenceInputDataSet processTextDocsInput(
                 docs.add(null);
             }
         }
-        if (preProcessFunction.contains("${parameters")) {
+        if (preProcessFunction.contains("${parameters.")) {
             StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
             preProcessFunction = substitutor.replace(preProcessFunction);
         }
@@ -186,7 +186,7 @@ public static ModelTensors processOutput(
         // execute user defined painless script.
         Optional<String> processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse);
         String response = processedResponse.orElse(modelResponse);
-        boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent();
+        boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent() && org.opensearch.ml.common.utils.StringUtils.isJson(response);
         if (responseFilter == null) {
             connector.parseResponse(response, modelTensors, scriptReturnModelTensor);
         } else {
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java
index 3dea04a7e7..b089dfffe1 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java
@@ -23,6 +23,8 @@
 import org.apache.http.entity.StringEntity;
 import org.apache.http.impl.client.CloseableHttpClient;
 import org.apache.http.util.EntityUtils;
+import org.opensearch.OpenSearchStatusException;
+import org.opensearch.core.rest.RestStatus;
 import org.opensearch.ml.common.connector.Connector;
 import org.opensearch.ml.common.connector.HttpConnector;
 import org.opensearch.ml.common.exception.MLException;
@@ -104,9 +106,13 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
                 return null;
             });
             String modelResponse = responseRef.get();
+            Integer statusCode = statusCodeRef.get();
+            if (statusCode < 200 || statusCode >= 300) {
+                throw new OpenSearchStatusException(modelResponse, RestStatus.fromCode(statusCode));
+            }
 
             ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
-            tensors.setStatusCode(statusCodeRef.get());
+            tensors.setStatusCode(statusCode);
             tensorOutputs.add(tensors);
         } catch (RuntimeException e) {
             log.error("Fail to execute http connector", e);
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java
index 0ff8f9a91e..cf99934779 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java
@@ -32,15 +32,14 @@ default ModelTensorOutput executePredict(MLInput mlInput) {
 
         if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
             TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset();
-            List<String> textDocs = new ArrayList<>(textDocsInputDataSet.getDocs());
-            preparePayloadAndInvokeRemoteModel(
-                MLInput
-                    .builder()
-                    .algorithm(FunctionName.TEXT_EMBEDDING)
-                    .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
-                    .build(),
-                tensorOutputs
-            );
+            int processedDocs = 0;
+            while (processedDocs < textDocsInputDataSet.getDocs().size()) {
+                List<String> textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size());
+                List<ModelTensors> tempTensorOutputs = new ArrayList<>();
+                preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tempTensorOutputs);
+                processedDocs += Math.max(tempTensorOutputs.size(), 1);
+                tensorOutputs.addAll(tempTensorOutputs);
+            }
         } else {
             preparePayloadAndInvokeRemoteModel(mlInput, tensorOutputs);
         }
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java
index 4cdb4387c2..f3bdbf0644 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java
@@ -140,6 +140,36 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
         executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
     }
 
+    @Test
+    public void executePredict_RemoteInferenceInput_InvalidToken() throws IOException {
+        exceptionRule.expect(OpenSearchStatusException.class);
+        exceptionRule.expectMessage("{\"message\":\"The security token included in the request is invalid\"}");
+        String jsonString = "{\"message\":\"The security token included in the request is invalid\"}";
+        InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes());
+        AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
+        when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
+        when(httpRequest.call()).thenReturn(response);
+        SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
+        when(httpResponse.statusCode()).thenReturn(403);
+        when(response.httpResponse()).thenReturn(httpResponse);
+        when(httpClient.prepareRequest(any())).thenReturn(httpRequest);
+
+        ConnectorAction predictAction = ConnectorAction.builder()
+            .actionType(ConnectorAction.ActionType.PREDICT)
+            .method("POST")
+            .url("http://test.com/mock")
+            .requestBody("{\"input\": \"${parameters.input}\"}")
+            .build();
+        Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
+        Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
+        Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
+        connector.decrypt((c) -> encryptor.decrypt(c));
+        AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));
+
+        MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
+        executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
+    }
+
     @Test
     public void executePredict_RemoteInferenceInput() throws IOException {
         String jsonString = "{\"key\":\"value\"}";
@@ -219,9 +249,8 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
         connector.decrypt((c) -> encryptor.decrypt(c));
         AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));
 
-        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input", "test input data")).build();
-        ModelTensorOutput modelTensorOutput = executor
-            .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
+        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build();
+        ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
         Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
         Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
         Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java
index 4c4e20b6d4..d8cdb4e9d5 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java
@@ -26,6 +26,7 @@
 import org.junit.rules.ExpectedException;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
+import org.opensearch.OpenSearchStatusException;
 import org.opensearch.cluster.ClusterStateTaskConfig;
 import org.opensearch.ingest.TestTemplateService;
 import org.opensearch.ml.common.FunctionName;
@@ -137,9 +138,8 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti
         HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
         when(executor.getHttpClient()).thenReturn(httpClient);
         MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
-        ModelTensorOutput modelTensorOutput = executor
-            .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
-        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
+        ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
+        Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size());
         Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
         Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
         Assert
@@ -149,6 +149,28 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti
             );
     }
 
+    @Test
+    public void executePredict_TextDocsInput_LimitExceed() throws IOException {
+        exceptionRule.expect(OpenSearchStatusException.class);
+        exceptionRule.expectMessage("{\"message\": \"Too many requests\"}");
+        ConnectorAction predictAction = ConnectorAction.builder()
+            .actionType(ConnectorAction.ActionType.PREDICT)
+            .method("POST")
+            .url("http://test.com/mock")
+            .requestBody("{\"input\": ${parameters.input}}")
+            .build();
+        when(httpClient.execute(any())).thenReturn(response);
+        HttpEntity entity = new StringEntity("{\"message\": \"Too many requests\"}");
+        when(response.getEntity()).thenReturn(entity);
+        StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 429, "OK");
+        when(response.getStatusLine()).thenReturn(statusLine);
+        Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
+        HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
+        when(executor.getHttpClient()).thenReturn(httpClient);
+        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
+        executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
+    }
+
     @Test
     public void executePredict_TextDocsInput() throws IOException {
         String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
@@ -190,9 +212,8 @@ public void executePredict_TextDocsInput() throws IOException {
         when(response.getEntity()).thenReturn(entity);
         when(executor.getHttpClient()).thenReturn(httpClient);
         MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
-        ModelTensorOutput modelTensorOutput = executor
-            .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
-        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
+        ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
+        Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size());
         Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
         Assert
             .assertArrayEquals(
diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
index a024591f21..dd1deac4ab 100644
--- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
+++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
@@ -942,8 +942,8 @@ public void deployModel(
                             CLUSTER_SERVICE,
                             clusterService
                         );
-                    // deploy remote model or model trained by built-in algorithm like kmeans
-                    if (mlModel.getConnector() != null) {
+                    // deploy remote model with internal connector or model trained by built-in algorithm like kmeans
+                    if (mlModel.getConnector() != null || FunctionName.REMOTE != mlModel.getAlgorithm()) {
                         setupPredictable(modelId, mlModel, params);
                         wrappedListener.onResponse("successful");
                         return;
@@ -952,6 +952,7 @@ public void deployModel(
                     GetRequest getConnectorRequest = new GetRequest();
                     FetchSourceContext fetchContext = new FetchSourceContext(true, null, null);
                     getConnectorRequest.index(ML_CONNECTOR_INDEX).id(mlModel.getConnectorId()).fetchSourceContext(fetchContext);
+                    // get connector and deploy remote model with standalone connector
                     client.get(getConnectorRequest, ActionListener.wrap(getResponse -> {
                         if (getResponse != null && getResponse.isExists()) {
                             try (
diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java
index 55313bc986..348618773a 100644
--- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java
+++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java
@@ -213,9 +213,9 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
         FunctionName algorithm = mlInput.getAlgorithm();
         // run predict
         if (modelId != null) {
-            try {
-                Predictable predictor = mlModelManager.getPredictor(modelId);
-                if (predictor != null) {
+            Predictable predictor = mlModelManager.getPredictor(modelId);
+            if (predictor != null) {
+                try {
                     if (!predictor.isModelReady()) {
                         throw new IllegalArgumentException("Model not ready: " + modelId);
                     }
@@ -229,11 +229,12 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
                     MLTaskResponse response = MLTaskResponse.builder().output(output).build();
                     internalListener.onResponse(response);
                     return;
-                } else if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) {
-                    throw new IllegalArgumentException("Model not ready to be used: " + modelId);
+                } catch (Exception e) {
+                    handlePredictFailure(mlTask, internalListener, e, false, modelId);
+                    return;
                 }
-            } catch (Exception e) {
-                handlePredictFailure(mlTask, internalListener, e, false, modelId);
+            } else if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) {
+                throw new IllegalArgumentException("Model not ready to be used: " + modelId);
             }
 
             // search model by model id.
@@ -252,6 +253,7 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
                     GetResponse getResponse = r;
                     String algorithmName = getResponse.getSource().get(ALGORITHM_FIELD).toString();
                     MLModel mlModel = MLModel.parse(xContentParser, algorithmName);
+                    mlModel.setModelId(modelId);
                     User resourceUser = mlModel.getUser();
                     User requestUser = getUserContext(client);
                     if (!checkUserPermissions(requestUser, resourceUser, modelId)) {
@@ -263,7 +265,9 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
                         return;
                     }
                     // run predict
-                    mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), mlTask.isAsync());
+                    if (mlTaskManager.contains(mlTask.getTaskId())) {
+                        mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), mlTask.isAsync());
+                    }
                     MLOutput output = mlEngine.predict(mlInput, mlModel);
                     if (output instanceof MLPredictionOutput) {
                         ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());