Skip to content

Commit

Permalink
Fix a deadlock connected to task stealing and task deserialization (#…
Browse files Browse the repository at this point in the history
…5128)

* Fix a deadlock connected to task stealing and deserialization

If a task is stolen while the task runspec is being deserialized this allows
for an edge case where the executing_count is never decreased again
such that the ready queue is never worked off

* Simplify exception handling for Worker.execute

* remove test about unintended behaviour

* function naming fix
  • Loading branch information
fjetter authored Jul 30, 2021
1 parent 3f1b250 commit 1999c15
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 95 deletions.
103 changes: 100 additions & 3 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import threading
import traceback
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from numbers import Number
from operator import add
from time import sleep
Expand Down Expand Up @@ -2005,17 +2006,53 @@ async def test_process_executor(c, s, a, b):
assert (await future) != os.getpid()


def kill_process():
import os
import signal

os.kill(os.getpid(), signal.SIGTERM)


@gen_cluster(client=True)
async def test_process_executor_kills_process(c, s, a, b):
with ProcessPoolExecutor() as e:
a.executors["processes"] = e
b.executors["processes"] = e
with dask.annotate(executor="processes", retries=1):
future = c.submit(kill_process)

with pytest.raises(
BrokenProcessPool,
match="A child process terminated abruptly, the process pool is not usable anymore",
):
await future

with dask.annotate(executor="processes", retries=1):
future = c.submit(inc, 1)

# FIXME: The processpool is now unusable and the worker is effectively
# dead
with pytest.raises(
BrokenProcessPool,
match="A child process terminated abruptly, the process pool is not usable anymore",
):
assert await future == 2


def raise_exc():
raise RuntimeError("foo")


@gen_cluster(client=True)
async def test_process_executor_raise_exception(c, s, a, b):
with ProcessPoolExecutor() as e:
a.executors["processes"] = e
b.executors["processes"] = e
with dask.annotate(executor="processes", retries=1):
future = c.submit(sys.exit, 1)
future = c.submit(raise_exc)

exc = await future.exception()
assert "SystemExit(1)" in repr(exc)
with pytest.raises(RuntimeError, match="foo"):
await future


def assert_task_states_on_worker(expected, worker):
Expand Down Expand Up @@ -2417,3 +2454,63 @@ async def test_forget_dependents_after_release(c, s, a):
while fut2.key in a.tasks:
await asyncio.sleep(0.001)
assert fut2.key not in {d.key for d in a.tasks[fut.key].dependents}


@gen_cluster(client=True, nthreads=[("", 1)] * 2, timeout=5000000)
async def test_steak_during_task_deserialization(c, s, a, b, monkeypatch):
stealing_ext = s.extensions["stealing"]
stealing_ext._pc.stop()
from distributed.utils import ThreadPoolExecutor

class CountingThreadPool(ThreadPoolExecutor):
counter = 0

def submit(self, *args, **kwargs):
CountingThreadPool.counter += 1
return super().submit(*args, **kwargs)

# Ensure we're always offloading
monkeypatch.setattr("distributed.worker.OFFLOAD_THRESHOLD", 1)
threadpool = CountingThreadPool(
max_workers=1, thread_name_prefix="Counting-Offload-Threadpool"
)
try:
monkeypatch.setattr("distributed.utils._offload_executor", threadpool)

class SlowDeserializeCallable:
def __init__(self, delay=0.1):
self.delay = delay

def __getstate__(self):
return self.delay

def __setstate__(self, state):
delay = state
import time

time.sleep(delay)
return SlowDeserializeCallable(delay)

def __call__(self, *args, **kwargs):
return 41

slow_deserialized_func = SlowDeserializeCallable()
fut = c.submit(
slow_deserialized_func, 1, workers=[a.address], allow_other_workers=True
)

while CountingThreadPool.counter == 0:
await asyncio.sleep(0)

ts = s.tasks[fut.key]
a.steal_request(fut.key)
stealing_ext.scheduler.send_task_to_worker(b.address, ts)

fut2 = c.submit(inc, fut, workers=[a.address])
fut3 = c.submit(inc, fut2, workers=[a.address])

assert await fut2 == 42
await fut3

finally:
threadpool.shutdown()
151 changes: 59 additions & 92 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2817,12 +2817,8 @@ async def _maybe_deserialize_task(self, ts):
{"action": "deserialize", "start": start, "stop": stop}
)
return function, args, kwargs
except Exception as e:
logger.warning("Could not deserialize task", exc_info=True)
emsg = error_message(e)
emsg["key"] = ts.key
emsg["op"] = "task-erred"
self.batched_stream.send(emsg)
except Exception:
logger.error("Could not deserialize task", exc_info=True)
self.log.append((ts.key, "deserialize-error"))
raise

Expand All @@ -2838,11 +2834,6 @@ async def ensure_computing(self):
continue
if self.meets_resource_constraints(key):
self.constrained.popleft()
try:
# Ensure task is deserialized prior to execution
ts.runspec = await self._maybe_deserialize_task(ts)
except Exception:
continue
self.transition(ts, "executing")
else:
break
Expand All @@ -2857,11 +2848,6 @@ async def ensure_computing(self):
elif ts.key in self.data:
self.transition(ts, "memory")
elif ts.state in READY:
try:
# Ensure task is deserialized prior to execution
ts.runspec = await self._maybe_deserialize_task(ts)
except Exception:
continue
self.transition(ts, "executing")
except Exception as e:
logger.exception(e)
Expand All @@ -2871,61 +2857,38 @@ async def ensure_computing(self):
pdb.set_trace()
raise

async def execute(self, key, report=False):
executor_error = None
async def execute(self, key):
if self.status in (Status.closing, Status.closed, Status.closing_gracefully):
return

if key not in self.tasks:
return

ts = self.tasks[key]

if ts.state != "executing":
# This might happen if keys are canceled
logger.debug(
"Trying to execute a task %s which is not in executing state anymore"
% ts
)
return

try:
if key not in self.tasks:
return
ts = self.tasks[key]
if ts.state != "executing":
# This might happen if keys are canceled
logger.debug(
"Trying to execute a task %s which is not in executing state anymore"
% ts
)
return
if ts.runspec is None:
logger.critical("No runspec available for task %s." % ts)
if self.validate:
assert not ts.waiting_for_data
assert ts.state == "executing"
assert ts.runspec is not None

function, args, kwargs = ts.runspec

start = time()
data = {}
for dep in ts.dependencies:
k = dep.key
try:
data[k] = self.data[k]
except KeyError:
from .actor import Actor # TODO: create local actor
function, args, kwargs = await self._maybe_deserialize_task(ts)

data[k] = Actor(type(self.actors[k]), self.address, k, self)
args2 = pack_data(args, data, key_types=(bytes, str))
kwargs2 = pack_data(kwargs, data, key_types=(bytes, str))
stop = time()
if stop - start > 0.005:
ts.startstops.append(
{"action": "disk-read", "start": start, "stop": stop}
)
if self.digests is not None:
self.digests["disk-load-duration"].add(stop - start)
args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs)

if ts.annotations is not None and "executor" in ts.annotations:
executor = ts.annotations["executor"]
else:
executor = "default"
assert executor in self.executors

logger.debug(
"Execute key: %s worker: %s, executor: %s",
ts.key,
self.address,
executor,
) # TODO: comment out?
assert key == ts.key
self.active_keys.add(ts.key)
try:
Expand All @@ -2945,41 +2908,26 @@ async def execute(self, key, report=False):
self.scheduler_delay,
)
else:
try:
start = time() + self.scheduler_delay
result = await self.loop.run_in_executor(
e,
apply_function_simple,
function,
args2,
kwargs2,
self.scheduler_delay,
)
except BaseException as e:
msg = error_message(e)
msg["op"] = "task-erred"
msg["actual-exception"] = e
msg["start"] = start
msg["stop"] = time() + self.scheduler_delay
msg["thread"] = None
result = msg

except RuntimeError as e:
executor_error = e
raise
result = await self.loop.run_in_executor(
e,
apply_function_simple,
function,
args2,
kwargs2,
self.scheduler_delay,
)
finally:
self.active_keys.discard(ts.key)

# We'll need to check again for the task state since it may have
# changed since the execution was kicked off. In particular, it may
# have been canceled and released already in which case we'll have
# to drop the result immediately
key = ts.key
ts = self.tasks.get(key)

if ts is None:
if ts.key not in self.tasks:
logger.debug(
"Dropping result for %s since task has already been released." % key
"Dropping result for %s since task has already been released."
% ts.key
)
return

Expand Down Expand Up @@ -3022,18 +2970,37 @@ async def execute(self, key, report=False):
assert ts.state != "executing"
assert not ts.waiting_for_data

except Exception as exc:
logger.error(
"Exception during execution of task %s.", ts.key, exc_info=True
)
emsg = error_message(exc)
ts.exception = emsg["exception"]
ts.traceback = emsg["traceback"]
self.transition(ts, "error")
finally:
await self.ensure_computing()
self.ensure_communicating()
except Exception as e:
if executor_error is e:
logger.error("Thread Pool Executor error: %s", e)
else:
logger.exception(e)
if LOG_PDB:
import pdb

pdb.set_trace()
raise
def _prepare_args_for_execution(self, ts, args, kwargs):
start = time()
data = {}
for dep in ts.dependencies:
k = dep.key
try:
data[k] = self.data[k]
except KeyError:
from .actor import Actor # TODO: create local actor

data[k] = Actor(type(self.actors[k]), self.address, k, self)
args2 = pack_data(args, data, key_types=(bytes, str))
kwargs2 = pack_data(kwargs, data, key_types=(bytes, str))
stop = time()
if stop - start > 0.005:
ts.startstops.append({"action": "disk-read", "start": start, "stop": stop})
if self.digests is not None:
self.digests["disk-load-duration"].add(stop - start)
return args2, kwargs2

##################
# Administrative #
Expand Down

0 comments on commit 1999c15

Please sign in to comment.