diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/NodesShutdownMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/NodesShutdownMetadata.java
index 441417641e03c..b4a727e587f4a 100644
--- a/server/src/main/java/org/elasticsearch/cluster/metadata/NodesShutdownMetadata.java
+++ b/server/src/main/java/org/elasticsearch/cluster/metadata/NodesShutdownMetadata.java
@@ -10,6 +10,7 @@
 
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.AbstractDiffable;
+import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.Diff;
 import org.elasticsearch.cluster.DiffableUtils;
 import org.elasticsearch.cluster.NamedDiff;
@@ -27,6 +28,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.TreeMap;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -66,6 +68,13 @@ public static NamedDiff<Metadata.Custom> readDiffFrom(StreamInput in) throws IOE
         return new NodeShutdownMetadataDiff(in);
     }
 
+    public static Optional<NodesShutdownMetadata> getShutdowns(final ClusterState state) {
+        assert state != null : "cluster state should never be null";
+        return Optional.ofNullable(state)
+            .map(ClusterState::metadata)
+            .map(m -> m.custom(TYPE));
+    }
+
     private final Map<String, SingleNodeShutdownMetadata> nodes;
 
     public NodesShutdownMetadata(Map<String, SingleNodeShutdownMetadata> nodes) {
@@ -84,7 +93,7 @@ public void writeTo(StreamOutput out) throws IOException {
     /**
      * @return A map of NodeID to shutdown metadata.
      */
-    public Map<String, SingleNodeShutdownMetadata> getAllNodeMetdataMap() {
+    public Map<String, SingleNodeShutdownMetadata> getAllNodeMetadataMap() {
         return Collections.unmodifiableMap(nodes);
     }
 
diff --git a/server/src/main/java/org/elasticsearch/cluster/node/DiscoveryNodes.java b/server/src/main/java/org/elasticsearch/cluster/node/DiscoveryNodes.java
index d9cd58eff2a92..c406ea8db0fff 100644
--- a/server/src/main/java/org/elasticsearch/cluster/node/DiscoveryNodes.java
+++ b/server/src/main/java/org/elasticsearch/cluster/node/DiscoveryNodes.java
@@ -25,6 +25,7 @@
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
@@ -33,6 +34,7 @@
 import java.util.Set;
 import java.util.function.Function;
 import java.util.function.Predicate;
+import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import java.util.stream.StreamSupport;
 
@@ -155,6 +157,14 @@ public ImmutableOpenMap<String, DiscoveryNode> getCoordinatingOnlyNodes() {
         return nodes.build();
     }
 
+    /**
+     * Return all the nodes as a collection
+     * @return
+     */
+    public Collection<DiscoveryNode> getAllNodes() {
+        return StreamSupport.stream(this.spliterator(), false).collect(Collectors.toList());
+    }
+
     /**
      * Returns a stream of all nodes, with master nodes at the front
      */
diff --git a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksClusterService.java b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksClusterService.java
index dd524886dfb1c..1ef0f57b7a633 100644
--- a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksClusterService.java
+++ b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksClusterService.java
@@ -19,9 +19,11 @@
 import org.elasticsearch.cluster.ClusterStateUpdateTask;
 import org.elasticsearch.cluster.NotMasterException;
 import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.metadata.NodesShutdownMetadata;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
@@ -33,7 +35,9 @@
 import org.elasticsearch.threadpool.ThreadPool;
 
 import java.io.Closeable;
+import java.util.List;
 import java.util.Objects;
+import java.util.stream.Collectors;
 
 /**
  * Component that runs only on the master node and is responsible for assigning running tasks to nodes
@@ -48,7 +52,7 @@ public class PersistentTasksClusterService implements ClusterStateListener, Clos
 
     private final ClusterService clusterService;
     private final PersistentTasksExecutorRegistry registry;
-    private final EnableAssignmentDecider decider;
+    private final EnableAssignmentDecider enableDecider;
     private final ThreadPool threadPool;
     private final PeriodicRechecker periodicRechecker;
 
@@ -56,7 +60,7 @@ public PersistentTasksClusterService(Settings settings, PersistentTasksExecutorR
                                          ThreadPool threadPool) {
         this.clusterService = clusterService;
         this.registry = registry;
-        this.decider = new EnableAssignmentDecider(settings, clusterService.getClusterSettings());
+        this.enableDecider = new EnableAssignmentDecider(settings, clusterService.getClusterSettings());
         this.threadPool = threadPool;
         this.periodicRechecker = new PeriodicRechecker(CLUSTER_TASKS_ALLOCATION_RECHECK_INTERVAL_SETTING.get(settings));
         if (DiscoveryNode.isMasterNode(settings)) {
@@ -298,12 +302,37 @@ private <Params extends PersistentTaskParams> Assignment createAssignment(final
                                                                               final ClusterState currentState) {
         PersistentTasksExecutor<Params> persistentTasksExecutor = registry.getPersistentTaskExecutorSafe(taskName);
 
-        AssignmentDecision decision = decider.canAssign();
+        AssignmentDecision decision = enableDecider.canAssign();
         if (decision.getType() == AssignmentDecision.Type.NO) {
             return unassignedAssignment("persistent task [" + taskName + "] cannot be assigned [" + decision.getReason() + "]");
         }
 
-        return persistentTasksExecutor.getAssignment(taskParams, currentState);
+        // Filter all nodes that are marked as shutting down, because we do not
+        // want to assign a persistent task to a node that will shortly be
+        // leaving the cluster
+        final List<DiscoveryNode> candidateNodes = currentState.nodes().getAllNodes().stream()
+            .filter(dn -> isNodeShuttingDown(currentState, dn.getId()) == false)
+            .collect(Collectors.toList());
+        // Task assignment should not rely on node order
+        Randomness.shuffle(candidateNodes);
+
+        final Assignment assignment = persistentTasksExecutor.getAssignment(taskParams, candidateNodes, currentState);
+        assert (assignment == null || isNodeShuttingDown(currentState, assignment.getExecutorNode()) == false) :
+            "expected task [" + taskName + "] to be assigned to a node that is not marked as shutting down, but " +
+                assignment.getExecutorNode() + " is currently marked as shutting down";
+        return assignment;
+    }
+
+    /**
+     * Returns true if the given node is marked as shutting down with any
+     * shutdown type.
+     */
+    static boolean isNodeShuttingDown(final ClusterState state, final String nodeId) {
+        // Right now we make no distinction between the type of shutdown, but maybe in the future we might?
+        return NodesShutdownMetadata.getShutdowns(state)
+            .map(NodesShutdownMetadata::getAllNodeMetadataMap)
+            .map(allNodes -> allNodes.get(nodeId))
+            .isPresent();
     }
 
     @Override
diff --git a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksExecutor.java b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksExecutor.java
index 158bd4a8d4eb7..9dcb851f28a89 100644
--- a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksExecutor.java
+++ b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksExecutor.java
@@ -15,6 +15,7 @@
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata.PersistentTask;
 import org.elasticsearch.tasks.TaskId;
 
+import java.util.Collection;
 import java.util.Map;
 import java.util.function.Predicate;
 
@@ -41,10 +42,10 @@ public String getTaskName() {
     /**
      * Returns the node id where the params has to be executed,
      * <p>
-     * The default implementation returns the least loaded data node
+     * The default implementation returns the least loaded data node from amongst the collection of candidate nodes
      */
-    public Assignment getAssignment(Params params, ClusterState clusterState) {
-        DiscoveryNode discoveryNode = selectLeastLoadedNode(clusterState, DiscoveryNode::canContainData);
+    public Assignment getAssignment(Params params, Collection<DiscoveryNode> candidateNodes, ClusterState clusterState) {
+        DiscoveryNode discoveryNode = selectLeastLoadedNode(clusterState, candidateNodes, DiscoveryNode::canContainData);
         if (discoveryNode == null) {
             return NO_NODE_FOUND;
         } else {
@@ -53,13 +54,16 @@ public Assignment getAssignment(Params params, ClusterState clusterState) {
     }
 
     /**
-     * Finds the least loaded node that satisfies the selector criteria
+     * Finds the least loaded node from amongs the candidate node collection
+     * that satisfies the selector criteria
      */
-    protected DiscoveryNode selectLeastLoadedNode(ClusterState clusterState, Predicate<DiscoveryNode> selector) {
+    protected DiscoveryNode selectLeastLoadedNode(ClusterState clusterState,
+                                                  Collection<DiscoveryNode> candidateNodes,
+                                                  Predicate<DiscoveryNode> selector) {
         long minLoad = Long.MAX_VALUE;
         DiscoveryNode minLoadedNode = null;
         PersistentTasksCustomMetadata persistentTasks = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE);
-        for (DiscoveryNode node : clusterState.getNodes()) {
+        for (DiscoveryNode node : candidateNodes) {
             if (selector.test(node)) {
                 if (persistentTasks == null) {
                     // We don't have any task running yet, pick the first available node
diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/NodesShutdownMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/NodesShutdownMetadataTests.java
index db8fb03fd1da9..1499e8af41baa 100644
--- a/server/src/test/java/org/elasticsearch/cluster/metadata/NodesShutdownMetadataTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/metadata/NodesShutdownMetadataTests.java
@@ -37,8 +37,8 @@ public void testInsertNewNodeShutdownMetadata() {
 
         nodesShutdownMetadata = nodesShutdownMetadata.putSingleNodeMetadata(newNodeMetadata);
 
-        assertThat(nodesShutdownMetadata.getAllNodeMetdataMap().get(newNodeMetadata.getNodeId()), equalTo(newNodeMetadata));
-        assertThat(nodesShutdownMetadata.getAllNodeMetdataMap().values(), contains(newNodeMetadata));
+        assertThat(nodesShutdownMetadata.getAllNodeMetadataMap().get(newNodeMetadata.getNodeId()), equalTo(newNodeMetadata));
+        assertThat(nodesShutdownMetadata.getAllNodeMetadataMap().values(), contains(newNodeMetadata));
     }
 
     public void testRemoveShutdownMetadata() {
@@ -52,9 +52,9 @@ public void testRemoveShutdownMetadata() {
         SingleNodeShutdownMetadata nodeToRemove = randomFrom(nodes);
         nodesShutdownMetadata = nodesShutdownMetadata.removeSingleNodeMetadata(nodeToRemove.getNodeId());
 
-        assertThat(nodesShutdownMetadata.getAllNodeMetdataMap().get(nodeToRemove.getNodeId()), nullValue());
-        assertThat(nodesShutdownMetadata.getAllNodeMetdataMap().values(), hasSize(nodes.size() - 1));
-        assertThat(nodesShutdownMetadata.getAllNodeMetdataMap().values(), not(hasItem(nodeToRemove)));
+        assertThat(nodesShutdownMetadata.getAllNodeMetadataMap().get(nodeToRemove.getNodeId()), nullValue());
+        assertThat(nodesShutdownMetadata.getAllNodeMetadataMap().values(), hasSize(nodes.size() - 1));
+        assertThat(nodesShutdownMetadata.getAllNodeMetadataMap().values(), not(hasItem(nodeToRemove)));
     }
 
     @Override
diff --git a/server/src/test/java/org/elasticsearch/persistent/PersistentTasksClusterServiceTests.java b/server/src/test/java/org/elasticsearch/persistent/PersistentTasksClusterServiceTests.java
index 2b3188bbab266..fba5d6d43bd07 100644
--- a/server/src/test/java/org/elasticsearch/persistent/PersistentTasksClusterServiceTests.java
+++ b/server/src/test/java/org/elasticsearch/persistent/PersistentTasksClusterServiceTests.java
@@ -8,7 +8,6 @@
 
 package org.elasticsearch.persistent;
 
-import com.carrotsearch.hppc.cursors.ObjectCursor;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
@@ -18,14 +17,18 @@
 import org.elasticsearch.cluster.ClusterStateUpdateTask;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.metadata.NodesShutdownMetadata;
+import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.routing.RoutingTable;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.TriFunction;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata.Assignment;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata.PersistentTask;
 import org.elasticsearch.persistent.TestPersistentTasksPlugin.TestParams;
@@ -40,15 +43,18 @@
 import org.junit.Before;
 import org.junit.BeforeClass;
 
-import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
-import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.BiFunction;
+import java.util.stream.Collectors;
 
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.singleton;
@@ -101,7 +107,7 @@ public void tearDown() throws Exception {
     }
 
     public void testReassignmentRequired() {
-        final PersistentTasksClusterService service = createService((params, clusterState) ->
+        final PersistentTasksClusterService service = createService((params, candidateNodes, clusterState) ->
             "never_assign".equals(((TestParams) params).getTestParam()) ? NO_NODE_FOUND : randomNodeAssignment(clusterState.nodes())
         );
 
@@ -166,7 +172,8 @@ public void testReassignmentRequiredOnMetadataChanges() {
 
         final ClusterChangedEvent event = new ClusterChangedEvent("test", current, previous);
 
-        final PersistentTasksClusterService service = createService((params, clusterState) -> randomNodeAssignment(clusterState.nodes()));
+        final PersistentTasksClusterService service = createService((params, candidateNodes, clusterState) ->
+            randomNodeAssignment(clusterState.nodes()));
         assertThat(dumpEvent(event), service.shouldReassignPersistentTasks(event), equalTo(changed && unassigned));
     }
 
@@ -426,7 +433,7 @@ public void testPeriodicRecheck() throws Exception {
         boolean shouldSimulateFailure = randomBoolean();
         ClusterService recheckTestClusterService = createRecheckTestClusterService(clusterState, shouldSimulateFailure);
         PersistentTasksClusterService service = createService(recheckTestClusterService,
-            (params, currentState) -> assignBasedOnNonClusterStateCondition(currentState.nodes()));
+            (params, candidateNodes, currentState) -> assignBasedOnNonClusterStateCondition(candidateNodes));
 
         ClusterChangedEvent event = new ClusterChangedEvent("test", clusterState, initialState);
         service.clusterChanged(event);
@@ -474,7 +481,7 @@ public void testPeriodicRecheckOffMaster() {
 
         ClusterService recheckTestClusterService = createRecheckTestClusterService(clusterState, false);
         PersistentTasksClusterService service = createService(recheckTestClusterService,
-            (params, currentState) -> assignBasedOnNonClusterStateCondition(currentState.nodes()));
+            (params, candidateNodes, currentState) -> assignBasedOnNonClusterStateCondition(candidateNodes));
 
         ClusterChangedEvent event = new ClusterChangedEvent("test", clusterState, initialState);
         service.clusterChanged(event);
@@ -524,7 +531,7 @@ public void testUnassignTask() {
         Metadata.Builder metadata = Metadata.builder(clusterState.metadata()).putCustom(PersistentTasksCustomMetadata.TYPE, tasks.build());
         clusterState = builder.metadata(metadata).nodes(nodes).build();
         setState(clusterService, clusterState);
-        PersistentTasksClusterService service = createService((params, currentState) ->
+        PersistentTasksClusterService service = createService((params, candidateNodes, currentState) ->
             new Assignment("_node_2", "test"));
         service.unassignPersistentTask(unassignedId, tasks.getLastAllocationId(), "unassignment test", ActionListener.wrap(
             task -> {
@@ -550,7 +557,7 @@ public void testUnassignNonExistentTask() {
         Metadata.Builder metadata = Metadata.builder(clusterState.metadata()).putCustom(PersistentTasksCustomMetadata.TYPE, tasks.build());
         clusterState = builder.metadata(metadata).nodes(nodes).build();
         setState(clusterService, clusterState);
-        PersistentTasksClusterService service = createService((params, currentState) ->
+        PersistentTasksClusterService service = createService((params, candidateNodes, currentState) ->
             new Assignment("_node_2", "test"));
         service.unassignPersistentTask("missing-task", tasks.getLastAllocationId(), "unassignment test", ActionListener.wrap(
             task -> fail(),
@@ -558,6 +565,82 @@ public void testUnassignNonExistentTask() {
         ));
     }
 
+    public void testIsNodeShuttingDown() {
+        NodesShutdownMetadata nodesShutdownMetadata = new NodesShutdownMetadata(Collections.singletonMap("this_node",
+            SingleNodeShutdownMetadata.builder()
+                .setNodeId("this_node")
+                .setReason("shutdown for a unit test")
+                .setType(randomBoolean() ? SingleNodeShutdownMetadata.Type.REMOVE : SingleNodeShutdownMetadata.Type.RESTART)
+                .setStartedAtMillis(randomNonNegativeLong())
+                .build()));
+        ClusterState state = initialState();
+
+        state = ClusterState.builder(state)
+            .metadata(Metadata.builder(state.metadata())
+                .putCustom(NodesShutdownMetadata.TYPE, nodesShutdownMetadata)
+                .build())
+            .nodes(DiscoveryNodes.builder(state.nodes())
+                .add(new DiscoveryNode("_node_1", buildNewFakeTransportAddress(), Version.CURRENT))
+                .build())
+            .build();
+
+        assertThat(PersistentTasksClusterService.isNodeShuttingDown(state, "this_node"), equalTo(true));
+        assertThat(PersistentTasksClusterService.isNodeShuttingDown(state, "_node_1"), equalTo(false));
+    }
+
+    public void testTasksNotAssignedToShuttingDownNodes() {
+        ClusterState clusterState = initialState();
+        ClusterState.Builder builder = ClusterState.builder(clusterState);
+        PersistentTasksCustomMetadata.Builder tasks = PersistentTasksCustomMetadata.builder(
+            clusterState.metadata().custom(PersistentTasksCustomMetadata.TYPE));
+        DiscoveryNodes.Builder nodes = DiscoveryNodes.builder(clusterState.nodes());
+        addTestNodes(nodes, randomIntBetween(2, 10));
+        int numberOfTasks = randomIntBetween(20, 40);
+        for (int i = 0; i < numberOfTasks; i++) {
+            addTask(tasks, randomFrom("assign_me", "assign_one", "assign_based_on_non_cluster_state_condition"),
+                randomBoolean() ? null : "no_longer_exists");
+        }
+
+        Metadata.Builder metadata = Metadata.builder(clusterState.metadata()).putCustom(PersistentTasksCustomMetadata.TYPE, tasks.build());
+        clusterState = builder.metadata(metadata).nodes(nodes).build();
+
+        // Now that we have a bunch of tasks that need to be assigned, let's
+        // mark half the nodes as shut down and make sure they do not have any
+        // tasks assigned
+        Collection<DiscoveryNode> allNodes = clusterState.nodes().getAllNodes();
+        Map<String, SingleNodeShutdownMetadata> shutdownMetadataMap = new HashMap<>();
+        allNodes.stream().limit(Math.floorDiv(allNodes.size(), 2)).forEach(node ->
+            shutdownMetadataMap.put(node.getId(), SingleNodeShutdownMetadata.builder()
+                .setNodeId(node.getId())
+                .setReason("shutdown for a unit test")
+                .setType(randomBoolean() ? SingleNodeShutdownMetadata.Type.REMOVE : SingleNodeShutdownMetadata.Type.RESTART)
+                .setStartedAtMillis(randomNonNegativeLong())
+                .build()));
+        logger.info("--> nodes marked as shutting down: {}", shutdownMetadataMap.keySet());
+
+        ClusterState shutdownState = ClusterState.builder(clusterState)
+            .metadata(Metadata.builder(clusterState.metadata())
+                .putCustom(NodesShutdownMetadata.TYPE, new NodesShutdownMetadata(shutdownMetadataMap))
+                .build())
+            .build();
+
+        logger.info("--> assigning after marking nodes as shutting down");
+        nonClusterStateCondition = randomBoolean();
+        clusterState = reassign(shutdownState);
+        PersistentTasksCustomMetadata tasksInProgress = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE);
+        assertThat(tasksInProgress, notNullValue());
+        Set<String> nodesWithTasks = tasksInProgress.tasks().stream()
+            .map(PersistentTask::getAssignment)
+            .map(Assignment::getExecutorNode)
+            .filter(Objects::nonNull)
+            .collect(Collectors.toSet());
+        Set<String> shutdownNodes = shutdownMetadataMap.keySet();
+
+        assertTrue("expected shut down nodes: " + shutdownNodes +
+            " to have no nodes in common with nodes assigned tasks: " + nodesWithTasks,
+            Sets.haveEmptyIntersection(shutdownNodes, nodesWithTasks));
+    }
+
     private ClusterService createRecheckTestClusterService(ClusterState initialState, boolean shouldSimulateFailure) {
         AtomicBoolean testFailureNextTime = new AtomicBoolean(shouldSimulateFailure);
         AtomicReference<ClusterState> state = new AtomicReference<>(initialState);
@@ -589,20 +672,28 @@ private void addTestNodes(DiscoveryNodes.Builder nodes, int nonLocalNodesCount)
     }
 
     private ClusterState reassign(ClusterState clusterState) {
-        PersistentTasksClusterService service = createService((params, currentState) -> {
+        PersistentTasksClusterService service = createService((params, candidateNodes, currentState) -> {
             TestParams testParams = (TestParams) params;
             switch (testParams.getTestParam()) {
                 case "assign_me":
-                    return randomNodeAssignment(currentState.nodes());
+                    logger.info("--> assigning task randomly from candidates [{}]",
+                        candidateNodes.stream().map(DiscoveryNode::getId).collect(Collectors.joining(",")));
+                    Assignment assignment = randomNodeAssignment(candidateNodes);
+                    logger.info("--> assigned task to {}", assignment);
+                    return assignment;
                 case "dont_assign_me":
+                    logger.info("--> not assigning task");
                     return NO_NODE_FOUND;
                 case "fail_me_if_called":
+                    logger.info("--> failing test from task assignment");
                     fail("the decision decider shouldn't be called on this task");
                     return null;
                 case "assign_one":
-                    return assignOnlyOneTaskAtATime(currentState);
+                    logger.info("--> assigning only a single task");
+                    return assignOnlyOneTaskAtATime(candidateNodes, currentState);
                 case "assign_based_on_non_cluster_state_condition":
-                    return assignBasedOnNonClusterStateCondition(currentState.nodes());
+                    logger.info("--> assigning based on non cluster state condition: {}", nonClusterStateCondition);
+                    return assignBasedOnNonClusterStateCondition(candidateNodes);
                 default:
                     fail("unknown param " + testParams.getTestParam());
             }
@@ -612,40 +703,37 @@ private ClusterState reassign(ClusterState clusterState) {
         return service.reassignTasks(clusterState);
     }
 
-    private Assignment assignOnlyOneTaskAtATime(ClusterState clusterState) {
+    private Assignment assignOnlyOneTaskAtATime(Collection<DiscoveryNode> candidateNodes, ClusterState clusterState) {
         DiscoveryNodes nodes = clusterState.nodes();
         PersistentTasksCustomMetadata tasksInProgress = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE);
         if (tasksInProgress.findTasks(TestPersistentTasksExecutor.NAME, task ->
                 "assign_one".equals(((TestParams) task.getParams()).getTestParam()) &&
                         nodes.nodeExists(task.getExecutorNode())).isEmpty()) {
-            return randomNodeAssignment(clusterState.nodes());
+            return randomNodeAssignment(candidateNodes);
         } else {
             return new Assignment(null, "only one task can be assigned at a time");
         }
     }
 
-    private Assignment assignBasedOnNonClusterStateCondition(DiscoveryNodes nodes) {
+    private Assignment assignBasedOnNonClusterStateCondition(Collection<DiscoveryNode> candidateNodes) {
         if (nonClusterStateCondition) {
-            return randomNodeAssignment(nodes);
+            return randomNodeAssignment(candidateNodes);
         } else {
             return new Assignment(null, "non-cluster state condition prevents assignment");
         }
     }
 
-    private Assignment randomNodeAssignment(DiscoveryNodes nodes) {
-        if (nodes.getNodes().isEmpty()) {
-            return NO_NODE_FOUND;
-        }
-        List<String> nodeList = new ArrayList<>();
-        for (ObjectCursor<String> node : nodes.getNodes().keys()) {
-            nodeList.add(node.value);
-        }
-        String node = randomFrom(nodeList);
-        if (node != null) {
-            return new Assignment(node, "test assignment");
-        } else {
+    private Assignment randomNodeAssignment(Collection<DiscoveryNode> nodes) {
+        if (nodes.isEmpty()) {
             return NO_NODE_FOUND;
         }
+        return Optional.ofNullable(randomFrom(nodes))
+            .map(node -> new Assignment(node.getId(), "test assignment"))
+            .orElse(NO_NODE_FOUND);
+    }
+
+    private Assignment randomNodeAssignment(DiscoveryNodes nodes) {
+        return randomNodeAssignment(nodes.getAllNodes());
     }
 
     private String dumpEvent(ClusterChangedEvent event) {
@@ -866,17 +954,19 @@ private void changeRoutingTable(Metadata.Builder metadata, RoutingTable.Builder
     }
 
     /** Creates a PersistentTasksClusterService with a single PersistentTasksExecutor implemented by a BiFunction **/
-    private <P extends PersistentTaskParams> PersistentTasksClusterService createService(final BiFunction<P, ClusterState, Assignment> fn) {
+    private <P extends PersistentTaskParams> PersistentTasksClusterService
+    createService(final TriFunction<P, Collection<DiscoveryNode>, ClusterState, Assignment> fn) {
         return createService(clusterService, fn);
     }
 
-    private <P extends PersistentTaskParams> PersistentTasksClusterService createService(ClusterService clusterService,
-                                                                                         final BiFunction<P, ClusterState, Assignment> fn) {
+    private <P extends PersistentTaskParams> PersistentTasksClusterService
+    createService(ClusterService clusterService,
+                  final TriFunction<P, Collection<DiscoveryNode>, ClusterState, Assignment> fn) {
         PersistentTasksExecutorRegistry registry = new PersistentTasksExecutorRegistry(
             singleton(new PersistentTasksExecutor<P>(TestPersistentTasksExecutor.NAME, null) {
                 @Override
-                public Assignment getAssignment(P params, ClusterState clusterState) {
-                    return fn.apply(params, clusterState);
+                public Assignment getAssignment(P params, Collection<DiscoveryNode> candidateNodes, ClusterState clusterState) {
+                    return fn.apply(params, candidateNodes, clusterState);
                 }
 
                 @Override
diff --git a/server/src/test/java/org/elasticsearch/persistent/TestPersistentTasksPlugin.java b/server/src/test/java/org/elasticsearch/persistent/TestPersistentTasksPlugin.java
index 2611b2ab9d278..ce6eea7b7f0ba 100644
--- a/server/src/test/java/org/elasticsearch/persistent/TestPersistentTasksPlugin.java
+++ b/server/src/test/java/org/elasticsearch/persistent/TestPersistentTasksPlugin.java
@@ -54,6 +54,7 @@
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -301,14 +302,14 @@ public static void setNonClusterStateCondition(boolean nonClusterStateCondition)
         }
 
         @Override
-        public Assignment getAssignment(TestParams params, ClusterState clusterState) {
+        public Assignment getAssignment(TestParams params, Collection<DiscoveryNode> candidateNodes, ClusterState clusterState) {
             if (nonClusterStateCondition == false) {
                 return new Assignment(null, "non cluster state condition prevents assignment");
             }
             if (params == null || params.getExecutorNodeAttr() == null) {
-                return super.getAssignment(params, clusterState);
+                return super.getAssignment(params, candidateNodes, clusterState);
             } else {
-                DiscoveryNode executorNode = selectLeastLoadedNode(clusterState,
+                DiscoveryNode executorNode = selectLeastLoadedNode(clusterState, candidateNodes,
                         discoveryNode -> params.getExecutorNodeAttr().equals(discoveryNode.getAttributes().get("test_attr")));
                 if (executorNode != null) {
                     return new Assignment(executorNode.getId(), "test assignment");
diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java
index f028b1af5ffdb..8d5804403558c 100644
--- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java
+++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java
@@ -75,6 +75,7 @@
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -127,14 +128,16 @@ public void validate(ShardFollowTask params, ClusterState clusterState) {
     private static final Assignment NO_ASSIGNMENT = new Assignment(null, "no nodes found with data and remote cluster client roles");
 
     @Override
-    public Assignment getAssignment(final ShardFollowTask params, final ClusterState clusterState) {
+    public Assignment getAssignment(final ShardFollowTask params,
+                                    Collection<DiscoveryNode> candidateNodes,
+                                    final ClusterState clusterState) {
         DiscoveryNode selectedNode = selectLeastLoadedNode(
-            clusterState,
+            clusterState, candidateNodes,
             ((Predicate<DiscoveryNode>) DiscoveryNode::canContainData).and(DiscoveryNode::isRemoteClusterClient)
         );
         if (selectedNode == null) {
             // best effort as nodes before 7.8 might not be able to connect to remote clusters
-            selectedNode = selectLeastLoadedNode(clusterState,
+            selectedNode = selectLeastLoadedNode(clusterState, candidateNodes,
                 node -> node.canContainData() && node.getVersion().before(DiscoveryNodeRole.REMOTE_CLUSTER_CLIENT_ROLE_VERSION));
         }
         if (selectedNode == null) {
diff --git a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutorAssignmentTests.java b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutorAssignmentTests.java
index ddcd73c89c23c..cc616dcf93590 100644
--- a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutorAssignmentTests.java
+++ b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutorAssignmentTests.java
@@ -109,7 +109,8 @@ private void runAssignmentTest(
             nodesBuilder.add(node);
         }
         clusterStateBuilder.nodes(nodesBuilder);
-        final Assignment assignment = executor.getAssignment(mock(ShardFollowTask.class), clusterStateBuilder.build());
+        final Assignment assignment = executor.getAssignment(mock(ShardFollowTask.class),
+            clusterStateBuilder.nodes().getAllNodes(), clusterStateBuilder.build());
         consumer.accept(targetNode, assignment);
     }
 
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java
index b397572a0c9b4..159d4e89a4f0a 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java
@@ -79,6 +79,7 @@
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 import org.elasticsearch.xpack.ml.task.AbstractJobPersistentTasksExecutor;
 
+import java.util.Collection;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -607,7 +608,9 @@ protected AllocatedPersistentTask createTask(
         }
 
         @Override
-        public PersistentTasksCustomMetadata.Assignment getAssignment(TaskParams params, ClusterState clusterState) {
+        public PersistentTasksCustomMetadata.Assignment getAssignment(TaskParams params,
+                                                                      Collection<DiscoveryNode> candidateNodes,
+                                                                      ClusterState clusterState) {
             boolean isMemoryTrackerRecentlyRefreshed = memoryTracker.isRecentlyRefreshed();
             Optional<PersistentTasksCustomMetadata.Assignment> optionalAssignment =
                 getPotentialAssignment(params, clusterState, isMemoryTrackerRecentlyRefreshed);
@@ -617,6 +620,7 @@ public PersistentTasksCustomMetadata.Assignment getAssignment(TaskParams params,
             JobNodeSelector jobNodeSelector =
                 new JobNodeSelector(
                     clusterState,
+                    candidateNodes,
                     params.getId(),
                     MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
                     memoryTracker,
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java
index 849500c9112d3..0b0eb2e7cf443 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java
@@ -65,6 +65,7 @@
 import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
 
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -425,7 +426,11 @@ public StartDatafeedPersistentTasksExecutor(DatafeedManager datafeedManager, Ind
 
         @Override
         public PersistentTasksCustomMetadata.Assignment getAssignment(StartDatafeedAction.DatafeedParams params,
+                                                                      Collection<DiscoveryNode> candidateNodes,
                                                                       ClusterState clusterState) {
+            // 'candidateNodes' is not actually used here because the assignment for the task is
+            // already filtered elsewhere (JobNodeSelector), this is only finding the node a task
+            // has already been assigned to.
             return new DatafeedNodeSelector(clusterState, resolver, params.getDatafeedId(), params.getJobId(),
                     params.getDatafeedIndices(), params.getIndicesOptions()).selectNode();
         }
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java
index a4f3da6c58aa5..eb9dd124b0d54 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java
@@ -23,6 +23,7 @@
 import org.elasticsearch.xpack.ml.utils.NativeMemoryCalculator;
 
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Locale;
@@ -68,6 +69,7 @@ private static String createReason(String job, String node, String msg, Object..
     private final String jobId;
     private final String taskName;
     private final ClusterState clusterState;
+    private final Collection<DiscoveryNode> candidateNodes;
     private final MlMemoryTracker memoryTracker;
     private final Function<DiscoveryNode, String> nodeFilter;
     private final NodeLoadDetector nodeLoadDetector;
@@ -79,6 +81,7 @@ private static String createReason(String job, String node, String msg, Object..
      *                   be <code>null</code> if no such function is needed.
      */
     public JobNodeSelector(ClusterState clusterState,
+                           Collection<DiscoveryNode> candidateNodes,
                            String jobId,
                            String taskName,
                            MlMemoryTracker memoryTracker,
@@ -87,6 +90,7 @@ public JobNodeSelector(ClusterState clusterState,
         this.jobId = Objects.requireNonNull(jobId);
         this.taskName = Objects.requireNonNull(taskName);
         this.clusterState = Objects.requireNonNull(clusterState);
+        this.candidateNodes = Objects.requireNonNull(candidateNodes);
         this.memoryTracker = Objects.requireNonNull(memoryTracker);
         this.nodeLoadDetector = new NodeLoadDetector(Objects.requireNonNull(memoryTracker));
         this.maxLazyNodes = maxLazyNodes;
@@ -102,8 +106,7 @@ public Tuple<NativeMemoryCapacity, Long> perceivedCapacityAndMaxFreeMemory(int m
                                                                                boolean useAutoMemoryPercentage,
                                                                                int maxOpenJobs,
                                                                                boolean isMemoryTrackerRecentlyRefreshed) {
-        List<DiscoveryNode> capableNodes = clusterState.getNodes()
-            .mastersFirstStream()
+        List<DiscoveryNode> capableNodes = candidateNodes.stream()
             .filter(n -> this.nodeFilter.apply(n) == null)
             .collect(Collectors.toList());
         NativeMemoryCapacity currentCapacityForMl = MlAutoscalingDeciderService.currentScale(
@@ -150,7 +153,7 @@ public PersistentTasksCustomMetadata.Assignment selectNode(int dynamicMaxOpenJob
         long maxAvailableMemory = Long.MIN_VALUE;
         DiscoveryNode minLoadedNodeByCount = null;
         DiscoveryNode minLoadedNodeByMemory = null;
-        for (DiscoveryNode node : clusterState.getNodes()) {
+        for (DiscoveryNode node : candidateNodes) {
 
             // First check conditions that would rule out the node regardless of what other tasks are assigned to it
             String reason = nodeFilter.apply(node);
@@ -301,7 +304,7 @@ PersistentTasksCustomMetadata.Assignment considerLazyAssignment(PersistentTasksC
         assert currentAssignment.getExecutorNode() == null;
 
         int numMlNodes = 0;
-        for (DiscoveryNode node : clusterState.getNodes()) {
+        for (DiscoveryNode node : candidateNodes) {
             if (MachineLearning.isMlNode(node)) {
                 numMlNodes++;
             }
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskExecutor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskExecutor.java
index 415a04f5cf804..461b661411aaf 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskExecutor.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskExecutor.java
@@ -16,6 +16,7 @@
 import org.elasticsearch.client.Client;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
+import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.persistent.AllocatedPersistentTask;
@@ -43,6 +44,7 @@
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 import org.elasticsearch.xpack.ml.task.AbstractJobPersistentTasksExecutor;
 
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Map;
 import java.util.Optional;
@@ -77,7 +79,9 @@ public SnapshotUpgradeTaskExecutor(Settings settings,
     }
 
     @Override
-    public PersistentTasksCustomMetadata.Assignment getAssignment(SnapshotUpgradeTaskParams params, ClusterState clusterState) {
+    public PersistentTasksCustomMetadata.Assignment getAssignment(SnapshotUpgradeTaskParams params,
+                                                                  Collection<DiscoveryNode> candidateNodes,
+                                                                  ClusterState clusterState) {
         boolean isMemoryTrackerRecentlyRefreshed = memoryTracker.isRecentlyRefreshed();
         Optional<PersistentTasksCustomMetadata.Assignment> optionalAssignment =
             getPotentialAssignment(params, clusterState, isMemoryTrackerRecentlyRefreshed);
@@ -86,6 +90,7 @@ public PersistentTasksCustomMetadata.Assignment getAssignment(SnapshotUpgradeTas
         }
         JobNodeSelector jobNodeSelector = new JobNodeSelector(
             clusterState,
+            candidateNodes,
             params.getJobId(),
             MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME,
             memoryTracker,
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutor.java
index 9c21c79f25554..66e68bcd60b33 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutor.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutor.java
@@ -52,6 +52,7 @@
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 import org.elasticsearch.xpack.ml.task.AbstractJobPersistentTasksExecutor;
 
+import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -106,7 +107,7 @@ public OpenJobPersistentTasksExecutor(Settings settings,
     }
 
     @Override
-    public Assignment getAssignment(OpenJobAction.JobParams params, ClusterState clusterState) {
+    public Assignment getAssignment(OpenJobAction.JobParams params, Collection<DiscoveryNode> candidateNodes, ClusterState clusterState) {
         // If the task parameters do not have a job field then the job
         // was first opened on a pre v6.6 node and has not been migrated
         Job job = params.getJob();
@@ -119,8 +120,8 @@ public Assignment getAssignment(OpenJobAction.JobParams params, ClusterState clu
             return optionalAssignment.get();
         }
 
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(clusterState, params.getJobId(), MlTasks.JOB_TASK_NAME, memoryTracker,
-            job.allowLazyOpen() ? Integer.MAX_VALUE : maxLazyMLNodes, node -> nodeFilter(node, job));
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(clusterState, candidateNodes, params.getJobId(),
+            MlTasks.JOB_TASK_NAME, memoryTracker, job.allowLazyOpen() ? Integer.MAX_VALUE : maxLazyMLNodes, node -> nodeFilter(node, job));
         Assignment assignment = jobNodeSelector.selectNode(
             maxOpenJobs,
             maxConcurrentJobAllocations,
@@ -180,7 +181,7 @@ public void validate(OpenJobAction.JobParams params, ClusterState clusterState)
         validateJobAndId(jobId, job);
         // If we already know that we can't find an ml node because all ml nodes are running at capacity or
         // simply because there are no ml nodes in the cluster then we fail quickly here:
-        PersistentTasksCustomMetadata.Assignment assignment = getAssignment(params, clusterState);
+        PersistentTasksCustomMetadata.Assignment assignment = getAssignment(params, clusterState.nodes().getAllNodes(), clusterState);
         if (assignment.equals(AWAITING_UPGRADE)) {
             throw makeCurrentlyBeingUpgradedException(logger, params.getJobId());
         }
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java
index c73f7a80b41ad..3ccc094b96ce7 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java
@@ -56,7 +56,7 @@ public void testGetAssignment_UpgradeModeIsEnabled() {
                 .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isUpgradeMode(true).build()))
                 .build();
 
-        Assignment assignment = executor.getAssignment(params, clusterState);
+        Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState);
         assertThat(assignment.getExecutorNode(), is(nullValue()));
         assertThat(assignment.getExplanation(), is(equalTo("persistent task cannot be assigned while upgrade mode is enabled.")));
     }
@@ -70,7 +70,7 @@ public void testGetAssignment_NoNodes() {
                 .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().build()))
                 .build();
 
-        Assignment assignment = executor.getAssignment(params, clusterState);
+        Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState);
         assertThat(assignment.getExecutorNode(), is(nullValue()));
         assertThat(assignment.getExplanation(), is(emptyString()));
     }
@@ -88,7 +88,7 @@ public void testGetAssignment_NoMlNodes() {
                     .add(createNode(2, false, Version.CURRENT)))
                 .build();
 
-        Assignment assignment = executor.getAssignment(params, clusterState);
+        Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState);
         assertThat(assignment.getExecutorNode(), is(nullValue()));
         assertThat(
             assignment.getExplanation(),
@@ -114,7 +114,7 @@ public void testGetAssignment_MlNodesAreTooOld() {
                     .add(createNode(2, true, Version.V_7_9_2)))
                 .build();
 
-        Assignment assignment = executor.getAssignment(params, clusterState);
+        Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState);
         assertThat(assignment.getExecutorNode(), is(nullValue()));
         assertThat(
             assignment.getExplanation(),
@@ -139,7 +139,7 @@ public void testGetAssignment_MlNodeIsNewerThanTheMlJobButTheAssignmentSuceeds()
                     .add(createNode(0, true, Version.V_7_10_0)))
                 .build();
 
-        Assignment assignment = executor.getAssignment(params, clusterState);
+        Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState);
         assertThat(assignment.getExecutorNode(), is(equalTo("_node_id0")));
         assertThat(assignment.getExplanation(), is(emptyString()));
     }
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java
index 07d93323cf368..10a11b84eb141 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java
@@ -26,9 +26,9 @@
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.ml.autoscaling.NativeMemoryCapacity;
 import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutorTests;
-import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
 import org.junit.Before;
@@ -124,8 +124,8 @@ public void testSelectLeastLoadedMlNode_byCount() {
         jobBuilder.setJobVersion(Version.CURRENT);
 
         Job job = jobBuilder.build();
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
-            node -> nodeFilter(node, job));
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), cs.nodes().getAllNodes(), job.getId(), MlTasks.JOB_TASK_NAME,
+            memoryTracker, 0, node -> nodeFilter(node, job));
         PersistentTasksCustomMetadata.Assignment result = jobNodeSelector.selectNode(10,
             2,
             30,
@@ -150,8 +150,8 @@ public void testSelectLeastLoadedMlNodeForAnomalyDetectorJob_maxCapacityCountLim
 
         Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id1000", JOB_MEMORY_REQUIREMENT).build(new Date());
 
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
-            node -> nodeFilter(node, job));
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), cs.nodes().getAllNodes(), job.getId(), MlTasks.JOB_TASK_NAME,
+            memoryTracker, 0, node -> nodeFilter(node, job));
         PersistentTasksCustomMetadata.Assignment result = jobNodeSelector.selectNode(maxRunningJobsPerNode,
             2,
             maxMachineMemoryPercent,
@@ -177,7 +177,7 @@ public void testSelectLeastLoadedMlNodeForDataFrameAnalyticsJob_maxCapacityCount
 
         String dataFrameAnalyticsId = "data_frame_analytics_id1000";
 
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), dataFrameAnalyticsId,
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), cs.nodes().getAllNodes(), dataFrameAnalyticsId,
             MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, memoryTracker, 0,
             node -> TransportStartDataFrameAnalyticsAction.TaskExecutor.nodeFilter(node, createTaskParams(dataFrameAnalyticsId)));
         PersistentTasksCustomMetadata.Assignment result = jobNodeSelector.selectNode(maxRunningJobsPerNode,
@@ -211,8 +211,8 @@ public void testSelectLeastLoadedMlNodeForAnomalyDetectorJob_maxCapacityMemoryLi
 
         Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id1000", JOB_MEMORY_REQUIREMENT).build(new Date());
 
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
-            node -> nodeFilter(node, job));
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), cs.nodes().getAllNodes(), job.getId(), MlTasks.JOB_TASK_NAME,
+            memoryTracker, 0, node -> nodeFilter(node, job));
         PersistentTasksCustomMetadata.Assignment result = jobNodeSelector.selectNode(maxRunningJobsPerNode,
             2,
             maxMachineMemoryPercent,
@@ -241,7 +241,7 @@ public void testSelectLeastLoadedMlNodeForDataFrameAnalyticsJob_givenTaskHasNull
 
         String dataFrameAnalyticsId = "data_frame_analytics_id_new";
 
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), dataFrameAnalyticsId,
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), cs.nodes().getAllNodes(), dataFrameAnalyticsId,
             MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, memoryTracker, 0,
             node -> TransportStartDataFrameAnalyticsAction.TaskExecutor.nodeFilter(node, createTaskParams(dataFrameAnalyticsId)));
         PersistentTasksCustomMetadata.Assignment result = jobNodeSelector.selectNode(maxRunningJobsPerNode,
@@ -268,8 +268,8 @@ public void testSelectLeastLoadedMlNodeForAnomalyDetectorJob_firstJobTooBigMemor
 
         Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id1000", JOB_MEMORY_REQUIREMENT).build(new Date());
 
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker,
-            0, node -> nodeFilter(node, job));
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), cs.nodes().getAllNodes(), job.getId(), MlTasks.JOB_TASK_NAME,
+            memoryTracker, 0, node -> nodeFilter(node, job));
         PersistentTasksCustomMetadata.Assignment result = jobNodeSelector.selectNode(maxRunningJobsPerNode,
             2,
             maxMachineMemoryPercent,
@@ -303,7 +303,7 @@ public void testSelectLeastLoadedMlNodeForDataFrameAnalyticsJob_maxCapacityMemor
 
         String dataFrameAnalyticsId = "data_frame_analytics_id1000";
 
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), dataFrameAnalyticsId,
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), cs.nodes().getAllNodes(), dataFrameAnalyticsId,
             MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, memoryTracker, 0,
             node -> TransportStartDataFrameAnalyticsAction.TaskExecutor.nodeFilter(node, createTaskParams(dataFrameAnalyticsId)));
         PersistentTasksCustomMetadata.Assignment result = jobNodeSelector.selectNode(
@@ -336,7 +336,7 @@ public void testSelectLeastLoadedMlNodeForDataFrameAnalyticsJob_firstJobTooBigMe
 
         String dataFrameAnalyticsId = "data_frame_analytics_id1000";
 
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), dataFrameAnalyticsId,
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), cs.nodes().getAllNodes(), dataFrameAnalyticsId,
             MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, memoryTracker, 0,
             node -> TransportStartDataFrameAnalyticsAction.TaskExecutor.nodeFilter(node, createTaskParams(dataFrameAnalyticsId)));
         PersistentTasksCustomMetadata.Assignment result = jobNodeSelector.selectNode(
@@ -373,8 +373,8 @@ public void testSelectLeastLoadedMlNode_noMlNodes() {
 
         Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id2", JOB_MEMORY_REQUIREMENT).build(new Date());
 
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
-            node -> nodeFilter(node, job));
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), cs.nodes().getAllNodes(), job.getId(), MlTasks.JOB_TASK_NAME,
+            memoryTracker, 0, node -> nodeFilter(node, job));
         PersistentTasksCustomMetadata.Assignment result = jobNodeSelector.selectNode(
             20,
             2,
@@ -416,9 +416,8 @@ public void testSelectLeastLoadedMlNode_maxConcurrentOpeningJobs() {
         Job job6 = BaseMlIntegTestCase.createFareQuoteJob("job_id6", JOB_MEMORY_REQUIREMENT).build(new Date());
 
         ClusterState cs = csBuilder.build();
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs, job6.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
-
-            node -> nodeFilter(node, job6));
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs, cs.nodes().getAllNodes(), job6.getId(), MlTasks.JOB_TASK_NAME,
+            memoryTracker, 0, node -> nodeFilter(node, job6));
         PersistentTasksCustomMetadata.Assignment result = jobNodeSelector.selectNode(
             10,
             2,
@@ -437,8 +436,7 @@ public void testSelectLeastLoadedMlNode_maxConcurrentOpeningJobs() {
         cs = csBuilder.build();
 
         Job job7 = BaseMlIntegTestCase.createFareQuoteJob("job_id7", JOB_MEMORY_REQUIREMENT).build(new Date());
-        jobNodeSelector = new JobNodeSelector(cs, job7.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
-
+        jobNodeSelector = new JobNodeSelector(cs, cs.nodes().getAllNodes(), job7.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
             node -> nodeFilter(node, job7));
         result = jobNodeSelector.selectNode(10,
             2,
@@ -457,7 +455,7 @@ public void testSelectLeastLoadedMlNode_maxConcurrentOpeningJobs() {
         csBuilder = ClusterState.builder(cs);
         csBuilder.metadata(Metadata.builder(cs.metadata()).putCustom(PersistentTasksCustomMetadata.TYPE, tasks));
         cs = csBuilder.build();
-        jobNodeSelector = new JobNodeSelector(cs, job7.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
+        jobNodeSelector = new JobNodeSelector(cs, cs.nodes().getAllNodes(), job7.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
             node -> nodeFilter(node, job7));
         result = jobNodeSelector.selectNode(10, 2, 30, MAX_JOB_BYTES, isMemoryTrackerRecentlyRefreshed, false);
         assertNull("no node selected, because stale task", result.getExecutorNode());
@@ -470,8 +468,7 @@ public void testSelectLeastLoadedMlNode_maxConcurrentOpeningJobs() {
         csBuilder = ClusterState.builder(cs);
         csBuilder.metadata(Metadata.builder(cs.metadata()).putCustom(PersistentTasksCustomMetadata.TYPE, tasks));
         cs = csBuilder.build();
-        jobNodeSelector = new JobNodeSelector(cs, job7.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
-
+        jobNodeSelector = new JobNodeSelector(cs, cs.nodes().getAllNodes(), job7.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
             node -> nodeFilter(node, job7));
         result = jobNodeSelector.selectNode(10, 2, 30, MAX_JOB_BYTES, isMemoryTrackerRecentlyRefreshed, false);
         assertNull("no node selected, because null state", result.getExecutorNode());
@@ -513,9 +510,8 @@ public void testSelectLeastLoadedMlNode_concurrentOpeningJobsAndStaleFailedJob()
         Job job7 = BaseMlIntegTestCase.createFareQuoteJob("job_id7", JOB_MEMORY_REQUIREMENT).build(new Date());
 
         // Allocation won't be possible if the stale failed job is treated as opening
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs, job7.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
-
-            node -> nodeFilter(node, job7));
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs, cs.nodes().getAllNodes(), job7.getId(), MlTasks.JOB_TASK_NAME,
+            memoryTracker, 0, node -> nodeFilter(node, job7));
         PersistentTasksCustomMetadata.Assignment result = jobNodeSelector.selectNode(10,
             2,
             30,
@@ -532,8 +528,7 @@ public void testSelectLeastLoadedMlNode_concurrentOpeningJobsAndStaleFailedJob()
         csBuilder.metadata(Metadata.builder(cs.metadata()).putCustom(PersistentTasksCustomMetadata.TYPE, tasks));
         cs = csBuilder.build();
         Job job8 = BaseMlIntegTestCase.createFareQuoteJob("job_id8", JOB_MEMORY_REQUIREMENT).build(new Date());
-        jobNodeSelector = new JobNodeSelector(cs, job8.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
-
+        jobNodeSelector = new JobNodeSelector(cs, cs.nodes().getAllNodes(), job8.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
             node -> nodeFilter(node, job8));
         result = jobNodeSelector.selectNode(10, 2, 30, MAX_JOB_BYTES, isMemoryTrackerRecentlyRefreshed, false);
         assertNull("no node selected, because OPENING state", result.getExecutorNode());
@@ -567,9 +562,8 @@ public void testSelectLeastLoadedMlNode_noCompatibleJobTypeNodes() {
         cs.nodes(nodes);
         metadata.putCustom(PersistentTasksCustomMetadata.TYPE, tasks);
         cs.metadata(metadata);
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
-
-            node -> nodeFilter(node, job));
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), cs.nodes().getAllNodes(), job.getId(), MlTasks.JOB_TASK_NAME,
+            memoryTracker, 0, node -> nodeFilter(node, job));
         PersistentTasksCustomMetadata.Assignment result = jobNodeSelector.selectNode(10,
             2,
             30,
@@ -605,7 +599,7 @@ public void testSelectLeastLoadedMlNode_noNodesMatchingModelSnapshotMinVersion()
         cs.nodes(nodes);
         metadata.putCustom(PersistentTasksCustomMetadata.TYPE, tasks);
         cs.metadata(metadata);
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(),
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), cs.nodes().getAllNodes(), job.getId(),
             MlTasks.JOB_TASK_NAME, memoryTracker, 0, node -> nodeFilter(node, job));
         PersistentTasksCustomMetadata.Assignment result = jobNodeSelector.selectNode(10,
             2,
@@ -640,8 +634,45 @@ public void testSelectLeastLoadedMlNode_jobWithRules() {
         cs.metadata(metadata);
 
         Job job = jobWithRules("job_with_rules");
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
-            node -> nodeFilter(node, job));
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), cs.nodes().getAllNodes(), job.getId(), MlTasks.JOB_TASK_NAME,
+            memoryTracker, 0, node -> nodeFilter(node, job));
+        PersistentTasksCustomMetadata.Assignment result = jobNodeSelector.selectNode(10,
+            2,
+            30,
+            MAX_JOB_BYTES,
+            isMemoryTrackerRecentlyRefreshed,
+            false);
+        assertNotNull(result.getExecutorNode());
+    }
+
+    public void testSelectMlNodeOnlyOutOfCandidates() {
+        Map<String, String> nodeAttr = new HashMap<>();
+        nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, "10");
+        nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, "1000000000");
+        DiscoveryNodes nodes = DiscoveryNodes.builder()
+            .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300),
+                nodeAttr, Collections.emptySet(), Version.CURRENT))
+            .add(new DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301),
+                nodeAttr, Collections.emptySet(), Version.CURRENT))
+            .build();
+
+        PersistentTasksCustomMetadata.Builder tasksBuilder = PersistentTasksCustomMetadata.builder();
+        OpenJobPersistentTasksExecutorTests.addJobTask("job_with_rules", "_node_id1", null, tasksBuilder);
+        PersistentTasksCustomMetadata tasks = tasksBuilder.build();
+
+        ClusterState.Builder cs = ClusterState.builder(new ClusterName("_name"));
+        Metadata.Builder metadata = Metadata.builder();
+        cs.nodes(nodes);
+        metadata.putCustom(PersistentTasksCustomMetadata.TYPE, tasks);
+        cs.metadata(metadata);
+
+        DiscoveryNode candidate = nodes.getNodes().get(randomBoolean() ? "_node_id1" : "_node_id2");
+
+        Job job = jobWithRules("job_with_rules");
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(),
+            Collections.singletonList(candidate),
+            job.getId(), MlTasks.JOB_TASK_NAME,
+            memoryTracker, 0, node -> nodeFilter(node, job));
         PersistentTasksCustomMetadata.Assignment result = jobNodeSelector.selectNode(10,
             2,
             30,
@@ -649,6 +680,7 @@ public void testSelectLeastLoadedMlNode_jobWithRules() {
             isMemoryTrackerRecentlyRefreshed,
             false);
         assertNotNull(result.getExecutorNode());
+        assertThat(result.getExecutorNode(), equalTo(candidate.getId()));
     }
 
     public void testConsiderLazyAssignmentWithNoLazyNodes() {
@@ -663,8 +695,8 @@ public void testConsiderLazyAssignmentWithNoLazyNodes() {
         cs.nodes(nodes);
 
         Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id1000", JOB_MEMORY_REQUIREMENT).build(new Date());
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
-            node -> nodeFilter(node, job));
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), cs.nodes().getAllNodes(), job.getId(), MlTasks.JOB_TASK_NAME,
+            memoryTracker, 0, node -> nodeFilter(node, job));
         PersistentTasksCustomMetadata.Assignment result =
             jobNodeSelector.considerLazyAssignment(new PersistentTasksCustomMetadata.Assignment(null, "foo"));
         assertEquals("foo", result.getExplanation());
@@ -683,8 +715,8 @@ public void testConsiderLazyAssignmentWithLazyNodes() {
         cs.nodes(nodes);
 
         Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id1000", JOB_MEMORY_REQUIREMENT).build(new Date());
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker,
-            randomIntBetween(1, 3), node -> nodeFilter(node, job));
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), cs.nodes().getAllNodes(), job.getId(), MlTasks.JOB_TASK_NAME,
+            memoryTracker, randomIntBetween(1, 3), node -> nodeFilter(node, job));
         PersistentTasksCustomMetadata.Assignment result =
             jobNodeSelector.considerLazyAssignment(new PersistentTasksCustomMetadata.Assignment(null, "foo"));
         assertEquals(JobNodeSelector.AWAITING_LAZY_ASSIGNMENT.getExplanation(), result.getExplanation());
@@ -706,9 +738,8 @@ public void testMaximumPossibleNodeMemoryTooSmall() {
         Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id1000", ByteSizeValue.ofMb(10)).build(new Date());
         when(memoryTracker.getJobMemoryRequirement(anyString(), eq("job_id1000"))).thenReturn(1000L);
 
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker,
-            randomIntBetween(1, 3),
-            node -> nodeFilter(node, job));
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), cs.nodes().getAllNodes(), job.getId(), MlTasks.JOB_TASK_NAME,
+            memoryTracker, randomIntBetween(1, 3), node -> nodeFilter(node, job));
         PersistentTasksCustomMetadata.Assignment result = jobNodeSelector.selectNode(maxRunningJobsPerNode,
             2,
             maxMachineMemoryPercent,
@@ -769,8 +800,8 @@ public void testPerceivedCapacityAndMaxFreeMemory() {
 
         Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id2", JOB_MEMORY_REQUIREMENT).build(new Date());
 
-        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0,
-            node -> nodeFilter(node, job));
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), cs.nodes().getAllNodes(), job.getId(), MlTasks.JOB_TASK_NAME,
+            memoryTracker, 0, node -> nodeFilter(node, job));
         Tuple<NativeMemoryCapacity, Long> capacityAndFreeMemory = jobNodeSelector.perceivedCapacityAndMaxFreeMemory(
             10,
             false,
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java
index 00ee28937546b..107d6a83db3cd 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java
@@ -141,7 +141,7 @@ public void testGetAssignment_GivenJobThatRequiresMigration() {
         OpenJobPersistentTasksExecutor executor = createExecutor(Settings.EMPTY);
 
         OpenJobAction.JobParams params = new OpenJobAction.JobParams("missing_job_field");
-        assertEquals(AWAITING_MIGRATION, executor.getAssignment(params, mock(ClusterState.class)));
+        assertEquals(AWAITING_MIGRATION, executor.getAssignment(params, Collections.emptyList(), mock(ClusterState.class)));
     }
 
     // An index being unavailable should take precedence over waiting for a lazy node
@@ -162,7 +162,7 @@ public void testGetAssignment_GivenUnavailableIndicesWithLazyNode() {
         params.setJob(mock(Job.class));
         assertEquals("Not opening [unavailable_index_with_lazy_node], " +
                 "because not all primary shards are active for the following indices [.ml-state]",
-            executor.getAssignment(params, csBuilder.build()).getExplanation());
+            executor.getAssignment(params, csBuilder.nodes().getAllNodes(), csBuilder.build()).getExplanation());
     }
 
     public void testGetAssignment_GivenLazyJobAndNoGlobalLazyNodes() {
@@ -180,7 +180,8 @@ public void testGetAssignment_GivenLazyJobAndNoGlobalLazyNodes() {
         when(job.allowLazyOpen()).thenReturn(true);
         OpenJobAction.JobParams params = new OpenJobAction.JobParams("lazy_job");
         params.setJob(job);
-        PersistentTasksCustomMetadata.Assignment assignment = executor.getAssignment(params, csBuilder.build());
+        PersistentTasksCustomMetadata.Assignment assignment = executor.getAssignment(params,
+            csBuilder.nodes().getAllNodes(), csBuilder.build());
         assertNotNull(assignment);
         assertNull(assignment.getExecutorNode());
         assertEquals(JobNodeSelector.AWAITING_LAZY_ASSIGNMENT.getExplanation(), assignment.getExplanation());
@@ -197,7 +198,8 @@ public void testGetAssignment_GivenResetInProgress() {
         Job job = mock(Job.class);
         OpenJobAction.JobParams params = new OpenJobAction.JobParams("job_during_reset");
         params.setJob(job);
-        PersistentTasksCustomMetadata.Assignment assignment = executor.getAssignment(params, csBuilder.build());
+        PersistentTasksCustomMetadata.Assignment assignment = executor.getAssignment(params,
+            csBuilder.nodes().getAllNodes(), csBuilder.build());
         assertNotNull(assignment);
         assertNull(assignment.getExecutorNode());
         assertEquals(MlTasks.RESET_IN_PROGRESS.getExplanation(), assignment.getExplanation());
diff --git a/x-pack/plugin/shutdown/build.gradle b/x-pack/plugin/shutdown/build.gradle
index d323fb6140424..dc8f29c7c17bd 100644
--- a/x-pack/plugin/shutdown/build.gradle
+++ b/x-pack/plugin/shutdown/build.gradle
@@ -1,4 +1,5 @@
 apply plugin: 'elasticsearch.esplugin'
+apply plugin: 'elasticsearch.internal-cluster-test'
 
 esplugin {
     name 'x-pack-shutdown'
@@ -14,3 +15,12 @@ dependencies {
 }
 
 addQaCheckDependencies()
+
+testClusters.all {
+  testDistribution = 'default'
+  setting 'xpack.security.enabled', 'true'
+  setting 'xpack.license.self_generated.type', 'trial'
+  keystore 'bootstrap.password', 'x-pack-test-password'
+  user username: "x_pack_rest_user", password: "x-pack-test-password"
+  systemProperty 'es.shutdown_feature_flag_enabled', 'true'
+}
diff --git a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java
new file mode 100644
index 0000000000000..6159f1e86ffc6
--- /dev/null
+++ b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java
@@ -0,0 +1,248 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.shutdown;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ResourceAlreadyExistsException;
+import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.admin.cluster.node.info.NodeInfo;
+import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.cluster.ClusterChangedEvent;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.ClusterStateListener;
+import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
+import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.settings.SettingsModule;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.env.Environment;
+import org.elasticsearch.env.NodeEnvironment;
+import org.elasticsearch.persistent.AllocatedPersistentTask;
+import org.elasticsearch.persistent.PersistentTaskParams;
+import org.elasticsearch.persistent.PersistentTaskState;
+import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
+import org.elasticsearch.persistent.PersistentTasksExecutor;
+import org.elasticsearch.persistent.PersistentTasksService;
+import org.elasticsearch.plugins.PersistentTaskPlugin;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.repositories.RepositoriesService;
+import org.elasticsearch.script.ScriptService;
+import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.watcher.ResourceWatcherService;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.not;
+
+/**
+ * This class is for testing that when shutting down a node, persistent tasks
+ * are not assigned to that node.
+ */
+@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0, numClientNodes = 0, transportClientRatio = 0)
+public class NodeShutdownTasksIT extends ESIntegTestCase {
+
+    private static final Logger logger = LogManager.getLogger(NodeShutdownTasksIT.class);
+    private static final AtomicBoolean startTask = new AtomicBoolean(false);
+    private static final AtomicBoolean taskCompleted = new AtomicBoolean(false);
+    private static final AtomicReference<Collection<DiscoveryNode>> candidates = new AtomicReference<>(null);
+
+    @Override
+    protected Collection<Class<? extends Plugin>> nodePlugins() {
+        return Arrays.asList(ShutdownEnabledPlugin.class, TaskPlugin.class);
+    }
+
+    public void testTasksAreNotAssignedToShuttingDownNode() throws Exception {
+        // Start two nodes, one will be marked as shutting down
+        final String node1 = internalCluster().startNode(Settings.EMPTY);
+        final String node2 = internalCluster().startNode(Settings.EMPTY);
+
+        final String shutdownNode;
+        final String candidateNode;
+        NodesInfoResponse nodes = client().admin().cluster().prepareNodesInfo().clear().get();
+        final String node1Id = nodes.getNodes()
+            .stream()
+            .map(NodeInfo::getNode)
+            .filter(node -> node.getName().equals(node1))
+            .map(DiscoveryNode::getId)
+            .findFirst()
+            .get();
+        final String node2Id = nodes.getNodes()
+            .stream()
+            .map(NodeInfo::getNode)
+            .filter(node -> node.getName().equals(node2))
+            .map(DiscoveryNode::getId)
+            .findFirst()
+            .get();
+
+        if (randomBoolean()) {
+            shutdownNode = node1Id;
+            candidateNode = node2Id;
+        } else {
+            shutdownNode = node2Id;
+            candidateNode = node1Id;
+        }
+        logger.info("--> node {} will be shut down, {} will remain", shutdownNode, candidateNode);
+
+        // Mark the node as shutting down
+        client().execute(
+            PutShutdownNodeAction.INSTANCE,
+            new PutShutdownNodeAction.Request(shutdownNode, SingleNodeShutdownMetadata.Type.REMOVE, "removal for testing")
+        ).get();
+
+        // Tell the persistent task executor it can start allocating the task
+        startTask.set(true);
+        // Issue a new cluster state update to force task assignment
+        client().admin().cluster().prepareReroute().get();
+        // Wait until the task has been assigned to a node
+        assertBusy(() -> assertNotNull("expected to have candidate nodes chosen for task", candidates.get()));
+        // Check that the node that is not shut down is the only candidate
+        assertThat(candidates.get().stream().map(DiscoveryNode::getId).collect(Collectors.toSet()), contains(candidateNode));
+        assertThat(candidates.get().stream().map(DiscoveryNode::getId).collect(Collectors.toSet()), not(contains(shutdownNode)));
+    }
+
+    public static class ShutdownEnabledPlugin extends ShutdownPlugin {
+        @Override
+        public boolean isEnabled() {
+            return true;
+        }
+    }
+
+    public static class TaskPlugin extends Plugin implements PersistentTaskPlugin {
+
+        TaskExecutor taskExecutor;
+
+        @Override
+        public Collection<Object> createComponents(
+            Client client,
+            ClusterService clusterService,
+            ThreadPool threadPool,
+            ResourceWatcherService resourceWatcherService,
+            ScriptService scriptService,
+            NamedXContentRegistry xContentRegistry,
+            Environment environment,
+            NodeEnvironment nodeEnvironment,
+            NamedWriteableRegistry namedWriteableRegistry,
+            IndexNameExpressionResolver indexNameExpressionResolver,
+            Supplier<RepositoriesService> repositoriesServiceSupplier
+        ) {
+            taskExecutor = new TaskExecutor(client, clusterService, threadPool);
+            return Collections.singletonList(taskExecutor);
+        }
+
+        @Override
+        public List<PersistentTasksExecutor<?>> getPersistentTasksExecutor(
+            ClusterService clusterService,
+            ThreadPool threadPool,
+            Client client,
+            SettingsModule settingsModule,
+            IndexNameExpressionResolver expressionResolver
+        ) {
+            return Collections.singletonList(taskExecutor);
+        }
+
+        @Override
+        public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
+            return Collections.singletonList(
+                new NamedWriteableRegistry.Entry(PersistentTaskParams.class, "task_name", TestTaskParams::new)
+            );
+        }
+    }
+
+    public static class TaskExecutor extends PersistentTasksExecutor<TestTaskParams> implements ClusterStateListener {
+
+        private final PersistentTasksService persistentTasksService;
+
+        protected TaskExecutor(Client client, ClusterService clusterService, ThreadPool threadPool) {
+            super("task_name", ThreadPool.Names.GENERIC);
+            persistentTasksService = new PersistentTasksService(clusterService, threadPool, client);
+            clusterService.addListener(this);
+        }
+
+        @Override
+        public PersistentTasksCustomMetadata.Assignment getAssignment(
+            TestTaskParams params,
+            Collection<DiscoveryNode> candidateNodes,
+            ClusterState clusterState
+        ) {
+            candidates.set(candidateNodes);
+            return super.getAssignment(params, candidateNodes, clusterState);
+        }
+
+        @Override
+        protected void nodeOperation(AllocatedPersistentTask task, TestTaskParams params, PersistentTaskState state) {
+            logger.info("--> executing the task");
+            taskCompleted.compareAndSet(false, true);
+        }
+
+        private void startTask() {
+            logger.info("--> sending start request");
+            persistentTasksService.sendStartRequest("task_id", "task_name", new TestTaskParams(), ActionListener.wrap(r -> {}, e -> {
+                if (e instanceof ResourceAlreadyExistsException == false) {
+                    logger.error("failed to create task", e);
+                    fail("failed to create task");
+                }
+            }));
+        }
+
+        @Override
+        public void clusterChanged(ClusterChangedEvent event) {
+            // Check if it's true, setting it to false if we are going to start task
+            if (startTask.compareAndSet(true, false)) {
+                startTask();
+            }
+        }
+    }
+
+    public static class TestTaskParams implements PersistentTaskParams {
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.endObject();
+            return builder;
+        }
+
+        public TestTaskParams() {}
+
+        public TestTaskParams(StreamInput in) {}
+
+        @Override
+        public String getWriteableName() {
+            return "task_name";
+        }
+
+        @Override
+        public Version getMinimalSupportedVersion() {
+            return Version.CURRENT;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+
+        }
+    }
+}
diff --git a/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/ShutdownPlugin.java b/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/ShutdownPlugin.java
index 3cc74c4de604c..62df425a71cbf 100644
--- a/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/ShutdownPlugin.java
+++ b/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/ShutdownPlugin.java
@@ -29,7 +29,7 @@ public class ShutdownPlugin extends Plugin implements ActionPlugin {
 
     public static final boolean SHUTDOWN_FEATURE_FLAG_ENABLED = "true".equals(System.getProperty("es.shutdown_feature_flag_enabled"));
 
-    public static boolean isEnabled() {
+    public boolean isEnabled() {
         return SHUTDOWN_FEATURE_FLAG_ENABLED;
     }
 
diff --git a/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeAction.java b/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeAction.java
index ad85dc1bf1e17..178b497a2cbdd 100644
--- a/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeAction.java
+++ b/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeAction.java
@@ -52,7 +52,7 @@ protected void masterOperation(
     ) throws Exception {
         { // This block solely to ensure this NodesShutdownMetadata isn't accidentally used in the cluster state update task below
             NodesShutdownMetadata nodesShutdownMetadata = state.metadata().custom(NodesShutdownMetadata.TYPE);
-            if (nodesShutdownMetadata.getAllNodeMetdataMap().get(request.getNodeId()) == null) {
+            if (nodesShutdownMetadata.getAllNodeMetadataMap().get(request.getNodeId()) == null) {
                 throw new IllegalArgumentException("node [" + request.getNodeId() + "] is not currently shutting down");
             }
         }
diff --git a/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportGetShutdownStatusAction.java b/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportGetShutdownStatusAction.java
index 6664876816da1..e5ba9162da3fd 100644
--- a/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportGetShutdownStatusAction.java
+++ b/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportGetShutdownStatusAction.java
@@ -64,9 +64,9 @@ protected void masterOperation(
         if (nodesShutdownMetadata == null) {
             response = new GetShutdownStatusAction.Response(new ArrayList<>());
         } else if (request.getNodeIds().length == 0) {
-            response = new GetShutdownStatusAction.Response(new ArrayList<>(nodesShutdownMetadata.getAllNodeMetdataMap().values()));
+            response = new GetShutdownStatusAction.Response(new ArrayList<>(nodesShutdownMetadata.getAllNodeMetadataMap().values()));
         } else {
-            Map<String, SingleNodeShutdownMetadata> nodeShutdownMetadataMap = nodesShutdownMetadata.getAllNodeMetdataMap();
+            Map<String, SingleNodeShutdownMetadata> nodeShutdownMetadataMap = nodesShutdownMetadata.getAllNodeMetadataMap();
             final List<SingleNodeShutdownMetadata> shutdownStatuses = Arrays.stream(request.getNodeIds())
                 .map(nodeShutdownMetadataMap::get)
                 .filter(Objects::nonNull)
diff --git a/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeAction.java b/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeAction.java
index 761f6bb4a3ade..db8e1f553f8b9 100644
--- a/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeAction.java
+++ b/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeAction.java
@@ -62,7 +62,7 @@ public ClusterState execute(ClusterState currentState) {
                     }
 
                     // Verify that there's not already a shutdown metadata for this node
-                    if (Objects.nonNull(currentShutdownMetadata.getAllNodeMetdataMap().get(request.getNodeId()))) {
+                    if (Objects.nonNull(currentShutdownMetadata.getAllNodeMetadataMap().get(request.getNodeId()))) {
                         throw new IllegalArgumentException("node [" + request.getNodeId() + "] is already shutting down");
                     }
 
diff --git a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutor.java b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutor.java
index 25e2ac74796ab..43f6cbe51069d 100644
--- a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutor.java
+++ b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutor.java
@@ -50,6 +50,7 @@
 import org.elasticsearch.xpack.transform.transforms.pivot.SchemaUtil;
 
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
 import java.util.Map;
 import java.util.TreeMap;
@@ -94,7 +95,9 @@ public TransformPersistentTasksExecutor(
     }
 
     @Override
-    public PersistentTasksCustomMetadata.Assignment getAssignment(TransformTaskParams params, ClusterState clusterState) {
+    public PersistentTasksCustomMetadata.Assignment getAssignment(TransformTaskParams params,
+                                                                  Collection<DiscoveryNode> candidateNodes,
+                                                                  ClusterState clusterState) {
         if (TransformMetadata.getTransformMetadata(clusterState).isResetMode()) {
             return new PersistentTasksCustomMetadata.Assignment(null,
                 "Transform task will not be assigned as a feature reset is in progress.");
@@ -123,6 +126,7 @@ public PersistentTasksCustomMetadata.Assignment getAssignment(TransformTaskParam
 
         DiscoveryNode discoveryNode = selectLeastLoadedNode(
             clusterState,
+            candidateNodes,
             node -> node.getVersion().onOrAfter(Version.V_7_7_0)
                 ? nodeCanRunThisTransform(node, params.getVersion(), params.requiresRemote(), null)
                 : nodeCanRunThisTransformPre77(node, params.getVersion(), null)
diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java
index 4c8925a17f458..73094f4af1d6c 100644
--- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java
+++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java
@@ -60,15 +60,18 @@ public void testNodeVersionAssignment() {
         TransformPersistentTasksExecutor executor = buildTaskExecutor();
 
         assertThat(
-            executor.getAssignment(new TransformTaskParams("new-task-id", Version.CURRENT, null, true), cs).getExecutorNode(),
+            executor.getAssignment(new TransformTaskParams("new-task-id", Version.CURRENT, null, true),
+                cs.nodes().getAllNodes(), cs).getExecutorNode(),
             equalTo("current-data-node-with-1-tasks")
         );
         assertThat(
-            executor.getAssignment(new TransformTaskParams("new-task-id", Version.CURRENT, null, false), cs).getExecutorNode(),
+            executor.getAssignment(new TransformTaskParams("new-task-id", Version.CURRENT, null, false),
+                cs.nodes().getAllNodes(), cs).getExecutorNode(),
             equalTo("current-data-node-with-0-tasks-transform-remote-disabled")
         );
         assertThat(
-            executor.getAssignment(new TransformTaskParams("new-old-task-id", Version.V_7_5_0, null, true), cs).getExecutorNode(),
+            executor.getAssignment(new TransformTaskParams("new-old-task-id", Version.V_7_5_0, null, true),
+                cs.nodes().getAllNodes(), cs).getExecutorNode(),
             equalTo("past-data-node-1")
         );
     }
@@ -79,7 +82,8 @@ public void testNodeAssignmentProblems() {
         ClusterState cs = buildClusterState(nodes);
         TransformPersistentTasksExecutor executor = buildTaskExecutor();
 
-        Assignment assignment = executor.getAssignment(new TransformTaskParams("new-task-id", Version.CURRENT, null, false), cs);
+        Assignment assignment = executor.getAssignment(new TransformTaskParams("new-task-id", Version.CURRENT, null, false),
+            cs.nodes().getAllNodes(), cs);
         assertNull(assignment.getExecutorNode());
         assertThat(
             assignment.getExplanation(),
@@ -91,7 +95,8 @@ public void testNodeAssignmentProblems() {
         cs = buildClusterState(nodes);
         executor = buildTaskExecutor();
 
-        assignment = executor.getAssignment(new TransformTaskParams("new-task-id", Version.CURRENT, null, false), cs);
+        assignment = executor.getAssignment(new TransformTaskParams("new-task-id", Version.CURRENT, null, false),
+            cs.nodes().getAllNodes(), cs);
         assertNotNull(assignment.getExecutorNode());
         assertThat(assignment.getExecutorNode(), equalTo("dedicated-transform-node"));
 
@@ -100,7 +105,8 @@ public void testNodeAssignmentProblems() {
         cs = buildClusterState(nodes);
         executor = buildTaskExecutor();
 
-        assignment = executor.getAssignment(new TransformTaskParams("new-task-id", Version.V_7_7_0, null, false), cs);
+        assignment = executor.getAssignment(new TransformTaskParams("new-task-id", Version.V_7_7_0, null, false),
+            cs.nodes().getAllNodes(), cs);
         assertNull(assignment.getExecutorNode());
         assertThat(
             assignment.getExplanation(),
@@ -113,7 +119,8 @@ public void testNodeAssignmentProblems() {
             )
         );
 
-        assignment = executor.getAssignment(new TransformTaskParams("new-task-id", Version.V_7_5_0, null, false), cs);
+        assignment = executor.getAssignment(new TransformTaskParams("new-task-id", Version.V_7_5_0, null, false),
+            cs.nodes().getAllNodes(), cs);
         assertNotNull(assignment.getExecutorNode());
         assertThat(assignment.getExecutorNode(), equalTo("past-data-node-1"));
 
@@ -122,7 +129,8 @@ public void testNodeAssignmentProblems() {
         cs = buildClusterState(nodes);
         executor = buildTaskExecutor();
 
-        assignment = executor.getAssignment(new TransformTaskParams("new-task-id", Version.V_7_5_0, null, true), cs);
+        assignment = executor.getAssignment(new TransformTaskParams("new-task-id", Version.V_7_5_0, null, true),
+            cs.nodes().getAllNodes(), cs);
         assertNull(assignment.getExecutorNode());
         assertThat(
             assignment.getExplanation(),
@@ -134,7 +142,8 @@ public void testNodeAssignmentProblems() {
             )
         );
 
-        assignment = executor.getAssignment(new TransformTaskParams("new-task-id", Version.CURRENT, null, false), cs);
+        assignment = executor.getAssignment(new TransformTaskParams("new-task-id", Version.CURRENT, null, false),
+            cs.nodes().getAllNodes(), cs);
         assertNotNull(assignment.getExecutorNode());
         assertThat(assignment.getExecutorNode(), equalTo("current-data-node-with-0-tasks-transform-remote-disabled"));
 
@@ -143,7 +152,8 @@ public void testNodeAssignmentProblems() {
         cs = buildClusterState(nodes);
         executor = buildTaskExecutor();
 
-        assignment = executor.getAssignment(new TransformTaskParams("new-task-id", Version.V_7_5_0, null, true), cs);
+        assignment = executor.getAssignment(new TransformTaskParams("new-task-id", Version.V_7_5_0, null, true),
+            cs.nodes().getAllNodes(), cs);
         assertNull(assignment.getExecutorNode());
         assertThat(
             assignment.getExplanation(),
@@ -161,7 +171,8 @@ public void testNodeAssignmentProblems() {
         cs = buildClusterState(nodes);
         executor = buildTaskExecutor();
 
-        assignment = executor.getAssignment(new TransformTaskParams("new-task-id", Version.V_7_5_0, null, true), cs);
+        assignment = executor.getAssignment(new TransformTaskParams("new-task-id", Version.V_7_5_0, null, true),
+            cs.nodes().getAllNodes(), cs);
         assertNotNull(assignment.getExecutorNode());
         assertThat(assignment.getExecutorNode(), equalTo("past-data-node-1"));
     }
@@ -254,7 +265,8 @@ public void testDoNotSelectOldNodes() {
         );
 
         // old-data-node-1 prevents assignment
-        assertNull(executor.getAssignment(new TransformTaskParams("new-task-id", Version.CURRENT, null, false), cs).getExecutorNode());
+        assertNull(executor.getAssignment(new TransformTaskParams("new-task-id", Version.CURRENT, null, false),
+                cs.nodes().getAllNodes(), cs).getExecutorNode());
 
         // remove the old 7.2 node
         nodes = DiscoveryNodes.builder()
@@ -279,7 +291,8 @@ public void testDoNotSelectOldNodes() {
         cs = ClusterState.builder(cs).nodes(nodes).build();
 
         assertThat(
-            executor.getAssignment(new TransformTaskParams("new-old-task-id", Version.V_7_2_0, null, false), cs).getExecutorNode(),
+            executor.getAssignment(new TransformTaskParams("new-old-task-id", Version.V_7_2_0, null, false),
+                    cs.nodes().getAllNodes(), cs).getExecutorNode(),
             equalTo("current-data-node-with-1-task")
         );
     }