Skip to content

Commit

Permalink
Ensure NannyPlugins are always installed (dask#8107)
Browse files Browse the repository at this point in the history
Co-authored-by: Hendrik Makait <[email protected]>
  • Loading branch information
fjetter and hendrikmakait authored Sep 1, 2023
1 parent e79c0c7 commit 9b15cd5
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 22 deletions.
34 changes: 26 additions & 8 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,17 +349,35 @@ async def start_unsafe(self):
self.ip = get_address_host(self.address)

await self.preloads.start()
saddr = self.scheduler.addr
comm = await self.rpc.connect(saddr)
comm.name = "Nanny->Scheduler (registration)"

msg = await self.scheduler.register_nanny()
for name, plugin in msg["nanny-plugins"].items():
await self.plugin_add(plugin=plugin, name=name)
try:
await comm.write({"op": "register_nanny", "address": self.address})
msg = await comm.read()
try:
for name, plugin in msg["nanny-plugins"].items():
await self.plugin_add(plugin=plugin, name=name)

logger.info(" Start Nanny at: %r", self.address)
response = await self.instantiate()
logger.info(" Start Nanny at: %r", self.address)
response = await self.instantiate()

if response != Status.running:
await self.close(reason="nanny-start-failed")
return
if response != Status.running:
raise RuntimeError("Nanny failed to start worker process")
except Exception:
try:
await comm.write({"status": "error"})

# If self.instantiate() failed, the comm will already be closed.
except CommClosedError:
pass
await self.close(reason="nanny-start-failed")
raise
else:
await comm.write({"status": "ok"})
finally:
await comm.close()

assert self.worker_address

Expand Down
43 changes: 30 additions & 13 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3623,6 +3623,8 @@ def __init__(
self.event_subscriber = defaultdict(set)
self.worker_plugins = {}
self.nanny_plugins = {}
self._starting_nannies = set()
self._starting_nannies_cond = asyncio.Condition()

worker_handlers = {
"task-finished": self.handle_task_finished,
Expand Down Expand Up @@ -4295,12 +4297,20 @@ async def add_worker(
# This will keep running until the worker is removed
await self.handle_worker(comm, address)

async def add_nanny(self) -> dict[str, Any]:
msg = {
"status": "OK",
"nanny-plugins": self.nanny_plugins,
}
return msg
async def add_nanny(self, comm: Comm, address: str) -> None:
    """Send the registered nanny plugins to a starting nanny and track it.

    ``address`` is kept in ``self._starting_nannies`` for the duration of the
    exchange. The reply containing ``self.nanny_plugins`` is written to
    ``comm`` and then one further message is read from ``comm``; the content
    of that message is not inspected here. Whether or not the exchange
    succeeds, ``address`` is removed from the set again and all waiters on
    ``self._starting_nannies_cond`` are notified.
    """
    # Mark the nanny as "starting" under the condition's lock.
    async with self._starting_nannies_cond:
        self._starting_nannies.add(address)

    try:
        await comm.write(
            {
                "status": "OK",
                "nanny-plugins": self.nanny_plugins,
            }
        )
        # Block until the nanny sends its follow-up message.
        await comm.read()
    finally:
        # Always clear the marker and wake up anyone waiting on the set.
        async with self._starting_nannies_cond:
            self._starting_nannies.discard(address)
            self._starting_nannies_cond.notify_all()

def _match_graph_with_tasks(
self, dsk: dict[str, Any], dependencies: dict[str, set[str]], keys: set[str]
Expand Down Expand Up @@ -7441,6 +7451,7 @@ def stop_task_metadata(self, name: str | None = None) -> dict:

async def register_worker_plugin(self, comm, plugin, name=None):
"""Registers a worker plugin on all running and future workers"""
logger.info("Registering Worker plugin %s", name)
self.worker_plugins[name] = plugin

responses = await self.broadcast(
Expand All @@ -7458,15 +7469,21 @@ async def unregister_worker_plugin(self, comm, name):
responses = await self.broadcast(msg=dict(op="plugin-remove", name=name))
return responses

async def register_nanny_plugin(self, comm, plugin, name=None):
async def register_nanny_plugin(self, comm, plugin, name):
"""Registers a setup function, and call it on every worker"""
logger.info("Registering Nanny plugin %s", name)
self.nanny_plugins[name] = plugin

responses = await self.broadcast(
msg=dict(op="plugin_add", plugin=plugin, name=name),
nanny=True,
)
return responses
async with self._starting_nannies_cond:
if self._starting_nannies:
logger.info("Waiting for Nannies to start %s", self._starting_nannies)
await self._starting_nannies_cond.wait_for(
lambda: not self._starting_nannies
)
responses = await self.broadcast(
msg=dict(op="plugin_add", plugin=plugin, name=name),
nanny=True,
)
return responses

async def unregister_nanny_plugin(self, comm, name):
"""Unregisters a worker plugin"""
Expand Down
140 changes: 139 additions & 1 deletion distributed/tests/test_nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import random
import sys
import warnings
import weakref
from contextlib import suppress
from unittest import mock

Expand All @@ -23,7 +24,7 @@
from distributed.compatibility import LINUX, WINDOWS
from distributed.core import CommClosedError, Status, error_message
from distributed.diagnostics import SchedulerPlugin
from distributed.diagnostics.plugin import WorkerPlugin
from distributed.diagnostics.plugin import NannyPlugin, WorkerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.utils import TimeoutError, get_mp_context, parse_ports
Expand Down Expand Up @@ -781,3 +782,140 @@ class C:
"traceback_text": "",
},
] == [msg[1] for msg in s.get_events("test-topic4")]


@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
async def test_nanny_plugin_simple(c, s, a):
    """A nanny plugin is set up on a running nanny and on one started later."""
    await c.register_worker_plugin(DummyNannyPlugin("foo"))

    # ``a`` was already part of the cluster when the plugin was registered.
    assert a._plugin_registered

    # A nanny that joins afterwards must have the plugin set up as well.
    async with Nanny(s.address) as late_nanny:
        assert late_nanny._plugin_registered


class DummyNannyPlugin(NannyPlugin):
    """Minimal nanny plugin that flags the nanny it was set up on.

    ``setup`` sets ``nanny._plugin_registered`` to ``True`` and ``teardown``
    sets it back to ``False``, so tests can check whether the plugin ran.
    """

    def __init__(self, name, restart=False):
        self.name = name
        self.restart = restart
        # Replaced by a weak reference to the nanny once ``setup`` has run.
        self.nanny = None

    def setup(self, nanny):
        print(f"Setup on {nanny}")
        # Keep only a weak reference to the nanny on the plugin.
        self.nanny = weakref.ref(nanny)
        nanny._plugin_registered = True

    def teardown(self, nanny):
        nanny._plugin_registered = False


class SlowNanny(Nanny):
    """Nanny whose ``instantiate`` pauses until the test releases it.

    ``in_instantiate`` is set when ``instantiate`` is entered;
    ``instantiate`` then waits for ``wait_instantiate`` before delegating to
    the parent implementation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.in_instantiate = asyncio.Event()
        self.wait_instantiate = asyncio.Event()

    async def instantiate(self):
        # Signal that startup has reached ``instantiate``.
        self.in_instantiate.set()
        # Hold here until the test allows startup to continue.
        await self.wait_instantiate.wait()
        result = await super().instantiate()
        return result


@pytest.mark.parametrize("restart", [True, False])
@gen_cluster(client=True, nthreads=[])
async def test_nanny_plugin_register_during_start_success(c, s, restart):
    """Registering a plugin while a nanny is starting waits, then succeeds."""
    plugin = DummyNannyPlugin("foo", restart=restart)
    nanny = SlowNanny(s.address)
    assert not hasattr(nanny, "_plugin_registered")

    start_task = asyncio.create_task(nanny.start())
    try:
        # Wait until the nanny is held inside ``instantiate``.
        await nanny.in_instantiate.wait()

        register_task = asyncio.create_task(c.register_worker_plugin(plugin))
        # Registration must not finish while the nanny is still starting;
        # ``shield`` keeps the task alive when ``wait_for`` times out.
        with pytest.raises(asyncio.TimeoutError):
            await asyncio.wait_for(asyncio.shield(register_task), timeout=0.1)

        # Let startup continue; both registration and start must complete.
        nanny.wait_instantiate.set()
        assert await register_task
        await start_task
        assert nanny._plugin_registered
    finally:
        start_task.cancel()
        await nanny.close()


class SlowBrokenNanny(Nanny):
    """Nanny whose ``instantiate`` pauses and then always fails.

    ``in_instantiate`` is set when ``instantiate`` is entered;
    ``instantiate`` then waits for ``wait_instantiate`` and raises
    ``RuntimeError`` instead of calling the parent implementation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.in_instantiate = asyncio.Event()
        self.wait_instantiate = asyncio.Event()

    async def instantiate(self):
        # Signal that startup has reached ``instantiate``.
        self.in_instantiate.set()
        # Hold here until the test releases the nanny, then fail.
        await self.wait_instantiate.wait()
        raise RuntimeError("Nope")


@pytest.mark.parametrize("restart", [True, False])
@gen_cluster(client=True, nthreads=[])
async def test_nanny_plugin_register_during_start_failure(c, s, restart):
    """Registration waits for a starting nanny even if its startup then fails."""
    plugin = DummyNannyPlugin("foo", restart=restart)
    nanny = SlowBrokenNanny(s.address)
    assert not hasattr(nanny, "_plugin_registered")

    start_task = asyncio.create_task(nanny.start())
    # Wait until the nanny is held inside ``instantiate``.
    await nanny.in_instantiate.wait()

    register_task = asyncio.create_task(c.register_worker_plugin(plugin))
    # Registration must not finish while the nanny is still starting;
    # ``shield`` keeps the task alive when ``wait_for`` times out.
    with pytest.raises(asyncio.TimeoutError):
        await asyncio.wait_for(asyncio.shield(register_task), timeout=0.1)

    # Release the nanny: its ``instantiate`` raises, so ``start`` fails.
    nanny.wait_instantiate.set()
    with pytest.raises(RuntimeError):
        await start_task

    # Registration completes with a falsy result once the nanny is gone.
    registration_result = await register_task
    assert not registration_result


class SlowDistNanny(Nanny):
    """Nanny that reports entering ``instantiate`` through caller-supplied events.

    Unlike the events created inside ``SlowNanny``, ``in_instantiate`` and
    ``wait_instantiate`` are passed in as keyword-only arguments.
    """

    def __init__(self, *args, in_instantiate, wait_instantiate, **kwargs):
        super().__init__(*args, **kwargs)
        self.wait_instantiate = wait_instantiate
        self.in_instantiate = in_instantiate

    async def instantiate(self):
        self.in_instantiate.set()
        # NOTE(review): this ``wait()`` is not awaited. With the
        # ``get_mp_context().Event()`` objects passed in by the test in this
        # file it is presumably a blocking call that stalls this process's
        # event loop until the event is set -- confirm this is intended.
        self.wait_instantiate.wait()
        result = await super().instantiate()
        return result


def run_nanny(scheduler_addr, in_instantiate, wait_instantiate):
    """Start a ``SlowDistNanny`` in a fresh event loop and run until it finishes.

    Used as the ``target`` of a separate process; the two events are
    forwarded unchanged to the nanny.
    """

    async def serve():
        nanny = await SlowDistNanny(
            scheduler_addr,
            in_instantiate=in_instantiate,
            wait_instantiate=wait_instantiate,
        )
        await nanny.finished()

    asyncio.run(serve())


@pytest.mark.parametrize("restart", [True, False])
@gen_cluster(client=True, nthreads=[])
async def test_nanny_plugin_register_nanny_killed(c, s, restart):
    """Plugin registration completes when a starting nanny's process is killed.

    A nanny is started in a separate process and held inside
    ``instantiate``. Registration is issued, the process is killed, and the
    registration must still return, with an empty result.
    """
    in_instantiate = get_mp_context().Event()
    wait_instantiate = get_mp_context().Event()
    proc = get_mp_context().Process(
        name="run_nanny",
        target=run_nanny,
        kwargs={
            "in_instantiate": in_instantiate,
            "wait_instantiate": wait_instantiate,
        },
        args=(s.address,),
    )
    proc.start()
    try:
        plugin = DummyNannyPlugin("foo", restart=restart)
        # The multiprocessing event's ``wait`` blocks, so run it in a thread
        # to keep this event loop responsive.
        await asyncio.to_thread(in_instantiate.wait)
        register = asyncio.create_task(c.register_worker_plugin(plugin))
    finally:
        proc.kill()
        # Reap the killed child so it does not linger as an unjoined
        # (zombie) process; join off-loop to avoid blocking the event loop.
        await asyncio.to_thread(proc.join)
    assert await register == {}

0 comments on commit 9b15cd5

Please sign in to comment.