diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index a029e29b9e8..85def1c4bbf 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -641,14 +641,14 @@ def test_workerstate_flight_to_flight(ws): assert ws.tasks["x"].state == "flight" -def test_workerstate_executing_skips_fetch_on_success(ws_with_running_task): +def test_workerstate_executing_success_to_fetch(ws_with_running_task): """Test state loops: - - executing -> cancelled -> resumed(fetch) -> memory - - executing -> long-running -> cancelled -> resumed(fetch) -> memory + - executing -> cancelled -> resumed(fetch) + - executing -> long-running -> cancelled -> resumed(fetch) The task execution later terminates successfully. - Test that the task is never fetched and that dependents are unblocked. + Test that the task completion is ignored and the task transitions to flight. See also: test_workerstate_executing_failure_to_fetch """ @@ -660,14 +660,11 @@ def test_workerstate_executing_skips_fetch_on_success(ws_with_running_task): ExecuteSuccessEvent.dummy("x", 123, stimulus_id="s3"), ) assert instructions == [ - TaskFinishedMsg.match(key="x", stimulus_id="s3"), - Execute(key="y", stimulus_id="s3"), + GatherDep(worker=ws2, to_gather={"x"}, total_nbytes=1, stimulus_id="s2") ] - assert ws.tasks["x"].state == "memory" - assert ws.data["x"] == 123 + assert ws.tasks["x"].state == "flight" -@pytest.mark.xfail(reason="distributed#6689") def test_workerstate_executing_failure_to_fetch(ws_with_running_task): """Test state loops: @@ -692,7 +689,7 @@ def test_workerstate_executing_failure_to_fetch(ws_with_running_task): ExecuteFailureEvent.dummy("x", stimulus_id="s3"), ) assert instructions == [ - GatherDep(worker=ws2, to_gather={"x"}, total_nbytes=1, stimulus_id="s3") + GatherDep(worker=ws2, to_gather={"x"}, total_nbytes=1, stimulus_id="s2") ] assert ws.tasks["x"].state == "flight" diff --git 
a/distributed/tests/test_reschedule.py b/distributed/tests/test_reschedule.py index 2fa53162825..3f9d48c1c63 100644 --- a/distributed/tests/test_reschedule.py +++ b/distributed/tests/test_reschedule.py @@ -13,6 +13,7 @@ from distributed import Event, Reschedule, get_worker, secede, wait from distributed.utils_test import captured_logger, gen_cluster, slowinc from distributed.worker_state_machine import ( + ComputeTaskEvent, FreeKeysEvent, RescheduleEvent, RescheduleMsg, @@ -113,7 +114,7 @@ def test_cancelled_reschedule_worker_state(ws_with_running_task): instructions = ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="s1")) assert not instructions - assert ws.tasks["x"].state == "cancelled" + assert ws.tasks["x"].state == "released" assert ws.available_resources == {"R": 0} instructions = ws.handle_stimulus(RescheduleEvent(key="x", stimulus_id="s2")) @@ -129,3 +130,32 @@ def test_reschedule_releases(ws_with_running_task): assert instructions == [RescheduleMsg(stimulus_id="s1", key="x")] assert ws.available_resources == {"R": 1} assert "x" not in ws.tasks + + +def test_reschedule_cancelled(ws_with_running_task): + """Test state loop: + + executing -> cancelled -> rescheduled + """ + ws = ws_with_running_task + instructions = ws.handle_stimulus( + FreeKeysEvent(keys=["x"], stimulus_id="s1"), + RescheduleEvent(key="x", stimulus_id="s2"), + ) + assert not instructions + assert "x" not in ws.tasks + + +def test_reschedule_resumed(ws_with_running_task): + """Test state loop: + + executing -> cancelled -> resumed(waiting) -> executing -> rescheduled + """ + ws = ws_with_running_task + instructions = ws.handle_stimulus( + FreeKeysEvent(keys=["x"], stimulus_id="s1"), + ComputeTaskEvent.dummy("x", stimulus_id="s2", resource_restrictions={"R": 1}), + RescheduleEvent(key="x", stimulus_id="s3"), + ) + assert instructions == [RescheduleMsg(key="x", stimulus_id="s3")] + assert "x" not in ws.tasks diff --git a/distributed/tests/test_resources.py 
b/distributed/tests/test_resources.py index cc11fce0448..173e2589877 100644 --- a/distributed/tests/test_resources.py +++ b/distributed/tests/test_resources.py @@ -14,6 +14,7 @@ from distributed.worker_state_machine import ( ComputeTaskEvent, Execute, + ExecuteFailureEvent, ExecuteSuccessEvent, FreeKeysEvent, TaskFinishedMsg, @@ -508,3 +509,46 @@ async def test_resources_from_python_override_config(c, s, a, b): info = c.scheduler_info() for worker in [a, b]: assert info["workers"][worker.address]["resources"] == {"my_resources": 10} + + +@pytest.mark.parametrize("done_ev_cls", [ExecuteSuccessEvent, ExecuteFailureEvent]) +def test_cancel_with_resources(ws_with_running_task, done_ev_cls): + ws = ws_with_running_task + assert ws.available_resources == {"R": 0} + ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="s1")) + assert ws.tasks["x"].state == "released" + assert ws.available_resources == {"R": 0} + ws.handle_stimulus(done_ev_cls.dummy(key="x", stimulus_id="s2")) + assert "x" not in ws.tasks + assert ws.available_resources == {"R": 1} + + +@pytest.mark.parametrize( + "done_ev_cls,done_status", + [(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "error")], +) +def test_resume_with_different_resources(ws, done_ev_cls, done_status): + """A task is cancelled and then resumed to the same state, but with different + resources. This is actually possible in case of manual cancellation from the client, + followed by resubmit. 
+ """ + ws.total_resources = {"R": 1} + ws.available_resources = {"R": 1} + + ws.handle_stimulus( + ComputeTaskEvent.dummy("x", stimulus_id="s2", resource_restrictions={"R": 0.2}) + ) + assert ws.available_resources == {"R": 0.8} + + ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="s2")) + assert ws.tasks["x"].state == "released" + assert ws.available_resources == {"R": 0.8} + + instructions = ws.handle_stimulus( + ComputeTaskEvent.dummy("x", stimulus_id="s3", resource_restrictions={"R": 0.3}) + ) + assert not instructions + assert ws.available_resources == {"R": 0.8} + ws.handle_stimulus(done_ev_cls.dummy(key="x", stimulus_id="s4")) + assert ws.tasks["x"].state == done_status + assert ws.available_resources == {"R": 1} diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 7b044d031f5..340223184d0 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1330,7 +1330,7 @@ def test_steal_worker_state(ws_with_running_task): ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="s1")) assert ws.available_resources == {"R": 0} - assert ws.tasks["x"].state == "cancelled" + assert ws.tasks["x"].state == "released" instructions = ws.handle_stimulus(ExecuteSuccessEvent.dummy("x", stimulus_id="s2")) assert not instructions diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 3d7d2c0105b..d8cc467cbf0 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3220,7 +3220,7 @@ async def test_gather_dep_cancelled_rescheduled(c, s): while fut4.key in b.state.tasks: await asyncio.sleep(0) - assert b.state.tasks[fut2.key].state == "cancelled" + assert b.state.tasks[fut2.key].state == "released" b.block_gather_dep.set() await a.in_get_data.wait() @@ -3258,7 +3258,7 @@ async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a) while fut4.key in b.state.tasks: await asyncio.sleep(0.01) - assert b.state.tasks[fut2.key].state == 
"cancelled" + assert b.state.tasks[fut2.key].state == "released" b.block_gather_dep.set() @@ -3286,7 +3286,7 @@ async def test_gather_dep_no_longer_in_flight_tasks(c, s, a): while fut2.key in b.state.tasks: await asyncio.sleep(0.01) - assert b.state.tasks[fut1.key].state == "cancelled" + assert b.state.tasks[fut1.key].state == "released" b.block_gather_dep.set() while fut2.key in b.state.tasks: diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index d8fd4bb4c41..c3e2cdc4bb2 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -754,6 +754,7 @@ def test_new_replica_while_all_workers_in_flight(ws): assert ws.tasks["y"].state == "flight" +# TODO before merge: move to test_cancelled_state.py @gen_cluster(client=True) async def test_cancelled_while_in_flight(c, s, a, b): event = asyncio.Event() @@ -763,7 +764,7 @@ async def test_cancelled_while_in_flight(c, s, a, b): y = c.submit(inc, x, key="y", workers=[a.address]) await wait_for_state("x", "flight", a) y.release() - await wait_for_state("x", "cancelled", a) + await wait_for_state("x", "released", a) # Let the comm from b to a return the result event.set() @@ -1163,35 +1164,29 @@ def test_task_with_dependencies_acquires_resources(ws): assert ws.available_resources == {"R": 0} -@pytest.mark.parametrize( - "done_ev_cls", - [ - ExecuteSuccessEvent, - pytest.param( - ExecuteFailureEvent, - marks=pytest.mark.xfail( - reason="distributed#6682,distributed#6689,distributed#6693" - ), - ), - ], -) +# TODO before merge: move to test_cancelled_state.py +@pytest.mark.parametrize("done_ev_cls", [ExecuteSuccessEvent, ExecuteFailureEvent]) def test_resumed_task_releases_resources(ws_with_running_task, done_ev_cls): ws = ws_with_running_task assert ws.available_resources == {"R": 0} ws2 = "127.0.0.1:2" - ws.handle_stimulus(FreeKeysEvent("cancel", ["x"])) + instructions = ws.handle_stimulus(FreeKeysEvent("cancel", 
["x"])) + assert not instructions assert ws.tasks["x"].state == "released" assert ws.available_resources == {"R": 0} instructions = ws.handle_stimulus( ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="compute") ) - assert not instructions + assert instructions == [ + GatherDep(worker=ws2, to_gather={"x"}, total_nbytes=1, stimulus_id="compute") + ] assert ws.tasks["x"].state == "flight" assert ws.available_resources == {"R": 0} - ws.handle_stimulus(done_ev_cls.dummy("x", stimulus_id="s2")) + instructions = ws.handle_stimulus(done_ev_cls.dummy("x", stimulus_id="s2")) + assert not instructions assert ws.tasks["x"].state == "flight" assert ws.available_resources == {"R": 1} @@ -1240,11 +1235,7 @@ def test_done_task_not_in_all_running_tasks( "done_ev_cls,done_status", [ (ExecuteSuccessEvent, "memory"), - pytest.param( - ExecuteFailureEvent, - "flight", - marks=pytest.mark.xfail(reason="distributed#6689"), - ), + (ExecuteFailureEvent, "flight"), ], ) def test_done_resumed_task_not_in_all_running_tasks( diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index dbc11d5ec18..213d0e69a39 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -1945,10 +1945,12 @@ def _transition_executing_released( else: assert ts in self.long_running running_set = self.long_running + resources = ts.resource_restrictions.copy() recs, instr = self._transition_generic_released(ts, stimulus_id=stimulus_id) # Partially revert _purge_state() running_set.add(ts) + ts.resource_restrictions = resources return recs, instr def _transition_flight_released( @@ -2449,7 +2451,8 @@ def _handle_compute_task(self, ev: ComputeTaskEvent) -> RecsInstrs: ts.traceback_text = "" ts.priority = priority ts.duration = ev.duration - ts.resource_restrictions = ev.resource_restrictions + if ts not in self.executing | self.long_running: + ts.resource_restrictions = ev.resource_restrictions ts.annotations = ev.annotations if 
self.validate: