diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java index b32c6243ab..b9bbff7db6 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java @@ -28,6 +28,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -42,6 +43,7 @@ import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -54,9 +56,11 @@ import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.script.ScriptService; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -73,6 +77,7 @@ public class CancelBatchJobTransportAction extends HandledTransportAction getModelListener = ActionListener.wrap(model -> { - if (model.getConnector() != null) { - Connector connector = model.getConnector(); - executeConnector(connector, mlInput, actionListener); - } else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) { - ActionListener listener = ActionListener - .wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> { - log.error("Failed to get connector " + model.getConnectorId(), e); - actionListener.onFailure(e); - }); - try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { - connectorAccessControlHelper - .getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore)); + modelAccessControlHelper.validateModelGroupAccess(user, model.getModelGroupId(), client, ActionListener.wrap(access -> { + if (!access) { + actionListener.onFailure(new MLValidationException("You don't have permission to cancel this batch job")); + } else { + if (model.getConnector() != null) { + Connector connector = model.getConnector(); + executeConnector(connector, mlInput, actionListener); + } else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) { + ActionListener listener = ActionListener + .wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> { + log.error("Failed to get connector " + model.getConnectorId(), e); + actionListener.onFailure(e); + }); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + connectorAccessControlHelper + .getConnector( + client, + model.getConnectorId(), + ActionListener.runBefore(listener, threadContext::restore) + ); + } + } else { + actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId())); + } } - } else { - actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId())); - } + }, e -> { + log.error("Failed to validate Access for Model Group " + model.getModelGroupId(), e); + actionListener.onFailure(e); + })); }, e -> { log.error("Failed to retrieve the ML model with the given ID", e); actionListener @@ -211,26 +232,20 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener actionListener) { - if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { - Optional cancelBatchPredictAction = connector.findAction(CANCEL_BATCH_PREDICT.name()); - if (!cancelBatchPredictAction.isPresent() || cancelBatchPredictAction.get().getRequestBody() == null) { - ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, CANCEL_BATCH_PREDICT); - connector.addAction(connectorAction); - } - connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential)); - RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader - .initInstance(connector.getProtocol(), connector, Connector.class); - connectorExecutor.setScriptService(scriptService); - connectorExecutor.setClusterService(clusterService); - connectorExecutor.setClient(client); - connectorExecutor.setXContentRegistry(xContentRegistry); - connectorExecutor.executeAction(CANCEL_BATCH_PREDICT.name(), mlInput, ActionListener.wrap(taskResponse -> { - processTaskResponse(taskResponse, actionListener); - }, e -> { actionListener.onFailure(e); })); - } else { - actionListener - .onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN)); + Optional cancelBatchPredictAction = connector.findAction(CANCEL_BATCH_PREDICT.name()); + if (!cancelBatchPredictAction.isPresent() || cancelBatchPredictAction.get().getRequestBody() == null) { + ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, CANCEL_BATCH_PREDICT); + connector.addAction(connectorAction); } + connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential)); + RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); + connectorExecutor.setScriptService(scriptService); + connectorExecutor.setClusterService(clusterService); + connectorExecutor.setClient(client); + connectorExecutor.setXContentRegistry(xContentRegistry); + connectorExecutor.executeAction(CANCEL_BATCH_PREDICT.name(), mlInput, ActionListener.wrap(taskResponse -> { + processTaskResponse(taskResponse, actionListener); + }, e -> { actionListener.onFailure(e); })); } private void processTaskResponse(MLTaskResponse taskResponse, ActionListener actionListener) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index 35e4e6d83d..fc6b3429a4 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -45,6 +45,7 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -59,6 +60,7 @@ import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -71,9 +73,11 @@ import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.script.ScriptService; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -90,6 +94,7 @@ public class GetTaskTransportAction extends HandledTransportAction getModelListener = ActionListener.wrap(model -> { - if (model.getConnector() != null) { - Connector connector = model.getConnector(); - executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener); - } else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) { - ActionListener listener = ActionListener.wrap(connector -> { - executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener); - }, e -> { - log.error("Failed to get connector " + model.getConnectorId(), e); - actionListener.onFailure(e); - }); - try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { - connectorAccessControlHelper - .getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore)); + modelAccessControlHelper.validateModelGroupAccess(user, model.getModelGroupId(), client, ActionListener.wrap(access -> { + if (!access) { + actionListener.onFailure(new MLValidationException("You don't have permission to access this batch job")); + } else { + if (model.getConnector() != null) { + Connector connector = model.getConnector(); + executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener); + } else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) { + ActionListener listener = ActionListener.wrap(connector -> { + executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener); + }, e -> { + log.error("Failed to get connector " + model.getConnectorId(), e); + actionListener.onFailure(e); + }); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + connectorAccessControlHelper + .getConnector( + client, + model.getConnectorId(), + ActionListener.runBefore(listener, threadContext::restore) + ); + } + } else { + actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId())); + } } - } else { - actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId())); - } + }, e -> { + log.error("Failed to validate Access for Model Group " + model.getModelGroupId(), e); + actionListener.onFailure(e); + })); }, e -> { log.error("Failed to retrieve the ML model for the given task ID", e); actionListener @@ -280,26 +301,20 @@ private void executeConnector( Map remoteJob, ActionListener actionListener ) { - if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { - Optional batchPredictStatusAction = connector.findAction(BATCH_PREDICT_STATUS.name()); - if (!batchPredictStatusAction.isPresent() || batchPredictStatusAction.get().getRequestBody() == null) { - ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, BATCH_PREDICT_STATUS); - connector.addAction(connectorAction); - } - connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential)); - RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader - .initInstance(connector.getProtocol(), connector, Connector.class); - connectorExecutor.setScriptService(scriptService); - connectorExecutor.setClusterService(clusterService); - connectorExecutor.setClient(client); - connectorExecutor.setXContentRegistry(xContentRegistry); - connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> { - processTaskResponse(mlTask, taskId, taskResponse, remoteJob, actionListener); - }, e -> { actionListener.onFailure(e); })); - } else { - actionListener - .onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN)); + Optional batchPredictStatusAction = connector.findAction(BATCH_PREDICT_STATUS.name()); + if (!batchPredictStatusAction.isPresent() || batchPredictStatusAction.get().getRequestBody() == null) { + ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, BATCH_PREDICT_STATUS); + connector.addAction(connectorAction); } + connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential)); + RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); + connectorExecutor.setScriptService(scriptService); + connectorExecutor.setClusterService(clusterService); + connectorExecutor.setClient(client); + connectorExecutor.setXContentRegistry(xContentRegistry); + connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> { + processTaskResponse(mlTask, taskId, taskResponse, remoteJob, actionListener); + }, e -> { actionListener.onFailure(e); })); } protected void processTaskResponse( diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java index 0c6939ea77..00d755b248 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java @@ -60,6 +60,7 @@ import org.opensearch.ml.common.transport.task.MLCancelBatchJobResponse; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.task.MLTaskManager; @@ -95,6 +96,8 @@ public class CancelBatchJobTransportActionTests extends OpenSearchTestCase { ActionFilters actionFilters; @Mock private ConnectorAccessControlHelper connectorAccessControlHelper; + @Mock + private ModelAccessControlHelper modelAccessControlHelper; @Mock private EncryptorImpl encryptor; @@ -141,6 +144,7 @@ public void setup() throws IOException { clusterService, scriptService, connectorAccessControlHelper, + modelAccessControlHelper, encryptor, mlTaskManager, mlModelManager, @@ -180,7 +184,11 @@ public void setup() throws IOException { return null; }).when(mlModelManager).getModel(eq("testModelID"), any(), any(), isA(ActionListener.class)); - when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -272,12 +280,16 @@ public void testGetTask_SuccessBatchPredictCancel() throws IOException { verify(actionListener).onResponse(any(MLCancelBatchJobResponse.class)); } - public void test_BatchPredictCancel_NoConnector() throws IOException { + public void test_BatchPredictCancel_NoModelGroupAccess() throws IOException { Map remoteJob = new HashMap<>(); remoteJob.put("Status", "IN PROGRESS"); remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); - when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); @@ -290,10 +302,10 @@ public void test_BatchPredictCancel_NoConnector() throws IOException { cancelBatchJobTransportAction.doExecute(null, mlCancelBatchJobRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("You don't have permission to access this connector", argumentCaptor.getValue().getMessage()); + assertEquals("You don't have permission to cancel this batch job", argumentCaptor.getValue().getMessage()); } - public void test_BatchPredictStatus_NoAccessToConnector() throws IOException { + public void test_BatchPredictStatus_NoConnectorFound() throws IOException { Map remoteJob = new HashMap<>(); remoteJob.put("Status", "IN PROGRESS"); remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java index 1c036adb92..8f655fe59e 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java @@ -70,6 +70,7 @@ import org.opensearch.ml.common.transport.task.MLTaskGetResponse; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.task.MLTaskManager; @@ -106,6 +107,9 @@ public class GetTaskTransportActionTests extends OpenSearchTestCase { @Mock private ConnectorAccessControlHelper connectorAccessControlHelper; + @Mock + private ModelAccessControlHelper modelAccessControlHelper; + @Mock private EncryptorImpl encryptor; @@ -173,6 +177,7 @@ public void setup() throws IOException { clusterService, scriptService, connectorAccessControlHelper, + modelAccessControlHelper, encryptor, mlTaskManager, mlModelManager, @@ -221,7 +226,11 @@ public void setup() throws IOException { return null; }).when(mlModelManager).getModel(eq("testModelID"), any(), any(), isA(ActionListener.class)); - when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -291,12 +300,16 @@ public void testGetTask_SuccessBatchPredictStatus() throws IOException { verify(actionListener).onResponse(any(MLTaskGetResponse.class)); } - public void test_BatchPredictStatus_NoConnector() throws IOException { + public void test_BatchPredictStatus_NoModelGroupAccess() throws IOException { Map remoteJob = new HashMap<>(); remoteJob.put("Status", "IN PROGRESS"); remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); - when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); @@ -309,7 +322,7 @@ public void test_BatchPredictStatus_NoConnector() throws IOException { getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("You don't have permission to access this connector", argumentCaptor.getValue().getMessage()); + assertEquals("You don't have permission to access this batch job", argumentCaptor.getValue().getMessage()); } public void test_BatchPredictStatus_FeatureFlagDisabled() throws IOException { @@ -317,7 +330,11 @@ public void test_BatchPredictStatus_FeatureFlagDisabled() throws IOException { remoteJob.put("Status", "IN PROGRESS"); remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); - when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); @@ -337,7 +354,7 @@ public void test_BatchPredictStatus_FeatureFlagDisabled() throws IOException { ); } - public void test_BatchPredictStatus_NoAccessToConnector() throws IOException { + public void test_BatchPredictStatus_NoConnectorFound() throws IOException { Map remoteJob = new HashMap<>(); remoteJob.put("Status", "IN PROGRESS"); remoteJob.put("TransformJobName", "SM-offline-batch-transform13");