diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java index 6456039774..52d83ce767 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.UNDEPLOYED; +import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; import java.io.IOException; import java.util.ArrayList; @@ -14,6 +15,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import org.opensearch.action.FailedNodeException; @@ -29,6 +32,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; @@ -44,7 +48,6 @@ import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; -import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -80,7 +83,7 @@ public TransportUndeployModelAction( actionFilters, MLUndeployModelNodesRequest::new, MLUndeployModelNodeRequest::new, - ThreadPool.Names.MANAGEMENT, + DEPLOY_THREAD_POOL, MLUndeployModelNodeResponse.class ); this.mlModelManager = mlModelManager; @@ -92,23 +95,14 @@ public TransportUndeployModelAction( } @Override - protected void doExecute(Task task, MLUndeployModelNodesRequest request, ActionListener listener) { - ActionListener wrappedListener = ActionListener.wrap(undeployModelNodesResponse -> { - processUndeployModelResponseAndUpdate(undeployModelNodesResponse, listener); - }, listener::onFailure); - super.doExecute(task, request, wrappedListener); - } - - void processUndeployModelResponseAndUpdate( - MLUndeployModelNodesResponse undeployModelNodesResponse, - ActionListener listener + protected MLUndeployModelNodesResponse newResponse( + MLUndeployModelNodesRequest nodesRequest, + List responses, + List failures ) { - List responses = undeployModelNodesResponse.getNodes(); - if (responses == null || responses.isEmpty()) { - listener.onResponse(undeployModelNodesResponse); - return; + if (CollectionUtils.isEmpty(responses)) { + return new MLUndeployModelNodesResponse(clusterService.getClusterName(), responses, failures); } - Map> actualRemovedNodesMap = new HashMap<>(); Map modelWorkNodesBeforeRemoval = new HashMap<>(); responses.forEach(r -> { @@ -116,10 +110,10 @@ void processUndeployModelResponseAndUpdate( if (nodeCounts != null) { for (Map.Entry entry : nodeCounts.entrySet()) { - // when undeploy an undeployed model, the entry.getvalue() is null + // when undeploy a undeployed model, the entry.getvalue() is null if (entry.getValue() != null && (!modelWorkNodesBeforeRemoval.containsKey(entry.getKey()) - || modelWorkNodesBeforeRemoval.get(entry.getKey()).length < entry.getValue().length)) { + || modelWorkNodesBeforeRemoval.get(entry.getKey()).length < entry.getValue().length)) { modelWorkNodesBeforeRemoval.put(entry.getKey(), entry.getValue()); } } @@ -144,8 +138,9 @@ void processUndeployModelResponseAndUpdate( .build(); MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(nodeFilter.getAllNodes(), syncUpInput); + CountDownLatch countDownLatch = new CountDownLatch(1); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - if (actualRemovedNodesMap.size() > 0) { + if (!actualRemovedNodesMap.isEmpty()) { BulkRequest bulkRequest = new BulkRequest(); Map deployToAllNodes = new HashMap<>(); for (String modelId : actualRemovedNodesMap.keySet()) { @@ -188,24 +183,32 @@ void processUndeployModelResponseAndUpdate( "updated model state as undeployed for : {}", Arrays.toString(actualRemovedNodesMap.keySet().toArray(new String[0])) ); - }, e -> { log.error("Failed to update model state as undeployed", e); }); + countDownLatch.countDown(); + }, e -> { + log.error("Failed to update model state as undeployed", e); + countDownLatch.countDown(); + }); client.bulk(bulkRequest, ActionListener.runAfter(actionListener, () -> { syncUpUndeployedModels(syncUpRequest); - listener.onResponse(undeployModelNodesResponse); + context.restore(); })); } else { syncUpUndeployedModels(syncUpRequest); - listener.onResponse(undeployModelNodesResponse); + context.restore(); + countDownLatch.countDown(); + } + } + if (countDownLatch.getCount() != 0) { + try { + boolean success = countDownLatch.await(1000, TimeUnit.MILLISECONDS); + if (!success) { + log.error("Failed to update model state as undeployed in model index after waiting for 1 second, please check model status manually"); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + log.error("Failed to update model state as undeployed in model index, current thread is interrupted", e); } } - } - - @Override - protected MLUndeployModelNodesResponse newResponse( - MLUndeployModelNodesRequest nodesRequest, - List responses, - List failures - ) { return new MLUndeployModelNodesResponse(clusterService.getClusterName(), responses, failures); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelActionTests.java index 87d42f3847..9d8de9530f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelActionTests.java @@ -7,12 +7,14 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import java.io.IOException; import java.net.InetAddress; @@ -24,33 +26,28 @@ import java.util.concurrent.ExecutorService; import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.mockito.Spy; import org.opensearch.Version; import org.opensearch.action.FailedNodeException; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; -import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.transport.TransportAddress; -import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.cluster.DiscoveryNodeHelper; -import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; -import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest; -import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeRequest; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest; @@ -58,7 +55,6 @@ import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStats; -import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -83,27 +79,12 @@ public class TransportUndeployModelActionTests extends OpenSearchTestCase { @Mock private Client client; - @Mock - ClusterState clusterState; - - @Mock - Task task; - - @Spy - ActionListener actionListener; - - @Mock - MLSyncUpNodeResponse syncUpNodeResponse; - @Mock private DiscoveryNodeHelper nodeFilter; @Mock private MLStats mlStats; - @Mock - NamedXContentRegistry xContentRegistry; - private ThreadContext threadContext; @Mock @@ -111,22 +92,7 @@ public class TransportUndeployModelActionTests extends OpenSearchTestCase { private TransportUndeployModelAction action; - DiscoveryNode localNode; - - private DiscoveryNode node1; - - private DiscoveryNode node2; - - DiscoveryNode[] nodesArray; - - @Mock - private MLUndeployModelNodesResponse undeployModelNodesResponse; - - @Mock - private TransportNodesAction transportNodesAction; - - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); + private DiscoveryNode localNode; @Before public void setup() throws IOException { @@ -135,26 +101,22 @@ public void setup() throws IOException { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - when(threadPool.generic()).thenReturn(executorService); when(threadPool.executor(anyString())).thenReturn(executorService); doAnswer(invocation -> { Runnable runnable = invocation.getArgument(0); runnable.run(); return null; }).when(executorService).execute(any(Runnable.class)); - action = spy( - new TransportUndeployModelAction( - transportService, - actionFilters, - mlModelManager, - clusterService, - threadPool, - client, - nodeFilter, - mlStats - ) + action = new TransportUndeployModelAction( + transportService, + actionFilters, + mlModelManager, + clusterService, + null, + client, + nodeFilter, + mlStats ); - localNode = new DiscoveryNode( "foo0", "foo0", @@ -163,34 +125,8 @@ public void setup() throws IOException { Collections.singleton(CLUSTER_MANAGER_ROLE), Version.CURRENT ); - - InetAddress inetAddress1 = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 1 }); - InetAddress inetAddress2 = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 2 }); - - DiscoveryNode node1 = new DiscoveryNode( - "foo1", - "foo1", - new TransportAddress(inetAddress1, 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT - ); - - DiscoveryNode node2 = new DiscoveryNode( - "foo2", - "foo2", - new TransportAddress(inetAddress2, 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT - ); - - DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build(); - when(clusterService.getClusterName()).thenReturn(new ClusterName("Local Cluster")); when(clusterService.localNode()).thenReturn(localNode); - when(clusterService.state()).thenReturn(clusterState); - when(clusterState.nodes()).thenReturn(nodes); } public void testConstructor() { @@ -229,23 +165,7 @@ public void testNodeOperation() { assertNotNull(response); } - public void testDoExecuteTransportUndeployedModelAction() { - MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( - new String[] { "nodeId1", "nodeId2" }, - new String[] { "modelId1", "modelId2" } - ); - - action.doExecute(task, nodesRequest, actionListener); - ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLUndeployModelNodesResponse.class); - verify(actionListener).onResponse(argCaptor.capture()); - } - - public void testProcessUndeployModelResponseAndUpdateNullResponse() { - when(undeployModelNodesResponse.getNodes()).thenReturn(null); - action.processUndeployModelResponseAndUpdate(undeployModelNodesResponse, actionListener); - } - - public void testProcessUndeployModelResponseAndUpdateResponse() { + public void testNewResponseWithUndeployedModelStatus() { final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( new String[] { "nodeId1", "nodeId2" }, new String[] { "modelId1", "modelId2" } @@ -261,60 +181,35 @@ public void testProcessUndeployModelResponseAndUpdateResponse() { responses.add(response2); final List failures = new ArrayList<>(); final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); - - BulkResponse bulkResponse = mock(BulkResponse.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(bulkResponse); - return null; - }).when(client).bulk(any(), any()); - - MLSyncUpNodesResponse syncUpNodesResponse = mock(MLSyncUpNodesResponse.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(syncUpNodesResponse); - return null; - }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); - - action.processUndeployModelResponseAndUpdate(response, actionListener); - verify(actionListener).onResponse(response); + assertNotNull(response); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(BulkRequest.class); + verify(client, times(1)).bulk(argumentCaptor.capture(), any()); + UpdateRequest updateRequest = (UpdateRequest) argumentCaptor.getValue().requests().get(0); + assertEquals(ML_MODEL_INDEX, updateRequest.index()); + Map updateContent = updateRequest.doc().sourceAsMap(); + assertEquals(MLModelState.UNDEPLOYED.name(), updateContent.get(MLModel.MODEL_STATE_FIELD)); } - public void testProcessUndeployModelResponseAndUpdateBulkException() { + public void testNewResponseWithNotFoundModelStatus() { final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( new String[] { "nodeId1", "nodeId2" }, new String[] { "modelId1", "modelId2" } ); final List responses = new ArrayList<>(); Map modelToDeployStatus = new HashMap<>(); - modelToDeployStatus.put("modelId1", "undeployed"); Map modelWorkerNodeCounts = new HashMap<>(); - modelWorkerNodeCounts.put("modelId1", new String[] { "foo0", "foo0" }); + modelToDeployStatus.put("modelId1", "not_found"); + modelWorkerNodeCounts.put("modelId1", new String[] { "node" }); MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); responses.add(response1); responses.add(response2); final List failures = new ArrayList<>(); final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new RuntimeException("Bulk request failed")); - return null; - }).when(client).bulk(any(), any()); - - MLSyncUpNodesResponse syncUpNodesResponse = mock(MLSyncUpNodesResponse.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(syncUpNodesResponse); - return null; - }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); - - action.processUndeployModelResponseAndUpdate(response, actionListener); - verify(actionListener).onResponse(response); + assertNotNull(response); } - public void testProcessUndeployModelResponseAndUpdateSyncUpException() { + public void test_whenBulkRequestFinished_countDownBy1() { final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( new String[] { "nodeId1", "nodeId2" }, new String[] { "modelId1", "modelId2" } @@ -330,189 +225,15 @@ public void testProcessUndeployModelResponseAndUpdateSyncUpException() { responses.add(response2); final List failures = new ArrayList<>(); final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); - - BulkResponse bulkResponse = mock(BulkResponse.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(bulkResponse); - return null; - }).when(client).bulk(any(), any()); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("SyncUp request failed")); - return null; - }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); - - action.processUndeployModelResponseAndUpdate(response, actionListener); - verify(actionListener).onResponse(response); - } - - public void testProcessUndeployModelResponseAndUpdateResponseDeployStatusWrong() { - final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( - new String[] { "nodeId1", "nodeId2" }, - new String[] { "modelId1", "modelId2" } - ); - final List responses = new ArrayList<>(); - Map modelToDeployStatus = new HashMap<>(); - modelToDeployStatus.put("modelId1", "wrong_status"); - Map modelWorkerNodeCounts = new HashMap<>(); - modelWorkerNodeCounts.put("modelId1", new String[] { "foo0", "foo0" }); - MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); - MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); - responses.add(response1); - responses.add(response2); - final List failures = new ArrayList<>(); - final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); - - BulkResponse bulkResponse = mock(BulkResponse.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(bulkResponse); - return null; - }).when(client).bulk(any(), any()); - - MLSyncUpNodesResponse syncUpNodesResponse = mock(MLSyncUpNodesResponse.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(syncUpNodesResponse); - return null; - }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); - - action.processUndeployModelResponseAndUpdate(response, actionListener); - verify(actionListener).onResponse(response); - } - - public void testProcessUndeployModelResponseAndUpdateResponseUndeployPartialNodes() { - final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( - new String[] { "nodeId1", "nodeId2" }, - new String[] { "modelId1", "modelId2" } - ); - final List responses = new ArrayList<>(); - Map modelToDeployStatus1 = new HashMap<>(); - modelToDeployStatus1.put("modelId1", "undeployed"); - Map modelToDeployStatus2 = new HashMap<>(); - modelToDeployStatus2.put("modelId1", "deployed"); - Map modelWorkerNodeCounts = new HashMap<>(); - modelWorkerNodeCounts.put("modelId1", new String[] { "foo0", "foo0" }); - MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus1, modelWorkerNodeCounts); - MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus2, modelWorkerNodeCounts); - responses.add(response1); - responses.add(response2); - final List failures = new ArrayList<>(); - final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); - - BulkResponse bulkResponse = mock(BulkResponse.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(bulkResponse); - return null; - }).when(client).bulk(any(), any()); - - MLSyncUpNodesResponse syncUpNodesResponse = mock(MLSyncUpNodesResponse.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(syncUpNodesResponse); - return null; - }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); - - action.processUndeployModelResponseAndUpdate(response, actionListener); - verify(actionListener).onResponse(response); - } - - public void testProcessUndeployModelResponseAndUpdateResponseUndeployEmptyNodes() { - final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( - new String[] { "nodeId1", "nodeId2" }, - new String[] { "modelId1", "modelId2" } - ); - final List responses = new ArrayList<>(); - Map modelToDeployStatus = new HashMap<>(); - modelToDeployStatus.put("modelId1", "undeployed"); - Map modelWorkerNodeCounts = new HashMap<>(); - modelWorkerNodeCounts.put("modelId1", new String[] {}); - MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); - MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); - responses.add(response1); - responses.add(response2); - final List failures = new ArrayList<>(); - final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); - - BulkResponse bulkResponse = mock(BulkResponse.class); + assertNotNull(response); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - listener.onResponse(bulkResponse); - return null; - }).when(client).bulk(any(), any()); - - MLSyncUpNodesResponse syncUpNodesResponse = mock(MLSyncUpNodesResponse.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(syncUpNodesResponse); + listener.onResponse(new BulkResponse(new BulkItemResponse[0], 0)); return null; - }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); - - action.processUndeployModelResponseAndUpdate(response, actionListener); - verify(actionListener).onResponse(response); - } - - public void testProcessUndeployModelResponseAndUpdateResponseUndeployNodeEntrySetNull() { - exceptionRule.expect(NullPointerException.class); - - final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( - new String[] { "nodeId1", "nodeId2" }, - new String[] { "modelId1", "modelId2" } - ); - final List responses = new ArrayList<>(); - Map modelToDeployStatus = new HashMap<>(); - modelToDeployStatus.put("modelId1", "undeployed"); - Map modelWorkerNodeCounts = new HashMap<>(); - modelWorkerNodeCounts.put("modelId1", null); - MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); - MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); - responses.add(response1); - responses.add(response2); - final List failures = new ArrayList<>(); - final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); - - action.processUndeployModelResponseAndUpdate(response, actionListener); - } - - public void testProcessUndeployModelResponseAndUpdateResponseUndeployModelWorkerNodeBeforeRemovalNull() { - exceptionRule.expect(NullPointerException.class); - - final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( - new String[] { "nodeId1", "nodeId2" }, - new String[] { "modelId1", "modelId2" } - ); - final List responses = new ArrayList<>(); - Map modelToDeployStatus = new HashMap<>(); - modelToDeployStatus.put("modelId1", "undeployed"); - MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, null); - MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, null); - responses.add(response1); - responses.add(response2); - final List failures = new ArrayList<>(); - final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); - - action.processUndeployModelResponseAndUpdate(response, actionListener); - } - - public void testNewResponseWithNotFoundModelStatus() { - final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( - new String[] { "nodeId1", "nodeId2" }, - new String[] { "modelId1", "modelId2" } - ); - final List responses = new ArrayList<>(); - Map modelToDeployStatus = new HashMap<>(); - Map modelWorkerNodeCounts = new HashMap<>(); - modelToDeployStatus.put("modelId1", "not_found"); - modelWorkerNodeCounts.put("modelId1", new String[] { "node" }); - MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); - MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); - responses.add(response1); - responses.add(response2); - final List failures = new ArrayList<>(); - final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); - assertNotNull(response); + }).when(client).bulk(any(), isA(ActionListener.class)); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(ActionListener.class); + verify(client, times(1)).bulk(any(), argumentCaptor.capture()); + argumentCaptor.getValue().onResponse(new BulkResponse(new BulkItemResponse[0], 0)); + argumentCaptor.getValue().onFailure(new Exception("error")); } }