Skip to content

Commit

Permalink
change to model group access for batch job task APIs (#3098)
Browse files Browse the repository at this point in the history
* change to model group access for batch job task APIs

Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna authored Oct 12, 2024
1 parent 5682d11 commit 6277410
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
Expand All @@ -42,6 +43,7 @@
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
Expand All @@ -54,9 +56,11 @@
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.script.ScriptService;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -73,6 +77,7 @@ public class CancelBatchJobTransportAction extends HandledTransportAction<Action
ScriptService scriptService;

ConnectorAccessControlHelper connectorAccessControlHelper;
ModelAccessControlHelper modelAccessControlHelper;
EncryptorImpl encryptor;
MLModelManager mlModelManager;

Expand All @@ -88,6 +93,7 @@ public CancelBatchJobTransportAction(
ClusterService clusterService,
ScriptService scriptService,
ConnectorAccessControlHelper connectorAccessControlHelper,
ModelAccessControlHelper modelAccessControlHelper,
EncryptorImpl encryptor,
MLTaskManager mlTaskManager,
MLModelManager mlModelManager,
Expand All @@ -99,6 +105,7 @@ public CancelBatchJobTransportAction(
this.clusterService = clusterService;
this.scriptService = scriptService;
this.connectorAccessControlHelper = connectorAccessControlHelper;
this.modelAccessControlHelper = modelAccessControlHelper;
this.encryptor = encryptor;
this.mlTaskManager = mlTaskManager;
this.mlModelManager = mlModelManager;
Expand Down Expand Up @@ -177,25 +184,39 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener<MLCancel
RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet(parameters, ActionType.BATCH_PREDICT_STATUS);
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build();
String modelId = mlTask.getModelId();
User user = RestActionUtils.getUserContext(client);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLModel> getModelListener = ActionListener.wrap(model -> {
if (model.getConnector() != null) {
Connector connector = model.getConnector();
executeConnector(connector, mlInput, actionListener);
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
ActionListener<Connector> listener = ActionListener
.wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> {
log.error("Failed to get connector " + model.getConnectorId(), e);
actionListener.onFailure(e);
});
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper
.getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore));
modelAccessControlHelper.validateModelGroupAccess(user, model.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
actionListener.onFailure(new MLValidationException("You don't have permission to cancel this batch job"));
} else {
if (model.getConnector() != null) {
Connector connector = model.getConnector();
executeConnector(connector, mlInput, actionListener);
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
ActionListener<Connector> listener = ActionListener
.wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> {
log.error("Failed to get connector " + model.getConnectorId(), e);
actionListener.onFailure(e);
});
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper
.getConnector(
client,
model.getConnectorId(),
ActionListener.runBefore(listener, threadContext::restore)
);
}
} else {
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
}
}
} else {
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
}
}, e -> {
log.error("Failed to validate Access for Model Group " + model.getModelGroupId(), e);
actionListener.onFailure(e);
}));
}, e -> {
log.error("Failed to retrieve the ML model with the given ID", e);
actionListener
Expand All @@ -211,26 +232,20 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener<MLCancel
}

private void executeConnector(Connector connector, MLInput mlInput, ActionListener<MLCancelBatchJobResponse> actionListener) {
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
Optional<ConnectorAction> cancelBatchPredictAction = connector.findAction(CANCEL_BATCH_PREDICT.name());
if (!cancelBatchPredictAction.isPresent() || cancelBatchPredictAction.get().getRequestBody() == null) {
ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, CANCEL_BATCH_PREDICT);
connector.addAction(connectorAction);
}
connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential));
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader
.initInstance(connector.getProtocol(), connector, Connector.class);
connectorExecutor.setScriptService(scriptService);
connectorExecutor.setClusterService(clusterService);
connectorExecutor.setClient(client);
connectorExecutor.setXContentRegistry(xContentRegistry);
connectorExecutor.executeAction(CANCEL_BATCH_PREDICT.name(), mlInput, ActionListener.wrap(taskResponse -> {
processTaskResponse(taskResponse, actionListener);
}, e -> { actionListener.onFailure(e); }));
} else {
actionListener
.onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN));
Optional<ConnectorAction> cancelBatchPredictAction = connector.findAction(CANCEL_BATCH_PREDICT.name());
if (!cancelBatchPredictAction.isPresent() || cancelBatchPredictAction.get().getRequestBody() == null) {
ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, CANCEL_BATCH_PREDICT);
connector.addAction(connectorAction);
}
connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential));
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
connectorExecutor.setScriptService(scriptService);
connectorExecutor.setClusterService(clusterService);
connectorExecutor.setClient(client);
connectorExecutor.setXContentRegistry(xContentRegistry);
connectorExecutor.executeAction(CANCEL_BATCH_PREDICT.name(), mlInput, ActionListener.wrap(taskResponse -> {
processTaskResponse(taskResponse, actionListener);
}, e -> { actionListener.onFailure(e); }));
}

private void processTaskResponse(MLTaskResponse taskResponse, ActionListener<MLCancelBatchJobResponse> actionListener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
Expand All @@ -59,6 +60,7 @@
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
Expand All @@ -71,9 +73,11 @@
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.script.ScriptService;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -90,6 +94,7 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest
ScriptService scriptService;

ConnectorAccessControlHelper connectorAccessControlHelper;
ModelAccessControlHelper modelAccessControlHelper;
EncryptorImpl encryptor;
MLModelManager mlModelManager;

Expand All @@ -111,6 +116,7 @@ public GetTaskTransportAction(
ClusterService clusterService,
ScriptService scriptService,
ConnectorAccessControlHelper connectorAccessControlHelper,
ModelAccessControlHelper modelAccessControlHelper,
EncryptorImpl encryptor,
MLTaskManager mlTaskManager,
MLModelManager mlModelManager,
Expand All @@ -123,6 +129,7 @@ public GetTaskTransportAction(
this.clusterService = clusterService;
this.scriptService = scriptService;
this.connectorAccessControlHelper = connectorAccessControlHelper;
this.modelAccessControlHelper = modelAccessControlHelper;
this.encryptor = encryptor;
this.mlTaskManager = mlTaskManager;
this.mlModelManager = mlModelManager;
Expand Down Expand Up @@ -238,26 +245,40 @@ private void processRemoteBatchPrediction(MLTask mlTask, String taskId, ActionLi
RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet(parameters, ActionType.BATCH_PREDICT_STATUS);
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build();
String modelId = mlTask.getModelId();
User user = RestActionUtils.getUserContext(client);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLModel> getModelListener = ActionListener.wrap(model -> {
if (model.getConnector() != null) {
Connector connector = model.getConnector();
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
ActionListener<Connector> listener = ActionListener.wrap(connector -> {
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
}, e -> {
log.error("Failed to get connector " + model.getConnectorId(), e);
actionListener.onFailure(e);
});
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper
.getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore));
modelAccessControlHelper.validateModelGroupAccess(user, model.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
actionListener.onFailure(new MLValidationException("You don't have permission to access this batch job"));
} else {
if (model.getConnector() != null) {
Connector connector = model.getConnector();
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
ActionListener<Connector> listener = ActionListener.wrap(connector -> {
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
}, e -> {
log.error("Failed to get connector " + model.getConnectorId(), e);
actionListener.onFailure(e);
});
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper
.getConnector(
client,
model.getConnectorId(),
ActionListener.runBefore(listener, threadContext::restore)
);
}
} else {
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
}
}
} else {
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
}
}, e -> {
log.error("Failed to validate Access for Model Group " + model.getModelGroupId(), e);
actionListener.onFailure(e);
}));
}, e -> {
log.error("Failed to retrieve the ML model for the given task ID", e);
actionListener
Expand All @@ -280,26 +301,20 @@ private void executeConnector(
Map<String, Object> remoteJob,
ActionListener<MLTaskGetResponse> actionListener
) {
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
Optional<ConnectorAction> batchPredictStatusAction = connector.findAction(BATCH_PREDICT_STATUS.name());
if (!batchPredictStatusAction.isPresent() || batchPredictStatusAction.get().getRequestBody() == null) {
ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, BATCH_PREDICT_STATUS);
connector.addAction(connectorAction);
}
connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential));
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader
.initInstance(connector.getProtocol(), connector, Connector.class);
connectorExecutor.setScriptService(scriptService);
connectorExecutor.setClusterService(clusterService);
connectorExecutor.setClient(client);
connectorExecutor.setXContentRegistry(xContentRegistry);
connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> {
processTaskResponse(mlTask, taskId, taskResponse, remoteJob, actionListener);
}, e -> { actionListener.onFailure(e); }));
} else {
actionListener
.onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN));
Optional<ConnectorAction> batchPredictStatusAction = connector.findAction(BATCH_PREDICT_STATUS.name());
if (!batchPredictStatusAction.isPresent() || batchPredictStatusAction.get().getRequestBody() == null) {
ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, BATCH_PREDICT_STATUS);
connector.addAction(connectorAction);
}
connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential));
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
connectorExecutor.setScriptService(scriptService);
connectorExecutor.setClusterService(clusterService);
connectorExecutor.setClient(client);
connectorExecutor.setXContentRegistry(xContentRegistry);
connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> {
processTaskResponse(mlTask, taskId, taskResponse, remoteJob, actionListener);
}, e -> { actionListener.onFailure(e); }));
}

protected void processTaskResponse(
Expand Down
Loading

0 comments on commit 6277410

Please sign in to comment.