From 0651e866fe5dc0aadd3fe3af54d589f8763f86fd Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Fri, 12 Aug 2022 14:03:30 +0100 Subject: [PATCH 1/3] use async with Client in tests --- distributed/deploy/tests/test_adaptive.py | 93 ++-- distributed/deploy/tests/test_local.py | 64 ++- distributed/tests/test_client.py | 567 ++++++++++------------ distributed/tests/test_failed_workers.py | 56 ++- distributed/tests/test_publish.py | 100 ++-- distributed/tests/test_queues.py | 62 ++- distributed/tests/test_scheduler.py | 13 +- distributed/tests/test_utils_test.py | 7 +- distributed/tests/test_variable.py | 62 ++- distributed/tests/test_worker.py | 5 +- distributed/utils_test.py | 159 +++--- 11 files changed, 572 insertions(+), 616 deletions(-) diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index 1248db0f0d8..04b518b063b 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -117,50 +117,47 @@ def scale_up(self, n, **kwargs): @gen_test() async def test_min_max(): - cluster = await LocalCluster( + async with LocalCluster( n_workers=0, silence_logs=False, processes=False, dashboard_address=":0", asynchronous=True, threads_per_worker=1, - ) - try: + ) as cluster: adapt = cluster.adapt(minimum=1, maximum=2, interval="20 ms", wait_count=10) - c = await Client(cluster, asynchronous=True) - - start = time() - while not cluster.scheduler.workers: - await asyncio.sleep(0.01) - assert time() < start + 1 - - await asyncio.sleep(0.2) - assert len(cluster.scheduler.workers) == 1 - assert len(adapt.log) == 1 and adapt.log[-1][1] == {"status": "up", "n": 1} + async with Client(cluster, asynchronous=True) as c: + start = time() + while not cluster.scheduler.workers: + await asyncio.sleep(0.01) + assert time() < start + 1 - futures = c.map(slowinc, range(100), delay=0.1) + await asyncio.sleep(0.2) + assert len(cluster.scheduler.workers) == 1 + assert len(adapt.log) == 1 and adapt.log[-1][1] == {"status": "up", "n": 1} - start = time() - while len(cluster.scheduler.workers) < 2: - await asyncio.sleep(0.01) - assert time() < start + 1 + futures = c.map(slowinc, range(100), delay=0.1) - assert len(cluster.scheduler.workers) == 2 - await asyncio.sleep(0.5) - assert len(cluster.scheduler.workers) == 2 - assert len(cluster.workers) == 2 - assert len(adapt.log) == 2 and all(d["status"] == "up" for _, d in adapt.log) + start = time() + while len(cluster.scheduler.workers) < 2: + await asyncio.sleep(0.01) + assert time() < start + 1 + + assert len(cluster.scheduler.workers) == 2 + await asyncio.sleep(0.5) + assert len(cluster.scheduler.workers) == 2 + assert len(cluster.workers) == 2 + assert len(adapt.log) == 2 and all( + d["status"] == "up" for _, d in adapt.log + ) - del futures + del futures - start = time() - while len(cluster.scheduler.workers) != 1: - await asyncio.sleep(0.01) - assert time() < start + 2 - assert adapt.log[-1][1]["status"] == "down" - finally: - await c.close() - await cluster.close() + start = time() + while len(cluster.scheduler.workers) != 1: + await asyncio.sleep(0.01) + assert time() < start + 2 + assert adapt.log[-1][1]["status"] == "down" @gen_test() @@ -194,16 +191,14 @@ async def test_adapt_quickly(): Instead we want to wait a few beats before removing a worker in case the user is taking a brief pause between work """ - cluster = await LocalCluster( + async with LocalCluster( n_workers=0, asynchronous=True, processes=False, silence_logs=False, dashboard_address=":0", - ) - client = await Client(cluster, asynchronous=True) - adapt = cluster.adapt(interval="20 ms", wait_count=5, maximum=10) - try: + ) as cluster, Client(cluster, asynchronous=True) as client: + adapt = cluster.adapt(interval="20 ms", wait_count=5, maximum=10) future = client.submit(slowinc, 1, delay=0.100) await wait(future) assert len(adapt.log) == 1 @@ -241,9 +236,6 @@ async def test_adapt_quickly(): await asyncio.sleep(0.1) assert len(cluster.workers) == 1 - finally: - await client.close() - await cluster.close() @gen_test() @@ -255,20 +247,19 @@ async def test_adapt_down(): processes=False, silence_logs=False, dashboard_address=":0", - ) as cluster: - async with Client(cluster, asynchronous=True) as client: - cluster.adapt(interval="20ms", maximum=5) + ) as cluster, Client(cluster, asynchronous=True) as client: + cluster.adapt(interval="20ms", maximum=5) - futures = client.map(slowinc, range(1000), delay=0.1) - while len(cluster.scheduler.workers) < 5: - await asyncio.sleep(0.1) + futures = client.map(slowinc, range(1000), delay=0.1) + while len(cluster.scheduler.workers) < 5: + await asyncio.sleep(0.1) - cluster.adapt(maximum=2) + cluster.adapt(maximum=2) - start = time() - while len(cluster.scheduler.workers) != 2: - await asyncio.sleep(0.1) - assert time() < start + 60 + start = time() + while len(cluster.scheduler.workers) != 2: + await asyncio.sleep(0.1) + assert time() < start + 60 @gen_test() diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 1da51d9f4e4..bc64cfca0b6 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -5,11 +5,11 @@ import sys from threading import Lock from time import sleep +from unittest import mock from urllib.parse import urlparse import pytest from tornado.httpclient import AsyncHTTPClient -from tornado.ioloop import IOLoop from dask.system import CPU_COUNT @@ -245,19 +245,21 @@ def test_Client_solo(loop): @gen_test() async def test_duplicate_clients(): pytest.importorskip("bokeh") - c1 = await Client( + async with Client( processes=False, silence_logs=False, dashboard_address=9876, asynchronous=True - ) - with pytest.warns(Warning) as info: - c2 = await Client( - processes=False, - silence_logs=False, - dashboard_address=9876, - asynchronous=True, - ) - - assert "dashboard" in c1.cluster.scheduler.services - assert "dashboard" in c2.cluster.scheduler.services + ) as c1: + c1_services = c1.cluster.scheduler.services + with pytest.warns(Warning) as info: + async with Client( + processes=False, + silence_logs=False, + dashboard_address=9876, + asynchronous=True, + ) as c2: + c2_services = c2.cluster.scheduler.services + + assert c1_services == {"dashboard": mock.ANY} + assert c2_services == {"dashboard": mock.ANY} assert any( all( @@ -266,8 +268,6 @@ async def test_duplicate_clients(): ) for msg in info.list ) - await c1.close() - await c2.close() def test_Client_kwargs(loop): @@ -824,35 +824,29 @@ class MyCluster(LocalCluster): def scale_down(self, *args, **kwargs): pass - loop = IOLoop.current() - cluster = await MyCluster( + async with MyCluster( n_workers=0, processes=False, silence_logs=False, dashboard_address=":0", - loop=loop, + loop=None, asynchronous=True, - ) - c = await Client(cluster, asynchronous=True) - - assert not cluster.workers - - await cluster.scale(2) + ) as cluster, Client(cluster, asynchronous=True) as c: + assert not cluster.workers - start = time() - while len(cluster.scheduler.workers) != 2: - await asyncio.sleep(0.01) - assert time() < start + 3 + await cluster.scale(2) - await cluster.scale(1) + start = time() + while len(cluster.scheduler.workers) != 2: + await asyncio.sleep(0.01) + assert time() < start + 3 - start = time() - while len(cluster.scheduler.workers) != 1: - await asyncio.sleep(0.01) - assert time() < start + 3 + await cluster.scale(1) - await c.close() - await cluster.close() + start = time() + while len(cluster.scheduler.workers) != 1: + await asyncio.sleep(0.01) + assert time() < start + 3 def test_local_tls_restart(loop): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index fbf2af7f99d..334b43f3fdd 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -6,6 +6,7 @@ import gc import inspect import logging +import operator import os import pathlib import pickle @@ -534,17 +535,14 @@ async def test_exceptions(c, s, a, b): @gen_cluster() async def test_gc(s, a, b): - c = await Client(s.address, asynchronous=True) - - x = c.submit(inc, 10) - await x - assert s.tasks[x.key].who_has - x.__del__() - await async_wait_for( - lambda: x.key not in s.tasks or not s.tasks[x.key].who_has, timeout=0.3 - ) - - await c.close() + async with Client(s.address, asynchronous=True) as c: + x = c.submit(inc, 10) + await x + assert s.tasks[x.key].who_has + x.__del__() + await async_wait_for( + lambda: x.key not in s.tasks or not s.tasks[x.key].who_has, timeout=0.3 + ) def test_thread(c): @@ -587,13 +585,12 @@ async def test_gather(c, s, a, b): @gen_cluster(client=True) async def test_gather_mismatched_client(c, s, a, b): - c2 = await Client(s.address, asynchronous=True) + async with Client(s.address, asynchronous=True) as c2: + x = c.submit(inc, 10) + y = c2.submit(inc, 5) - x = c.submit(inc, 10) - y = c2.submit(inc, 5) - - with pytest.raises(ValueError, match="Futures created by another client"): - await c.gather([x, y]) + with pytest.raises(ValueError, match="Futures created by another client"): + await c.gather([x, y]) @gen_cluster(client=True) @@ -1067,20 +1064,16 @@ async def test_map_quotes(c, s, a, b): @gen_cluster() async def test_two_consecutive_clients_share_results(s, a, b): - c = await Client(s.address, asynchronous=True) - - x = c.submit(random.randint, 0, 1000, pure=True) - xx = await x - - f = await Client(s.address, asynchronous=True) + async with Client(s.address, asynchronous=True) as c: - y = f.submit(random.randint, 0, 1000, pure=True) - yy = await y + x = c.submit(random.randint, 0, 1000, pure=True) + xx = await x - assert xx == yy + async with Client(s.address, asynchronous=True) as f: + y = f.submit(random.randint, 0, 1000, pure=True) + yy = await y - await c.close() - await f.close() + assert xx == yy @gen_cluster(client=True) @@ -1423,14 +1416,10 @@ async def test_scatter_direct(c, s, a, b): @gen_cluster() async def test_scatter_direct_2(s, a, b): - c = await Client(s.address, asynchronous=True, heartbeat_interval=10) - - last = s.clients[c.id].last_seen - - while s.clients[c.id].last_seen == last: - await asyncio.sleep(0.10) - - await c.close() + async with Client(s.address, asynchronous=True, heartbeat_interval=10) as c: + last = s.clients[c.id].last_seen + while s.clients[c.id].last_seen == last: + await asyncio.sleep(0.10) @gen_cluster(client=True) @@ -2014,21 +2003,26 @@ async def test_badly_serialized_input_stderr(capsys, c): assert future.status == "error" -def test_repr(loop): - funcs = [str, repr, lambda x: x._repr_html_()] +@pytest.mark.parametrize( + "func", + [ + str, + repr, + operator.methodcaller("_repr_html_"), + ], +) +def test_repr(loop, func): with cluster(nworkers=3, worker_kwargs={"memory_limit": "2 GiB"}) as (s, [a, b, c]): with Client(s["address"], loop=loop) as c: - for func in funcs: - text = func(c) - assert c.scheduler.address in text - assert "threads=3" in text or "Total threads: " in text - assert "6.00 GiB" in text - if "" in text + assert "6.00 GiB" in text + if " Date: Fri, 19 Aug 2022 13:13:05 +0100 Subject: [PATCH 2/3] lingering clients are no-longer deprecated --- distributed/tests/test_utils_test.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 08529aa54b8..f3fd449e16b 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -273,11 +273,9 @@ def test_new_config(): def test_lingering_client(): @gen_cluster() async def f(s, a, b): - with pytest.warns( - DeprecationWarning, - match=r"await Client\(\) is deprecated, use async with Client\(\)", - ): - await Client(s.address, asynchronous=True) + # TODO: force async-with here? + # see https://github.com/dask/distributed/issues/6616 + await Client(s.address, asynchronous=True) f() From a10d2b793b2d3b167f66c020eb7f6ac9cf5e61e2 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Fri, 19 Aug 2022 13:24:30 +0100 Subject: [PATCH 3/3] Apply suggestions from code review --- distributed/tests/test_utils_test.py | 2 +- distributed/tests/test_worker.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index f3fd449e16b..2147982f750 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -284,7 +284,7 @@ async def f(s, a, b): def test_lingering_client_2(loop): - # assert where the client went + # TODO: assert where the client went with cluster() as (s, [a, b]): client = Client(s["address"], loop=loop) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 49c4cdeb447..572439ed992 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1019,9 +1019,10 @@ def f(x): @gen_cluster(client=True) async def test_get_client_coroutine(c, s, a, b): async def f(): - # TODO: `await get_client()` raises a deprecation warning to use - # `async with get_client()` and that will kill the workers' client + # TODO: the existence of `await get_client()` implies the possibility + # of `async with get_client()` and that will kill the workers' client # if you do that. We really don't want users to do that. + # https://github.com/dask/distributed/pull/6921/ client = get_client() future = client.submit(inc, 10) result = await future