From b7e184a02282ea6fb6a7ec540c2c2fd75211bafb Mon Sep 17 00:00:00 2001
From: Thomas Grainger
Date: Mon, 22 Aug 2022 14:16:05 +0100
Subject: [PATCH] use `async with Client:` in tests (#6921)
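
Many of these tests created a client with `c = await Client(...)` and
closed it with a trailing `await c.close()` (sometimes in a `finally:`
block, often not), so a failing assertion could leak the client or
cluster. Entering the object with `async with` guarantees the close.
A minimal sketch of the pattern applied throughout, not code from this
patch (`scheduler_address` and the `...` bodies are placeholders):

    from distributed import Client

    async def old_style(scheduler_address: str) -> None:
        # before: manual close that is skipped if the test body raises
        c = await Client(scheduler_address, asynchronous=True)
        ...  # test body
        await c.close()

    async def new_style(scheduler_address: str) -> None:
        # after: Client.__aexit__ closes the client on success and on error
        async with Client(scheduler_address, asynchronous=True) as c:
            ...  # test body
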
---
 distributed/deploy/tests/test_adaptive.py |  93 ++--
 distributed/deploy/tests/test_local.py    |  64 ++-
 distributed/tests/test_client.py          | 567 ++++++++++------------
 distributed/tests/test_failed_workers.py  |  56 ++-
 distributed/tests/test_publish.py         | 100 ++--
 distributed/tests/test_queues.py          |  62 ++-
 distributed/tests/test_scheduler.py       |  13 +-
 distributed/tests/test_utils_test.py      |   3 +
 distributed/tests/test_variable.py        |  62 ++-
 distributed/tests/test_worker.py          |   6 +-
 distributed/utils_test.py                 | 159 +++---
 11 files changed, 570 insertions(+), 615 deletions(-)

diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py
index 1248db0f0d8..04b518b063b 100644
--- a/distributed/deploy/tests/test_adaptive.py
+++ b/distributed/deploy/tests/test_adaptive.py
@@ -117,50 +117,47 @@ def scale_up(self, n, **kwargs):
 
 @gen_test()
 async def test_min_max():
-    cluster = await LocalCluster(
+    async with LocalCluster(
         n_workers=0,
         silence_logs=False,
         processes=False,
         dashboard_address=":0",
         asynchronous=True,
         threads_per_worker=1,
-    )
-    try:
+    ) as cluster:
         adapt = cluster.adapt(minimum=1, maximum=2, interval="20 ms", wait_count=10)
-        c = await Client(cluster, asynchronous=True)
-
-        start = time()
-        while not cluster.scheduler.workers:
-            await asyncio.sleep(0.01)
-        assert time() < start + 1
-
-        await asyncio.sleep(0.2)
-        assert len(cluster.scheduler.workers) == 1
-        assert len(adapt.log) == 1 and adapt.log[-1][1] == {"status": "up", "n": 1}
+        async with Client(cluster, asynchronous=True) as c:
+            start = time()
+            while not cluster.scheduler.workers:
+                await asyncio.sleep(0.01)
+            assert time() < start + 1
 
-        futures = c.map(slowinc, range(100), delay=0.1)
+            await asyncio.sleep(0.2)
+            assert len(cluster.scheduler.workers) == 1
+            assert len(adapt.log) == 1 and adapt.log[-1][1] == {"status": "up", "n": 1}
 
-        start = time()
-        while len(cluster.scheduler.workers) < 2:
-            await asyncio.sleep(0.01)
-        assert time() < start + 1
+            futures = c.map(slowinc, range(100), delay=0.1)
 
-        assert len(cluster.scheduler.workers) == 2
-        await asyncio.sleep(0.5)
-        assert len(cluster.scheduler.workers) == 2
-        assert len(cluster.workers) == 2
-        assert len(adapt.log) == 2 and all(d["status"] == "up" for _, d in adapt.log)
+            start = time()
+            while len(cluster.scheduler.workers) < 2:
+                await asyncio.sleep(0.01)
+            assert time() < start + 1
+
+            assert len(cluster.scheduler.workers) == 2
+            await asyncio.sleep(0.5)
+            assert len(cluster.scheduler.workers) == 2
+            assert len(cluster.workers) == 2
+            assert len(adapt.log) == 2 and all(
+                d["status"] == "up" for _, d in adapt.log
+            )
 
-        del futures
+            del futures
 
-        start = time()
-        while len(cluster.scheduler.workers) != 1:
-            await asyncio.sleep(0.01)
-        assert time() < start + 2
-        assert adapt.log[-1][1]["status"] == "down"
-    finally:
-        await c.close()
-        await cluster.close()
+            start = time()
+            while len(cluster.scheduler.workers) != 1:
+                await asyncio.sleep(0.01)
+            assert time() < start + 2
+            assert adapt.log[-1][1]["status"] == "down"
 
 
 @gen_test()
@@ -194,16 +191,14 @@ async def test_adapt_quickly():
     Instead we want to wait a few beats before removing a worker in case the
     user is taking a brief pause between work
     """
-    cluster = await LocalCluster(
+    async with LocalCluster(
         n_workers=0,
         asynchronous=True,
         processes=False,
         silence_logs=False,
         dashboard_address=":0",
-    )
-    client = await Client(cluster, asynchronous=True)
-    adapt = cluster.adapt(interval="20 ms", wait_count=5, maximum=10)
-    try:
+    ) as cluster, Client(cluster, asynchronous=True) as client:
+        adapt = cluster.adapt(interval="20 ms", wait_count=5, maximum=10)
         future = client.submit(slowinc, 1, delay=0.100)
         await wait(future)
         assert len(adapt.log) == 1
@@ -241,9 +236,6 @@ async def test_adapt_quickly():
             await asyncio.sleep(0.1)
 
         assert len(cluster.workers) == 1
-    finally:
-        await client.close()
-        await cluster.close()
 
 
 @gen_test()
@@ -255,20 +247,19 @@ async def test_adapt_down():
         processes=False,
         silence_logs=False,
         dashboard_address=":0",
-    ) as cluster:
-        async with Client(cluster, asynchronous=True) as client:
-            cluster.adapt(interval="20ms", maximum=5)
+    ) as cluster, Client(cluster, asynchronous=True) as client:
+        cluster.adapt(interval="20ms", maximum=5)
 
-            futures = client.map(slowinc, range(1000), delay=0.1)
-            while len(cluster.scheduler.workers) < 5:
-                await asyncio.sleep(0.1)
+        futures = client.map(slowinc, range(1000), delay=0.1)
+        while len(cluster.scheduler.workers) < 5:
+            await asyncio.sleep(0.1)
 
-            cluster.adapt(maximum=2)
+        cluster.adapt(maximum=2)
 
-            start = time()
-            while len(cluster.scheduler.workers) != 2:
-                await asyncio.sleep(0.1)
-            assert time() < start + 60
+        start = time()
+        while len(cluster.scheduler.workers) != 2:
+            await asyncio.sleep(0.1)
+        assert time() < start + 60
 
 
 @gen_test()
diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py
index 1da51d9f4e4..bc64cfca0b6 100644
--- a/distributed/deploy/tests/test_local.py
+++ b/distributed/deploy/tests/test_local.py
@@ -5,11 +5,11 @@
 import sys
 from threading import Lock
 from time import sleep
+from unittest import mock
 from urllib.parse import urlparse
 
 import pytest
 from tornado.httpclient import AsyncHTTPClient
-from tornado.ioloop import IOLoop
 
 from dask.system import CPU_COUNT
 
@@ -245,19 +245,21 @@ def test_Client_solo(loop):
 @gen_test()
 async def test_duplicate_clients():
     pytest.importorskip("bokeh")
-    c1 = await Client(
+    async with Client(
         processes=False, silence_logs=False, dashboard_address=9876, asynchronous=True
-    )
-    with pytest.warns(Warning) as info:
-        c2 = await Client(
-            processes=False,
-            silence_logs=False,
-            dashboard_address=9876,
-            asynchronous=True,
-        )
-
-    assert "dashboard" in c1.cluster.scheduler.services
-    assert "dashboard" in c2.cluster.scheduler.services
+    ) as c1:
+        c1_services = c1.cluster.scheduler.services
+        with pytest.warns(Warning) as info:
+            async with Client(
+                processes=False,
+                silence_logs=False,
+                dashboard_address=9876,
+                asynchronous=True,
+            ) as c2:
+                c2_services = c2.cluster.scheduler.services
+
+    assert c1_services == {"dashboard": mock.ANY}
+    assert c2_services == {"dashboard": mock.ANY}
 
     assert any(
         all(
@@ -266,8 +268,6 @@ async def test_duplicate_clients():
         )
         for msg in info.list
     )
-    await c1.close()
-    await c2.close()
 
 
 def test_Client_kwargs(loop):
@@ -824,35 +824,29 @@ class MyCluster(LocalCluster):
         def scale_down(self, *args, **kwargs):
             pass
 
-    loop = IOLoop.current()
-    cluster = await MyCluster(
+    async with MyCluster(
         n_workers=0,
         processes=False,
         silence_logs=False,
         dashboard_address=":0",
-        loop=loop,
+        loop=None,
         asynchronous=True,
-    )
-    c = await Client(cluster, asynchronous=True)
-
-    assert not cluster.workers
-
-    await cluster.scale(2)
+    ) as cluster, Client(cluster, asynchronous=True) as c:
+        assert not cluster.workers
 
-    start = time()
-    while len(cluster.scheduler.workers) != 2:
-        await asyncio.sleep(0.01)
-    assert time() < start + 3
+        await cluster.scale(2)
 
-    await cluster.scale(1)
+        start = time()
+        while len(cluster.scheduler.workers) != 2:
+            await asyncio.sleep(0.01)
+        assert time() < start + 3
 
-    start = time()
-    while len(cluster.scheduler.workers) != 1:
-        await asyncio.sleep(0.01)
-    assert time() < start + 3
+        await cluster.scale(1)
 
-    await c.close()
-    await cluster.close()
+        start = time()
+        while len(cluster.scheduler.workers) != 1:
+            await asyncio.sleep(0.01)
+        assert time() < start + 3
 
 
 def test_local_tls_restart(loop):
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index fbf2af7f99d..334b43f3fdd 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -6,6 +6,7 @@
 import gc
 import inspect
 import logging
+import operator
 import os
 import pathlib
 import pickle
@@ -534,17 +535,14 @@ async def test_exceptions(c, s, a, b):
 
 @gen_cluster()
 async def test_gc(s, a, b):
-    c = await Client(s.address, asynchronous=True)
-
-    x = c.submit(inc, 10)
-    await x
-    assert s.tasks[x.key].who_has
-    x.__del__()
-    await async_wait_for(
-        lambda: x.key not in s.tasks or not s.tasks[x.key].who_has, timeout=0.3
-    )
-
-    await c.close()
+    async with Client(s.address, asynchronous=True) as c:
+        x = c.submit(inc, 10)
+        await x
+        assert s.tasks[x.key].who_has
+        x.__del__()
+        await async_wait_for(
+            lambda: x.key not in s.tasks or not s.tasks[x.key].who_has, timeout=0.3
+        )
 
 
 def test_thread(c):
@@ -587,13 +585,12 @@ async def test_gather(c, s, a, b):
 
 @gen_cluster(client=True)
 async def test_gather_mismatched_client(c, s, a, b):
-    c2 = await Client(s.address, asynchronous=True)
+    async with Client(s.address, asynchronous=True) as c2:
+        x = c.submit(inc, 10)
+        y = c2.submit(inc, 5)
 
-    x = c.submit(inc, 10)
-    y = c2.submit(inc, 5)
-
-    with pytest.raises(ValueError, match="Futures created by another client"):
-        await c.gather([x, y])
+        with pytest.raises(ValueError, match="Futures created by another client"):
+            await c.gather([x, y])
 
 
 @gen_cluster(client=True)
@@ -1067,20 +1064,16 @@ async def test_map_quotes(c, s, a, b):
 
 @gen_cluster()
 async def test_two_consecutive_clients_share_results(s, a, b):
-    c = await Client(s.address, asynchronous=True)
-
-    x = c.submit(random.randint, 0, 1000, pure=True)
-    xx = await x
-
-    f = await Client(s.address, asynchronous=True)
+    async with Client(s.address, asynchronous=True) as c:
 
-    y = f.submit(random.randint, 0, 1000, pure=True)
-    yy = await y
+        x = c.submit(random.randint, 0, 1000, pure=True)
+        xx = await x
 
-    assert xx == yy
+        async with Client(s.address, asynchronous=True) as f:
+            y = f.submit(random.randint, 0, 1000, pure=True)
+            yy = await y
 
-    await c.close()
-    await f.close()
+            assert xx == yy
 
 
 @gen_cluster(client=True)
@@ -1423,14 +1416,10 @@ async def test_scatter_direct(c, s, a, b):
 
 @gen_cluster()
 async def test_scatter_direct_2(s, a, b):
-    c = await Client(s.address, asynchronous=True, heartbeat_interval=10)
-
-    last = s.clients[c.id].last_seen
-
-    while s.clients[c.id].last_seen == last:
-        await asyncio.sleep(0.10)
-
-    await c.close()
+    async with Client(s.address, asynchronous=True, heartbeat_interval=10) as c:
+        last = s.clients[c.id].last_seen
+        while s.clients[c.id].last_seen == last:
+            await asyncio.sleep(0.10)
 
 
 @gen_cluster(client=True)
@@ -2014,21 +2003,26 @@ async def test_badly_serialized_input_stderr(capsys, c):
     assert future.status == "error"
 
 
-def test_repr(loop):
-    funcs = [str, repr, lambda x: x._repr_html_()]
+@pytest.mark.parametrize(
+    "func",
+    [
+        str,
+        repr,
+        operator.methodcaller("_repr_html_"),
+    ],
+)
+def test_repr(loop, func):
     with cluster(nworkers=3, worker_kwargs={"memory_limit": "2 GiB"}) as (s, [a, b, c]):
         with Client(s["address"], loop=loop) as c:
-            for func in funcs:
-                text = func(c)
-                assert c.scheduler.address in text
-                assert "threads=3" in text or "Total threads: " in text
-                assert "6.00 GiB" in text
-                if "<table" not in text:
-                    assert len(text) < 80
+            text = func(c)
+            assert c.scheduler.address in text
+            assert "threads=3" in text or "Total threads: " in text
+            assert "6.00 GiB" in text
+            if "<table" not in text:
+                assert len(text) < 80