Skip to content

Commit

Permalink
Handle failing plugin.close() calls during scheduler shutdown (#6450)
Browse files (browse the repository at this point in the history)
  • Loading branch information
mrocklin authored May 26, 2022
1 parent 1db476c commit 9a70a53
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 3 deletions.
36 changes: 35 additions & 1 deletion distributed/diagnostics/tests/test_scheduler_plugin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging

import pytest

from distributed import Scheduler, SchedulerPlugin, Worker, get_worker
from distributed.utils_test import gen_cluster, gen_test, inc
from distributed.utils_test import captured_logger, gen_cluster, gen_test, inc


@gen_cluster(client=True)
Expand Down Expand Up @@ -209,3 +211,35 @@ async def start(self, scheduler: Scheduler) -> None:
await s.register_scheduler_plugin(MyPlugin())

assert s._foo == "bar"


@gen_cluster(client=True)
async def test_closing_errors_ok(c, s, a, b, capsys):
    """Check that a plugin raising in its close hooks does not break shutdown.

    Two plugins are registered on the scheduler: ``OK`` prints a marker from
    both ``before_close`` and ``close``; ``Bad`` raises from both.  After
    ``s.close()`` the test asserts that

    * both markers printed by ``OK`` reached stdout, and
    * both exception messages raised by ``Bad`` appear in the output
      captured from the ``distributed.scheduler`` logger.
    """

    class OK(SchedulerPlugin):
        # Well-behaved plugin: each hook leaves a marker on stdout.
        async def before_close(self):
            print(123)

        async def close(self):
            print(456)

    class Bad(SchedulerPlugin):
        # Misbehaving plugin: each hook raises with a distinct message so the
        # two failures can be told apart in the captured log text.
        async def before_close(self):
            raise Exception("BEFORE_CLOSE")

        async def close(self):
            raise Exception("AFTER_CLOSE")

    await s.register_scheduler_plugin(OK())
    await s.register_scheduler_plugin(Bad())

    with captured_logger(logging.getLogger("distributed.scheduler")) as logger:
        await s.close()

    # Only stdout is inspected; stderr is intentionally discarded.
    out, _ = capsys.readouterr()
    assert "123" in out
    assert "456" in out

    # Read the captured log once and check both failures were recorded.
    text = logger.getvalue()
    assert "BEFORE_CLOSE" in text
    assert "AFTER_CLOSE" in text
10 changes: 8 additions & 2 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3358,8 +3358,14 @@ async def close(self):
await self.finished()
return

async def log_errors(func):
    """Await ``func()`` and log, rather than propagate, any ``Exception``.

    Parameters
    ----------
    func
        A zero-argument callable returning an awaitable (here, a bound
        plugin hook such as ``plugin.before_close`` or ``plugin.close``).

    Returns
    -------
    None
        Always; the awaited result of ``func()`` is discarded.

    Notes
    -----
    Only ``Exception`` subclasses are caught, so ``BaseException`` types
    that do not derive from ``Exception`` still propagate.
    """
    try:
        await func()
    except Exception:
        # Deliberate best-effort: record the full traceback and carry on.
        logger.exception("Plugin call failed during scheduler.close")

await asyncio.gather(
*[plugin.before_close() for plugin in list(self.plugins.values())]
*[log_errors(plugin.before_close) for plugin in list(self.plugins.values())]
)

self.status = Status.closing
Expand All @@ -3371,7 +3377,7 @@ async def close(self):
await preload.teardown()

await asyncio.gather(
*[plugin.close() for plugin in list(self.plugins.values())]
*[log_errors(plugin.close) for plugin in list(self.plugins.values())]
)

for pc in self.periodic_callbacks.values():
Expand Down

0 comments on commit 9a70a53

Please sign in to comment.