diff --git a/common/src/main/java/org/opensearch/ml/common/MLTask.java b/common/src/main/java/org/opensearch/ml/common/MLTask.java index 766d88aec1..1165628711 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTask.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTask.java @@ -7,12 +7,8 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.USER; -import static org.opensearch.ml.common.utils.StringUtils.gson; import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; @@ -52,7 +48,7 @@ public class MLTask implements ToXContentObject, Writeable { public static final String ERROR_FIELD = "error"; public static final String IS_ASYNC_TASK_FIELD = "is_async"; public static final String REMOTE_JOB_FIELD = "remote_job"; - public static final Version MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB = CommonValue.VERSION_2_16_0; + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_BATCH_PREDICTION_JOB = CommonValue.VERSION_2_17_0; @Setter private String taskId; @@ -136,10 +132,9 @@ public MLTask(StreamInput input) throws IOException { this.user = null; } this.async = input.readBoolean(); - if (streamInputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB)) { + if (streamInputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_PREDICTION_JOB)) { if (input.readBoolean()) { - String mapStr = input.readString(); - this.remoteJob = gson.fromJson(mapStr, Map.class); + this.remoteJob = input.readMap(s -> s.readString(), s -> s.readGenericValue()); } } } @@ -170,17 +165,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeBoolean(async); - if (streamOutputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB)) { + if (streamOutputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_PREDICTION_JOB)) { if (remoteJob != null) { out.writeBoolean(true); - try { - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - out.writeString(gson.toJson(remoteJob)); - return null; - }); - } catch (PrivilegedActionException e) { - throw new RuntimeException(e); - } + out.writeMap(remoteJob, StreamOutput::writeString, StreamOutput::writeGenericValue); } else { out.writeBoolean(false); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobRequestTest.java new file mode 100644 index 0000000000..e6e1f3838c --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobRequestTest.java @@ -0,0 +1,75 @@ +package org.opensearch.ml.common.transport.task; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLCancelBatchJobRequestTest { + private String taskId; + + @Before + public void setUp() { + taskId = "test_id"; + } + + @Test + public void writeTo_Success() throws IOException { + MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().taskId(taskId).build(); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlCancelBatchJobRequest.writeTo(bytesStreamOutput); + MLCancelBatchJobRequest parsedTask = new MLCancelBatchJobRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(parsedTask.getTaskId(), taskId); + } + + @Test + public void validate_Exception_NullTaskId() { + MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().build(); + + ActionRequestValidationException exception = mlCancelBatchJobRequest.validate(); + assertEquals("Validation Failed: 1: ML task id can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequest_Success() { + MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().taskId(taskId).build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlCancelBatchJobRequest.writeTo(out); + } + }; + MLCancelBatchJobRequest result = MLCancelBatchJobRequest.fromActionRequest(actionRequest); + assertNotSame(result, mlCancelBatchJobRequest); + assertEquals(result.getTaskId(), mlCancelBatchJobRequest.getTaskId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLCancelBatchJobRequest.fromActionRequest(actionRequest); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobResponseTest.java new file mode 100644 index 0000000000..4cb3837df4 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobResponseTest.java @@ -0,0 +1,36 @@ +package org.opensearch.ml.common.transport.task; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; + +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +public class MLCancelBatchJobResponseTest { + + @Test + public void writeTo_Success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + MLCancelBatchJobResponse response = MLCancelBatchJobResponse.builder().status(RestStatus.OK).build(); + response.writeTo(bytesStreamOutput); + MLCancelBatchJobResponse parsedResponse = new MLCancelBatchJobResponse(bytesStreamOutput.bytes().streamInput()); + assertEquals(response.getStatus(), parsedResponse.getStatus()); + } + + @Test + public void toXContentTest() throws IOException { + MLCancelBatchJobResponse mlCancelBatchJobResponse1 = MLCancelBatchJobResponse.builder().status(RestStatus.OK).build(); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + mlCancelBatchJobResponse1.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = builder.toString(); + assertEquals("{\"status\":\"OK\"}", jsonStr); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index ef4f25c79a..ccceff3d68 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -6,6 +6,7 @@ package org.opensearch.ml.engine.algorithms.remote; import static org.apache.commons.text.StringEscapeUtils.escapeJson; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT; import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD; import static org.opensearch.ml.common.connector.MLPreProcessFunction.CONVERT_INPUT_TO_JSON_STRING; import static org.opensearch.ml.common.connector.MLPreProcessFunction.PROCESS_REMOTE_INFERENCE_INPUT; @@ -286,7 +287,9 @@ public static SdkHttpFullRequest buildSdkRequest( } else { requestBody = RequestBody.empty(); } - if (SdkHttpMethod.POST == method && 0 == requestBody.optionalContentLength().get()) { + if (SdkHttpMethod.POST == method + && 0 == requestBody.optionalContentLength().get() + && !action.equals(CANCEL_BATCH_PREDICT.toString())) { log.error("Content length is 0. Aborting request to remote model"); throw new IllegalArgumentException("Content length is 0. Aborting request to remote model"); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java index 0b96ac7af4..ec6b089ebb 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java @@ -8,13 +8,15 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; -import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT_STATUS; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import java.util.HashMap; import java.util.Map; +import java.util.Optional; -import org.apache.http.HttpStatus; +import org.apache.hc.core5.http.HttpStatus; +import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionRequest; @@ -153,12 +155,14 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener Optional + .ofNullable(parameters.get("TransformJobArn")) + .map(jobArn -> jobArn.substring(jobArn.lastIndexOf("/") + 1)) + .orElse(null) + ); RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet(parameters, ActionType.BATCH_PREDICT_STATUS); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build(); @@ -184,24 +188,33 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener { log.error("Failed to retrieve the ML model with the given ID", e); - actionListener.onFailure(e); + actionListener + .onFailure( + new OpenSearchStatusException("Failed to retrieve the ML model for the given task ID", RestStatus.NOT_FOUND) + ); })); } catch (Exception e) { - // fetch the connector - log.error("Unable to fetch status for ml task ", e); + log.error("Unable to fetch cancel batch job in ml task ", e); + throw new OpenSearchException("Unable to fetch cancel batch job in ml task " + e.getMessage()); } } private void executeConnector(Connector connector, MLInput mlInput, ActionListener actionListener) { - connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential)); - RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); - connectorExecutor.setScriptService(scriptService); - connectorExecutor.setClusterService(clusterService); - connectorExecutor.setClient(client); - connectorExecutor.setXContentRegistry(xContentRegistry); - connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> { - processTaskResponse(taskResponse, actionListener); - }, e -> { actionListener.onFailure(e); })); + if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { + connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential)); + RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader + .initInstance(connector.getProtocol(), connector, Connector.class); + connectorExecutor.setScriptService(scriptService); + connectorExecutor.setClusterService(clusterService); + connectorExecutor.setClient(client); + connectorExecutor.setXContentRegistry(xContentRegistry); + connectorExecutor.executeAction(CANCEL_BATCH_PREDICT.name(), mlInput, ActionListener.wrap(taskResponse -> { + processTaskResponse(taskResponse, actionListener); + }, e -> { actionListener.onFailure(e); })); + } else { + actionListener + .onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN)); + } } private void processTaskResponse(MLTaskResponse taskResponse, ActionListener actionListener) { @@ -213,7 +226,7 @@ private void processTaskResponse(MLTaskResponse taskResponse, ActionListener Optional + .ofNullable(parameters.get("TransformJobArn")) + .map(jobArn -> jobArn.substring(jobArn.lastIndexOf("/") + 1)) + .orElse(null) + ); RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet(parameters, ActionType.BATCH_PREDICT_STATUS); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build(); @@ -187,12 +194,15 @@ private void processRemoteBatchPrediction(MLTask mlTask, String taskId, ActionLi actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId())); } }, e -> { - log.error("Failed to retrieve the ML model with the given ID", e); - actionListener.onFailure(e); + log.error("Failed to retrieve the ML model for the given task ID", e); + actionListener + .onFailure( + new OpenSearchStatusException("Failed to retrieve the ML model for the given task ID", RestStatus.NOT_FOUND) + ); })); } catch (Exception e) { - // fetch the connector log.error("Unable to fetch status for ml task ", e); + throw new OpenSearchException("Unable to fetch status for ml task " + e.getMessage()); } } @@ -204,15 +214,21 @@ private void executeConnector( Map transformJob, ActionListener actionListener ) { - connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential)); - RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); - connectorExecutor.setScriptService(scriptService); - connectorExecutor.setClusterService(clusterService); - connectorExecutor.setClient(client); - connectorExecutor.setXContentRegistry(xContentRegistry); - connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> { - processTaskResponse(mlTask, taskId, taskResponse, transformJob, actionListener); - }, e -> { actionListener.onFailure(e); })); + if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { + connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential)); + RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader + .initInstance(connector.getProtocol(), connector, Connector.class); + connectorExecutor.setScriptService(scriptService); + connectorExecutor.setClusterService(clusterService); + connectorExecutor.setClient(client); + connectorExecutor.setXContentRegistry(xContentRegistry); + connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> { + processTaskResponse(mlTask, taskId, taskResponse, transformJob, actionListener); + }, e -> { actionListener.onFailure(e); })); + } else { + actionListener + .onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN)); + } } private void processTaskResponse( diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java new file mode 100644 index 0000000000..99d9fbf8a1 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java @@ -0,0 +1,309 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.tasks; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.get.GetResult; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.task.MLCancelBatchJobRequest; +import org.opensearch.ml.common.transport.task.MLCancelBatchJobResponse; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.script.ScriptService; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class CancelBatchJobTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + private ClusterService clusterService; + @Mock + private ScriptService scriptService; + @Mock + ClusterState clusterState; + + @Mock + private Metadata metaData; + + @Mock + ActionFilters actionFilters; + @Mock + private ConnectorAccessControlHelper connectorAccessControlHelper; + + @Mock + private EncryptorImpl encryptor; + + @Mock + ActionListener actionListener; + @Mock + private MLModelManager mlModelManager; + + @Mock + private MLTaskManager mlTaskManager; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + CancelBatchJobTransportAction cancelBatchJobTransportAction; + MLCancelBatchJobRequest mlCancelBatchJobRequest; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().taskId("test_id").build(); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + doReturn(clusterState).when(clusterService).state(); + doReturn(metaData).when(clusterState).metadata(); + + doReturn(true).when(metaData).hasIndex(anyString()); + + cancelBatchJobTransportAction = spy( + new CancelBatchJobTransportAction( + transportService, + actionFilters, + client, + xContentRegistry, + clusterService, + scriptService, + connectorAccessControlHelper, + encryptor, + mlTaskManager, + mlModelManager + ) + ); + + MLModel mlModel = mock(MLModel.class); + + Connector connector = HttpConnector + .builder() + .name("test") + .protocol("http") + .version("1") + .credential(Map.of("api_key", "credential_value")) + .parameters(Map.of("param1", "value1")) + .actions( + Arrays + .asList( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.BATCH_PREDICT_STATUS) + .method("POST") + .url("https://api.sagemaker.us-east-1.amazonaws.com/DescribeTransformJob") + .headers(Map.of("Authorization", "Bearer ${credential.api_key}")) + .requestBody("{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}") + .build() + ) + ) + .build(); + + when(mlModel.getConnectorId()).thenReturn("testConnectorID"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(eq("testModelID"), any(), any(), isA(ActionListener.class)); + + when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(connector); + return null; + }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any()); + + } + + public void testGetTask_NullResponse() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(), any()); + cancelBatchJobTransportAction.doExecute(null, mlCancelBatchJobRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Fail to find task", argumentCaptor.getValue().getMessage()); + } + + public void testGetTask_RuntimeException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("errorMessage")); + return null; + }).when(client).get(any(), any()); + cancelBatchJobTransportAction.doExecute(null, mlCancelBatchJobRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + } + + public void testGetTask_IndexNotFoundException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new IndexNotFoundException("Index Not Found")); + return null; + }).when(client).get(any(), any()); + cancelBatchJobTransportAction.doExecute(null, mlCancelBatchJobRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Fail to find task", argumentCaptor.getValue().getMessage()); + } + + @Ignore + public void testGetTask_SuccessBatchPredictCancel() throws IOException { + Map remoteJob = new HashMap<>(); + remoteJob.put("Status", "IN PROGRESS"); + remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); + + GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(Map.of("TransformJobStatus", "COMPLETED")).build(); + ModelTensorOutput modelTensorOutput = ModelTensorOutput + .builder() + .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build())) + .build(); + + cancelBatchJobTransportAction.doExecute(null, mlCancelBatchJobRequest, actionListener); + verify(actionListener).onResponse(any(MLCancelBatchJobResponse.class)); + } + + public void test_BatchPredictCancel_NoConnector() throws IOException { + Map remoteJob = new HashMap<>(); + remoteJob.put("Status", "IN PROGRESS"); + remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); + + when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(false); + + GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + cancelBatchJobTransportAction.doExecute(null, mlCancelBatchJobRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("You don't have permission to access this connector", argumentCaptor.getValue().getMessage()); + } + + public void test_BatchPredictStatus_NoAccessToConnector() throws IOException { + Map remoteJob = new HashMap<>(); + remoteJob.put("Status", "IN PROGRESS"); + remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new ResourceNotFoundException("Failed to get connector")); + return null; + }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any()); + + GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + cancelBatchJobTransportAction.doExecute(null, mlCancelBatchJobRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to get connector", argumentCaptor.getValue().getMessage()); + } + + public GetResponse prepareMLTask(FunctionName functionName, MLTaskType mlTaskType, Map remoteJob) throws IOException { + MLTask mlTask = MLTask + .builder() + .taskId("taskID") + .modelId("testModelID") + .functionName(functionName) + .taskType(mlTaskType) + .remoteJob(remoteJob) + .build(); + XContentBuilder content = mlTask.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); + GetResponse getResponse = new GetResponse(getResult); + return getResponse; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java index 3e59258436..3707c89eae 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java @@ -6,28 +6,56 @@ package org.opensearch.ml.action.tasks; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.junit.Before; +import org.junit.Ignore; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.ResourceNotFoundException; import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.get.GetResult; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.task.MLTaskGetRequest; import org.opensearch.ml.common.transport.task.MLTaskGetResponse; import org.opensearch.ml.engine.encryptor.EncryptorImpl; @@ -56,6 +84,11 @@ public class GetTaskTransportActionTests extends OpenSearchTestCase { private ClusterService clusterService; @Mock private ScriptService scriptService; + @Mock + ClusterState clusterState; + + @Mock + private Metadata metaData; @Mock ActionFilters actionFilters; @@ -85,6 +118,16 @@ public void setup() throws IOException { MockitoAnnotations.openMocks(this); mlTaskGetRequest = MLTaskGetRequest.builder().taskId("test_id").build(); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + doReturn(clusterState).when(clusterService).state(); + doReturn(metaData).when(clusterState).metadata(); + + doReturn(true).when(metaData).hasIndex(anyString()); + getTaskTransportAction = spy( new GetTaskTransportAction( transportService, @@ -100,10 +143,46 @@ public void setup() throws IOException { ) ); - Settings settings = Settings.builder().build(); - threadContext = new ThreadContext(settings); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(threadContext); + MLModel mlModel = mock(MLModel.class); + + Connector connector = HttpConnector + .builder() + .name("test") + .protocol("http") + .version("1") + .credential(Map.of("api_key", "credential_value")) + .parameters(Map.of("param1", "value1")) + .actions( + Arrays + .asList( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.BATCH_PREDICT_STATUS) + .method("POST") + .url("https://api.sagemaker.us-east-1.amazonaws.com/DescribeTransformJob") + .headers(Map.of("Authorization", "Bearer ${credential.api_key}")) + .requestBody("{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}") + .build() + ) + ) + .build(); + + when(mlModel.getConnectorId()).thenReturn("testConnectorID"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(eq("testModelID"), any(), any(), isA(ActionListener.class)); + + when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(connector); + return null; + }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any()); + } public void testGetTask_NullResponse() { @@ -141,4 +220,115 @@ public void testGetTask_IndexNotFoundException() { verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Fail to find task", argumentCaptor.getValue().getMessage()); } + + @Ignore + public void testGetTask_SuccessBatchPredictStatus() throws IOException { + Map remoteJob = new HashMap<>(); + remoteJob.put("Status", "IN PROGRESS"); + remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); + + GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(Map.of("TransformJobStatus", "COMPLETED")).build(); + ModelTensorOutput modelTensorOutput = ModelTensorOutput + .builder() + .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build())) + .build(); + + getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener); + verify(actionListener).onResponse(any(MLTaskGetResponse.class)); + } + + public void test_BatchPredictStatus_NoConnector() throws IOException { + Map remoteJob = new HashMap<>(); + remoteJob.put("Status", "IN PROGRESS"); + remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); + + when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(false); + + GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("You don't have permission to access this connector", argumentCaptor.getValue().getMessage()); + } + + public void test_BatchPredictStatus_NoAccessToConnector() throws IOException { + Map remoteJob = new HashMap<>(); + remoteJob.put("Status", "IN PROGRESS"); + remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new ResourceNotFoundException("Failed to get connector")); + return null; + }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any()); + + GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to get connector", argumentCaptor.getValue().getMessage()); + } + + public void test_BatchPredictStatus_NoModel() throws IOException { + Map remoteJob = new HashMap<>(); + remoteJob.put("Status", "IN PROGRESS"); + remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new ResourceNotFoundException("Failed to get connector")); + return null; + }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any()); + + GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to get connector", argumentCaptor.getValue().getMessage()); + } + + public GetResponse prepareMLTask(FunctionName functionName, MLTaskType mlTaskType, Map remoteJob) throws IOException { + MLTask mlTask = MLTask + .builder() + .taskId("taskID") + .modelId("testModelID") + .functionName(functionName) + .taskType(mlTaskType) + .remoteJob(remoteJob) + .build(); + XContentBuilder content = mlTask.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); + GetResponse getResponse = new GetResponse(getResult); + return getResponse; + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCancelBatchJobActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCancelBatchJobActionTests.java new file mode 100644 index 0000000000..1498750e6a --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCancelBatchJobActionTests.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TASK_ID; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.task.MLCancelBatchJobAction; +import org.opensearch.ml.common.transport.task.MLCancelBatchJobRequest; +import org.opensearch.ml.common.transport.task.MLCancelBatchJobResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLCancelBatchJobActionTests extends OpenSearchTestCase { + + private RestMLCancelBatchJobAction restMLCancelBatchJobAction; + + NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + restMLCancelBatchJobAction = new RestMLCancelBatchJobAction(); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLCancelBatchJobAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLCancelBatchJobAction mlCancelBatchJobAction = new RestMLCancelBatchJobAction(); + assertNotNull(mlCancelBatchJobAction); + } + + public void testGetName() { + String actionName = restMLCancelBatchJobAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_cancel_batch_action", actionName); + } + + public void testRoutes() { + List routes = restMLCancelBatchJobAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.POST, route.getMethod()); + assertEquals("/_plugins/_ml/tasks/{task_id}/_cancel_batch", route.getPath()); + } + + public void test_PrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLCancelBatchJobAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLCancelBatchJobRequest.class); + verify(client, times(1)).execute(eq(MLCancelBatchJobAction.INSTANCE), argumentCaptor.capture(), any()); + String taskId = argumentCaptor.getValue().getTaskId(); + assertEquals(taskId, "test_id"); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_TASK_ID, "test_id"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + return request; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index e09610140d..064008a9c4 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -51,7 +51,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; -import org.opensearch.ml.common.PredictMode; +import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataset; @@ -418,11 +418,22 @@ public void testValidateModelTensorOutputSuccess() { public void testValidateBatchPredictionSuccess() throws IOException { setupMocks(true, false, false, false); - RemoteInferenceInputDataSet remoteInputDataSet = RemoteInferenceInputDataSet.builder().predictMode(PredictMode.BATCH).build(); + RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters( + Map + .of( + "messages", + "[{\\\"role\\\":\\\"system\\\",\\\"content\\\":\\\"You are a helpful assistant.\\\"}," + + "{\\\"role\\\":\\\"user\\\",\\\"content\\\":\\\"Hello!\\\"}]" + ) + ) + .actionType(ConnectorAction.ActionType.BATCH_PREDICT) + .build(); MLPredictionTaskRequest remoteInputRequest = MLPredictionTaskRequest .builder() .modelId("test_model") - .mlInput(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInputDataSet).build()) + .mlInput(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build()) .build(); Predictable predictor = mock(Predictable.class); when(predictor.isModelReady()).thenReturn(true); @@ -466,11 +477,22 @@ public void testValidateBatchPredictionSuccess() throws IOException { public void testValidateBatchPredictionFailure() throws IOException { setupMocks(true, false, false, false); - RemoteInferenceInputDataSet remoteInputDataSet = RemoteInferenceInputDataSet.builder().predictMode(PredictMode.BATCH).build(); + RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters( + Map + .of( + "messages", + "[{\\\"role\\\":\\\"system\\\",\\\"content\\\":\\\"You are a helpful assistant.\\\"}," + + "{\\\"role\\\":\\\"user\\\",\\\"content\\\":\\\"Hello!\\\"}]" + ) + ) + .actionType(ConnectorAction.ActionType.BATCH_PREDICT) + .build(); MLPredictionTaskRequest remoteInputRequest = MLPredictionTaskRequest .builder() .modelId("test_model") - .mlInput(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInputDataSet).build()) + .mlInput(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build()) .build(); Predictable predictor = mock(Predictable.class); when(predictor.isModelReady()).thenReturn(true);