From 21d995279e618297ea423105f7928796d139d984 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Mon, 6 May 2024 10:25:53 -0700 Subject: [PATCH] =?UTF-8?q?avoid=20race=20condition=20in=20syncup=20model?= =?UTF-8?q?=20state=20refresh=20and=20handle=20NP=20of=20I=E2=80=A6=20(#24?= =?UTF-8?q?05)=20(#2408)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * avoid race condition in syncup model state refresh and handle NP of IsAutoDeployEnabled Signed-off-by: Xun Zhang * log the error message from syncUp response Signed-off-by: Xun Zhang * include the syncup response error messages as a string to help debug Signed-off-by: Xun Zhang --------- Signed-off-by: Xun Zhang (cherry picked from commit 21bf0792128cd2bfac9dc341c615b9f2b50e7b99) Co-authored-by: Xun Zhang --- .../opensearch/ml/cluster/MLSyncUpCron.java | 74 +++++++++++-------- .../ml/rest/RestMLUndeployModelAction.java | 17 +---- .../ml/task/MLPredictTaskRunner.java | 2 +- 3 files changed, 49 insertions(+), 44 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index b11fe7afc2..da39166742 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.CommonValue.MASTER_KEY; import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.utils.RestActionUtils.getAllNodes; import java.time.Instant; import java.util.ArrayList; @@ -41,8 +42,9 @@ import org.opensearch.ml.common.transport.sync.MLSyncUpInput; import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest; -import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction; -import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.search.SearchHit; @@ -97,6 +99,14 @@ public void run() { // gather running model/tasks on nodes client.execute(MLSyncUpAction.INSTANCE, gatherInfoRequest, ActionListener.wrap(r -> { List responses = r.getNodes(); + if (r.failures() != null && r.failures().size() != 0) { + log + .debug( + "Received {} failures in the sync up response on nodes. Error messages are {}", + r.failures().size(), + r.failures().stream().map(Exception::getMessage).collect(Collectors.joining(", ")) + ); + } // key is model id, value is set of worker node ids Map> modelWorkerNodes = new HashMap<>(); // key is task id, value is set of worker node ids @@ -143,7 +153,6 @@ public void run() { if (modelWorkerNodes.containsKey(modelId) && expiredModelToNodes.get(modelId).size() == modelWorkerNodes.get(modelId).size()) { // this model has expired in all the nodes - modelWorkerNodes.remove(modelId); modelsToUndeploy.add(modelId); } } @@ -168,37 +177,44 @@ public void run() { MLSyncUpInput syncUpInput = inputBuilder.build(); MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(allNodes, syncUpInput); // sync up running model/tasks on nodes - client - .execute( - MLSyncUpAction.INSTANCE, - syncUpRequest, - ActionListener.wrap(re -> { log.debug("sync model routing job finished"); }, ex -> { - log.error("Failed to sync model routing", ex); - }) - ); - // Undeploy expired models - undeployExpiredModels(modelsToUndeploy, modelWorkerNodes); + client.execute(MLSyncUpAction.INSTANCE, syncUpRequest, ActionListener.wrap(re -> { + log.debug("sync model routing job finished"); + if (!modelsToUndeploy.isEmpty()) { + // Undeploy expired models + undeployExpiredModels(modelsToUndeploy, modelWorkerNodes, deployingModels); + return; + } + // refresh model status + mlIndicesHandler + .initModelIndexIfAbsent(ActionListener.wrap(res -> { refreshModelState(modelWorkerNodes, deployingModels); }, e -> { + log.error("Failed to init model index", e); + })); + }, ex -> { log.error("Failed to sync model routing", ex); })); + }, e -> { log.error("Failed to sync model routing", e); })); + } + + private void undeployExpiredModels( + Set expiredModels, + Map> modelWorkerNodes, + Map> deployingModels + ) { + String[] targetNodeIds = getAllNodes(clusterService); + MLUndeployModelsRequest mlUndeployModelsRequest = new MLUndeployModelsRequest( + expiredModels.toArray(new String[expiredModels.size()]), + targetNodeIds + ); + + client.execute(MLUndeployModelsAction.INSTANCE, mlUndeployModelsRequest, ActionListener.wrap(r -> { + MLUndeployModelNodesResponse mlUndeployModelNodesResponse = r.getResponse(); + if (mlUndeployModelNodesResponse.failures() != null && mlUndeployModelNodesResponse.failures().size() != 0) { + log.debug("Received failures in undeploying expired models", mlUndeployModelNodesResponse.failures()); + } - // refresh model status mlIndicesHandler .initModelIndexIfAbsent(ActionListener.wrap(res -> { refreshModelState(modelWorkerNodes, deployingModels); }, e -> { log.error("Failed to init model index", e); })); - }, e -> { log.error("Failed to sync model routing", e); })); - } - - private void undeployExpiredModels(Set expiredModels, Map> modelWorkerNodes) { - expiredModels.forEach(modelId -> { - String[] targetNodeIds = modelWorkerNodes.keySet().toArray(new String[0]); - - MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest( - targetNodeIds, - new String[] { modelId } - ); - client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> { - log.debug("model {} is un_deployed", modelId); - }, e -> { log.error("Failed to undeploy model {}", modelId, e); })); - }); + }, e -> { log.error("Failed to undeploy models {}", expiredModels, e); })); } @VisibleForTesting diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUndeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUndeployModelAction.java index c895163e1c..0cc30752df 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUndeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUndeployModelAction.java @@ -9,16 +9,14 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; +import static org.opensearch.ml.utils.RestActionUtils.getAllNodes; import java.io.IOException; -import java.util.ArrayList; -import java.util.Iterator; import java.util.List; import java.util.Locale; import org.apache.commons.lang3.ArrayUtils; import org.opensearch.client.node.NodeClient; -import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.XContentParser; @@ -102,24 +100,15 @@ MLUndeployModelsRequest getRequest(RestRequest request) throws IOException { } targetNodeIds = nodeIds; } else { - targetNodeIds = getAllNodes(); + targetNodeIds = getAllNodes(clusterService); } if (ArrayUtils.isNotEmpty(modelIds)) { targetModelIds = modelIds; } } else { - targetNodeIds = getAllNodes(); + targetNodeIds = getAllNodes(clusterService); } return new MLUndeployModelsRequest(targetModelIds, targetNodeIds); } - - private String[] getAllNodes() { - Iterator iterator = clusterService.state().nodes().iterator(); - List nodeIds = new ArrayList<>(); - while (iterator.hasNext()) { - nodeIds.add(iterator.next().getId()); - } - return nodeIds.toArray(new String[0]); - } } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 4841cb2b35..101d9c9244 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -263,7 +263,7 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener