# Patch summary (commentary only; not part of any hunk).
#
# Purpose: track an "auto deploying" flag per cached model so that initModelState
# does not reject a deploy with "Duplicate deploy model task" while the flag is set.
#
# plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java
#   - Adds field `Boolean isAutoDeploying` with @Setter/@Getter at AccessLevel.PROTECTED.
#
# plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java
#   - Imports org.apache.commons.lang3.BooleanUtils.
#   - initModelState: the duplicate-task check now throws only when
#     isModelRunningOnNode(modelId) is true AND isAutoDeploying(modelId) is false.
#   - Renames initModelStateLocal to initModelStateAutoDeploy; after the cache entry is put
#     into modelCaches it now calls setIsAutoDeploying(modelId, true).
#   - Adds setIsAutoDeploying(modelId, flag): synchronized; logs at debug level and sets the
#     flag on the entry returned by getExistingModelCache(modelId).
#   - Adds isAutoDeploying(modelId): reads modelCaches directly and returns true only when an
#     entry exists and BooleanUtils.isTrue(entry.getIsAutoDeploying()) holds, so a missing
#     entry or a null flag yields false.
#
# plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
#   - deployModel: the auto-deploy branch calls the renamed initModelStateAutoDeploy.
#   - The run-before hook of wrappedListener now also calls
#     modelCacheHelper.setIsAutoDeploying(modelId, false) after removeAutoDeployModel(modelId).
#
# NOTE(review): the run-before hook clears the flag on every deploy path, not only when
#   autoDeployModel is true. setIsAutoDeploying goes through getExistingModelCache, whose
#   behaviour for a model id with no cache entry is not visible in this patch -- confirm it
#   cannot throw inside the listener hook (e.g. after a failed deploy removed the entry).
# NOTE(review): this text looks flattened (hunk lines joined by spaces), and `List
#   targetWorkerNodes` / `ActionListener wrappedListener` appear as raw types, presumably
#   because generic arguments were lost in extraction -- TODO confirm against the original
#   patch before applying.
diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java index 732d28663b..9ffd45475b 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java @@ -54,6 +54,7 @@ public class MLModelCache { @Setter private Boolean deployToAllNodes; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Instant lastAccessTime; + private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Boolean isAutoDeploying; public MLModelCache() { targetWorkerNodes = ConcurrentHashMap.newKeySet(); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 6230ad2944..0e4eb23e11 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -17,6 +17,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; +import org.apache.commons.lang3.BooleanUtils; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.TokenBucket; @@ -61,7 +62,7 @@ public synchronized void initModelState( List targetWorkerNodes, boolean deployToAllNodes ) { - if (isModelRunningOnNode(modelId)) { + if (isModelRunningOnNode(modelId) && !isAutoDeploying(modelId)) { throw new MLLimitExceededException("Duplicate deploy model task"); } log.debug("init model state for model {}, state: {}", modelId, state); @@ -74,7 +75,7 @@ public synchronized void initModelState( modelCaches.put(modelId, modelCache); } - public synchronized void initModelStateLocal( + public synchronized void initModelStateAutoDeploy( String modelId, MLModelState state, FunctionName functionName, @@ -92,6 +93,7 @@ public synchronized void initModelStateLocal( 
modelCache.setDeployToAllNodes(false); modelCache.setLastAccessTime(Instant.now()); modelCaches.put(modelId, modelCache); + setIsAutoDeploying(modelId, true); } /** @@ -279,6 +281,28 @@ public Boolean getIsModelEnabled(String modelId) { return modelCache.getIsModelEnabled(); } + /** + * Set a flag to show if model is in auto deploying status + * + * @param modelId model id + * @param isModelAutoDeploying auto deploy flag + */ + public synchronized void setIsAutoDeploying(String modelId, Boolean isModelAutoDeploying) { + log.debug("Setting the auto deploying flag for Model {}", modelId); + getExistingModelCache(modelId).setIsAutoDeploying(isModelAutoDeploying); + } + + /** + * Check if model is in auto deploying. + * + * @param modelId model id + * @return true if model is auto deploying. + */ + public boolean isAutoDeploying(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + return modelCache != null && BooleanUtils.isTrue(modelCache.getIsAutoDeploying()); + } + /** * Set memory size estimation CPU/GPU * diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index d354fdf771..ce24619226 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -987,13 +987,14 @@ public void deployModel( if (!autoDeployModel) { modelCacheHelper.initModelState(modelId, MLModelState.DEPLOYING, functionName, workerNodes, deployToAllNodes); } else { - modelCacheHelper.initModelStateLocal(modelId, MLModelState.DEPLOYING, functionName, workerNodes); + modelCacheHelper.initModelStateAutoDeploy(modelId, MLModelState.DEPLOYING, functionName, workerNodes); } try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(listener, () -> { context.restore(); 
modelCacheHelper.removeAutoDeployModel(modelId); + modelCacheHelper.setIsAutoDeploying(modelId, false); }); if (!autoDeployModel) { checkAndAddRunningTask(mlTask, maxDeployTasksPerNode);