Skip to content

Commit

Permalink
Ensure resumed flight tasks are still fetched (#5426)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Oct 22, 2021
1 parent cdc68cc commit 1670cf8
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 124 deletions.
4 changes: 2 additions & 2 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@

@gen_cluster(client=True)
async def test_submit(c, s, a, b):
x = c.submit(inc, 10)
x = c.submit(inc, 10, key="x")
assert not x.done()

assert isinstance(x, Future)
Expand All @@ -112,7 +112,7 @@ async def test_submit(c, s, a, b):
assert result == 11
assert x.done()

y = c.submit(inc, 20)
y = c.submit(inc, 20, key="y")
z = c.submit(add, x, y)

result = await z
Expand Down
2 changes: 2 additions & 0 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,8 @@ async def test_reschedule_concurrent_requests_deadlock(c, s, *workers):
slowinc,
range(10),
key=[f"f1-{ix}" for ix in range(10)],
workers=[w0.address],
allow_other_workers=True,
)
while not w0.active_keys:
await asyncio.sleep(0.01)
Expand Down
89 changes: 69 additions & 20 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1804,14 +1804,14 @@ async def test_story_with_deps(c, s, a, b):
stimulus_ids.add(msg[-2])
pruned_story.append(tuple(pruned_msg[:-2]))

assert len(stimulus_ids) == 3
assert len(stimulus_ids) == 3, stimulus_ids
stimulus_id = pruned_story[0][-1]
assert isinstance(stimulus_id, str)
assert stimulus_id.startswith("compute-task")
# This is a simple transition log
expected_story = [
(key, "compute-task"),
(key, "released", "waiting", {}),
(key, "released", "waiting", {dep.key: "fetch"}),
(key, "waiting", "ready", {}),
(key, "ready", "executing", {}),
(key, "put-in-memory"),
Expand All @@ -1832,11 +1832,11 @@ async def test_story_with_deps(c, s, a, b):
stimulus_ids.add(msg[-2])
pruned_story.append(tuple(pruned_msg[:-2]))

assert len(stimulus_ids) == 3
assert len(stimulus_ids) == 2, stimulus_ids
stimulus_id = pruned_story[0][-1]
assert isinstance(stimulus_id, str)
expected_story = [
(dep_story, "register-replica", "released"),
(dep_story, "ensure-task-exists", "released"),
(dep_story, "released", "fetch", {}),
(
"gather-dependencies",
Expand Down Expand Up @@ -2794,7 +2794,7 @@ async def test_acquire_replicas_same_channel(c, s, a, b):
_acquire_replicas(s, b, fut)

await futC
while fut.key not in b.tasks:
while fut.key not in b.tasks or not b.tasks[fut.key].state == "memory":
await asyncio.sleep(0.005)
assert len(s.who_has[fut.key]) == 2

Expand Down Expand Up @@ -3082,12 +3082,14 @@ def clear_leak():
]


async def _wait_for_flight(key, worker):
while key not in worker.tasks or worker.tasks[key].state != "flight":
async def _wait_for_state(key: str, worker: Worker, state: str):
# Keep the sleep interval at 0 since the tests using this are very sensitive
# to timing. They intend to capture the event loop cycles right after this
# specific condition was set
while key not in worker.tasks or worker.tasks[key].state != state:
await asyncio.sleep(0)


@pytest.mark.xfail(reason="#5406")
@gen_cluster(client=True)
async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a, b):
"""At time of writing, the gather_dep implementation filtered tasks again
Expand All @@ -3107,21 +3109,26 @@ async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a,

fut2_key = fut2.key

await _wait_for_flight(fut2_key, b)
await _wait_for_state(fut2_key, b, "flight")
while not mocked_gather.call_args:
await asyncio.sleep(0)

fut4.release()
while fut4.key in b.tasks:
await asyncio.sleep(0)

story_before = b.story(fut2.key)
assert fut2.key in mocked_gather.call_args.kwargs["to_gather"]
await Worker.gather_dep(b, **mocked_gather.call_args.kwargs)
story_after = b.story(fut2.key)
assert story_before == story_after
assert b.tasks[fut2.key].state == "cancelled"
args, kwargs = mocked_gather.call_args
assert fut2.key in kwargs["to_gather"]

await Worker.gather_dep(b, *args, **kwargs)
assert fut2.key not in b.tasks
f2_story = b.story(fut2.key)
assert f2_story
assert not any("missing-dep" in msg for msg in b.story(fut2.key))
await fut3


@pytest.mark.xfail(reason="#5406")
@gen_cluster(
client=True,
config={
Expand All @@ -3137,13 +3144,55 @@ async def test_gather_dep_no_longer_in_flight_tasks(c, s, a, b):

fut1_key = fut1.key

await _wait_for_flight(fut1_key, b)
await _wait_for_state(fut1_key, b, "flight")
while not mocked_gather.call_args:
await asyncio.sleep(0)

fut2.release()
while fut2.key in b.tasks:
await asyncio.sleep(0)

assert b.tasks[fut1.key] != "flight"
log_before = list(b.log)
await Worker.gather_dep(b, **mocked_gather.call_args.kwargs)
assert log_before == list(b.log)
assert b.tasks[fut1.key].state == "cancelled"

args, kwargs = mocked_gather.call_args
await Worker.gather_dep(b, *args, **kwargs)

assert fut2.key not in b.tasks
f1_story = b.story(fut1.key)
assert f1_story
assert not any("missing-dep" in msg for msg in b.story(fut2.key))


@pytest.mark.parametrize("intermediate_state", ["resumed", "cancelled"])
@pytest.mark.parametrize("close_worker", [False, True])
@gen_cluster(client=True, nthreads=[("", 1)] * 3)
async def test_deadlock_cancelled_after_inflight_before_gather_from_worker(
c, s, a, b, x, intermediate_state, close_worker
):
"""If a task was transitioned to in-flight, the gather-dep coroutine was
scheduled, and a cancel request came in before gather_data_from_worker was
issued, this might corrupt the state machine if the cancelled key is not
properly handled."""

fut1 = c.submit(slowinc, 1, workers=[a.address], key="f1")
fut1B = c.submit(slowinc, 2, workers=[x.address], key="f1B")
fut2 = c.submit(sum, [fut1, fut1B], workers=[x.address], key="f2")
await fut2
with mock.patch.object(distributed.worker.Worker, "gather_dep") as mocked_gather:
fut3 = c.submit(inc, fut2, workers=[b.address], key="f3")

fut2_key = fut2.key

await _wait_for_state(fut2_key, b, "flight")

s.set_restrictions(worker={fut1B.key: a.address, fut2.key: b.address})
while not mocked_gather.call_args:
await asyncio.sleep(0)

await s.remove_worker(address=x.address, safe=True, close=close_worker)

await _wait_for_state(fut2_key, b, intermediate_state)

args, kwargs = mocked_gather.call_args
await Worker.gather_dep(b, *args, **kwargs)
await fut3
Loading

0 comments on commit 1670cf8

Please sign in to comment.