Skip to content

Commit

Permalink
avoid race condition in syncup model state refresh and handle NP of I… (
Browse files Browse the repository at this point in the history
#2405)

* avoid race condition in syncup model state refresh and handle NP of IsAutoDeployEnabled

Signed-off-by: Xun Zhang <[email protected]>

* log the error message from syncUp response

Signed-off-by: Xun Zhang <[email protected]>

* include the syncup response error messages as a string to help debug

Signed-off-by: Xun Zhang <[email protected]>

---------

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt authored May 6, 2024
1 parent 950f864 commit 21bf079
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 44 deletions.
74 changes: 45 additions & 29 deletions plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;

import java.time.Instant;
import java.util.ArrayList;
Expand Down Expand Up @@ -41,8 +42,9 @@
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.search.SearchHit;
Expand Down Expand Up @@ -97,6 +99,14 @@ public void run() {
// gather running model/tasks on nodes
client.execute(MLSyncUpAction.INSTANCE, gatherInfoRequest, ActionListener.wrap(r -> {
List<MLSyncUpNodeResponse> responses = r.getNodes();
if (r.failures() != null && r.failures().size() != 0) {
log
.debug(
"Received {} failures in the sync up response on nodes. Error messages are {}",
r.failures().size(),
r.failures().stream().map(Exception::getMessage).collect(Collectors.joining(", "))
);
}
// key is model id, value is set of worker node ids
Map<String, Set<String>> modelWorkerNodes = new HashMap<>();
// key is task id, value is set of worker node ids
Expand Down Expand Up @@ -143,7 +153,6 @@ public void run() {
if (modelWorkerNodes.containsKey(modelId)
&& expiredModelToNodes.get(modelId).size() == modelWorkerNodes.get(modelId).size()) {
// this model has expired in all the nodes
modelWorkerNodes.remove(modelId);
modelsToUndeploy.add(modelId);
}
}
Expand All @@ -168,37 +177,44 @@ public void run() {
MLSyncUpInput syncUpInput = inputBuilder.build();
MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(allNodes, syncUpInput);
// sync up running model/tasks on nodes
client
.execute(
MLSyncUpAction.INSTANCE,
syncUpRequest,
ActionListener.wrap(re -> { log.debug("sync model routing job finished"); }, ex -> {
log.error("Failed to sync model routing", ex);
})
);
// Undeploy expired models
undeployExpiredModels(modelsToUndeploy, modelWorkerNodes);
client.execute(MLSyncUpAction.INSTANCE, syncUpRequest, ActionListener.wrap(re -> {
log.debug("sync model routing job finished");
if (!modelsToUndeploy.isEmpty()) {
// Undeploy expired models
undeployExpiredModels(modelsToUndeploy, modelWorkerNodes, deployingModels);
return;
}
// refresh model status
mlIndicesHandler
.initModelIndexIfAbsent(ActionListener.wrap(res -> { refreshModelState(modelWorkerNodes, deployingModels); }, e -> {
log.error("Failed to init model index", e);
}));
}, ex -> { log.error("Failed to sync model routing", ex); }));
}, e -> { log.error("Failed to sync model routing", e); }));
}

private void undeployExpiredModels(
Set<String> expiredModels,
Map<String, Set<String>> modelWorkerNodes,
Map<String, Set<String>> deployingModels
) {
String[] targetNodeIds = getAllNodes(clusterService);
MLUndeployModelsRequest mlUndeployModelsRequest = new MLUndeployModelsRequest(
expiredModels.toArray(new String[expiredModels.size()]),
targetNodeIds
);

client.execute(MLUndeployModelsAction.INSTANCE, mlUndeployModelsRequest, ActionListener.wrap(r -> {
MLUndeployModelNodesResponse mlUndeployModelNodesResponse = r.getResponse();
if (mlUndeployModelNodesResponse.failures() != null && mlUndeployModelNodesResponse.failures().size() != 0) {
log.debug("Received failures in undeploying expired models", mlUndeployModelNodesResponse.failures());
}

// refresh model status
mlIndicesHandler
.initModelIndexIfAbsent(ActionListener.wrap(res -> { refreshModelState(modelWorkerNodes, deployingModels); }, e -> {
log.error("Failed to init model index", e);
}));
}, e -> { log.error("Failed to sync model routing", e); }));
}

private void undeployExpiredModels(Set<String> expiredModels, Map<String, Set<String>> modelWorkerNodes) {
expiredModels.forEach(modelId -> {
String[] targetNodeIds = modelWorkerNodes.keySet().toArray(new String[0]);

MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(
targetNodeIds,
new String[] { modelId }
);
client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
log.debug("model {} is un_deployed", modelId);
}, e -> { log.error("Failed to undeploy model {}", modelId, e); }));
});
}, e -> { log.error("Failed to undeploy models {}", expiredModels, e); }));
}

@VisibleForTesting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,14 @@
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;

import org.apache.commons.lang3.ArrayUtils;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.XContentParser;
Expand Down Expand Up @@ -102,24 +100,15 @@ MLUndeployModelsRequest getRequest(RestRequest request) throws IOException {
}
targetNodeIds = nodeIds;
} else {
targetNodeIds = getAllNodes();
targetNodeIds = getAllNodes(clusterService);
}
if (ArrayUtils.isNotEmpty(modelIds)) {
targetModelIds = modelIds;
}
} else {
targetNodeIds = getAllNodes();
targetNodeIds = getAllNodes(clusterService);
}

return new MLUndeployModelsRequest(targetModelIds, targetNodeIds);
}

private String[] getAllNodes() {
Iterator<DiscoveryNode> iterator = clusterService.state().nodes().iterator();
List<String> nodeIds = new ArrayList<>();
while (iterator.hasNext()) {
nodeIds.add(iterator.next().getId());
}
return nodeIds.toArray(new String[0]);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTas
}

private boolean checkModelAutoDeployEnabled(MLModel mlModel) {
if (mlModel.getDeploySetting() == null) {
if (mlModel.getDeploySetting() == null || mlModel.getDeploySetting().getIsAutoDeployEnabled() == null) {
return true;
}
return mlModel.getDeploySetting().getIsAutoDeployEnabled();
Expand Down

0 comments on commit 21bf079

Please sign in to comment.