Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Aug 7, 2022
1 parent b580116 commit 4af7102
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 37 deletions.
17 changes: 7 additions & 10 deletions distributed/tests/test_cancelled_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,14 +641,14 @@ def test_workerstate_flight_to_flight(ws):
assert ws.tasks["x"].state == "flight"


def test_workerstate_executing_skips_fetch_on_success(ws_with_running_task):
def test_workerstate_executing_success_to_fetch(ws_with_running_task):
"""Test state loops:
- executing -> cancelled -> resumed(fetch) -> memory
- executing -> long-running -> cancelled -> resumed(fetch) -> memory
- executing -> cancelled -> resumed(fetch)
- executing -> long-running -> cancelled -> resumed(fetch)
The task execution later terminates successfully.
Test that the task is never fetched and that dependents are unblocked.
Test that the task completion is ignored and the task transitions to flight
See also: test_workerstate_executing_failure_to_fetch
"""
Expand All @@ -660,14 +660,11 @@ def test_workerstate_executing_skips_fetch_on_success(ws_with_running_task):
ExecuteSuccessEvent.dummy("x", 123, stimulus_id="s3"),
)
assert instructions == [
TaskFinishedMsg.match(key="x", stimulus_id="s3"),
Execute(key="y", stimulus_id="s3"),
GatherDep(worker=ws2, to_gather={"x"}, total_nbytes=1, stimulus_id="s2")
]
assert ws.tasks["x"].state == "memory"
assert ws.data["x"] == 123
assert ws.tasks["x"].state == "flight"


@pytest.mark.xfail(reason="distributed#6689")
def test_workerstate_executing_failure_to_fetch(ws_with_running_task):
"""Test state loops:
Expand All @@ -692,7 +689,7 @@ def test_workerstate_executing_failure_to_fetch(ws_with_running_task):
ExecuteFailureEvent.dummy("x", stimulus_id="s3"),
)
assert instructions == [
GatherDep(worker=ws2, to_gather={"x"}, total_nbytes=1, stimulus_id="s3")
GatherDep(worker=ws2, to_gather={"x"}, total_nbytes=1, stimulus_id="s2")
]
assert ws.tasks["x"].state == "flight"

Expand Down
32 changes: 31 additions & 1 deletion distributed/tests/test_reschedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from distributed import Event, Reschedule, get_worker, secede, wait
from distributed.utils_test import captured_logger, gen_cluster, slowinc
from distributed.worker_state_machine import (
ComputeTaskEvent,
FreeKeysEvent,
RescheduleEvent,
RescheduleMsg,
Expand Down Expand Up @@ -113,7 +114,7 @@ def test_cancelled_reschedule_worker_state(ws_with_running_task):

instructions = ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="s1"))
assert not instructions
assert ws.tasks["x"].state == "cancelled"
assert ws.tasks["x"].state == "released"
assert ws.available_resources == {"R": 0}

instructions = ws.handle_stimulus(RescheduleEvent(key="x", stimulus_id="s2"))
Expand All @@ -129,3 +130,32 @@ def test_reschedule_releases(ws_with_running_task):
assert instructions == [RescheduleMsg(stimulus_id="s1", key="x")]
assert ws.available_resources == {"R": 1}
assert "x" not in ws.tasks


def test_reschedule_cancelled(ws_with_running_task):
"""Test state loop:
executing -> cancelled -> rescheduled
"""
ws = ws_with_running_task
instructions = ws.handle_stimulus(
FreeKeysEvent(keys=["x"], stimulus_id="s1"),
RescheduleEvent(key="x", stimulus_id="s2"),
)
assert not instructions
assert "x" not in ws.tasks


def test_reschedule_resumed(ws_with_running_task):
"""Test state loop:
executing -> cancelled -> resumed(waiting) -> executing -> rescheduled
"""
ws = ws_with_running_task
instructions = ws.handle_stimulus(
FreeKeysEvent(keys=["x"], stimulus_id="s1"),
ComputeTaskEvent.dummy("x", stimulus_id="s2", resource_restrictions={"R": 1}),
RescheduleEvent(key="x", stimulus_id="s3"),
)
assert instructions == [RescheduleMsg(key="x", stimulus_id="s3")]
assert "x" not in ws.tasks
44 changes: 44 additions & 0 deletions distributed/tests/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from distributed.worker_state_machine import (
ComputeTaskEvent,
Execute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FreeKeysEvent,
TaskFinishedMsg,
Expand Down Expand Up @@ -508,3 +509,46 @@ async def test_resources_from_python_override_config(c, s, a, b):
info = c.scheduler_info()
for worker in [a, b]:
assert info["workers"][worker.address]["resources"] == {"my_resources": 10}


@pytest.mark.parametrize("done_ev_cls", [ExecuteSuccessEvent, ExecuteFailureEvent])
def test_cancel_with_resources(ws_with_running_task, done_ev_cls):
ws = ws_with_running_task
assert ws.available_resources == {"R": 0}
ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="s1"))
assert ws.tasks["x"].state == "released"
assert ws.available_resources == {"R": 0}
ws.handle_stimulus(done_ev_cls.dummy(key="x", stimulus_id="s2"))
assert "x" not in ws.tasks
assert ws.available_resources == {"R": 1}


@pytest.mark.parametrize(
"done_ev_cls,done_status",
[(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "error")],
)
def test_resume_with_different_resources(ws, done_ev_cls, done_status):
"""A task is cancelled and then resumed to the same state, but with different
resources. This is actually possible in case of manual cancellation from the client,
followed by resubmit.
"""
ws.total_resources = {"R": 1}
ws.available_resources = {"R": 1}

ws.handle_stimulus(
ComputeTaskEvent.dummy("x", stimulus_id="s2", resource_restrictions={"R": 0.2})
)
assert ws.available_resources == {"R": 0.8}

ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="s2"))
assert ws.tasks["x"].state == "released"
assert ws.available_resources == {"R": 0.8}

instructions = ws.handle_stimulus(
ComputeTaskEvent.dummy("x", stimulus_id="s3", resource_restrictions={"R": 0.3})
)
assert not instructions
assert ws.available_resources == {"R": 0.8}
ws.handle_stimulus(done_ev_cls.dummy(key="x", stimulus_id="s4"))
assert ws.tasks["x"].state == done_status
assert ws.available_resources == {"R": 1}
2 changes: 1 addition & 1 deletion distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,7 +1330,7 @@ def test_steal_worker_state(ws_with_running_task):

ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="s1"))
assert ws.available_resources == {"R": 0}
assert ws.tasks["x"].state == "cancelled"
assert ws.tasks["x"].state == "released"

instructions = ws.handle_stimulus(ExecuteSuccessEvent.dummy("x", stimulus_id="s2"))
assert not instructions
Expand Down
6 changes: 3 additions & 3 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3220,7 +3220,7 @@ async def test_gather_dep_cancelled_rescheduled(c, s):
while fut4.key in b.state.tasks:
await asyncio.sleep(0)

assert b.state.tasks[fut2.key].state == "cancelled"
assert b.state.tasks[fut2.key].state == "released"

b.block_gather_dep.set()
await a.in_get_data.wait()
Expand Down Expand Up @@ -3258,7 +3258,7 @@ async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a)
while fut4.key in b.state.tasks:
await asyncio.sleep(0.01)

assert b.state.tasks[fut2.key].state == "cancelled"
assert b.state.tasks[fut2.key].state == "released"

b.block_gather_dep.set()

Expand Down Expand Up @@ -3286,7 +3286,7 @@ async def test_gather_dep_no_longer_in_flight_tasks(c, s, a):
while fut2.key in b.state.tasks:
await asyncio.sleep(0.01)

assert b.state.tasks[fut1.key].state == "cancelled"
assert b.state.tasks[fut1.key].state == "released"

b.block_gather_dep.set()
while fut2.key in b.state.tasks:
Expand Down
33 changes: 12 additions & 21 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,7 @@ def test_new_replica_while_all_workers_in_flight(ws):
assert ws.tasks["y"].state == "flight"


# TODO before merge: move to test_cancelled_state.py
@gen_cluster(client=True)
async def test_cancelled_while_in_flight(c, s, a, b):
event = asyncio.Event()
Expand All @@ -763,7 +764,7 @@ async def test_cancelled_while_in_flight(c, s, a, b):
y = c.submit(inc, x, key="y", workers=[a.address])
await wait_for_state("x", "flight", a)
y.release()
await wait_for_state("x", "cancelled", a)
await wait_for_state("x", "released", a)

# Let the comm from b to a return the result
event.set()
Expand Down Expand Up @@ -1163,35 +1164,29 @@ def test_task_with_dependencies_acquires_resources(ws):
assert ws.available_resources == {"R": 0}


@pytest.mark.parametrize(
"done_ev_cls",
[
ExecuteSuccessEvent,
pytest.param(
ExecuteFailureEvent,
marks=pytest.mark.xfail(
reason="distributed#6682,distributed#6689,distributed#6693"
),
),
],
)
# TODO before merge: move to test_cancelled_state.py
@pytest.mark.parametrize("done_ev_cls", [ExecuteSuccessEvent, ExecuteFailureEvent])
def test_resumed_task_releases_resources(ws_with_running_task, done_ev_cls):
ws = ws_with_running_task
assert ws.available_resources == {"R": 0}
ws2 = "127.0.0.1:2"

ws.handle_stimulus(FreeKeysEvent("cancel", ["x"]))
instructions = ws.handle_stimulus(FreeKeysEvent("cancel", ["x"]))
assert not instructions
assert ws.tasks["x"].state == "released"
assert ws.available_resources == {"R": 0}

instructions = ws.handle_stimulus(
ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="compute")
)
assert not instructions
assert instructions == [
GatherDep(worker=ws2, to_gather={"x"}, total_nbytes=1, stimulus_id="compute")
]
assert ws.tasks["x"].state == "flight"
assert ws.available_resources == {"R": 0}

ws.handle_stimulus(done_ev_cls.dummy("x", stimulus_id="s2"))
instructions = ws.handle_stimulus(done_ev_cls.dummy("x", stimulus_id="s2"))
assert not instructions
assert ws.tasks["x"].state == "flight"
assert ws.available_resources == {"R": 1}

Expand Down Expand Up @@ -1240,11 +1235,7 @@ def test_done_task_not_in_all_running_tasks(
"done_ev_cls,done_status",
[
(ExecuteSuccessEvent, "memory"),
pytest.param(
ExecuteFailureEvent,
"flight",
marks=pytest.mark.xfail(reason="distributed#6689"),
),
(ExecuteFailureEvent, "flight"),
],
)
def test_done_resumed_task_not_in_all_running_tasks(
Expand Down
5 changes: 4 additions & 1 deletion distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1945,10 +1945,12 @@ def _transition_executing_released(
else:
assert ts in self.long_running
running_set = self.long_running
resources = ts.resource_restrictions.copy()

recs, instr = self._transition_generic_released(ts, stimulus_id=stimulus_id)
# Partially revert _purge_state()
running_set.add(ts)
ts.resource_restrictions = resources
return recs, instr

def _transition_flight_released(
Expand Down Expand Up @@ -2449,7 +2451,8 @@ def _handle_compute_task(self, ev: ComputeTaskEvent) -> RecsInstrs:
ts.traceback_text = ""
ts.priority = priority
ts.duration = ev.duration
ts.resource_restrictions = ev.resource_restrictions
if ts not in self.executing | self.long_running:
ts.resource_restrictions = ev.resource_restrictions
ts.annotations = ev.annotations

if self.validate:
Expand Down

0 comments on commit 4af7102

Please sign in to comment.