diff --git a/distributed/client.py b/distributed/client.py
index ce50e2a6ada..46982203530 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -1497,8 +1497,7 @@ async def __aexit__(self, exc_type, exc_value, traceback):
         await self._close(
             # if we're handling an exception, we assume that it's more
             # important to deliver that exception than shutdown gracefully.
-            fast=exc_type
-            is not None
+            fast=(exc_type is not None)
         )

     def __exit__(self, exc_type, exc_value, traceback):
@@ -1669,16 +1668,15 @@ async def _wait_for_handle_report_task(self, fast=False):
         await wait_for(handle_report_task, 0 if fast else 2)

     @log_errors
-    async def _close(self, fast=False):
-        """
-        Send close signal and wait until scheduler completes
+    async def _close(self, fast: bool = False) -> None:
+        """Send close signal and wait until scheduler completes

         If fast is True, the client will close forcefully, by cancelling tasks
         the background _handle_report_task.
         """
-        # TODO: aclose more forcefully by aborting the RPC and cancelling all
+        # TODO: close more forcefully by aborting the RPC and cancelling all
         # background tasks.
-        # see https://trio.readthedocs.io/en/stable/reference-io.html#trio.aclose_forcefully
+        # See https://trio.readthedocs.io/en/stable/reference-io.html#trio.aclose_forcefully

         if self.status == "closed":
             return
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index c00c2bbf2cd..b491539416e 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -5196,7 +5196,7 @@ def test_quiet_client_close(loop):
             threads_per_worker=4,
         ) as c:
             futures = c.map(slowinc, range(1000), delay=0.01)
-            sleep(0.200)  # stop part-way
+            sleep(0.2)  # stop part-way
             sleep(0.1)  # let things settle

             out = logger.getvalue()
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index c581ea14c46..636efcd9e48 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -3058,7 +3058,7 @@ async def connect(self, *args, **kwargs):


 @gen_cluster(client=True)
-async def test_gather_failing_cnn_recover(c, s, a, b):
+async def test_gather_failing_can_recover(c, s, a, b):
     x = await c.scatter({"x": 1}, workers=a.address)
     rpc = await FlakyConnectionPool(failing_connections=1)
     with mock.patch.object(s, "rpc", rpc), dask.config.set(
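
A minimal usage sketch (not part of the patch; the cluster arguments are assumptions) of the behaviour the first hunk touches: when `async with Client(...)` exits because of an exception, `__aexit__` passes `fast=(exc_type is not None)` to `_close()`, so the exception is delivered promptly instead of waiting up to two seconds for the background handle-report task.

```python
import asyncio

from distributed import Client


async def main():
    try:
        # Assumes an in-process local cluster; any Client used as an async
        # context manager behaves the same way.
        async with Client(asynchronous=True, processes=False) as client:
            future = client.submit(sum, [1, 2, 3])
            print(await future)  # 6
            raise RuntimeError("boom")  # exc_type is not None -> _close(fast=True)
    except RuntimeError:
        pass  # the exception propagates; shutdown is fast rather than graceful


if __name__ == "__main__":
    asyncio.run(main())
```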