Skip to content

Commit

Permalink
add a flag to distinguish duplicate remote model auto deploy and tran… (
Browse files Browse the repository at this point in the history
opensearch-project#2410) (opensearch-project#2411)

* add a flag to distinguish duplicate remote model auto deploy and transport deploy

Signed-off-by: Xun Zhang <[email protected]>

* check for NPE for getIsAutoDeploying

Signed-off-by: Xun Zhang <[email protected]>

---------

Signed-off-by: Xun Zhang <[email protected]>
(cherry picked from commit f82e148)

Co-authored-by: Xun Zhang <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and Zhangxunmt authored May 6, 2024
1 parent b3b36ed commit 618678f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public class MLModelCache {
@Setter
private Boolean deployToAllNodes;
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Instant lastAccessTime;
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Boolean isAutoDeploying;

public MLModelCache() {
targetWorkerNodes = ConcurrentHashMap.newKeySet();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

import org.apache.commons.lang3.BooleanUtils;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.TokenBucket;
Expand Down Expand Up @@ -61,7 +62,7 @@ public synchronized void initModelState(
List<String> targetWorkerNodes,
boolean deployToAllNodes
) {
if (isModelRunningOnNode(modelId)) {
if (isModelRunningOnNode(modelId) && !isAutoDeploying(modelId)) {
throw new MLLimitExceededException("Duplicate deploy model task");
}
log.debug("init model state for model {}, state: {}", modelId, state);
Expand All @@ -74,7 +75,7 @@ public synchronized void initModelState(
modelCaches.put(modelId, modelCache);
}

public synchronized void initModelStateLocal(
public synchronized void initModelStateAutoDeploy(
String modelId,
MLModelState state,
FunctionName functionName,
Expand All @@ -92,6 +93,7 @@ public synchronized void initModelStateLocal(
modelCache.setDeployToAllNodes(false);
modelCache.setLastAccessTime(Instant.now());
modelCaches.put(modelId, modelCache);
setIsAutoDeploying(modelId, true);
}

/**
Expand Down Expand Up @@ -279,6 +281,28 @@ public Boolean getIsModelEnabled(String modelId) {
return modelCache.getIsModelEnabled();
}

/**
* Set a flag to show if model is in auto deploying status
*
* @param modelId model id
* @param isModelAutoDeploying auto deploy flag
*/
public synchronized void setIsAutoDeploying(String modelId, Boolean isModelAutoDeploying) {
log.debug("Setting the auto deploying flag for Model {}", modelId);
getExistingModelCache(modelId).setIsAutoDeploying(isModelAutoDeploying);
}

/**
* Check if model is in auto deploying.
*
* @param modelId model id
* @return true if model is auto deploying.
*/
public boolean isAutoDeploying(String modelId) {
MLModelCache modelCache = modelCaches.get(modelId);
return modelCache != null && BooleanUtils.isTrue(modelCache.getIsAutoDeploying());
}

/**
* Set memory size estimation CPU/GPU
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -983,13 +983,14 @@ public void deployModel(
if (!autoDeployModel) {
modelCacheHelper.initModelState(modelId, MLModelState.DEPLOYING, functionName, workerNodes, deployToAllNodes);
} else {
modelCacheHelper.initModelStateLocal(modelId, MLModelState.DEPLOYING, functionName, workerNodes);
modelCacheHelper.initModelStateAutoDeploy(modelId, MLModelState.DEPLOYING, functionName, workerNodes);
}

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<String> wrappedListener = ActionListener.runBefore(listener, () -> {
context.restore();
modelCacheHelper.removeAutoDeployModel(modelId);
modelCacheHelper.setIsAutoDeploying(modelId, false);
});
if (!autoDeployModel) {
checkAndAddRunningTask(mlTask, maxDeployTasksPerNode);
Expand Down

0 comments on commit 618678f

Please sign in to comment.