From 22b558d1f985d6de2c7b4713a6928fde73ea0030 Mon Sep 17 00:00:00 2001 From: Sicheng Song Date: Tue, 11 Jun 2024 09:34:12 -0700 Subject: [PATCH] Fix model still deployed after calling undeploy API (#2510) * Fix model still deployed after calling undeploy API Signed-off-by: Sicheng Song * Add UT coverage Signed-off-by: Sicheng Song * Fix style Signed-off-by: Sicheng Song * Add UT coverage Signed-off-by: Sicheng Song * Add UT coverage Signed-off-by: Sicheng Song --------- Signed-off-by: Sicheng Song --- .../TransportUndeployModelAction.java | 201 +++++----- .../TransportUndeployModelActionTests.java | 357 ++++++++++++++++-- .../ml/tools/ToolIntegrationWithLLMTest.java | 1 - 3 files changed, 437 insertions(+), 122 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java index 662971b2c7..6456039774 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java @@ -29,7 +29,6 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; @@ -42,10 +41,10 @@ import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse; -import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; +import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -59,11 +58,8 @@ public class TransportUndeployModelAction extends private final MLModelManager mlModelManager; private final ClusterService clusterService; private final Client client; - private DiscoveryNodeHelper nodeFilter; + private final DiscoveryNodeHelper nodeFilter; private final MLStats mlStats; - private NamedXContentRegistry xContentRegistry; - - private ModelAccessControlHelper modelAccessControlHelper; @Inject public TransportUndeployModelAction( @@ -74,9 +70,7 @@ public TransportUndeployModelAction( ThreadPool threadPool, Client client, DiscoveryNodeHelper nodeFilter, - MLStats mlStats, - NamedXContentRegistry xContentRegistry, - ModelAccessControlHelper modelAccessControlHelper + MLStats mlStats ) { super( MLUndeployModelAction.NAME, @@ -90,107 +84,128 @@ public TransportUndeployModelAction( MLUndeployModelNodeResponse.class ); this.mlModelManager = mlModelManager; + this.clusterService = clusterService; this.client = client; this.nodeFilter = nodeFilter; this.mlStats = mlStats; - this.xContentRegistry = xContentRegistry; - this.modelAccessControlHelper = modelAccessControlHelper; } @Override - protected MLUndeployModelNodesResponse newResponse( - MLUndeployModelNodesRequest nodesRequest, - List responses, - List failures + protected void doExecute(Task task, MLUndeployModelNodesRequest request, ActionListener listener) { + ActionListener wrappedListener = ActionListener.wrap(undeployModelNodesResponse -> { + processUndeployModelResponseAndUpdate(undeployModelNodesResponse, listener); + }, listener::onFailure); + super.doExecute(task, request, wrappedListener); + } + + void processUndeployModelResponseAndUpdate( + MLUndeployModelNodesResponse undeployModelNodesResponse, + ActionListener listener ) { - if (responses != null) { - Map> actualRemovedNodesMap = new HashMap<>(); - Map modelWorkNodesBeforeRemoval = new HashMap<>(); - responses.forEach(r -> { - Map nodeCounts = r.getModelWorkerNodeBeforeRemoval(); - - if (nodeCounts != null) { - for (Map.Entry entry : nodeCounts.entrySet()) { - // when undeploy a undeployed model, the entry.getvalue() is null - if (entry.getValue() != null - && (!modelWorkNodesBeforeRemoval.containsKey(entry.getKey()) - || modelWorkNodesBeforeRemoval.get(entry.getKey()).length < entry.getValue().length)) { - modelWorkNodesBeforeRemoval.put(entry.getKey(), entry.getValue()); - } + List responses = undeployModelNodesResponse.getNodes(); + if (responses == null || responses.isEmpty()) { + listener.onResponse(undeployModelNodesResponse); + return; + } + + Map> actualRemovedNodesMap = new HashMap<>(); + Map modelWorkNodesBeforeRemoval = new HashMap<>(); + responses.forEach(r -> { + Map nodeCounts = r.getModelWorkerNodeBeforeRemoval(); + + if (nodeCounts != null) { + for (Map.Entry entry : nodeCounts.entrySet()) { + // when undeploy an undeployed model, the entry.getvalue() is null + if (entry.getValue() != null + && (!modelWorkNodesBeforeRemoval.containsKey(entry.getKey()) + || modelWorkNodesBeforeRemoval.get(entry.getKey()).length < entry.getValue().length)) { + modelWorkNodesBeforeRemoval.put(entry.getKey(), entry.getValue()); } } + } - Map modelUndeployStatus = r.getModelUndeployStatus(); - for (Map.Entry entry : modelUndeployStatus.entrySet()) { - String status = entry.getValue(); - if (UNDEPLOYED.equals(status)) { - String modelId = entry.getKey(); - if (!actualRemovedNodesMap.containsKey(modelId)) { - actualRemovedNodesMap.put(modelId, new ArrayList<>()); - } - actualRemovedNodesMap.get(modelId).add(r.getNode().getId()); + Map modelUndeployStatus = r.getModelUndeployStatus(); + for (Map.Entry entry : modelUndeployStatus.entrySet()) { + String status = entry.getValue(); + if (UNDEPLOYED.equals(status)) { + String modelId = entry.getKey(); + if (!actualRemovedNodesMap.containsKey(modelId)) { + actualRemovedNodesMap.put(modelId, new ArrayList<>()); } + actualRemovedNodesMap.get(modelId).add(r.getNode().getId()); } - }); - - MLSyncUpInput syncUpInput = MLSyncUpInput - .builder() - .removedWorkerNodes(covertRemoveNodesMapForSyncUp(actualRemovedNodesMap)) - .build(); - - MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(nodeFilter.getAllNodes(), syncUpInput); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - if (actualRemovedNodesMap.size() > 0) { - BulkRequest bulkRequest = new BulkRequest(); - Map deployToAllNodes = new HashMap<>(); - for (String modelId : actualRemovedNodesMap.keySet()) { - UpdateRequest updateRequest = new UpdateRequest(); - List removedNodes = actualRemovedNodesMap.get(modelId); - int removedNodeCount = removedNodes.size(); - /** - * If allow custom deploy is false, user can only undeploy all nodes and status is undeployed. - * If allow custom deploy is true, user can undeploy all nodes and status is undeployed, - * or undeploy partial nodes, and status is deployed, this case means user created a new deployment plan, and - * we need to update both planning worker nodes (count) and current worker nodes (count) - * and deployToAllNodes value in model index. - */ - Map updateDocument = new HashMap<>(); - if (modelWorkNodesBeforeRemoval.get(modelId).length == removedNodeCount) { // undeploy all nodes. - updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, ImmutableList.of()); - updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0); - updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0); - updateDocument.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED); - } else { // undeploy partial nodes. - // TODO (to fix) when undeploy partial nodes, the original model status could be partially_deployed, - // and the user could be undeploying not running model nodes, and we should update model status to deployed. - updateDocument.put(MLModel.DEPLOY_TO_ALL_NODES_FIELD, false); - List newPlanningWorkerNodes = Arrays - .stream(modelWorkNodesBeforeRemoval.get(modelId)) - .filter(x -> !removedNodes.contains(x)) - .collect(Collectors.toList()); - updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, newPlanningWorkerNodes); - updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size()); - updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size()); - deployToAllNodes.put(modelId, false); - } - updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(updateDocument); - bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + } + }); + + MLSyncUpInput syncUpInput = MLSyncUpInput + .builder() + .removedWorkerNodes(covertRemoveNodesMapForSyncUp(actualRemovedNodesMap)) + .build(); + + MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(nodeFilter.getAllNodes(), syncUpInput); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + if (actualRemovedNodesMap.size() > 0) { + BulkRequest bulkRequest = new BulkRequest(); + Map deployToAllNodes = new HashMap<>(); + for (String modelId : actualRemovedNodesMap.keySet()) { + UpdateRequest updateRequest = new UpdateRequest(); + List removedNodes = actualRemovedNodesMap.get(modelId); + int removedNodeCount = removedNodes.size(); + /** + * If allow custom deploy is false, user can only undeploy all nodes and status is undeployed. + * If allow custom deploy is true, user can undeploy all nodes and status is undeployed, + * or undeploy partial nodes, and status is deployed, this case means user created a new deployment plan, and + * we need to update both planning worker nodes (count) and current worker nodes (count) + * and deployToAllNodes value in model index. + */ + Map updateDocument = new HashMap<>(); + if (modelWorkNodesBeforeRemoval.get(modelId).length == removedNodeCount) { // undeploy all nodes. + updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, ImmutableList.of()); + updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0); + updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0); + updateDocument.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED); + } else { // undeploy partial nodes. + // TODO (to fix) when undeploy partial nodes, the original model status could be partially_deployed, + // and the user could be undeploying not running model nodes, and we should update model status to deployed. + updateDocument.put(MLModel.DEPLOY_TO_ALL_NODES_FIELD, false); + List newPlanningWorkerNodes = Arrays + .stream(modelWorkNodesBeforeRemoval.get(modelId)) + .filter(x -> !removedNodes.contains(x)) + .collect(Collectors.toList()); + updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, newPlanningWorkerNodes); + updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size()); + updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size()); + deployToAllNodes.put(modelId, false); } - syncUpInput.setDeployToAllNodes(deployToAllNodes); - ActionListener actionListener = ActionListener.wrap(r -> { - log - .debug( - "updated model state as undeployed for : {}", - Arrays.toString(actualRemovedNodesMap.keySet().toArray(new String[0])) - ); - }, e -> { log.error("Failed to update model state as undeployed", e); }); - client.bulk(bulkRequest, ActionListener.runAfter(actionListener, () -> { syncUpUndeployedModels(syncUpRequest); })); - } else { - syncUpUndeployedModels(syncUpRequest); + updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(updateDocument); + bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); } + syncUpInput.setDeployToAllNodes(deployToAllNodes); + ActionListener actionListener = ActionListener.wrap(r -> { + log + .debug( + "updated model state as undeployed for : {}", + Arrays.toString(actualRemovedNodesMap.keySet().toArray(new String[0])) + ); + }, e -> { log.error("Failed to update model state as undeployed", e); }); + client.bulk(bulkRequest, ActionListener.runAfter(actionListener, () -> { + syncUpUndeployedModels(syncUpRequest); + listener.onResponse(undeployModelNodesResponse); + })); + } else { + syncUpUndeployedModels(syncUpRequest); + listener.onResponse(undeployModelNodesResponse); } } + } + + @Override + protected MLUndeployModelNodesResponse newResponse( + MLUndeployModelNodesRequest nodesRequest, + List responses, + List failures + ) { return new MLUndeployModelNodesResponse(clusterService.getClusterName(), responses, failures); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelActionTests.java index 8de5cf5ec8..87d42f3847 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelActionTests.java @@ -9,11 +9,10 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import java.io.IOException; import java.net.InetAddress; @@ -25,34 +24,41 @@ import java.util.concurrent.ExecutorService; import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.mockito.Spy; import org.opensearch.Version; import org.opensearch.action.FailedNodeException; -import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.support.nodes.TransportNodesAction; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.cluster.DiscoveryNodeHelper; -import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; +import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest; +import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeRequest; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse; -import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStats; +import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -77,6 +83,18 @@ public class TransportUndeployModelActionTests extends OpenSearchTestCase { @Mock private Client client; + @Mock + ClusterState clusterState; + + @Mock + Task task; + + @Spy + ActionListener actionListener; + + @Mock + MLSyncUpNodeResponse syncUpNodeResponse; + @Mock private DiscoveryNodeHelper nodeFilter; @@ -93,10 +111,22 @@ public class TransportUndeployModelActionTests extends OpenSearchTestCase { private TransportUndeployModelAction action; - private DiscoveryNode localNode; + DiscoveryNode localNode; + + private DiscoveryNode node1; + + private DiscoveryNode node2; + + DiscoveryNode[] nodesArray; + + @Mock + private MLUndeployModelNodesResponse undeployModelNodesResponse; @Mock - private ModelAccessControlHelper modelAccessControlHelper; + private TransportNodesAction transportNodesAction; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); @Before public void setup() throws IOException { @@ -105,24 +135,26 @@ public void setup() throws IOException { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.generic()).thenReturn(executorService); when(threadPool.executor(anyString())).thenReturn(executorService); doAnswer(invocation -> { Runnable runnable = invocation.getArgument(0); runnable.run(); return null; }).when(executorService).execute(any(Runnable.class)); - action = new TransportUndeployModelAction( - transportService, - actionFilters, - mlModelManager, - clusterService, - null, - client, - nodeFilter, - mlStats, - xContentRegistry, - modelAccessControlHelper + action = spy( + new TransportUndeployModelAction( + transportService, + actionFilters, + mlModelManager, + clusterService, + threadPool, + client, + nodeFilter, + mlStats + ) ); + localNode = new DiscoveryNode( "foo0", "foo0", @@ -131,8 +163,34 @@ public void setup() throws IOException { Collections.singleton(CLUSTER_MANAGER_ROLE), Version.CURRENT ); + + InetAddress inetAddress1 = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 1 }); + InetAddress inetAddress2 = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 2 }); + + DiscoveryNode node1 = new DiscoveryNode( + "foo1", + "foo1", + new TransportAddress(inetAddress1, 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + + DiscoveryNode node2 = new DiscoveryNode( + "foo2", + "foo2", + new TransportAddress(inetAddress2, 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + + DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build(); + when(clusterService.getClusterName()).thenReturn(new ClusterName("Local Cluster")); when(clusterService.localNode()).thenReturn(localNode); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.nodes()).thenReturn(nodes); } public void testConstructor() { @@ -171,7 +229,23 @@ public void testNodeOperation() { assertNotNull(response); } - public void testNewResponseWithUndeployedModelStatus() { + public void testDoExecuteTransportUndeployedModelAction() { + MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + + action.doExecute(task, nodesRequest, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLUndeployModelNodesResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + } + + public void testProcessUndeployModelResponseAndUpdateNullResponse() { + when(undeployModelNodesResponse.getNodes()).thenReturn(null); + action.processUndeployModelResponseAndUpdate(undeployModelNodesResponse, actionListener); + } + + public void testProcessUndeployModelResponseAndUpdateResponse() { final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( new String[] { "nodeId1", "nodeId2" }, new String[] { "modelId1", "modelId2" } @@ -187,13 +261,240 @@ public void testNewResponseWithUndeployedModelStatus() { responses.add(response2); final List failures = new ArrayList<>(); final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); - assertNotNull(response); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(BulkRequest.class); - verify(client, times(1)).bulk(argumentCaptor.capture(), any()); - UpdateRequest updateRequest = (UpdateRequest) argumentCaptor.getValue().requests().get(0); - assertEquals(ML_MODEL_INDEX, updateRequest.index()); - Map updateContent = updateRequest.doc().sourceAsMap(); - assertEquals(MLModelState.UNDEPLOYED.name(), updateContent.get(MLModel.MODEL_STATE_FIELD)); + + BulkResponse bulkResponse = mock(BulkResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(bulkResponse); + return null; + }).when(client).bulk(any(), any()); + + MLSyncUpNodesResponse syncUpNodesResponse = mock(MLSyncUpNodesResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(syncUpNodesResponse); + return null; + }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); + + action.processUndeployModelResponseAndUpdate(response, actionListener); + verify(actionListener).onResponse(response); + } + + public void testProcessUndeployModelResponseAndUpdateBulkException() { + final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final List responses = new ArrayList<>(); + Map modelToDeployStatus = new HashMap<>(); + modelToDeployStatus.put("modelId1", "undeployed"); + Map modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", new String[] { "foo0", "foo0" }); + MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + responses.add(response1); + responses.add(response2); + final List failures = new ArrayList<>(); + final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Bulk request failed")); + return null; + }).when(client).bulk(any(), any()); + + MLSyncUpNodesResponse syncUpNodesResponse = mock(MLSyncUpNodesResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(syncUpNodesResponse); + return null; + }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); + + action.processUndeployModelResponseAndUpdate(response, actionListener); + verify(actionListener).onResponse(response); + } + + public void testProcessUndeployModelResponseAndUpdateSyncUpException() { + final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final List responses = new ArrayList<>(); + Map modelToDeployStatus = new HashMap<>(); + modelToDeployStatus.put("modelId1", "undeployed"); + Map modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", new String[] { "foo0", "foo0" }); + MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + responses.add(response1); + responses.add(response2); + final List failures = new ArrayList<>(); + final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); + + BulkResponse bulkResponse = mock(BulkResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(bulkResponse); + return null; + }).when(client).bulk(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("SyncUp request failed")); + return null; + }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); + + action.processUndeployModelResponseAndUpdate(response, actionListener); + verify(actionListener).onResponse(response); + } + + public void testProcessUndeployModelResponseAndUpdateResponseDeployStatusWrong() { + final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final List responses = new ArrayList<>(); + Map modelToDeployStatus = new HashMap<>(); + modelToDeployStatus.put("modelId1", "wrong_status"); + Map modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", new String[] { "foo0", "foo0" }); + MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + responses.add(response1); + responses.add(response2); + final List failures = new ArrayList<>(); + final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); + + BulkResponse bulkResponse = mock(BulkResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(bulkResponse); + return null; + }).when(client).bulk(any(), any()); + + MLSyncUpNodesResponse syncUpNodesResponse = mock(MLSyncUpNodesResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(syncUpNodesResponse); + return null; + }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); + + action.processUndeployModelResponseAndUpdate(response, actionListener); + verify(actionListener).onResponse(response); + } + + public void testProcessUndeployModelResponseAndUpdateResponseUndeployPartialNodes() { + final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final List responses = new ArrayList<>(); + Map modelToDeployStatus1 = new HashMap<>(); + modelToDeployStatus1.put("modelId1", "undeployed"); + Map modelToDeployStatus2 = new HashMap<>(); + modelToDeployStatus2.put("modelId1", "deployed"); + Map modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", new String[] { "foo0", "foo0" }); + MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus1, modelWorkerNodeCounts); + MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus2, modelWorkerNodeCounts); + responses.add(response1); + responses.add(response2); + final List failures = new ArrayList<>(); + final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); + + BulkResponse bulkResponse = mock(BulkResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(bulkResponse); + return null; + }).when(client).bulk(any(), any()); + + MLSyncUpNodesResponse syncUpNodesResponse = mock(MLSyncUpNodesResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(syncUpNodesResponse); + return null; + }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); + + action.processUndeployModelResponseAndUpdate(response, actionListener); + verify(actionListener).onResponse(response); + } + + public void testProcessUndeployModelResponseAndUpdateResponseUndeployEmptyNodes() { + final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final List responses = new ArrayList<>(); + Map modelToDeployStatus = new HashMap<>(); + modelToDeployStatus.put("modelId1", "undeployed"); + Map modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", new String[] {}); + MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + responses.add(response1); + responses.add(response2); + final List failures = new ArrayList<>(); + final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); + + BulkResponse bulkResponse = mock(BulkResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(bulkResponse); + return null; + }).when(client).bulk(any(), any()); + + MLSyncUpNodesResponse syncUpNodesResponse = mock(MLSyncUpNodesResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(syncUpNodesResponse); + return null; + }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); + + action.processUndeployModelResponseAndUpdate(response, actionListener); + verify(actionListener).onResponse(response); + } + + public void testProcessUndeployModelResponseAndUpdateResponseUndeployNodeEntrySetNull() { + exceptionRule.expect(NullPointerException.class); + + final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final List responses = new ArrayList<>(); + Map modelToDeployStatus = new HashMap<>(); + modelToDeployStatus.put("modelId1", "undeployed"); + Map modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", null); + MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + responses.add(response1); + responses.add(response2); + final List failures = new ArrayList<>(); + final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); + + action.processUndeployModelResponseAndUpdate(response, actionListener); + } + + public void testProcessUndeployModelResponseAndUpdateResponseUndeployModelWorkerNodeBeforeRemovalNull() { + exceptionRule.expect(NullPointerException.class); + + final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final List responses = new ArrayList<>(); + Map modelToDeployStatus = new HashMap<>(); + modelToDeployStatus.put("modelId1", "undeployed"); + MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, null); + MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, null); + responses.add(response1); + responses.add(response2); + final List failures = new ArrayList<>(); + final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); + + action.processUndeployModelResponseAndUpdate(response, actionListener); } public void testNewResponseWithNotFoundModelStatus() { diff --git a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java index fe033645d6..3e7c2e64f4 100644 --- a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java +++ b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java @@ -72,7 +72,6 @@ public void stopMockLLM() { @After public void deleteModel() throws IOException { undeployModel(modelId); - waitModelUndeployed(modelId); deleteModel(client(), modelId, null); }