Skip to content

Commit

Permalink
applying multi-tenancy for forward request which is important for dep…
Browse files Browse the repository at this point in the history
…loy request (#3158)

* applying multi-tenancy for forward request which is important for deploy request

Signed-off-by: Dhrubo Saha <[email protected]>

* addressed comments

Signed-off-by: Dhrubo Saha <[email protected]>

* applying spotlessApply

Signed-off-by: Dhrubo Saha <[email protected]>

---------

Signed-off-by: Dhrubo Saha <[email protected]>
  • Loading branch information
dhrubo-os authored Oct 24, 2024
1 parent 6180017 commit 5cac357
Show file tree
Hide file tree
Showing 16 changed files with 270 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
@Data
public class MLDeployModelInput implements Writeable {
private String modelId;
private String tenantId;
private String taskId;
private String modelContentHash;
private Integer nodeCount;
Expand All @@ -26,6 +27,8 @@ public class MLDeployModelInput implements Writeable {
public MLDeployModelInput(StreamInput in) throws IOException {
this.modelId = in.readString();
this.taskId = in.readString();
// todo: need to check BWC test
this.tenantId = in.readOptionalString();
this.modelContentHash = in.readOptionalString();
this.nodeCount = in.readInt();
this.coordinatingNodeId = in.readString();
Expand All @@ -34,9 +37,10 @@ public MLDeployModelInput(StreamInput in) throws IOException {
}

@Builder
public MLDeployModelInput(String modelId, String taskId, String modelContentHash, Integer nodeCount, String coordinatingNodeId, Boolean isDeployToAllNodes, MLTask mlTask) {
public MLDeployModelInput(String modelId, String taskId, String tenantId, String modelContentHash, Integer nodeCount, String coordinatingNodeId, Boolean isDeployToAllNodes, MLTask mlTask) {
this.modelId = modelId;
this.taskId = taskId;
this.tenantId = tenantId;
this.modelContentHash = modelContentHash;
this.nodeCount = nodeCount;
this.coordinatingNodeId = coordinatingNodeId;
Expand All @@ -51,6 +55,7 @@ public MLDeployModelInput() {
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
out.writeString(taskId);
out.writeOptionalString(tenantId);
out.writeOptionalString(modelContentHash);
out.writeInt(nodeCount);
out.writeString(coordinatingNodeId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public class MLForwardInput implements Writeable {

private String taskId;
private String modelId;
private String tenantId;
private String workerNodeId;
private MLForwardRequestType requestType;
private MLTask mlTask;
Expand All @@ -32,11 +33,12 @@ public class MLForwardInput implements Writeable {
private MLRegisterModelInput registerModelInput;

@Builder(toBuilder = true)
public MLForwardInput(String taskId, String modelId, String workerNodeId, MLForwardRequestType requestType,
public MLForwardInput(String taskId, String modelId, String tenantId, String workerNodeId, MLForwardRequestType requestType,
MLTask mlTask, MLInput modelInput,
String error, String[] workerNodes, MLRegisterModelInput registerModelInput) {
this.taskId = taskId;
this.modelId = modelId;
this.tenantId = tenantId;
this.workerNodeId = workerNodeId;
this.requestType = requestType;
this.mlTask = mlTask;
Expand All @@ -49,6 +51,8 @@ public MLForwardInput(String taskId, String modelId, String workerNodeId, MLForw
public MLForwardInput(StreamInput in) throws IOException {
this.taskId = in.readOptionalString();
this.modelId = in.readOptionalString();
// todo: need to do BWC check
this.tenantId = in.readOptionalString();
this.workerNodeId = in.readOptionalString();
this.requestType = in.readEnum(MLForwardRequestType.class);
if (in.readBoolean()) {
Expand All @@ -68,6 +72,8 @@ public MLForwardInput(StreamInput in) throws IOException {
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(taskId);
out.writeOptionalString(modelId);
// TODO: need to do BWC check
out.writeOptionalString(tenantId);
out.writeOptionalString(workerNodeId);
out.writeEnum(requestType);
if (this.mlTask != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public void setUp() throws Exception {
@Test
public void testConstructorSerialization1() throws IOException {
String [] nodeIds = {"id1", "id2", "id3"};
MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", "modelContentHash", 3, "coordinatingNodeId", true, mlTask);
MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", null, "modelContentHash", 3, "coordinatingNodeId", true, mlTask);
MLDeployModelNodeRequest MLDeployModelNodeRequest = new MLDeployModelNodeRequest(
new MLDeployModelNodesRequest(nodeIds, deployModelInput)
);
Expand All @@ -104,7 +104,7 @@ public void testConstructorSerialization1() throws IOException {
@Test
public void testConstructorSerialization2() throws IOException {
DiscoveryNode [] nodeIds = {localNode1, localNode2, localNode3};
MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", "modelContentHash", 3, "coordinatingNodeId", true, mlTask);
MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", null, "modelContentHash", 3, "coordinatingNodeId", true, mlTask);
MLDeployModelNodeRequest MLDeployModelNodeRequest = new MLDeployModelNodeRequest(
new MLDeployModelNodesRequest(nodeIds, deployModelInput)
);
Expand Down Expand Up @@ -140,7 +140,7 @@ public void testConstructorSerialization3() throws IOException {
@Test
public void testConstructorFromInputStream() throws IOException {
String [] nodeIds = {"id1", "id2", "id3"};
MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", "modelContentHash", 3, "coordinatingNodeId", true, mlTask);
MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", null, "modelContentHash", 3, "coordinatingNodeId", true, mlTask);
MLDeployModelNodeRequest MLDeployModelNodeRequest = new MLDeployModelNodeRequest(
new MLDeployModelNodesRequest(nodeIds, deployModelInput)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ private void indexAndCreateController(
String errorMessage = getErrorMessage("Model controller saved into index, result:{}", modelId, isHidden);
log.info(errorMessage, indexResponse.getResult());
if (indexResponse.getResult() == DocWriteResponse.Result.CREATED) {
mlModelManager.updateModel(modelId, isHidden, Map.of(MLModel.IS_CONTROLLER_ENABLED_FIELD, true));
// we aren't enabling controller feature for multi-tenancy. So tenant id is null by default.
mlModelManager.updateModel(modelId, null, isHidden, Map.of(MLModel.IS_CONTROLLER_ENABLED_FIELD, true));
}
if (!ArrayUtils.isEmpty(mlModelCacheHelper.getWorkerNodes(modelId))) {
log
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ public void onResponse(DeleteResponse deleteResponse) {
getErrorMessage("Model controller for the provided successfully deleted from index, result: {}", modelId, isHidden),
deleteResponse.getResult()
);
mlModelManager.updateModel(modelId, isHidden, Map.of(MLModel.IS_CONTROLLER_ENABLED_FIELD, false));
mlModelManager.updateModel(modelId, null, isHidden, Map.of(MLModel.IS_CONTROLLER_ENABLED_FIELD, false));
actionListener.onResponse(deleteResponse);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,6 @@ private void deployModel(
String taskId = response.getId();
mlTask.setTaskId(taskId);
if (algorithm == FunctionName.REMOTE) {
if (mlFeatureEnabledSetting.isMultiTenancyEnabled()) {
listener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.CREATED.name()));
return;
}
mlTaskManager.add(mlTask, eligibleNodeIds);
deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener);
return;
Expand All @@ -302,6 +298,7 @@ private void deployModel(
() -> updateModelDeployStatusAndTriggerOnNodesAction(
modelId,
taskId,
tenantId,
mlModel,
localNodeId,
mlTask,
Expand Down Expand Up @@ -344,6 +341,7 @@ void deployRemoteModel(
MLDeployModelInput deployModelInput = new MLDeployModelInput(
mlModel.getModelId(),
mlTask.getTaskId(),
mlModel.getTenantId(),
mlModel.getModelContentHash(),
eligibleNodes.size(),
localNodeId,
Expand All @@ -367,6 +365,7 @@ void deployRemoteModel(
mlModelManager
.updateModel(
mlModel.getModelId(),
mlModel.getTenantId(),
Map
.of(
MLModel.MODEL_STATE_FIELD,
Expand Down Expand Up @@ -408,7 +407,7 @@ private ActionListener<MLDeployModelNodesResponse> deployModelNodesResponseListe
TASK_SEMAPHORE_TIMEOUT,
true
);
mlModelManager.updateModel(modelId, isHidden, Map.of(MLModel.MODEL_STATE_FIELD, MLModelState.DEPLOY_FAILED));
mlModelManager.updateModel(modelId, tenantId, isHidden, Map.of(MLModel.MODEL_STATE_FIELD, MLModelState.DEPLOY_FAILED));
listener.onFailure(e);
});
}
Expand All @@ -417,6 +416,7 @@ private ActionListener<MLDeployModelNodesResponse> deployModelNodesResponseListe
void updateModelDeployStatusAndTriggerOnNodesAction(
String modelId,
String taskId,
String tenantId,
MLModel mlModel,
String localNodeId,
MLTask mlTask,
Expand All @@ -426,6 +426,7 @@ void updateModelDeployStatusAndTriggerOnNodesAction(
MLDeployModelInput deployModelInput = new MLDeployModelInput(
modelId,
taskId,
tenantId,
mlModel.getModelContentHash(),
eligibleNodes.size(),
localNodeId,
Expand All @@ -451,13 +452,20 @@ void updateModelDeployStatusAndTriggerOnNodesAction(
TASK_SEMAPHORE_TIMEOUT,
true
);
mlModelManager.updateModel(modelId, mlModel.getIsHidden(), Map.of(MLModel.MODEL_STATE_FIELD, MLModelState.DEPLOY_FAILED));
mlModelManager
.updateModel(
modelId,
mlModel.getTenantId(),
mlModel.getIsHidden(),
Map.of(MLModel.MODEL_STATE_FIELD, MLModelState.DEPLOY_FAILED)
);
});

List<String> workerNodes = eligibleNodes.stream().map(DiscoveryNode::getId).collect(Collectors.toList());
mlModelManager
.updateModel(
modelId,
mlModel.getTenantId(),
Map
.of(
MLModel.MODEL_STATE_FIELD,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNod
MLDeployModelInput deployModelInput = MLDeployModelNodesRequest.getMlDeployModelInput();
String modelId = deployModelInput.getModelId();
String taskId = deployModelInput.getTaskId();
String tenantId = deployModelInput.getTenantId();
String coordinatingNodeId = deployModelInput.getCoordinatingNodeId();
MLTask mlTask = deployModelInput.getMlTask();
String modelContentHash = deployModelInput.getModelContentHash();
Expand All @@ -140,12 +141,13 @@ private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNod
String localNodeId = clusterService.localNode().getId();

ActionListener<MLForwardResponse> taskDoneListener = ActionListener
.wrap(res -> { log.info("deploy model task done " + taskId); }, ex -> {
.wrap(res -> { log.info("deploy model task done {}", taskId); }, ex -> {
logException("Deploy model task failed: " + taskId, ex, log);
});

deployModel(
modelId,
tenantId,
modelContentHash,
mlTask.getFunctionName(),
localNodeId,
Expand All @@ -158,6 +160,7 @@ private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNod
.requestType(MLForwardRequestType.DEPLOY_MODEL_DONE)
.taskId(taskId)
.modelId(modelId)
.tenantId(tenantId)
.workerNodeId(clusterService.localNode().getId())
.build();
MLForwardRequest deployModelDoneMessage = new MLForwardRequest(mlForwardInput);
Expand All @@ -177,6 +180,7 @@ private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNod
.requestType(MLForwardRequestType.DEPLOY_MODEL_DONE)
.taskId(taskId)
.modelId(modelId)
.tenantId(tenantId)
.workerNodeId(clusterService.localNode().getId())
.error(MLExceptionUtils.getRootCauseMessage(e))
.build();
Expand Down Expand Up @@ -211,6 +215,7 @@ private DiscoveryNode getNodeById(String nodeId) {

private void deployModel(
String modelId,
String tenantId,
String modelContentHash,
FunctionName functionName,
String localNodeId,
Expand All @@ -224,7 +229,7 @@ private void deployModel(
mlModelManager
.deployModel(
modelId,
null,
tenantId,
modelContentHash,
functionName,
deployToAllNodes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskCache;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.sdk.SdkClient;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand All @@ -62,6 +65,7 @@ public class TransportForwardAction extends HandledTransportAction<ActionRequest
private final ClusterService clusterService;
final MLTaskManager mlTaskManager;
final Client client;
final SdkClient sdkClient;
final MLModelManager mlModelManager;
final DiscoveryNodeHelper nodeHelper;

Expand All @@ -73,26 +77,32 @@ public class TransportForwardAction extends HandledTransportAction<ActionRequest

private final MLModelAutoReDeployer mlModelAutoReDeployer;

private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public TransportForwardAction(
TransportService transportService,
ActionFilters actionFilters,
MLTaskManager mlTaskManager,
Client client,
SdkClient sdkClient,
MLModelManager mlModelManager,
DiscoveryNodeHelper nodeHelper,
Settings settings,
ClusterService clusterService,
MLModelAutoReDeployer mlModelAutoReDeployer
MLModelAutoReDeployer mlModelAutoReDeployer,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLForwardAction.NAME, transportService, actionFilters, MLForwardRequest::new);
this.mlTaskManager = mlTaskManager;
this.client = client;
this.sdkClient = sdkClient;
this.mlModelManager = mlModelManager;
this.nodeHelper = nodeHelper;
this.settings = settings;
this.clusterService = clusterService;
this.mlModelAutoReDeployer = mlModelAutoReDeployer;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;

modelAutoRedeploySuccessRatio = ML_COMMONS_MODEL_AUTO_REDEPLOY_SUCCESS_RATIO.get(settings);
enableAutoReDeployModel = ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE.get(settings);
Expand All @@ -107,6 +117,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
MLForwardRequest mlForwardRequest = MLForwardRequest.fromActionRequest(request);
MLForwardInput forwardInput = mlForwardRequest.getForwardInput();
String modelId = forwardInput.getModelId();
String tenantId = forwardInput.getTenantId();
if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, listener)) {
return;
}
String taskId = forwardInput.getTaskId();
MLRegisterModelInput registerModelInput = forwardInput.getRegisterModelInput();
MLTask mlTask = forwardInput.getMlTask();
Expand Down Expand Up @@ -147,7 +161,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
builder.put(MLTask.ERROR_FIELD, toJsonString(mlTaskCache.getErrors()));
}
boolean clearAutoReDeployRetryTimes = triggerNextModelDeployAndCheckIfRestRetryTimes(workNodes, taskId);
mlTaskManager.updateMLTask(taskId, null, builder.build(), TASK_SEMAPHORE_TIMEOUT, true);
mlTaskManager.updateMLTask(taskId, tenantId, builder.build(), TASK_SEMAPHORE_TIMEOUT, true);

MLModelState modelState;
if (!mlTaskCache.allNodeFailed()) {
Expand All @@ -171,8 +185,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
} else {
log.error("Failed to update ML model {}, status: {}", modelId, response.status());
}
}, e -> { log.error("Failed to update ML model: " + modelId, e); });
mlModelManager.updateModel(modelId, updateFields, ActionListener.runBefore(updateModelListener, () -> {
}, e -> log.error("Failed to update ML model: {}", modelId, e));
mlModelManager.updateModel(modelId, tenantId, updateFields, ActionListener.runBefore(updateModelListener, () -> {
mlModelManager.removeAutoDeployModel(modelId);
}));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import static org.opensearch.ml.utils.MLExceptionUtils.logException;

import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -402,7 +401,7 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen
);
});
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
mlTaskManager.add(mlTask, Arrays.asList(nodeId));
mlTaskManager.add(mlTask, List.of(nodeId));
MLForwardInput forwardInput = MLForwardInput
.builder()
.requestType(MLForwardRequestType.REGISTER_MODEL)
Expand Down
Loading

0 comments on commit 5cac357

Please sign in to comment.